윤영준 윤영준 2023-07-05
Added a vizdom window for attention map to view training
@472af24ca65bbc328559b5050b4548e225b0b586
train.py
--- train.py
+++ train.py
@@ -74,7 +74,7 @@
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 # declare generator loss
 
-
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
 optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate)
 optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
@@ -91,6 +91,8 @@
 ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
 AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
 Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
+attenton_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Attention Map'))
+
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
@@ -125,10 +127,17 @@
 
         optimizer_D.zero_grad()
         real_clean_prediction = discriminator(clean_img)
+        fake_clean_prediction = discriminator(generator_result)["fc_out"]
         discriminator_loss = discriminator.loss(clean_img, generator_result, generator_attention_map)
-
         discriminator_loss.backward()
+
         optimizer_D.step()
+
+
+        optimizer_G.zero_grad()
+        generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
+            torch.log(torch.subtract(1, fake_clean_prediction)))
+        optimizer_G.step()
 
         losses = {
             "generator_loss_ARNN": generator_loss_ARNN,
@@ -145,6 +154,7 @@
                  update='append')
         vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                  update='append')
+        vis.image(generator_attention_map[-1], win=attenton_map_visualizer, opts=dict(title="Attention Map"))
 
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
Add a comment
List