--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -101,7 +101,7 @@
         attention_map = torch.sigmoid(attention_map)
 
         ret = {
-            "attention_amp" : attention_map,
+            "attention_map" : attention_map,
             "cell_state" : cell_state,
             "lstm_feats" : lstm_feats
         }
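Note: this rename fixes a typo in the returned dict key ("attention_amp" → "attention_map"). A minimal sketch of the lookup this enables downstream, assuming a caller unpacks the returned dict (`model` and `x` are hypothetical names, not part of this diff):

```python
# hypothetical call site; before this fix, the "attention_map" lookup raised KeyError
out = model(x)
attention_map = out["attention_map"]
cell_state = out["cell_state"]
lstm_feats = out["lstm_feats"]
```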
--- train.py
+++ train.py
@@ -8,8 +8,12 @@
 import subprocess
 import atexit
 import torchvision.transforms
+import cv2
+
 from visdom import Visdom
 from torchvision.utils import save_image
+from torchvision import transforms
+from torchvision.transforms import RandomCrop, RandomPerspective, Compose
 from torch.utils.data import DataLoader
 from time import gmtime, strftime
 
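Note: `from torchvision import transforms` and the direct class imports overlap; either style alone would suffice. With only the module import, the augmentation pipeline built below could equally be written as:

```python
# equivalent spelling using only the module import (same classes, different reference style)
transform = transforms.Compose([
    transforms.RandomPerspective(),
    transforms.RandomCrop((height, width))
])
```

`import cv2` has no visible use in the hunks shown, so it presumably serves code outside this diff.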
@@ -20,17 +24,6 @@
 from tools.argparser import get_param
 from tools.logger import Logger
 from tools.dataloader import ImagePairDataset
-
-
-# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
-# MIT license
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find("BatchNorm2d") != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
 
 args = get_param()
 # I am doing this for easier debugging
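Note: the removed `weights_init_normal` is the standard DCGAN-style initializer; no `apply` call appears anywhere in this diff, so it is presumably being dropped as dead code. If custom initialization is reintroduced later, the usual PyTorch invocation is module-wise, e.g.:

```python
# typical usage pattern for a module-wise initializer (not part of this PR)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
```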
@@ -71,14 +64,22 @@
 clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
 clean_data_path = sorted(clean_data_path)
 
-resize = torchvision.transforms.Resize((480, 720), antialias=True)
-dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
+height = 480
+width = 720
+
+transform = Compose([
+    RandomPerspective(),
+    RandomCrop((height, width))
+])
+
+resize = torchvision.transforms.Resize((height, width), antialias=True)
+dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=transform)
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 # declare generator loss
 
 optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
-optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate)
-optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
+optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate, betas=(0.5, 0.999))
+optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate, betas=(0.5, 0.999))
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
 
 # ------visdom visualizer ----------
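Note: two separate changes land here. First, the deterministic `Resize` is replaced by random `RandomPerspective` + `RandomCrop` augmentation (`resize` stays defined but is no longer passed to the dataset, so it presumably serves another part of the script). Since `ImagePairDataset` yields clean/rainy pairs, beware that if `transform` is applied to each image independently, the two random draws will differ and the pair loses pixel alignment. A minimal sketch of one way to keep a pair aligned, assuming the dataset can call a pair-level hook (the `get_params` and functional-crop calls are standard torchvision; the helper itself is hypothetical):

```python
import torchvision.transforms.functional as TF
from torchvision.transforms import RandomCrop

def paired_random_crop(clean, rainy, size):
    # draw the crop window once, then apply the identical window to both images
    i, j, h, w = RandomCrop.get_params(clean, output_size=size)
    return TF.crop(clean, i, j, h, w), TF.crop(rainy, i, j, h, w)
```

Second, the added `betas=(0.5, 0.999)` is the common GAN choice, lowering Adam's first-moment decay from its default of 0.9.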
@@ -93,10 +94,10 @@
 ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
 AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
 Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
-Attention_map_visualizer = vis.image(np.zeros((692, 776)), opts=dict(title='Attention Map'))
-Difference_mask_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Mask Map'))
-Generator_output_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Generated Derain Output'))
-Input_image_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='input clean image'))
+Attention_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Attention Map'))
+Difference_mask_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Mask Map'))
+Generator_output_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Generated Derain Output'))
+Input_image_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='input clean image'))
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
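Note: the placeholder windows now follow the training resolution instead of the hard-coded 692×776, so they match the `(height, width)` crops produced above; visdom renders a 2-D array as a single-channel image, and each window is later overwritten in place via its `win=` handle.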
@@ -153,11 +154,11 @@
 
         logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)
         # visdom logger
-        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch * epoch_num + i]), win=ARNN_loss_window,
+        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch_num * epochs + i]), win=ARNN_loss_window,
                  update='append')
-        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch * epoch_num + i]), win=AE_loss_window,
+        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch_num * epochs + i]), win=AE_loss_window,
                  update='append')
-        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
+        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch_num * epochs + i]), win=Discriminator_loss_window,
                  update='append')
         vis.image(generator_attention_map[-1][0, 0, :, :], win=Attention_map_visualizer,
                   opts=dict(title="Attention Map"))
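Note: the X-axis fix makes the plots grow across epochs; the old `epoch * epoch_num + i` collapsed to `epoch_num ** 2 + i`, since `enumerate(range(epochs))` makes `epoch` and `epoch_num` equal. Strictly, a global step is usually scaled by batches per epoch rather than by `epochs`, so `epoch_num * epochs + i` can still produce overlapping X values whenever an epoch has more batches than there are epochs. A one-line alternative sketch:

```python
# global step: one unique, monotonically increasing x-coordinate per batch
global_step = epoch_num * len(dataloader) + i
vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([global_step]),
         win=ARNN_loss_window, update='append')
```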