윤영준 윤영준 2023-07-10
train code update
@a7290ec674c4f46888a63c84e9c7b9af6bad60a7
train.py
--- train.py
+++ train.py
@@ -93,7 +93,9 @@
 AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
 Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
 Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))
+Difference_mask_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Mask Map'))
 Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output'))
+Input_image_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='input clean image'))
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
@@ -111,14 +113,14 @@
         optimizer_G_ARNN.zero_grad()
         optimizer_G_AE.zero_grad()
 
-        attentiveRNNresults = generator.attentiveRNN(clean_img)
+        attentiveRNNresults = generator.attentiveRNN(rainy_img)
         generator_attention_map = attentiveRNNresults['attention_map_list']
-        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img)
-        generator_loss_ARNN = generator.attentiveRNN.loss(clean_img, binary_difference_mask)
+        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.2)
+        generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask)
         generator_loss_ARNN.backward()
         optimizer_G_ARNN.step()
 
-        generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
+        generator_outputs = generator.autoencoder(rainy_img * attentiveRNNresults['attention_map_list'][-1])
         generator_result = generator_outputs['skip_3']
         generator_output = generator_outputs['output']
 
@@ -156,8 +158,12 @@
                  update='append')
         vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                  update='append')
-        vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
+        vis.image(generator_attention_map[-1][0, 0, :, :], win=Attention_map_visualizer,
+                  opts=dict(title="Attention Map"))
+        vis.image(binary_difference_mask[-1], win=Difference_mask_map_visualizer,
+                  opts=dict(title="Binary Mask Map"))
         vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
+        vis.image(clean_img[-1], win=Input_image_visualizer, opts=dict(title="input clean image"))
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
         torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
Add a comment
List