import sys
import os
import torch
import glob
import numpy as np
import time
import pandas as pd
import subprocess
import atexit
import torchvision.transforms
from visdom import Visdom
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from time import gmtime, strftime
from model.Autoencoder import AutoEncoder
from model.Generator import Generator
from model.Discriminator import DiscriminativeNet as Discriminator
from model.AttentiveRNN import AttentiveRNN
from tools.argparser import get_param
from tools.logger import Logger
from tools.dataloader import ImagePairDataset
# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
# MIT license
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
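# NOTE: the initializer above is not applied anywhere yet; when training from scratch
# it could be wired in with generator.apply(weights_init_normal) and
# discriminator.apply(weights_init_normal) once the models are constructed.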
args = get_param()
# Unpack the parsed arguments into local variables for easier debugging:
# if an expected argument is missing, the failure surfaces right here with a
# clear line number instead of as an opaque attribute error deeper in the script.
epochs = args.epochs
batch_size = args.batch_size
save_interval = args.save_interval
sample_interval = args.sample_interval
num_worker = args.num_worker
device = args.device
load = args.load
generator_learning_rate = args.generator_learning_rate
generator_ARNN_learning_rate = args.generator_ARNN_learning_rate if args.generator_ARNN_learning_rate is not None else args.generator_learning_rate
generator_learning_miniepoch = args.generator_learning_miniepoch
generator_attentivernn_blocks = args.generator_attentivernn_blocks
generator_resnet_depth = args.generator_resnet_depth
discriminator_learning_rate = args.discriminator_learning_rate if args.discriminator_learning_rate is not None else args.generator_learning_rate
logger = Logger()
cuda = torch.cuda.is_available()
# honor the device argument when given; otherwise pick CUDA if available
device = torch.device(device) if device is not None else torch.device('cuda' if cuda else 'cpu')
generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device)
discriminator = Discriminator().to(device=device)
if load is not None:
    # NOTE: the modules below are saved as separate checkpoint files (see the save
    # block at the end), so `load` needs to be handled per-module rather than
    # pointing at a single file as it does now.
    generator.attentiveRNN.load_state_dict(torch.load(load))
    generator.autoencoder.load_state_dict(torch.load(load))
    discriminator.load_state_dict(torch.load(load))
# This is a stopgap; a proper data-management module will be added later.
rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
rainy_data_path = sorted(rainy_data_path)
clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
clean_data_path = sorted(clean_data_path)
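# clean/rainy frames are paired by sorted filename order, which assumes both folders
# contain identically named counterpart images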
resize = torchvision.transforms.Resize((692, 776), antialias=True)
dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)
# optimizers: one for the whole generator (adversarial step), one each for the
# AttentiveRNN and autoencoder stages, and one for the discriminator
optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
# ------visdom visualizer ----------
# launch a visdom server with the current interpreter (no shell, so terminate() reaches it directly)
server_process = subprocess.Popen([sys.executable, "-m", "visdom.server"])
# make sure the visdom server is shut down whenever this script exits
def cleanup():
server_process.terminate()
atexit.register(cleanup)
time.sleep(10)  # give the visdom server time to start before connecting
vis = Visdom(server="http://localhost", port=8097)
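# loss-curve windows and image panels for monitoring training progress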
ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))
Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output'))
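# Training loop: for every batch the AttentiveRNN and the contextual autoencoder are
# first updated with their own losses, then the discriminator, and finally the whole
# generator takes an adversarial step against the updated discriminator.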
for epoch in range(epochs):
for i, imgs in enumerate(dataloader):
img_batch = imgs
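        # scale to [0, 1]; assumes the dataloader yields 8-bit image tensors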
clean_img = img_batch["clean_image"] / 255
rainy_img = img_batch["rainy_image"] / 255
clean_img = clean_img.to(device=device)
rainy_img = rainy_img.to(device=device)
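        # -------- stage 1: attention-map (AttentiveRNN) update --------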
optimizer_G_ARNN.zero_grad()
optimizer_G_AE.zero_grad()
        # compute the attention maps on the rainy input (not the clean ground truth)
        attentiveRNNresults = generator.attentiveRNN(rainy_img)
        generator_attention_map = attentiveRNNresults['attention_map_list']
        # binary mask of rainy-vs-clean differences, the supervision target for the attention maps
        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img)
        generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask)
generator_loss_ARNN.backward()
optimizer_G_ARNN.step()
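        # -------- stage 2: contextual autoencoder update --------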
generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
generator_result = generator_outputs['skip_3']
generator_output = generator_outputs['output']
generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img)
generator_loss_AE.backward()
optimizer_G_AE.step()
        # -------- stage 3: discriminator update --------
        optimizer_D.zero_grad()
        real_clean_prediction = discriminator(clean_img)
        # detach the fake sample so the discriminator step does not back-propagate
        # into (and free) the generator graph
        fake_clean_prediction = discriminator(generator_result.detach())["fc_out"]
        discriminator_loss = discriminator.loss(clean_img, generator_result.detach(), generator_attention_map)
        discriminator_loss.backward()
        optimizer_D.step()

        # -------- stage 4: whole-generator adversarial update --------
        # The ARNN and AE losses were already back-propagated above, so only the
        # adversarial term is back-propagated here, through a fresh forward pass.
        optimizer_G.zero_grad()
        arnn_out = generator.attentiveRNN(rainy_img)
        derained = generator.autoencoder(arnn_out['x'] * arnn_out['attention_map_list'][-1])['skip_3']
        adversarial_prediction = discriminator(derained)["fc_out"]
        generator_adversarial_loss = torch.mean(
            torch.log(torch.clamp(1.0 - adversarial_prediction, min=1e-8))
        )
        generator_loss_whole = generator_loss_ARNN.detach() + generator_loss_AE.detach() + generator_adversarial_loss
        generator_adversarial_loss.backward()
        optimizer_G.step()
        losses = {
            "generator_loss_ARNN": generator_loss_ARNN,
            "generator_loss_AE": generator_loss_AE,
            "discriminator_loss": discriminator_loss,
        }
        logger.print_training_log(epoch, epochs, i, len(dataloader), losses)
# visdom logger
        global_step = epoch * len(dataloader) + i
        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([global_step]),
                 win=ARNN_loss_window, update='append')
        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([global_step]),
                 win=AE_loss_window, update='append')
        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([global_step]),
                 win=Discriminator_loss_window, update='append')
        vis.image(generator_attention_map[-1][0, 0, :, :].detach().cpu(),
                  win=Attention_map_visualizer, opts=dict(title="Attention Map"))
        vis.image(generator_result[-1].detach().cpu(),
                  win=Generator_output_visualizer, opts=dict(title="Generator Output"))
    # save checkpoints every `save_interval` epochs (filesystem-safe timestamp)
    day = strftime("%Y-%m-%d_%H-%M-%S", gmtime())
    if epoch % save_interval == 0 and epoch != 0:
        torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
        torch.save(generator.state_dict(), f"weight/Generator_{epoch}_{day}.pt")
        torch.save(discriminator.state_dict(), f"weight/Discriminator_{epoch}_{day}.pt")
server_process.terminate()
## Could the RNN be run separately with its memory moved to the CPU,
## and the autoencoder run separately with its memory moved as well?
## Would that not work?
## TODO: code that properly wires up the GAN
## TODO: export the weights, split into inference and training sets
## TODO: for inference, produce two variants: one that stops at the attention map and one that goes all the way to deraining
## TODO: the training export therefore contains the full weights
## Can GAN training actually be pushed toward a Nash equilibrium ...?
## TODO: figure out how the training should actually be run
## Should the intermediate outputs coming out of the generator be saved separately?
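
# A minimal sketch of the weight-export idea from the notes above. The helper below
# is hypothetical (it is not called anywhere and the file names are illustrative);
# it only relies on the module layout already used in this script.
def export_weights(prefix):
    # attention-map-only inference weights
    torch.save(generator.attentiveRNN.state_dict(), f"{prefix}_attention_only.pt")
    # full deraining inference weights (AttentiveRNN + autoencoder)
    torch.save(generator.state_dict(), f"{prefix}_derain.pt")
    # full training weights (generator and discriminator together)
    torch.save(
        {"generator": generator.state_dict(), "discriminator": discriminator.state_dict()},
        f"{prefix}_training_full.pt",
    )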