윤영준 2023-07-24
small fix for return type of networks
@fdc2bde10f68271025a10130d52e87f76deb970b
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -10,13 +10,6 @@
             input = module(*input)
         return input
 
-def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1):
-    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, groups=groups, dilation=dilation)
-
-
-def conv1x1(in_ch, out_ch, stride=1):
-    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
-
 
 class ResNetBlock(nn.Module):
     def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
@@ -107,7 +100,12 @@
         attention_map = self.conv_attention_map(lstm_feats)
         attention_map = torch.sigmoid(attention_map)
 
-        return attention_map, cell_state, lstm_feats
+        ret = {
+            "attention_amp" : attention_map,
+            "cell_state" : cell_state,
+            "lstm_feats" : lstm_feats
+        }
+        return ret
 
 
 class AttentiveRNNBLCK(nn.Module):
@@ -150,7 +148,10 @@
 
     def forward(self, original_image, prev_cell_state=None):
         x = self.resnet(original_image)
-        attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
+        lstm_ret = self.LSTM(x, prev_cell_state)
+        attention_map = lstm_ret["attention_map"]
+        cell_state = lstm_ret["cell_state"]
+        lstm_feats = lstm_ret["lstm_feats"]
         x = attention_map * original_image
         ret = {
             'x' : x,
@@ -180,21 +181,23 @@
         self.groups = groups
         self.dilation = dilation
         self.repetition = repetition
-        self.generator_block = mySequential(
-            AttentiveRNNBLCK(blocks=blocks,
-                             layers=layers,
-                             input_ch=input_ch,
-                             out_ch=out_ch,
-                             kernel_size=kernel_size,
-                             stride=stride,
-                             padding=padding,
-                             groups=groups,
-                             dilation=dilation)
+        self.arnn_block = mySequential(
+            AttentiveRNNBLCK(
+             blocks=blocks,
+             layers=layers,
+             input_ch=input_ch,
+             out_ch=out_ch,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             groups=groups,
+             dilation=dilation
+            )
         )
-        self.generator_blocks = nn.ModuleList()
+        self.arnn_blocks = nn.ModuleList()
         for repetition in range(repetition):
-            self.generator_blocks.append(
-                self.generator_block
+            self.arnn_blocks.append(
+                self.arnn_block
             )
         self.name = "AttentiveRNN"
 
@@ -202,12 +205,12 @@
         cell_state = None
         attention_map = []
         lstm_feats = []
-        for generator_block in self.generator_blocks:
-            generator_block_return = generator_block(x, cell_state)
-            attention_map_i = generator_block_return['attention_map']
-            lstm_feats_i = generator_block_return['lstm_feats']
-            cell_state = generator_block_return['cell_state']
-            x = generator_block_return['x']
+        for arnn_block in self.arnn_blocks:
+            arnn_block_return = arnn_block(x, cell_state)
+            attention_map_i = arnn_block_return['attention_map']
+            lstm_feats_i = arnn_block_return['lstm_feats']
+            cell_state = arnn_block_return['cell_state']
+            x = arnn_block_return['x']
 
             attention_map.append(attention_map_i)
             lstm_feats.append(lstm_feats_i)
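
For reference, a minimal usage sketch of the dict-based return convention introduced by this commit. AttentiveRNNBLCK, its forward(original_image, prev_cell_state=None) signature, and the dictionary keys are taken from the diff above; the constructor arguments, batch size, and image size are illustrative assumptions (defaults guessed from the ResNetBlock signature shown above), not values confirmed by this change.

    import torch

    from model.AttentiveRNN import AttentiveRNNBLCK

    # Constructor arguments are assumptions borrowed from ResNetBlock's defaults above.
    block = AttentiveRNNBLCK(blocks=3, layers=1, input_ch=3, out_ch=32)

    image = torch.randn(1, 3, 64, 64)  # illustrative 1x3x64x64 RGB input

    # First stage: no previous cell state yet.
    ret = block(image)
    attention_map = ret["attention_map"]  # keys as consumed in AttentiveRNN.forward above
    cell_state = ret["cell_state"]
    lstm_feats = ret["lstm_feats"]
    x = ret["x"]

    # A later stage receives the previous stage's output and cell state, as in the loop above.
    ret = block(x, cell_state)

Returning a dict instead of a positional tuple lets callers pick fields by name rather than by position, which is what the renamed arnn_block loop above now does.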