corrected wrong ... binary mask. It was inverted
@5225b78301d1af4858c87f98e44204efa59c2a9b
--- model/Generator.py
+++ model/Generator.py
@@ -42,7 +42,7 @@
         diff = abs(clean - dirty)
         diff = sum(diff, dim=1)

-        bin_diff = (diff > thresold).to(clean.dtype)
+        bin_diff = (diff < thresold).to(clean.dtype)

         return bin_diff

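For reference, a minimal standalone sketch of the mask computation this hunk touches, written with explicit `torch.abs`/`torch.sum` and toy tensors (the real method presumably has `clean`, `dirty`, and `thresold` in scope). It only illustrates that flipping the comparison flips which pixels become 1 in the binary mask.

```python
import torch

# Hypothetical toy inputs: 1x3x2x2 "clean" and "dirty" images and a scalar threshold.
clean = torch.zeros(1, 3, 2, 2)
dirty = clean.clone()
dirty[0, :, 0, 0] = 1.0           # one heavily changed (e.g. rain-occluded) pixel
thresold = 0.5                    # spelling kept to match the repository variable

diff = torch.abs(clean - dirty)   # per-channel absolute difference
diff = torch.sum(diff, dim=1)     # collapse channels -> (N, H, W) difference map

mask_gt = (diff > thresold).to(clean.dtype)  # 1 where the images differ strongly
mask_lt = (diff < thresold).to(clean.dtype)  # 1 where the images are (nearly) unchanged

print(mask_gt)  # tensor([[[1., 0.], [0., 0.]]])
print(mask_lt)  # tensor([[[0., 1.], [1., 1.]]])
```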
--- tools/argparser.py
+++ tools/argparser.py
@@ -15,6 +15,10 @@
     parser.add_argument("--load", "-l", type=str, default=None, help="Path to previous weights for continuing training")
     parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of "
                         "generator")
+    parser.add_argument("--generator_ARNN_learning_rate", "-g_arnn_lr", type=float, help="Learning rate of Attention "
+                        "RNN network, default is "
+                        "same as the whole generator "
+                        "(autoencoder)")
     parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times "
                         "generator trains in "
                         "a single epoch")
@@ -23,9 +27,12 @@
                         "attention network")
     parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each "
                         "attention RNN blocks")
-    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. "
-                        "If not given, it is assumed to be"
-                        " the same as the generator")
+    parser.add_argument("--discriminator_learning_rate", "-d_lr", default=None, type=float, help="Learning rate of "
+                        "discriminator. "
+                        "If not given, it is "
+                        "assumed to be "
+                        "the same as the "
+                        "generator")

     args = parser.parse_args()
     return args
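A minimal sketch (not the repository's full parser) of how the three learning-rate flags from this diff interact when some are omitted; the fallback for the ARNN rate comes from the train.py hunk below, and the discriminator fallback is the one described in its help text.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True)
parser.add_argument("--generator_ARNN_learning_rate", "-g_arnn_lr", type=float, default=None)
parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, default=None)

# Only the required generator rate is given here; the other two stay None.
args = parser.parse_args(["-g_lr", "0.0002"])

# Unset rates fall back to the generator rate.
g_arnn_lr = args.generator_ARNN_learning_rate if args.generator_ARNN_learning_rate is not None else args.generator_learning_rate
d_lr = args.discriminator_learning_rate if args.discriminator_learning_rate is not None else args.generator_learning_rate
print(g_arnn_lr, d_lr)  # 0.0002 0.0002
```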
--- train.py
+++ train.py
@@ -44,6 +44,7 @@
 device = args.device
 load = args.load
 generator_learning_rate = args.generator_learning_rate
+generator_ARNN_learning_rate = args.generator_ARNN_learning_rate if args.generator_ARNN_learning_rate is not None else args.generator_learning_rate
 generator_learning_miniepoch = args.generator_learning_miniepoch
 generator_attentivernn_blocks = args.generator_attentivernn_blocks
 generator_resnet_depth = args.generator_resnet_depth
@@ -75,7 +76,7 @@
 # declare generator loss

 optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
-optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate)
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
 optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)

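The pattern above is one Adam optimizer per submodule, each with its own learning rate. A self-contained sketch of the same idea, where `DummyGenerator` is a hypothetical stand-in for the real generator (all that matters is that `attentiveRNN` and `autoencoder` are distinct submodules):

```python
import torch
import torch.nn as nn

class DummyGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.attentiveRNN = nn.Linear(4, 4)   # stand-in for the attentive RNN
        self.autoencoder = nn.Linear(4, 4)    # stand-in for the autoencoder

generator = DummyGenerator()
generator_learning_rate = 2e-4
generator_ARNN_learning_rate = 1e-4           # may now differ from the generator rate

# Each optimizer only sees the parameters of its submodule.
optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
```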
@@ -119,6 +120,7 @@

        generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
        generator_result = generator_outputs['skip_3']
+       generator_output = generator_outputs['output']

        generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img)
        generator_loss_AE.backward()
@@ -136,7 +138,8 @@

        optimizer_G.zero_grad()
        generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
-           torch.log(torch.subtract(1, fake_clean_prediction)))
+           torch.log(torch.subtract(1, fake_clean_prediction))
+       )
        optimizer_G.step()

        losses = {
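A toy illustration of the adversarial term reformatted in the hunk above; `fake_clean_prediction` here is just a made-up tensor standing for the discriminator's probabilities on generated images. Writing `1 - fake_clean_prediction` is equivalent to the `torch.subtract(1, fake_clean_prediction)` call in the diff.

```python
import torch

fake_clean_prediction = torch.tensor([0.1, 0.4, 0.7])   # hypothetical D outputs in (0, 1)

# Mean of log(1 - p): grows more negative as the predictions approach 1,
# i.e. as the discriminator is fooled more by the generator.
adv_term = torch.mean(torch.log(1 - fake_clean_prediction))
print(adv_term)  # ≈ tensor(-0.6067)
```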
@@ -154,7 +157,7 @@
        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                 update='append')
        vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
-       vis.image(generator_result[-1]*255, win=Generator_output_visualizer, opts=dict(title="Generator Output"))
+       vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
        day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        if epoch % save_interval == 0 and epoch != 0:
            torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
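The checkpoint saved above is a plain `state_dict`, so restoring it later requires rebuilding the module with the same architecture first. A self-contained round trip of that pattern, with `DummyARNN` as a hypothetical stand-in for `generator.attentiveRNN`:

```python
import torch
import torch.nn as nn

class DummyARNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = DummyARNN()
torch.save(model.state_dict(), "weight_demo.pt")          # weights only, as in train.py

restored = DummyARNN()                                     # rebuild with the same architecture
restored.load_state_dict(torch.load("weight_demo.pt"))
```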