--- train.py
+++ train.py
... | ... | @@ -93,7 +93,9 @@ |
93 | 93 |
AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss')) |
94 | 94 |
Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss')) |
95 | 95 |
Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map')) |
96 |
+Difference_mask_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Mask Map')) |
|
96 | 97 |
Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output')) |
98 |
+Input_image_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='input clean image')) |
|
97 | 99 |
|
98 | 100 |
for epoch_num, epoch in enumerate(range(epochs)): |
99 | 101 |
for i, imgs in enumerate(dataloader): |
... | ... | @@ -111,14 +113,14 @@ |
111 | 113 |
optimizer_G_ARNN.zero_grad() |
112 | 114 |
optimizer_G_AE.zero_grad() |
113 | 115 |
|
114 |
- attentiveRNNresults = generator.attentiveRNN(clean_img) |
|
116 |
+ attentiveRNNresults = generator.attentiveRNN(rainy_img) |
|
115 | 117 |
generator_attention_map = attentiveRNNresults['attention_map_list'] |
116 |
- binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img) |
|
117 |
- generator_loss_ARNN = generator.attentiveRNN.loss(clean_img, binary_difference_mask) |
|
118 |
+ binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.2) |
|
119 |
+ generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask) |
|
118 | 120 |
generator_loss_ARNN.backward() |
119 | 121 |
optimizer_G_ARNN.step() |
120 | 122 |
|
121 |
- generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1]) |
|
123 |
+ generator_outputs = generator.autoencoder(rainy_img * attentiveRNNresults['attention_map_list'][-1]) |
|
122 | 124 |
generator_result = generator_outputs['skip_3'] |
123 | 125 |
generator_output = generator_outputs['output'] |
124 | 126 |
|
... | ... | @@ -156,8 +158,12 @@ |
156 | 158 |
update='append') |
157 | 159 |
vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window, |
158 | 160 |
update='append') |
159 |
- vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map")) |
|
161 |
+ vis.image(generator_attention_map[-1][0, 0, :, :], win=Attention_map_visualizer, |
|
162 |
+ opts=dict(title="Attention Map")) |
|
163 |
+ vis.image(binary_difference_mask[-1], win=Difference_mask_map_visualizer, |
|
164 |
+ opts=dict(title="Binary Mask Map")) |
|
160 | 165 |
vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output")) |
166 |
+ vis.image(clean_img[-1], win=Input_image_visualizer, opts=dict(title="input clean image")) |
|
161 | 167 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
162 | 168 |
if epoch % save_interval == 0 and epoch != 0: |
163 | 169 |
torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?