data:image/s3,"s3://crabby-images/77fc1/77fc1ecd598263bdfa1d6248fbe60b3bfc41f6f8" alt=""
finished training code for generator, now the rest is discriminator.
@a4559baa74c7d70ea5cb3c2cfc1e1a1b1163219e
+++ .gitignore
... | ... | @@ -0,0 +1,2 @@ |
1 | +# this is image dataset that is about 9 GB | |
2 | +data/source/Oxford_raindrop_dataset(파일 끝에 줄바꿈 문자 없음) |
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
... | ... | @@ -202,7 +202,12 @@ |
202 | 202 |
attention_map = [] |
203 | 203 |
lstm_feats = [] |
204 | 204 |
for generator_block in self.generator_blocks: |
205 |
- x, attention_map_i, cell_state, lstm_feats_i = generator_block(x, cell_state) |
|
205 |
+ generator_block_return = generator_block(x, cell_state) |
|
206 |
+ attention_map_i = generator_block_return['attention_map'] |
|
207 |
+ lstm_feats_i = generator_block_return['lstm_feats'] |
|
208 |
+ cell_state = generator_block_return['cell_state'] |
|
209 |
+ x = generator_block_return['x'] |
|
210 |
+ |
|
206 | 211 |
attention_map.append(attention_map_i) |
207 | 212 |
lstm_feats.append(lstm_feats_i) |
208 | 213 |
ret = { |
--- model/Autoencoder.py
+++ model/Autoencoder.py
... | ... | @@ -45,7 +45,7 @@ |
45 | 45 |
# pretrained network for feature extractor for the loss function or even not using feature extractor. |
46 | 46 |
# I honestly do not think using neural network for VGG is strictly necessary, and may have to be replaced with |
47 | 47 |
# other image preprocessing like MSCN which was implemented before. |
48 |
- self.vgg = vgg16(pretrained=True).features |
|
48 |
+ self.vgg = vgg16(weights='VGG16_Weights.IMAGENET1K_V1').features |
|
49 | 49 |
self.vgg.eval() |
50 | 50 |
for param in self.vgg.parameters(): |
51 | 51 |
param.requires_grad = False |
... | ... | @@ -93,15 +93,15 @@ |
93 | 93 |
|
94 | 94 |
return ret |
95 | 95 |
|
96 |
- def loss(self, input_image_tensor, input_clean_image_tensor): |
|
97 |
- ori_height, ori_width = input_clean_image_tensor.shape[2:] |
|
96 |
+ def loss(self, input_clean_img, input_rainy_img): |
|
97 |
+ ori_height, ori_width = input_clean_img.shape[2:] |
|
98 | 98 |
|
99 | 99 |
# Rescale labels to match the scales of the outputs |
100 |
- label_tensor_resize_2 = F.interpolate(input_clean_image_tensor, size=(ori_height // 2, ori_width // 2)) |
|
101 |
- label_tensor_resize_4 = F.interpolate(input_clean_image_tensor, size=(ori_height // 4, ori_width // 4)) |
|
102 |
- label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_image_tensor] |
|
100 |
+ label_tensor_resize_2 = F.interpolate(input_clean_img, size=(ori_height // 2, ori_width // 2)) |
|
101 |
+ label_tensor_resize_4 = F.interpolate(input_clean_img, size=(ori_height // 4, ori_width // 4)) |
|
102 |
+ label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_img] |
|
103 | 103 |
|
104 |
- inference_ret = self.forward(input_image_tensor) |
|
104 |
+ inference_ret = self.forward(input_rainy_img) |
|
105 | 105 |
|
106 | 106 |
output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']] |
107 | 107 |
|
... | ... | @@ -112,7 +112,7 @@ |
112 | 112 |
lm_loss += mse_loss |
113 | 113 |
|
114 | 114 |
# Compute lp_loss |
115 |
- src_vgg_feats = self.vgg(input_clean_image_tensor) |
|
115 |
+ src_vgg_feats = self.vgg(input_clean_img) |
|
116 | 116 |
pred_vgg_feats = self.vgg(output_list[-1]) |
117 | 117 |
|
118 | 118 |
lp_losses = [] |
--- model/Discriminator.py
+++ model/Discriminator.py
... | ... | @@ -2,7 +2,7 @@ |
2 | 2 |
from torch.functional import F |
3 | 3 |
|
4 | 4 |
class DiscriminativeNet(nn.Module): |
5 |
- def __init__(self, W, H): |
|
5 |
+ def __init__(self): |
|
6 | 6 |
super(DiscriminativeNet, self).__init__() |
7 | 7 |
self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2, padding=1) |
8 | 8 |
self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2) |
... | ... | @@ -49,24 +49,24 @@ |
49 | 49 |
:param attention_map: This is the final attention map from the generator. |
50 | 50 |
:return: |
51 | 51 |
""" |
52 |
- with torch.no_grad(): |
|
53 |
- batch_size, image_h, image_w, _ = real_clean.size() |
|
54 | 52 |
|
55 |
- zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32) |
|
53 |
+ batch_size, image_h, image_w, _ = real_clean.size() |
|
56 | 54 |
|
57 |
- # Inference function |
|
58 |
- ret = self.forward(real_clean) |
|
59 |
- fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
|
60 |
- ret = self.forward(generated_clean) |
|
61 |
- fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
|
55 |
+ zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32) |
|
62 | 56 |
|
63 |
- l_map = F.mse_loss(attention_map, attention_mask_o) + \ |
|
64 |
- F.mse_loss(attention_mask_r, zeros_mask) |
|
57 |
+ # Inference function |
|
58 |
+ ret = self.forward(real_clean) |
|
59 |
+ fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
|
60 |
+ ret = self.forward(generated_clean) |
|
61 |
+ fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
|
65 | 62 |
|
66 |
- entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0)) |
|
67 |
- entropy_loss = torch.mean(entropy_loss) |
|
63 |
+ l_map = F.mse_loss(attention_map, attention_mask_o) + \ |
|
64 |
+ F.mse_loss(attention_mask_r, zeros_mask) |
|
68 | 65 |
|
69 |
- loss = entropy_loss + 0.05 * l_map |
|
66 |
+ entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0)) |
|
67 |
+ entropy_loss = torch.mean(entropy_loss) |
|
68 |
+ |
|
69 |
+ loss = entropy_loss + 0.05 * l_map |
|
70 | 70 |
|
71 | 71 |
return fc_out_o, loss |
72 | 72 |
|
--- model/Generator.py
+++ model/Generator.py
... | ... | @@ -1,6 +1,7 @@ |
1 |
-from AttentiveRNN import AttentiveRNN |
|
2 |
-from Autoencoder import AutoEncoder |
|
1 |
+from model.AttentiveRNN import AttentiveRNN |
|
2 |
+from model.Autoencoder import AutoEncoder |
|
3 | 3 |
from torch import nn |
4 |
+from torch import sum, pow, abs |
|
4 | 5 |
|
5 | 6 |
|
6 | 7 |
class Generator(nn.Module): |
... | ... | @@ -36,10 +37,10 @@ |
36 | 37 |
|
37 | 38 |
def binary_diff_mask(self, clean, dirty, thresold=0.1): |
38 | 39 |
# this parts corrects gamma, and always remember, sRGB values are not in linear scale with lights intensity, |
39 |
- clean = torch.pow(clean, 0.45) |
|
40 |
- dirty = torch.pow(dirty, 0.45) |
|
41 |
- diff = torch.abs(clean - dirty) |
|
42 |
- diff = torch.sum(diff, dim=1) |
|
40 |
+ clean = pow(clean, 0.45) |
|
41 |
+ dirty = pow(dirty, 0.45) |
|
42 |
+ diff = abs(clean - dirty) |
|
43 |
+ diff = sum(diff, dim=1) |
|
43 | 44 |
|
44 | 45 |
bin_diff = (diff > thresold).to(clean.dtype) |
45 | 46 |
|
--- tools/argparser.py
+++ tools/argparser.py
... | ... | @@ -8,13 +8,22 @@ |
8 | 8 |
parser.add_argument("--batch_size", "-b", type=int, required=True, help="Size of single batch") |
9 | 9 |
parser.add_argument("--save_interval", "-s", type=int, required=True, help="Interval for saving weights") |
10 | 10 |
parser.add_argument("--sample_interval", type=int, required=True, help="Interval for saving inference result") |
11 |
- parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for computation") |
|
12 |
- parser.add_argument("--load", "-l", type=str, help="Path to previous weights for continuing training") |
|
13 |
- parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of generator") |
|
14 |
- parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times generator trains in a single epoch") |
|
15 |
- parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks of RNN in attention network") |
|
16 |
- parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each attention RNN blocks") |
|
17 |
- parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. If not given, it is assumed to be the same as the generator") |
|
11 |
+ parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for " |
|
12 |
+ "computation") |
|
13 |
+ parser.add_argument("--load", "-l", type=str, default=None, help="Path to previous weights for continuing training") |
|
14 |
+ parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of " |
|
15 |
+ "generator") |
|
16 |
+ parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times " |
|
17 |
+ "generator trains in " |
|
18 |
+ "a single epoch") |
|
19 |
+ parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks " |
|
20 |
+ "of RNN in " |
|
21 |
+ "attention network") |
|
22 |
+ parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each " |
|
23 |
+ "attention RNN blocks") |
|
24 |
+ parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. " |
|
25 |
+ "If not given, it is assumed to be" |
|
26 |
+ " the same as the generator") |
|
18 | 27 |
|
19 | 28 |
args = parser.parse_args() |
20 | 29 |
return args |
--- tools/dataloader.py
+++ tools/dataloader.py
... | ... | @@ -3,7 +3,7 @@ |
3 | 3 |
import numpy as np |
4 | 4 |
|
5 | 5 |
|
6 |
-class ImageDataSet(Dataset): |
|
6 |
+class ImagePairDataset(Dataset): |
|
7 | 7 |
def __init__(self, clean_img_dir, rainy_img_dirs, transform=None): |
8 | 8 |
self.clean_img = clean_img_dir |
9 | 9 |
self.rainy_img = rainy_img_dirs |
... | ... | @@ -16,15 +16,17 @@ |
16 | 16 |
clean_img_path = self.clean_img[idx] |
17 | 17 |
|
18 | 18 |
i = 0 |
19 |
- if len(self.rainy_img) != 1: |
|
19 |
+ if len(self.rainy_img[idx]) is list: |
|
20 | 20 |
rng = np.random.default_rng() |
21 | 21 |
i = rng.integers(low=0, high=len(self.rainy_img)-1) |
22 |
- rainy_img_path = self.rainy_img[idx][i] |
|
22 |
+ rainy_img_path = self.rainy_img[idx][i] |
|
23 |
+ else: |
|
24 |
+ rainy_img_path = self.rainy_img[idx] |
|
23 | 25 |
clean_image = read_image(clean_img_path) |
24 | 26 |
rainy_image = read_image(rainy_img_path) |
25 | 27 |
if self.transform: |
26 | 28 |
clean_image = self.transform(clean_image) |
27 |
- rainy_image = self.rainy_img(rainy_image) |
|
29 |
+ rainy_image = self.transform(rainy_image) |
|
28 | 30 |
|
29 | 31 |
ret = { |
30 | 32 |
"clean_image" : clean_image, |
... | ... | @@ -33,7 +35,7 @@ |
33 | 35 |
return ret |
34 | 36 |
|
35 | 37 |
def __add__(self, other): |
36 |
- return ImageDataSet( |
|
38 |
+ return ImagePairDataset( |
|
37 | 39 |
clean_img_dir=self.clean_img+other.clean_img, |
38 | 40 |
rainy_img_dirs=self.rainy_img+other.rainy_img, |
39 | 41 |
transform=self.transform |
--- tools/logger.py
+++ tools/logger.py
... | ... | @@ -2,8 +2,8 @@ |
2 | 2 |
import time |
3 | 3 |
import plotly.express as px |
4 | 4 |
import dash |
5 |
-import dash_core_components as dcc |
|
6 |
-import dash_html_components as html |
|
5 |
+from dash import dcc |
|
6 |
+from dash import html |
|
7 | 7 |
from dash.dependencies import Input, Output |
8 | 8 |
|
9 | 9 |
|
--- train.py
+++ train.py
... | ... | @@ -1,20 +1,25 @@ |
1 | 1 |
import sys |
2 | 2 |
import os |
3 | 3 |
import torch |
4 |
+import glob |
|
4 | 5 |
import numpy as np |
5 | 6 |
import pandas as pd |
6 | 7 |
import plotly.express as px |
8 |
+import torchvision.transforms |
|
7 | 9 |
from torchvision.utils import save_image |
10 |
+from torch.utils.data import DataLoader |
|
11 |
+from time import gmtime, strftime |
|
8 | 12 |
|
9 | 13 |
from model import Autoencoder |
10 |
-from model import Generator |
|
11 |
-from model import Discriminator |
|
12 |
-from model import AttentiveRNN |
|
14 |
+from model.Generator import Generator |
|
15 |
+from model.Discriminator import DiscriminativeNet as Discriminator |
|
16 |
+from model.AttentiveRNN import AttentiveRNN |
|
13 | 17 |
from tools.argparser import get_param |
14 | 18 |
from tools.logger import Logger |
15 |
-from tools.dataloader import Dataset |
|
19 |
+from tools.dataloader import ImagePairDataset |
|
16 | 20 |
|
17 | 21 |
# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py |
22 |
+# MIT license |
|
18 | 23 |
def weights_init_normal(m): |
19 | 24 |
classname = m.__class__.__name__ |
20 | 25 |
if classname.find("Conv") != -1: |
... | ... | @@ -24,7 +29,9 @@ |
24 | 29 |
torch.nn.init.constant_(m.bias.data, 0.0) |
25 | 30 |
|
26 | 31 |
args = get_param() |
27 |
- |
|
32 |
+# I am doing this for easier debugging |
|
33 |
+# when you have error on those variables without doing this, |
|
34 |
+# you will be in trouble because error message will not say anything. |
|
28 | 35 |
epochs = args.epochs |
29 | 36 |
batch_size = args.batch_size |
30 | 37 |
save_interval = args.save_interval |
... | ... | @@ -40,44 +47,61 @@ |
40 | 47 |
logger = Logger() |
41 | 48 |
|
42 | 49 |
cuda = True if torch.cuda.is_available() else False |
43 |
- |
|
44 |
-generator = Generator() # get network values and stuff |
|
45 |
-discriminator = Discriminator() |
|
46 |
- |
|
47 | 50 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
48 | 51 |
|
49 |
-generator = Generator().to(device) |
|
50 |
-discriminator = Discriminator().to(device) |
|
52 |
+generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device) # get network values and stuff |
|
53 |
+discriminator = Discriminator().to(device=device) |
|
51 | 54 |
|
52 |
-if load is not False: |
|
53 |
- generator.load_state_dict(torch.load("example_path")) |
|
54 |
- discriminator.load_state_dict(torch.load("example_path")) |
|
55 |
+if load is not None: |
|
56 |
+ generator.load_state_dict(torch.load(load)) |
|
57 |
+ discriminator.load_state_dict(torch.load(load)) |
|
55 | 58 |
else: |
56 |
- generator.apply(weights_init_normal) |
|
57 |
- discriminator.apply(weights_init_normal) |
|
59 |
+ pass |
|
58 | 60 |
|
59 |
-dataloader = Dataloader() |
|
60 | 61 |
|
62 |
+ |
|
63 |
+# 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임 |
|
64 |
+rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png") |
|
65 |
+rainy_data_path = sorted(rainy_data_path) |
|
66 |
+clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png") |
|
67 |
+clean_data_path = sorted(clean_data_path) |
|
68 |
+ |
|
69 |
+resize = torchvision.transforms.Resize((692, 776)) |
|
70 |
+dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize) |
|
71 |
+dataloader = DataLoader(dataset, batch_size=batch_size) |
|
61 | 72 |
# declare generator loss |
62 | 73 |
|
63 |
-optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate) |
|
64 |
-optimizer_D = torch.optim.Adam(generator.parameters(), lr=discriminator_learning_rate) |
|
74 |
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate) |
|
75 |
+optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate) |
|
76 |
+optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate) |
|
65 | 77 |
|
66 | 78 |
for epoch_num, epoch in enumerate(range(epochs)): |
67 |
- for i, (imgs, _) in enumerate(dataloader): |
|
68 |
- logger.print_training_log(epoch_num, epochs, i, len(enumerate(dataloader))) |
|
79 |
+ for i, imgs in enumerate(dataloader): |
|
69 | 80 |
|
70 |
- img_batch = data[0].to(device) |
|
71 |
- clean_img = img_batch["clean_image"] |
|
72 |
- rainy_img = img_batch["rainy_image"] |
|
81 |
+ img_batch = imgs |
|
82 |
+ clean_img = img_batch["clean_image"] / 255 |
|
83 |
+ rainy_img = img_batch["rainy_image"] / 255 |
|
73 | 84 |
|
74 |
- optimizer_G.zero_grad() |
|
75 |
- generator_outputs = generator(clean_img) |
|
76 |
- generator_result = generator_outputs["x"] |
|
77 |
- generator_attention_map = generator_outputs["attention_map_list"] |
|
78 |
- generator_loss = generator.loss(clean_img, rainy_img) |
|
79 |
- generator_loss.backward() |
|
80 |
- optimizer_G.step() |
|
85 |
+ clean_img = clean_img.to(device=device) |
|
86 |
+ rainy_img = rainy_img.to(device=device) |
|
87 |
+ |
|
88 |
+ optimizer_G_ARNN.zero_grad() |
|
89 |
+ optimizer_G_AE.zero_grad() |
|
90 |
+ |
|
91 |
+ attentiveRNNresults = generator.attentiveRNN(clean_img) |
|
92 |
+ generator_attention_map = attentiveRNNresults['attention_map_list'] |
|
93 |
+ binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img) |
|
94 |
+ generator_loss_ARNN = generator.attentiveRNN.loss(clean_img, binary_difference_mask) |
|
95 |
+ generator_loss_ARNN.backward() |
|
96 |
+ optimizer_G_ARNN.step() |
|
97 |
+ |
|
98 |
+ generator_outputs = generator.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1]) |
|
99 |
+ generator_result = generator_outputs['skip_3'] |
|
100 |
+ |
|
101 |
+ generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img) |
|
102 |
+ generator_loss_AE.backward() |
|
103 |
+ optimizer_G_AE.step() |
|
104 |
+ |
|
81 | 105 |
|
82 | 106 |
optimizer_D.zero_grad() |
83 | 107 |
real_clean_prediction = discriminator(clean_img) |
... | ... | @@ -86,6 +110,19 @@ |
86 | 110 |
discriminator_loss.backward() |
87 | 111 |
optimizer_D.step() |
88 | 112 |
|
113 |
+ losses = { |
|
114 |
+ "generator_loss_ARNN": generator_loss_ARNN, |
|
115 |
+ "generator_loss_AE" : generator_loss_AE, |
|
116 |
+ "discriminator loss" : discriminator_loss |
|
117 |
+ } |
|
118 |
+ |
|
119 |
+ logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses) |
|
120 |
+ |
|
121 |
+ day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
|
122 |
+ if epoch % save_interval == 0 and epoch != 0: |
|
123 |
+ torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{day}.pt") |
|
124 |
+ torch.save(generator.state_dict(), f"weight/Generator_{day}.pt") |
|
125 |
+ torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt") |
|
89 | 126 |
|
90 | 127 |
|
91 | 128 |
|
... | ... | @@ -93,7 +130,8 @@ |
93 | 130 |
|
94 | 131 |
|
95 | 132 |
|
96 |
-torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path") |
|
133 |
+ |
|
134 |
+ |
|
97 | 135 |
|
98 | 136 |
|
99 | 137 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?