윤영준 2023-07-05
Generator visualizer
@7cbe1b2dee697674c26dd8b7f95db7c44c354532
train.py
--- train.py
+++ train.py
@@ -91,8 +91,8 @@
 ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
 AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
 Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
-attenton_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Attention Map'))
-
+Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))
+Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output'))
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
@@ -146,7 +146,6 @@
         }
 
         logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)
-
         # visdom logger
         vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch * epoch_num + i]), win=ARNN_loss_window,
                  update='append')
@@ -154,8 +153,8 @@
                  update='append')
         vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                  update='append')
-        vis.image(generator_attention_map[-1], win=attenton_map_visualizer, opts=dict(title="Attention Map"))
-
+        vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
+        vis.image(generator_result['skip_3'][-1][0,0,:,:], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
         torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
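For context, a minimal sketch of the visdom pattern this commit relies on: line and image windows are created once before the training loop, then updated in place by window handle each iteration. Variable names, shapes, and the loop below are illustrative placeholders, not taken from train.py; it assumes a visdom server is running (python -m visdom.server).

import numpy as np
import torch
import visdom

vis = visdom.Visdom()  # connects to a running visdom server

# Create persistent windows once, before the training loop.
loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Loss'))
image_window = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))

for step in range(100):
    loss = torch.rand(1)                         # placeholder scalar loss
    attention_map = torch.rand(1, 1, 692, 776)   # placeholder NCHW tensor

    # Append the new scalar to the existing line plot instead of redrawing it.
    vis.line(Y=np.array([loss.item()]), X=np.array([step]), win=loss_window, update='append')

    # vis.image expects an HxW (or CxHxW) array, so slice one channel of one sample,
    # mirroring the [0, 0, :, :] indexing added in the diff above.
    vis.image(attention_map[0, 0, :, :], win=image_window, opts=dict(title='Attention Map'))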