
--- model/Discriminator.py
+++ model/Discriminator.py
... | ... | @@ -17,6 +17,25 @@ |
17 | 17 |
self.fc1 = nn.Linear(32, 1) # You need to adjust the input dimension here depending on your input size |
18 | 18 |
self.fc2 = nn.Linear(1, 1) |
19 | 19 |
|
20 |
+ def loss(self, input_tensor, label_tensor, attention_map, name): |
|
21 |
+ batch_size, image_h, image_w, _ = input_tensor.size() |
|
22 |
+ |
|
23 |
+ zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32) |
|
24 |
+ |
|
25 |
+ # Inference function |
|
26 |
+ fc_out_o, attention_mask_o, fc2_o = self.forward(input_tensor) |
|
27 |
+ fc_out_r, attention_mask_r, fc2_r = self.forward(label_tensor) |
|
28 |
+ |
|
29 |
+ l_map = F.mse_loss(attention_map, attention_mask_o) + \ |
|
30 |
+ F.mse_loss(attention_mask_r, zeros_mask) |
|
31 |
+ |
|
32 |
+ entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0)) |
|
33 |
+ entropy_loss = torch.mean(entropy_loss) |
|
34 |
+ |
|
35 |
+ loss = entropy_loss + 0.05 * l_map |
|
36 |
+ |
|
37 |
+ return fc_out_o, loss |
|
38 |
+ |
|
20 | 39 |
def forward(self, x): |
21 | 40 |
x1 = F.leaky_relu(self.conv1(x)) |
22 | 41 |
x2 = F.leaky_relu(self.conv2(x1)) |
--- tools/dataloader.py
+++ tools/dataloader.py
... | ... | @@ -8,8 +8,6 @@ |
8 | 8 |
self.clean_img = clean_img_dir |
9 | 9 |
self.rainy_img = rainy_img_dirs |
10 | 10 |
self.transform = transform |
11 |
- # self.target_transform = target_transform |
|
12 |
- # print(self.transform) |
|
13 | 11 |
|
14 | 12 |
def __len__(self): |
15 | 13 |
return len(self.clean_img) |
--- train.py
+++ train.py
... | ... | @@ -44,9 +44,18 @@ |
44 | 44 |
generator.apply(weights_init_normal) |
45 | 45 |
discriminator.apply(weights_init_normal) |
46 | 46 |
|
47 |
-dataloader = Dataloader |
|
47 |
+dataloader = Dataloader() |
|
48 | 48 |
|
49 | 49 |
|
50 |
+# declare generator loss |
|
51 |
+ |
|
52 |
+ |
|
53 |
+for epoch in range(epoch): |
|
54 |
+ for i, (imgs, _) in enumerate(dataloader): |
|
55 |
+ |
|
56 |
+ |
|
57 |
+torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path") |
|
58 |
+ |
|
50 | 59 |
|
51 | 60 |
|
52 | 61 |
## RNN 따로 돌리고 CPU로 메모리 옳기고 |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?