
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -10,13 +10,6 @@
             input = module(*input)
         return input
 
-def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1):
-    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, groups=groups, dilation=dilation)
-
-
-def conv1x1(in_ch, out_ch, stride=1):
-    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
-
 
 class ResNetBlock(nn.Module):
     def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
@@ -107,7 +100,12 @@
         attention_map = self.conv_attention_map(lstm_feats)
         attention_map = torch.sigmoid(attention_map)
 
-        return attention_map, cell_state, lstm_feats
+        ret = {
+            "attention_map": attention_map,
+            "cell_state": cell_state,
+            "lstm_feats": lstm_feats
+        }
+        return ret
 
 
 class AttentiveRNNBLCK(nn.Module):
@@ -150,7 +148,10 @@
 
     def forward(self, original_image, prev_cell_state=None):
         x = self.resnet(original_image)
-        attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
+        lstm_ret = self.LSTM(x, prev_cell_state)
+        attention_map = lstm_ret["attention_map"]
+        cell_state = lstm_ret["cell_state"]
+        lstm_feats = lstm_ret["lstm_feats"]
         x = attention_map * original_image
         ret = {
             'x' : x,
@@ -180,21 +181,23 @@
         self.groups = groups
         self.dilation = dilation
         self.repetition = repetition
-        self.generator_block = mySequential(
-            AttentiveRNNBLCK(blocks=blocks,
-                             layers=layers,
-                             input_ch=input_ch,
-                             out_ch=out_ch,
-                             kernel_size=kernel_size,
-                             stride=stride,
-                             padding=padding,
-                             groups=groups,
-                             dilation=dilation)
+        self.arnn_block = mySequential(
+            AttentiveRNNBLCK(
+                blocks=blocks,
+                layers=layers,
+                input_ch=input_ch,
+                out_ch=out_ch,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                groups=groups,
+                dilation=dilation
+            )
         )
-        self.generator_blocks = nn.ModuleList()
+        self.arnn_blocks = nn.ModuleList()
         for repetition in range(repetition):
-            self.generator_blocks.append(
-                self.generator_block
+            self.arnn_blocks.append(
+                self.arnn_block
             )
         self.name = "AttentiveRNN"
 
@@ -202,12 +205,12 @@
         cell_state = None
         attention_map = []
         lstm_feats = []
-        for generator_block in self.generator_blocks:
-            generator_block_return = generator_block(x, cell_state)
-            attention_map_i = generator_block_return['attention_map']
-            lstm_feats_i = generator_block_return['lstm_feats']
-            cell_state = generator_block_return['cell_state']
-            x = generator_block_return['x']
+        for arnn_block in self.arnn_blocks:
+            arnn_block_return = arnn_block(x, cell_state)
+            attention_map_i = arnn_block_return['attention_map']
+            lstm_feats_i = arnn_block_return['lstm_feats']
+            cell_state = arnn_block_return['cell_state']
+            x = arnn_block_return['x']
 
             attention_map.append(attention_map_i)
             lstm_feats.append(lstm_feats_i)
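
For reference, a minimal usage sketch of the dict-based return introduced by this change. Only the dictionary keys ("x", "attention_map", "cell_state", "lstm_feats") come from the diff; the constructor defaults, import path, and tensor shape below are illustrative assumptions.

# Sketch only: assumes AttentiveRNNBLCK can be built with its default arguments.
import torch
from model.AttentiveRNN import AttentiveRNNBLCK

block = AttentiveRNNBLCK()                    # default arguments assumed
image = torch.randn(1, 3, 64, 64)             # dummy RGB batch

out = block(image, prev_cell_state=None)
x = out["x"]                                  # attention-weighted image
attention_map = out["attention_map"]          # sigmoid attention mask
cell_state = out["cell_state"]                # fed back into the next block
lstm_feats = out["lstm_feats"]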