윤영준 2023-07-05
corrected wrong ... binary mask. It was inverted
@5225b78301d1af4858c87f98e44204efa59c2a9b
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -42,7 +42,7 @@
         diff = abs(clean - dirty)
         diff = torch.sum(diff, dim=1)
 
-        bin_diff = (diff > thresold).to(clean.dtype)
+        bin_diff = (diff < thresold).to(clean.dtype)
 
         return bin_diff
 
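A minimal sketch of the corrected mask semantics, assuming NCHW float tensors and keeping the repo's `thresold` spelling; the sample values are illustrative only:

    import torch

    def binary_mask(clean, dirty, thresold):
        # Per-pixel absolute difference, summed over the channel dimension.
        diff = torch.sum(torch.abs(clean - dirty), dim=1)
        # After the fix, the mask is 1 where the images agree (difference
        # below the threshold) and 0 on heavily degraded pixels.
        return (diff < thresold).to(clean.dtype)

    clean = torch.zeros(1, 3, 2, 2)
    dirty = torch.zeros(1, 3, 2, 2)
    dirty[0, :, 0, 0] = 1.0  # one "rainy" pixel
    print(binary_mask(clean, dirty, thresold=0.5))
    # prints a (1, 2, 2) mask that is 0 only at the rainy pixel
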
tools/argparser.py
--- tools/argparser.py
+++ tools/argparser.py
@@ -15,6 +15,10 @@
     parser.add_argument("--load", "-l", type=str, default=None, help="Path to previous weights for continuing training")
     parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of "
                                                                                               "generator")
+    parser.add_argument("--generator_ARNN_learning_rate", "-g_arnn_lr", type=float, help="learning rate of Attention "
+                                                                                         "RNN network, default is "
+                                                                                         "same as the whole generator "
+                                                                                         "(autoencoder)")
     parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times "
                                                                                                 "generator trains in "
                                                                                                 "a single epoch")
@@ -23,9 +27,12 @@
                                                                                                   "attention network")
     parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each "
                                                                                           "attention RNN blocks")
-    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. "
-                                                                                   "If not given, it is assumed to be"
-                                                                                   " the same as the generator")
+    parser.add_argument("--discriminator_learning_rate", "-d_lr", default=None, type=float, help="Learning rate of "
+                                                                                                 "discriminator."
+                                                                                                 "If not given, it is "
+                                                                                                 "assumed to be"
+                                                                                                 "the same as the "
+                                                                                                 "generator")
 
     args = parser.parse_args()
     return args
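
The new flag is optional; the intended fallback (implemented in the train.py hunk below) is a plain None check. A minimal sketch using a stand-in Namespace, since the wrapper function's name is not shown in this diff:

    from argparse import Namespace

    # Stand-in for the parsed args; here the ARNN flag was omitted.
    args = Namespace(generator_learning_rate=1e-4,
                     generator_ARNN_learning_rate=None)
    arnn_lr = (args.generator_ARNN_learning_rate
               if args.generator_ARNN_learning_rate is not None
               else args.generator_learning_rate)
    assert arnn_lr == 1e-4  # falls back to the whole-generator rate
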
train.py
--- train.py
+++ train.py
@@ -44,6 +44,7 @@
 device = args.device
 load = args.load
 generator_learning_rate = args.generator_learning_rate
+generator_ARNN_learning_rate = args.generator_ARNN_learning_rate if args.generator_ARNN_learning_rate is not None else args.generator_learning_rate
 generator_learning_miniepoch = args.generator_learning_miniepoch
 generator_attentivernn_blocks = args.generator_attentivernn_blocks
 generator_resnet_depth = args.generator_resnet_depth
@@ -75,7 +76,7 @@
 # declare generator loss
 
 optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
-optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate)
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
 optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
 
@@ -119,6 +120,7 @@
 
         generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
         generator_result = generator_outputs['skip_3']
+        generator_output = generator_outputs['output']
 
         generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img)
         generator_loss_AE.backward()
@@ -136,7 +138,8 @@
 
         optimizer_G.zero_grad()
         generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
-            torch.log(torch.subtract(1, fake_clean_prediction)))
+            torch.log(torch.subtract(1, fake_clean_prediction))
+        )
         optimizer_G.step()
 
         losses = {
@@ -154,7 +157,7 @@
         vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                  update='append')
         vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
-        vis.image(generator_result[-1]*255, win=Generator_output_visualizer, opts=dict(title="Generator Output"))
+        vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
         torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
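
As an aside on the optimizer hunk above: the same per-submodule learning rates can also be expressed as parameter groups inside a single Adam optimizer. A runnable sketch with small nn.Linear stand-ins for attentiveRNN and the autoencoder; the rates are illustrative only:

    import torch
    import torch.nn as nn

    class TinyGenerator(nn.Module):
        def __init__(self):
            super().__init__()
            self.attentiveRNN = nn.Linear(4, 4)  # stand-in for the real ARNN
            self.autoencoder = nn.Linear(4, 4)   # stand-in for the real AE

    generator = TinyGenerator()
    optimizer_G = torch.optim.Adam([
        {"params": generator.attentiveRNN.parameters(), "lr": 5e-5},  # ARNN-specific rate
        {"params": generator.autoencoder.parameters(), "lr": 1e-4},   # generator-wide rate
    ])
    # Each group keeps its own lr and can be scheduled independently.
    print([g["lr"] for g in optimizer_G.param_groups])  # [5e-05, 0.0001]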