
Integrate loss functions into the neural network classes
@e5885c1ea38a8eb0e7d1f3194bf339b197c0ce2b
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
... | ... | @@ -212,23 +212,17 @@ |
212 | 212 |
} |
213 | 213 |
return ret |
214 | 214 |
|
215 |
-# need fixing |
|
216 |
-class AttentiveRNNLoss(nn.Module): |
|
217 |
- def __init__(self, theta=0.8): |
|
218 |
- super(AttentiveRNNLoss, self).__init__() |
|
215 |
+ # |
|
216 |
+ def loss(self, input_image_tensor, difference_maskmap, theta=0.8): |
|
219 | 217 |
self.theta = theta |
220 |
- def forward(self, input_tensor, label_tensor): |
|
221 | 218 |
# Initialize attentive rnn model |
222 |
- attentive_rnn = AttentiveRNN |
|
223 |
- inference_ret = attentive_rnn(input_tensor) |
|
224 |
- |
|
219 |
+ inference_ret = self.forward(input_image_tensor) |
|
225 | 220 |
loss = 0.0 |
226 | 221 |
n = len(inference_ret['attention_map_list']) |
227 | 222 |
for index, attention_map in enumerate(inference_ret['attention_map_list']): |
228 |
- mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, label_tensor) |
|
223 |
+ mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, difference_maskmap) |
|
229 | 224 |
loss += mse_loss |
230 |
- |
|
231 |
- return loss, inference_ret['final_attention_map'] |
|
225 |
+ return loss |
|
232 | 226 |
|
233 | 227 |
# Need work |
234 | 228 |
|
--- model/Autoencoder.py
+++ model/Autoencoder.py
... | ... | @@ -93,15 +93,15 @@ |
93 | 93 |
|
94 | 94 |
return ret |
95 | 95 |
|
96 |
- def loss(self, input_tensor, label_tensor): |
|
97 |
- ori_height, ori_width = label_tensor.shape[2:] |
|
96 |
+ def loss(self, input_image_tensor, input_clean_image_tensor): |
|
97 |
+ ori_height, ori_width = input_clean_image_tensor.shape[2:] |
|
98 | 98 |
|
99 | 99 |
# Rescale labels to match the scales of the outputs |
100 |
- label_tensor_resize_2 = F.interpolate(label_tensor, size=(ori_height // 2, ori_width // 2)) |
|
101 |
- label_tensor_resize_4 = F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4)) |
|
102 |
- label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor] |
|
100 |
+ label_tensor_resize_2 = F.interpolate(input_clean_image_tensor, size=(ori_height // 2, ori_width // 2)) |
|
101 |
+ label_tensor_resize_4 = F.interpolate(input_clean_image_tensor, size=(ori_height // 4, ori_width // 4)) |
|
102 |
+ label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_image_tensor] |
|
103 | 103 |
|
104 |
- inference_ret = self.forward(input_tensor) |
|
104 |
+ inference_ret = self.forward(input_image_tensor) |
|
105 | 105 |
|
106 | 106 |
output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']] |
107 | 107 |
|
... | ... | @@ -112,7 +112,7 @@ |
112 | 112 |
lm_loss += mse_loss |
113 | 113 |
|
114 | 114 |
# Compute lp_loss |
115 |
- src_vgg_feats = self.vgg(label_tensor) |
|
115 |
+ src_vgg_feats = self.vgg(input_clean_image_tensor) |
|
116 | 116 |
pred_vgg_feats = self.vgg(output_list[-1]) |
117 | 117 |
|
118 | 118 |
lp_losses = [] |
... | ... | @@ -122,4 +122,4 @@ |
122 | 122 |
|
123 | 123 |
loss = lm_loss + lp_loss |
124 | 124 |
|
125 |
- return loss, inference_ret['skip_3'](No newline at end of file) |
|
125 |
+ return loss |
--- model/Generator.py
+++ model/Generator.py
... | ... | @@ -34,6 +34,10 @@ |
34 | 34 |
} |
35 | 35 |
return ret |
36 | 36 |
|
37 |
+ def loss(self, x, diff): |
|
38 |
+ self.attentiveRNN.loss(x, diff) |
|
39 |
+ self.autoencoder.loss(x) |
|
40 |
+ |
|
37 | 41 |
if __name__ == "__main__": |
38 | 42 |
import torch |
39 | 43 |
from torchinfo import summary |