윤영준 2023-07-24
typo
@97e1e42f0ac512996e82807a606d568732ee375c
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -101,7 +101,7 @@
         attention_map = torch.sigmoid(attention_map)
 
         ret = {
-            "attention_amp" : attention_map,
+            "attention_map" : attention_map,
             "cell_state" : cell_state,
             "lstm_feats" : lstm_feats
         }
train.py
--- train.py
+++ train.py
@@ -8,8 +8,12 @@
 import subprocess
 import atexit
 import torchvision.transforms
+import cv2
+
 from visdom import Visdom
 from torchvision.utils import save_image
+from torchvision import transforms
+from torchvision.transforms import RandomCrop, RandomPerspective, Compose
 from torch.utils.data import DataLoader
 from time import gmtime, strftime
 
@@ -20,17 +24,6 @@
 from tools.argparser import get_param
 from tools.logger import Logger
 from tools.dataloader import ImagePairDataset
-
-
-# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
-# MIT license
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find("BatchNorm2d") != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
 
 args = get_param()
 # I am doing this for easier debugging
@@ -71,14 +64,22 @@
 clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
 clean_data_path = sorted(clean_data_path)
 
-resize = torchvision.transforms.Resize((480, 720), antialias=True)
-dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
+height = 480
+width  = 720
+
+transform = Compose([
+    RandomPerspective(),
+    RandomCrop((height, width))
+])
+
+resize = torchvision.transforms.Resize((height, width), antialias=True)
+dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=transform)
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 # declare generator loss
 
 optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
-optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
-optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate, betas=(0.5, 0.999))
+optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate, betas=(0.5, 0.999))
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
 
 # ------visdom visualizer ----------
@@ -93,10 +94,10 @@
 ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
 AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
 Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
-Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))
-Difference_mask_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Mask Map'))
-Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output'))
-Input_image_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='input clean image'))
+Attention_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Attention Map'))
+Difference_mask_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Mask Map'))
+Generator_output_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Generated Derain Output'))
+Input_image_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='input clean image'))
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
@@ -153,11 +154,11 @@
 
         logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)
         # visdom logger
-        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch * epoch_num + i]), win=ARNN_loss_window,
+        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch_num * epochs + i]), win=ARNN_loss_window,
                  update='append')
-        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch * epoch_num + i]), win=AE_loss_window,
+        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch_num * epochs + i]), win=AE_loss_window,
                  update='append')
-        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
+        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch_num * epochs + i]), win=Discriminator_loss_window,
                  update='append')
         vis.image(generator_attention_map[-1][0, 0, :, :], win=Attention_map_visualizer,
                   opts=dict(title="Attention Map"))
Add a comment
List