--- data/example.yaml
+++ data/example.yaml
@@ -3,6 +3,8 @@
   num_link: 1
   path_source: /path/to/source/files/file1.jpg
   path_raindrop1: /path/to/source/files/file1.jpg
+  data_source: #url or origin
+  license:
 def456: # hash key
   path: /path/to/source/files/file2.jpg
 ghi789: # hash key
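The two new per-entry fields sit alongside the existing path fields under each hash key. A minimal consumer sketch, assuming PyYAML and that each top-level hash key maps to a dict of these fields (the file path and field access below are illustrative, not part of this change):

import yaml  # PyYAML is assumed to be available

with open('data/example.yaml') as f:
    entries = yaml.safe_load(f)

for hash_key, fields in entries.items():
    # data_source and license are the fields added here; older entries may
    # not have them yet, so .get falls back to None.
    print(hash_key, fields.get('data_source'), fields.get('license'))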
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -2,7 +2,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-# nn.Sequential does not handel multiple input by design
+# nn.Sequential does not handle multiple input by design, and this is a workaround
 # https://github.com/pytorch/pytorch/issues/19808#
 class mySequential(nn.Sequential):
     def forward(self, *input):
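The linked pytorch issue describes the usual workaround for chaining modules that take or return multiple tensors: override forward so that tuple outputs from one module are unpacked into the next module's inputs. A minimal sketch of that pattern, using a hypothetical class name and the common recipe from the issue rather than the exact body in this file:

from torch import nn

class MySequentialSketch(nn.Sequential):
    """Illustrative only: unpacks tuple outputs between chained modules."""
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                # Previous module returned several tensors; spread them out.
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs

With this in place, MySequentialSketch(block_a, block_b) forwards whatever tuple block_a returns straight into block_b.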
@@ -193,16 +193,24 @@
 
     def forward(self, x):
         cell_state = None
-        attention_map = None
+        attention_map = []
+        lstm_feats = []
         for generator_block in self.generator_blocks:
-            x, attention_map, cell_state, lstm_feats = generator_block(x, cell_state)
-        return x, attention_map
+            x, attention_map_i, cell_state, lstm_feats_i = generator_block(x, cell_state)
+            attention_map.append(attention_map_i)
+            lstm_feats.append(lstm_feats_i)
+        ret = {
+            'x' : x,
+            'attention_map_list' : attention_map,
+            'lstm_feats' : lstm_feats
+        }
+        return x, attention_map, lstm_feats
 
 # need fixing
 class AttentiveRNNLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, theta=0.8):
         super(AttentiveRNNLoss, self).__init__()
-
+        self.theta = theta
     def forward(self, input_tensor, label_tensor):
         # Initialize attentive rnn model
         attentive_rnn = AttentiveRNN
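The loss hunk below indexes its inference_ret with 'attention_map_list' and 'final_attention_map', which lines up with the ret dict built above rather than with the tuple return. A rough sketch (not taken from this patch) of a forward that hands that dict back, with 'final_attention_map' filled in as an assumption:

    def forward(self, x):
        # Sketch only; assumes each generator block returns
        # (x, attention_map, cell_state, lstm_feats) as in the hunk above.
        cell_state = None
        attention_maps = []
        lstm_feats = []
        for generator_block in self.generator_blocks:
            x, attention_map_i, cell_state, lstm_feats_i = generator_block(x, cell_state)
            attention_maps.append(attention_map_i)
            lstm_feats.append(lstm_feats_i)
        return {
            'x': x,
            'attention_map_list': attention_maps,
            'lstm_feats': lstm_feats,
            # Assumed key: the loss below reads inference_ret['final_attention_map'];
            # the last map in the sequence is a natural candidate.
            'final_attention_map': attention_maps[-1],
        }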
@@ -211,7 +219,7 @@
         loss = 0.0
         n = len(inference_ret['attention_map_list'])
         for index, attention_map in enumerate(inference_ret['attention_map_list']):
-            mse_loss = (0.8 ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor)
+            mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor)
             loss += mse_loss
 
         return loss, inference_ret['final_attention_map']
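The hardcoded 0.8 becomes the configurable theta from __init__. Because theta < 1 and the exponent n - index + 1 shrinks as index grows, earlier attention maps are down-weighted and the later, more refined maps dominate the loss. A self-contained sketch of just that weighting, with hypothetical names and dummy shapes:

import torch
from torch import nn

def attention_map_loss(attention_maps, label_tensor, theta=0.8):
    """Illustrative weighted MSE over a list of attention maps (not the repo's API)."""
    n = len(attention_maps)
    mse = nn.MSELoss()
    loss = torch.zeros(())
    for index, attention_map in enumerate(attention_maps):
        # The first map gets weight theta ** (n + 1); the last gets theta ** 2.
        loss = loss + (theta ** (n - index + 1)) * mse(attention_map, label_tensor)
    return loss

# Usage with dummy tensors (shapes are illustrative only):
maps = [torch.rand(1, 1, 64, 64) for _ in range(4)]
mask = torch.rand(1, 1, 64, 64)
print(attention_map_loss(maps, mask, theta=0.8).item())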