윤영준 2023-06-30
finished the training code for the generator; the discriminator is what remains.
@a4559baa74c7d70ea5cb3c2cfc1e1a1b1163219e
 
.gitignore (added)
+++ .gitignore
@@ -0,0 +1,2 @@
+# this is the image dataset, about 9 GB
+data/source/Oxford_raindrop_dataset (No newline at end of file)
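Note on the ignore rule: without a trailing slash, a gitignore pattern matches both a file and a directory of that name. If the dataset is always a directory, the directory-only form is slightly more precise (a minor sketch, not required):

    # directory-only form; ignores the dataset folder but not a file of the same name
    data/source/Oxford_raindrop_dataset/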
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -202,7 +202,12 @@
         attention_map = []
         lstm_feats = []
         for generator_block in self.generator_blocks:
-            x, attention_map_i, cell_state, lstm_feats_i = generator_block(x, cell_state)
+            generator_block_return = generator_block(x, cell_state)
+            attention_map_i = generator_block_return['attention_map']
+            lstm_feats_i = generator_block_return['lstm_feats']
+            cell_state = generator_block_return['cell_state']
+            x = generator_block_return['x']
+
             attention_map.append(attention_map_i)
             lstm_feats.append(lstm_feats_i)
         ret = {
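Context for this hunk: each generator block now returns a dict instead of a tuple, so call sites no longer depend on positional order. A runnable toy sketch of the assumed contract (ToyGeneratorBlock is hypothetical; the key names are taken from the diff above):

    import torch
    from torch import nn

    class ToyGeneratorBlock(nn.Module):
        # hypothetical stand-in illustrating the dict-return contract
        def forward(self, x, cell_state):
            attention_map = torch.sigmoid(x.mean(dim=1, keepdim=True))
            return {
                "x": x * attention_map,     # attended features for the next block
                "attention_map": attention_map,
                "cell_state": cell_state,   # recurrent state threaded between blocks
                "lstm_feats": x,            # placeholder for the block's LSTM features
            }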
model/Autoencoder.py
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -45,7 +45,7 @@
         # pretrained network for feature extractor for the loss function or even not using feature extractor.
         # I honestly do not think using neural network for VGG is strictly necessary, and may have to be replaced with
         # other image preprocessing like MSCN which was implemented before.
-        self.vgg = vgg16(pretrained=True).features
+        self.vgg = vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features
         self.vgg.eval()
         for param in self.vgg.parameters():
             param.requires_grad = False
@@ -93,15 +93,15 @@
 
         return ret
 
-    def loss(self, input_image_tensor, input_clean_image_tensor):
-        ori_height, ori_width = input_clean_image_tensor.shape[2:]
+    def loss(self, input_clean_img, input_rainy_img):
+        ori_height, ori_width = input_clean_img.shape[2:]
 
         # Rescale labels to match the scales of the outputs
-        label_tensor_resize_2 = F.interpolate(input_clean_image_tensor, size=(ori_height // 2, ori_width // 2))
-        label_tensor_resize_4 = F.interpolate(input_clean_image_tensor, size=(ori_height // 4, ori_width // 4))
-        label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_image_tensor]
+        label_tensor_resize_2 = F.interpolate(input_clean_img, size=(ori_height // 2, ori_width // 2))
+        label_tensor_resize_4 = F.interpolate(input_clean_img, size=(ori_height // 4, ori_width // 4))
+        label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_img]
 
-        inference_ret = self.forward(input_image_tensor)
+        inference_ret = self.forward(input_rainy_img)
 
         output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
 
@@ -112,7 +112,7 @@
             lm_loss += mse_loss
 
         # Compute lp_loss
-        src_vgg_feats = self.vgg(input_clean_image_tensor)
+        src_vgg_feats = self.vgg(input_clean_img)
         pred_vgg_feats = self.vgg(output_list[-1])
 
         lp_losses = []
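Background on the vgg16 change: torchvision 0.13 deprecated pretrained=True in favor of explicit weights enums, and recent versions also accept the string form used in the diff. A minimal sketch of the frozen perceptual-loss extractor as configured here:

    from torchvision.models import vgg16, VGG16_Weights

    # enum form; equivalent to weights='VGG16_Weights.IMAGENET1K_V1' above
    vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
    vgg.eval()
    for p in vgg.parameters():
        p.requires_grad = False  # frozen: only used to extract features for lp_loss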
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -2,7 +2,7 @@
 from torch.functional import F
 
 class DiscriminativeNet(nn.Module):
-    def __init__(self, W, H):
+    def __init__(self):
         super(DiscriminativeNet, self).__init__()
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2, padding=1)
         self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2)
@@ -49,24 +49,24 @@
         :param attention_map: This is the final attention map from the generator.
         :return:
         """
-        with torch.no_grad():
-            batch_size, image_h, image_w, _ = real_clean.size()
 
-            zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+        batch_size, _, image_h, image_w = real_clean.size()
 
-            # Inference function
-            ret = self.forward(real_clean)
-            fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
-            ret = self.forward(generated_clean)
-            fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
+        zeros_mask = torch.zeros([batch_size, 1, image_h, image_w], dtype=torch.float32, device=real_clean.device)
 
-            l_map = F.mse_loss(attention_map, attention_mask_o) + \
-                    F.mse_loss(attention_mask_r, zeros_mask)
+        # Inference function
+        ret = self.forward(real_clean)
+        fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
+        ret = self.forward(generated_clean)
+        fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
 
-            entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
-            entropy_loss = torch.mean(entropy_loss)
+        l_map = F.mse_loss(attention_map, attention_mask_o) + \
+                F.mse_loss(attention_mask_r, zeros_mask)
 
-            loss = entropy_loss + 0.05 * l_map
+        entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
+        entropy_loss = torch.mean(entropy_loss)
+
+        loss = entropy_loss + 0.05 * l_map
 
         return fc_out_o, loss
 
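The substance of this hunk is dropping torch.no_grad(): a loss computed under no_grad records no autograd graph, so loss.backward() would fail and the discriminator could never update. A runnable toy illustration (the tiny network and loss term are hypothetical stand-ins):

    import torch
    from torch import nn

    disc = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1), nn.Sigmoid())
    fake = torch.rand(1, 3, 8, 8)

    out = disc(fake)                 # gradients tracked (no no_grad context)
    loss = -torch.log(out).mean()    # toy adversarial term
    loss.backward()                  # gradients reach the discriminator weights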
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -1,6 +1,7 @@
-from AttentiveRNN import AttentiveRNN
-from Autoencoder import AutoEncoder
+from model.AttentiveRNN import AttentiveRNN
+from model.Autoencoder import AutoEncoder
 from torch import nn
+from torch import sum, pow, abs
 
 
 class Generator(nn.Module):
@@ -36,10 +37,10 @@
 
     def binary_diff_mask(self, clean, dirty, thresold=0.1):
         # this part corrects gamma; remember, sRGB values do not scale linearly with light intensity.
-        clean = torch.pow(clean, 0.45)
-        dirty = torch.pow(dirty, 0.45)
-        diff = torch.abs(clean - dirty)
-        diff = torch.sum(diff, dim=1)
+        clean = pow(clean, 0.45)
+        dirty = pow(dirty, 0.45)
+        diff = abs(clean - dirty)
+        diff = sum(diff, dim=1)
 
         bin_diff = (diff > thresold).to(clean.dtype)
 
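For context, binary_diff_mask gamma-adjusts both images (exponent 0.45 ≈ 1/2.2), takes the per-pixel absolute difference, sums over the channel dimension, and thresholds. A runnable sketch of the same steps, assuming NCHW float tensors in [0, 1]:

    import torch

    clean = torch.rand(1, 3, 64, 64)      # hypothetical NCHW batch in [0, 1]
    dirty = clean.clone()
    dirty[..., 20:30, 20:30] = 1.0        # simulate a raindrop highlight

    diff = torch.sum(torch.abs(clean.pow(0.45) - dirty.pow(0.45)), dim=1)
    mask = (diff > 0.1).to(clean.dtype)   # 0.1 matches the default thresold
    print(mask.shape)                     # torch.Size([1, 64, 64]); channels summed out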
tools/argparser.py
--- tools/argparser.py
+++ tools/argparser.py
@@ -8,13 +8,22 @@
     parser.add_argument("--batch_size", "-b", type=int, required=True, help="Size of single batch")
     parser.add_argument("--save_interval", "-s", type=int, required=True, help="Interval for saving weights")
     parser.add_argument("--sample_interval", type=int, required=True, help="Interval for saving inference result")
-    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for computation")
-    parser.add_argument("--load", "-l", type=str, help="Path to previous weights for continuing training")
-    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of generator")
-    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times generator trains in a single epoch")
-    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks of RNN in attention network")
-    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each attention RNN blocks")
-    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. If not given, it is assumed to be the same as the generator")
+    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for "
+                                                                                                 "computation")
+    parser.add_argument("--load", "-l", type=str, default=None, help="Path to previous weights for continuing training")
+    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of "
+                                                                                              "generator")
+    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times "
+                                                                                                "generator trains in "
+                                                                                                "a single epoch")
+    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks "
+                                                                                                  "of RNN in "
+                                                                                                  "attention network")
+    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each "
+                                                                                          "attention RNN blocks")
+    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. "
+                                                                                   "If not given, it is assumed to be"
+                                                                                   " the same as the generator")
 
     args = parser.parse_args()
     return args
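Given the flags above, a hypothetical invocation would look like this (values illustrative; --epochs is defined before this hunk's window, since train.py reads args.epochs):

    python train.py --epochs 100 -b 4 -s 10 --sample_interval 50 \
        -d cuda -g_lr 0.0002 -d_lr 0.0002 -g_arnn_b 3 -g_depth 2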
tools/dataloader.py
--- tools/dataloader.py
+++ tools/dataloader.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 
-class ImageDataSet(Dataset):
+class ImagePairDataset(Dataset):
     def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
         self.clean_img = clean_img_dir
         self.rainy_img = rainy_img_dirs
@@ -16,15 +16,17 @@
         clean_img_path = self.clean_img[idx]
 
         i = 0
-        if len(self.rainy_img) != 1:
+        if isinstance(self.rainy_img[idx], list):
             rng = np.random.default_rng()
-            i = rng.integers(low=0, high=len(self.rainy_img)-1)
-        rainy_img_path = self.rainy_img[idx][i]
+            i = rng.integers(low=0, high=len(self.rainy_img[idx]))
+            rainy_img_path = self.rainy_img[idx][i]
+        else:
+            rainy_img_path = self.rainy_img[idx]
         clean_image = read_image(clean_img_path)
         rainy_image = read_image(rainy_img_path)
         if self.transform:
             clean_image = self.transform(clean_image)
-            rainy_image = self.rainy_img(rainy_image)
+            rainy_image = self.transform(rainy_image)
 
         ret = {
             "clean_image" : clean_image,
@@ -33,7 +35,7 @@
         return ret
 
     def __add__(self, other):
-        return ImageDataSet(
+        return ImagePairDataset(
             clean_img_dir=self.clean_img+other.clean_img,
             rainy_img_dirs=self.rainy_img+other.rainy_img,
             transform=self.transform
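A usage sketch for the renamed dataset (paths hypothetical and must exist on disk; an entry in rainy_img_dirs may be a single path or a list of candidate paths for one clean image, matching the branch added above):

    from torch.utils.data import DataLoader
    from tools.dataloader import ImagePairDataset

    dataset = ImagePairDataset(
        clean_img_dir=["data/clean/0001.png"],
        rainy_img_dirs=[["data/rainy/0001_a.png", "data/rainy/0001_b.png"]],
    )
    sample = dataset[0]                      # dict with "clean_image" / "rainy_image"
    loader = DataLoader(dataset, batch_size=1)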
tools/logger.py
--- tools/logger.py
+++ tools/logger.py
@@ -2,8 +2,8 @@
 import time
 import plotly.express as px
 import dash
-import dash_core_components as dcc
-import dash_html_components as html
+from dash import dcc
+from dash import html
 from dash.dependencies import Input, Output
 
 
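Background on the import fix: Dash 2.0 folded the separate dash_core_components and dash_html_components packages into dash itself, so the old imports fail on current installs. A minimal sketch of the new style (layout contents hypothetical):

    from dash import Dash, dcc, html

    app = Dash(__name__)
    app.layout = html.Div([dcc.Graph(id="training-loss")])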
train.py
--- train.py
+++ train.py
@@ -1,20 +1,25 @@
 import sys
 import os
 import torch
+import glob
 import numpy as np
 import pandas as pd
 import plotly.express as px
+import torchvision.transforms
 from torchvision.utils import save_image
+from torch.utils.data import DataLoader
+from time import gmtime, strftime
 
 from model import Autoencoder
-from model import Generator
-from model import Discriminator
-from model import AttentiveRNN
+from model.Generator import Generator
+from model.Discriminator import DiscriminativeNet as Discriminator
+from model.AttentiveRNN import AttentiveRNN
 from tools.argparser import get_param
 from tools.logger import Logger
-from tools.dataloader import Dataset
+from tools.dataloader import ImagePairDataset
 
 # this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
+# MIT license
 def weights_init_normal(m):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
@@ -24,7 +29,9 @@
         torch.nn.init.constant_(m.bias.data, 0.0)
 
 args = get_param()
-
+# Unpack args into local variables for easier debugging:
+# if an error is raised on one of these, the traceback names the variable
+# directly instead of pointing at an opaque args attribute.
 epochs = args.epochs
 batch_size = args.batch_size
 save_interval = args.save_interval
@@ -40,44 +47,61 @@
 logger = Logger()
 
 cuda = True if torch.cuda.is_available() else False
-
-generator = Generator() # get network values and stuff
-discriminator = Discriminator()
-
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-generator = Generator().to(device)
-discriminator = Discriminator().to(device)
+generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device)  # attentive RNN + autoencoder
+discriminator = Discriminator().to(device=device)
 
-if load is not False:
-    generator.load_state_dict(torch.load("example_path"))
-    discriminator.load_state_dict(torch.load("example_path"))
+if load is not None:
+    generator.load_state_dict(torch.load(load))
+    discriminator.load_state_dict(torch.load(load))
 else:
-    generator.apply(weights_init_normal)
-    discriminator.apply(weights_init_normal)
+    pass
 
-dataloader = Dataloader()
 
+
+# This is a stopgap; a proper data-management module will be added later.
+rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
+rainy_data_path = sorted(rainy_data_path)
+clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
+clean_data_path = sorted(clean_data_path)
+
+resize = torchvision.transforms.Resize((692, 776))
+dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
+dataloader = DataLoader(dataset, batch_size=batch_size)
 # declare generator loss
 
-optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
-optimizer_D = torch.optim.Adam(generator.parameters(), lr=discriminator_learning_rate)
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate)
+optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
+optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
 
 for epoch_num, epoch in enumerate(range(epochs)):
-    for i, (imgs, _) in enumerate(dataloader):
-        logger.print_training_log(epoch_num, epochs, i, len(enumerate(dataloader)))
+    for i, imgs in enumerate(dataloader):
 
-        img_batch = data[0].to(device)
-        clean_img = img_batch["clean_image"]
-        rainy_img = img_batch["rainy_image"]
+        img_batch = imgs
+        clean_img = img_batch["clean_image"] / 255
+        rainy_img = img_batch["rainy_image"] / 255
 
-        optimizer_G.zero_grad()
-        generator_outputs = generator(clean_img)
-        generator_result = generator_outputs["x"]
-        generator_attention_map = generator_outputs["attention_map_list"]
-        generator_loss = generator.loss(clean_img, rainy_img)
-        generator_loss.backward()
-        optimizer_G.step()
+        clean_img = clean_img.to(device=device)
+        rainy_img = rainy_img.to(device=device)
+
+        optimizer_G_ARNN.zero_grad()
+        optimizer_G_AE.zero_grad()
+
+        attentiveRNNresults = generator.attentiveRNN(clean_img)
+        generator_attention_map = attentiveRNNresults['attention_map_list']
+        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img)
+        generator_loss_ARNN = generator.attentiveRNN.loss(clean_img, binary_difference_mask)
+        generator_loss_ARNN.backward()
+        optimizer_G_ARNN.step()
+
+        generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
+        generator_result = generator_outputs['skip_3']
+
+        generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img)
+        generator_loss_AE.backward()
+        optimizer_G_AE.step()
+
 
         optimizer_D.zero_grad()
         real_clean_prediction = discriminator(clean_img)
@@ -86,6 +110,19 @@
         discriminator_loss.backward()
         optimizer_D.step()
 
+        losses = {
+            "generator_loss_ARNN": generator_loss_ARNN,
+            "generator_loss_AE": generator_loss_AE,
+            "discriminator_loss": discriminator_loss,
+        }
+
+        logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)
+
+    day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
+    if epoch % save_interval == 0 and epoch != 0:
+        torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{day}.pt")
+        torch.save(generator.state_dict(), f"weight/Generator_{day}.pt")
+        torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt")
 
 
 
@@ -93,7 +130,8 @@
 
 
 
-torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path")
+
+
 
 
 
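One caveat on the checkpoint naming above: strftime("%Y-%m-%d %H:%M:%S", gmtime()) puts spaces and colons into the filename, which Linux accepts but Windows does not. A filesystem-safe variant (a sketch, assuming the same weight/ directory):

    from time import gmtime, strftime

    day = strftime("%Y-%m-%d_%H-%M-%S", gmtime())  # no spaces or colons
    path = f"weight/Generator_{day}.pt"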