윤영준 2023-06-23
Change return types to dicts so that changing the functions' return values later does not incur additional cost at call sites, just in case.
@2bbc71667c01b3eaf51dd4be8597f67421a644a2
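A minimal sketch of the motivation, with hypothetical function names that are not part of this repository: a caller that unpacks a tuple must be edited whenever the callee grows a new return value, while dict access by key keeps working.

    # Hypothetical illustration only; none of these names come from this commit.
    import torch

    def forward_tuple(x):
        return x, x.mean()                    # adding a third value breaks "a, b = ..." callers

    def forward_dict(x):
        return {"x": x, "mean": x.mean()}     # new keys can be added without touching callers

    t = torch.ones(2, 3)
    out, mean = forward_tuple(t)              # must be rewritten if the tuple grows
    out = forward_dict(t)["x"]                # unaffected by extra keys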
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -152,7 +152,13 @@
         x = self.resnet(original_image)
         attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
         x = attention_map * original_image
-        return x, attention_map, cell_state, lstm_feats
+        ret = {
+            'x' : x,
+            'attention_map' : attention_map,
+            'cell_state' : cell_state,
+            'lstm_feats' : lstm_feats
+        }
+        return ret
 
 
 class AttentiveRNN(nn.Module):
@@ -204,7 +210,7 @@
             'attention_map_list' : attention_map,
             'lstm_feats' : lstm_feats
         }
-        return x, attention_map, lstm_feats
+        return ret
 
 # need fixing
 class AttentiveRNNLoss(nn.Module):
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -4,19 +4,18 @@
 class DiscriminativeNet(nn.Module):
     def __init__(self, W, H):
         super(DiscriminativeNet, self).__init__()
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
-        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
-        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
-        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2, padding=1)
+        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2)
+        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
+        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2)
         self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
         self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
-        self.conv_map = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
+        self.conv_attention = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
         self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
         self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
-        self.fc1 = nn.Linear(32 * W * H,
-                             1024)  # You need to adjust the input dimension here depending on your input size
-        self.fc2 = nn.Linear(1024, 1)
+        self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
+        self.fc2 = nn.Linear(1, 1)
 
     def forward(self, x):
         x1 = F.leaky_relu(self.conv1(x))
@@ -25,19 +24,24 @@
         x4 = F.leaky_relu(self.conv4(x3))
         x5 = F.leaky_relu(self.conv5(x4))
         x6 = F.leaky_relu(self.conv6(x5))
-        attention_map = self.conv_map(x6)
+        attention_map = self.conv_attention(x6)
         x7 = F.leaky_relu(self.conv7(attention_map * x6))
         x8 = F.leaky_relu(self.conv8(x7))
         x9 = F.leaky_relu(self.conv9(x8))
         x9 = x9.view(x9.size(0), -1)  # flatten the tensor
         fc1 = self.fc1(x9)
-        fc2 = self.fc2(fc1)
-        fc_out = F.sigmoid(fc2)
+        fc_raw = self.fc2(fc1)
+        fc_out = F.sigmoid(fc_raw)
 
         # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
         fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
 
-        return fc_out, attention_map, fc2
+        ret = {
+            "fc_out" : fc_out,
+            "attention_map": attention_map,
+            "fc_raw" : fc_raw
+        }
+        return ret
 
 if __name__ == "__main__":
     import torch
@@ -45,5 +49,5 @@
 
     torch.set_default_tensor_type(torch.FloatTensor)
     generator = DiscriminativeNet(960,540)
-    batch_size = 2
+    batch_size = 1
     summary(generator, input_size=(batch_size, 3, 960,540))
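The comment on fc1 still notes that in_features depends on the input size. One common alternative, sketched here as an assumption and not part of this commit, is to derive the flattened size from a dummy forward pass:

    # Hedged sketch: infer the flattened feature size instead of hard-coding it.
    import torch
    import torch.nn as nn

    def flattened_size(conv_stack: nn.Module, c: int, h: int, w: int) -> int:
        """Run one zero tensor through conv_stack and return the flattened width."""
        with torch.no_grad():
            out = conv_stack(torch.zeros(1, c, h, w))
        return out.view(1, -1).size(1)

    # Hypothetical use inside __init__ (the real forward also multiplies by the
    # attention map mid-stack, so a plain nn.Sequential is only an approximation):
    # fc_in = flattened_size(nn.Sequential(self.conv1, ..., self.conv9), 3, H, W)
    # self.fc1 = nn.Linear(fc_in, 1)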
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -9,7 +9,7 @@
         super(Generator, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
-        self.attentiveRNN = AttentiveRNN( repetition,
+        self.attentiveRNN = AttentiveRNN(repetition,
             blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
             kernel_size=None, stride=stride, padding=padding, groups=groups, dilation=dilation
         )
@@ -26,8 +26,8 @@
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
-        x, attention_map = self.attentiveRNN(x)
-        x = self.autoencoder(x * attention_map)
+        ret = self.attentiveRNN(x)
+        x = self.autoencoder(ret['x'] * ret['attention_map_list'][-1])
         return x
 
 if __name__ == "__main__":
@@ -37,4 +37,4 @@
     torch.set_default_tensor_type(torch.FloatTensor)
     generator = Generator(3, blocks=2)
     batch_size = 2
-    summary(generator, input_size=(batch_size, 3, 960,540))
+    summary(generator, input_size=(batch_size, 3, 720,720))
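For reference, a usage sketch mirroring the __main__ smoke test above (constructor arguments and input shape copied from it): Generator.forward still hands back a plain tensor, so this commit only changes the internal return shapes of AttentiveRNN and DiscriminativeNet.

    # Usage sketch; arguments mirror the __main__ block in the diff above.
    import torch
    from model.Generator import Generator

    generator = Generator(3, blocks=2)
    derained = generator(torch.rand(2, 3, 720, 720))   # still a plain tensor
    print(derained.shape)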