It's running!
Still not sure if it's really learning.
@2fbce641ac973499763f83fc8be17e51f803bbae
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -1,3 +1,5 @@
+import torch
+from numpy import ceil
 from torch import nn, clamp
 from torch.functional import F
 
@@ -50,9 +52,9 @@
         :return:
         """
 
-        batch_size, image_h, image_w, _ = real_clean.size()
-
-        zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+        batch_size, _, image_h, image_w = real_clean.size()
+        attention_map = F.interpolate(attention_map[-1], size=(int(ceil(image_h/16)), int(ceil(image_w/16))))
+        zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=torch.float32).to(attention_map.device)
 
         # Inference function
         ret = self.forward(real_clean)
@@ -68,7 +70,7 @@
 
         loss = entropy_loss + 0.05 * l_map
 
-        return fc_out_o, loss
+        return loss
 
 if __name__ == "__main__":
     import torch
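
For context, a minimal runnable sketch of the new shape logic, assuming NCHW tensors and that the discriminator's attention output sits at 1/16 of the input resolution (which is what the ceil(./16) targets suggest). math.ceil stands in for numpy.ceil to keep the sketch dependency-free, and all tensor sizes below are made up:

import math
import torch
import torch.nn.functional as F

real_clean = torch.randn(4, 3, 128, 96)        # N, C, H, W
attention_maps = [torch.rand(4, 1, 32, 24)]    # last RNN step, arbitrary spatial size

batch_size, _, image_h, image_w = real_clean.size()
target_h, target_w = math.ceil(image_h / 16), math.ceil(image_w / 16)

# Resize the last attention map down to the discriminator's feature resolution
attention_map = F.interpolate(attention_maps[-1], size=(target_h, target_w))

# All-zero target mask for real (clean) inputs, created on the matching device
zeros_mask = torch.zeros(batch_size, 1, target_h, target_w,
                         dtype=torch.float32, device=attention_map.device)

assert attention_map.shape == zeros_mask.shape  # torch.Size([4, 1, 8, 6])

Note the commit also fixes the unpacking order: size() on an NCHW tensor is (batch, channels, height, width), so the old "batch_size, image_h, image_w, _" was reading the channel count as the image height. Passing device= at construction, as in the sketch, would also avoid the extra .to() copy the committed code does.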
--- tools/logger.py
+++ tools/logger.py
@@ -50,7 +50,7 @@
 
         sys.stdout.write(
             f"epoch : {current_epoch}/{total_epoch}\n"
-            f"batch : {current_batch}/{total_batch}"
+            f"batch : {current_batch}/{total_batch}\n"
             f"estimated time remaining : {remaining_time}"
             f"{terminal_logging_string}"
         )
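
The one-character fix works because Python concatenates adjacent string literals at compile time, so without the trailing \n two log fields run together on one line. A toy illustration with made-up values:

current_batch, total_batch, remaining_time = 12, 500, "03:41"

merged = (
    f"batch : {current_batch}/{total_batch}"        # old: no newline separator
    f"estimated time remaining : {remaining_time}"
)
fixed = (
    f"batch : {current_batch}/{total_batch}\n"      # new: fields on separate lines
    f"estimated time remaining : {remaining_time}"
)
print(merged)   # batch : 12/500estimated time remaining : 03:41
print(fixed)    # batch : 12/500
                # estimated time remaining : 03:41

Judging from the unchanged context lines, the remaining_time fragment is still followed directly by terminal_logging_string with no newline, so those two fields may still merge the same way.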
--- train.py
+++ train.py
@@ -105,7 +105,7 @@
 
         optimizer_D.zero_grad()
         real_clean_prediction = discriminator(clean_img)
-        discriminator_loss = discriminator.loss(real_clean_prediction, generator_result, generator_attention_map)
+        discriminator_loss = discriminator.loss(clean_img, generator_result, generator_attention_map)
 
         discriminator_loss.backward()
         optimizer_D.step()
@@ -123,11 +123,6 @@
         torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{day}.pt")
         torch.save(generator.state_dict(), f"weight/Generator_{day}.pt")
         torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt")
-
-
-
-
-
 
 
 