윤영준 윤영준 2023-06-28
integrated loss functions into neural network class
@e5885c1ea38a8eb0e7d1f3194bf339b197c0ce2b
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -212,23 +212,17 @@
         }
         return ret
 
-# need fixing
-class AttentiveRNNLoss(nn.Module):
-    def __init__(self, theta=0.8):
-        super(AttentiveRNNLoss, self).__init__()
+    # attention supervision: theta-weighted MSE between each recurrent attention map and the difference mask
+    def loss(self, input_image_tensor, difference_maskmap, theta=0.8):
         self.theta = theta
-    def forward(self, input_tensor, label_tensor):
         # Initialize attentive rnn model
-        attentive_rnn = AttentiveRNN
-        inference_ret = attentive_rnn(input_tensor)
-
+        inference_ret = self.forward(input_image_tensor)
         loss = 0.0
         n = len(inference_ret['attention_map_list'])
         for index, attention_map in enumerate(inference_ret['attention_map_list']):
-            mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor)
+            mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, difference_maskmap)
             loss += mse_loss
-
-        return loss, inference_ret['final_attention_map']
+        return loss
 
 # Need work
 
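The integrated AttentiveRNN.loss above weights the MSE of each map in attention_map_list by a power of theta, so maps produced later in the recurrence carry more weight. A self-contained sketch of that weighting with stand-in tensors in place of the real attention maps and difference mask (tensor shapes are assumptions, not taken from this commit):

import torch
import torch.nn as nn

theta = 0.8
mask = torch.rand(1, 1, 32, 32)                                 # stand-in for difference_maskmap
attention_maps = [torch.rand(1, 1, 32, 32) for _ in range(4)]   # stand-ins for attention_map_list

loss = 0.0
n = len(attention_maps)
for index, attention_map in enumerate(attention_maps):
    # same weighting as in AttentiveRNN.loss: theta ** (n - index + 1)
    loss = loss + (theta ** (n - index + 1)) * nn.MSELoss()(attention_map, mask)

print(loss)  # later maps get the larger weights: 0.8**2 for the last map, 0.8**5 for the first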
model/Autoencoder.py
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -93,15 +93,15 @@
 
         return ret
 
-    def loss(self, input_tensor, label_tensor):
-        ori_height, ori_width = label_tensor.shape[2:]
+    def loss(self, input_image_tensor, input_clean_image_tensor):
+        ori_height, ori_width = input_clean_image_tensor.shape[2:]
 
         # Rescale labels to match the scales of the outputs
-        label_tensor_resize_2 = F.interpolate(label_tensor, size=(ori_height // 2, ori_width // 2))
-        label_tensor_resize_4 = F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4))
-        label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor]
+        label_tensor_resize_2 = F.interpolate(input_clean_image_tensor, size=(ori_height // 2, ori_width // 2))
+        label_tensor_resize_4 = F.interpolate(input_clean_image_tensor, size=(ori_height // 4, ori_width // 4))
+        label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_image_tensor]
 
-        inference_ret = self.forward(input_tensor)
+        inference_ret = self.forward(input_image_tensor)
 
         output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
 
@@ -112,7 +112,7 @@
             lm_loss += mse_loss
 
         # Compute lp_loss
-        src_vgg_feats = self.vgg(label_tensor)
+        src_vgg_feats = self.vgg(input_clean_image_tensor)
         pred_vgg_feats = self.vgg(output_list[-1])
 
         lp_losses = []
@@ -122,4 +122,4 @@
 
         loss = lm_loss + lp_loss
 
-        return loss, inference_ret['skip_3']
(No newline at end of file)
+        return loss
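Autoencoder.loss now downscales the clean target image to the resolutions of the three skip outputs before taking MSE, then adds a VGG perceptual term on top. A self-contained sketch of the multi-scale MSE part only, with stand-in tensors in place of the skip outputs taken from inference_ret (shapes are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

clean = torch.rand(1, 3, 64, 64)               # stand-in for input_clean_image_tensor
outputs = [torch.rand(1, 3, 16, 16),           # stand-ins for skip_1, skip_2, skip_3
           torch.rand(1, 3, 32, 32),
           torch.rand(1, 3, 64, 64)]

h, w = clean.shape[2:]
labels = [F.interpolate(clean, size=(h // 4, w // 4)),
          F.interpolate(clean, size=(h // 2, w // 2)),
          clean]

lm_loss = sum(nn.MSELoss()(out, lab) for out, lab in zip(outputs, labels))
print(lm_loss)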
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -34,6 +34,10 @@
         }
         return ret
 
+    def loss(self, x, diff, clean):
+        # sum both losses into one scalar; 'clean' (ground-truth clean image) is an assumed extra argument needed by autoencoder.loss
+        return self.attentiveRNN.loss(x, diff) + self.autoencoder.loss(x, clean)
+
 if __name__ == "__main__":
     import torch
     from torchinfo import summary
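With both losses reachable through Generator.loss, a training step only needs the generator instance. A hedged sketch of the call pattern; GeneratorStub mirrors the loss signature shown above, and the stub internals, optimizer choice, and tensor shapes are assumptions, not part of this commit (the real Generator wires the AttentiveRNN and the autoencoder together):

import torch
import torch.nn as nn

class GeneratorStub(nn.Module):
    # stand-in exposing the same loss(x, diff, clean) call shape as Generator.loss above;
    # the real method sums the AttentiveRNN and Autoencoder losses
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def loss(self, x, diff, clean):
        return nn.MSELoss()(self.conv(x), clean)

generator = GeneratorStub()
optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

rainy = torch.rand(1, 3, 64, 64)   # rainy input image
clean = torch.rand(1, 3, 64, 64)   # clean ground-truth image
mask = torch.rand(1, 1, 64, 64)    # difference mask supervising the attention maps

optimizer.zero_grad()
loss = generator.loss(rainy, mask, clean)
loss.backward()
optimizer.step()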