윤영준 2023-06-23
updating loss function
@3690a9e3555e148f80134cf4afb7174c58a4a9a6
data/example.yaml
--- data/example.yaml
+++ data/example.yaml
@@ -3,6 +3,8 @@
     num_link: 1
     path_source: /path/to/source/files/file1.jpg
     path_raindrop1: /path/to/source/files/file1.jpg
+    data_source: #url or origin
+    license:
   def456: # hash key
     path: /path/to/source/files/file2.jpg
   ghi789: # hash key
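As a side note, here is a minimal sketch of how the new optional data_source and license fields could be read back from data/example.yaml. It assumes PyYAML is used; the load_entries helper, its path default, and the fallback handling are illustrative only and not part of this commit.

import yaml

def load_entries(path="data/example.yaml"):
    # Each top-level hash key (e.g. def456, ghi789) maps to one source entry.
    with open(path, "r") as f:
        entries = yaml.safe_load(f)
    for key, entry in entries.items():
        # The new fields are optional, so missing values fall back to None.
        data_source = entry.get("data_source")  # URL or origin of the data
        license_info = entry.get("license")
        yield key, entry.get("path_source", entry.get("path")), data_source, license_info

Keeping both fields optional means existing entries such as def456 and ghi789 load unchanged.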
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -2,7 +2,7 @@
 from torch import nn
 from torch.nn import functional as F
 
-# nn.Sequential does not handel multiple input by design
+# nn.Sequential does not handle multiple inputs by design, and this is a workaround
 # https://github.com/pytorch/pytorch/issues/19808#
 class mySequential(nn.Sequential):
     def forward(self, *input):
@@ -193,16 +193,25 @@
 
     def forward(self, x):
         cell_state = None
-        attention_map = None
+        attention_map = []
+        lstm_feats = []
         for generator_block in self.generator_blocks:
-            x, attention_map, cell_state, lstm_feats = generator_block(x, cell_state)
-        return x, attention_map
+            x, attention_map_i, cell_state, lstm_feats_i = generator_block(x, cell_state)
+            attention_map.append(attention_map_i)
+            lstm_feats.append(lstm_feats_i)
+        ret = {
+            'x' : x,
+            'attention_map_list' : attention_map,
+            'final_attention_map' : attention_map[-1],
+            'lstm_feats' : lstm_feats
+        }
+        return ret
 
 # need fixing
 class AttentiveRNNLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, theta=0.8):
         super(AttentiveRNNLoss, self).__init__()
-
+        self.theta = theta
     def forward(self, input_tensor, label_tensor):
         # Initialize attentive rnn model
         attentive_rnn = AttentiveRNN
@@ -211,7 +219,7 @@
         loss = 0.0
         n = len(inference_ret['attention_map_list'])
         for index, attention_map in enumerate(inference_ret['attention_map_list']):
-            mse_loss = (0.8 ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor)
+            mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor)
             loss += mse_loss
 
         return loss, inference_ret['final_attention_map']
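For reference, a standalone sketch of how the now-configurable theta weights the per-block attention-map MSE terms, mirroring the loop in AttentiveRNNLoss.forward. The tensor shapes and the number of generator blocks below are assumptions for illustration only.

import torch
from torch import nn

theta = 0.8        # default value passed to AttentiveRNNLoss(theta=0.8)
n_blocks = 4       # assumed number of generator blocks
label = torch.rand(1, 1, 64, 64)  # assumed ground-truth attention mask
attention_map_list = [torch.rand(1, 1, 64, 64) for _ in range(n_blocks)]

loss = 0.0
n = len(attention_map_list)
mse = nn.MSELoss()
for index, attention_map in enumerate(attention_map_list):
    # theta ** (n - index + 1) decays for earlier blocks, so later,
    # more refined attention maps contribute more to the total loss.
    loss = loss + (theta ** (n - index + 1)) * mse(attention_map, label)

Exposing theta as a constructor argument makes this decay rate tunable without editing the loss code.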