Tried to make logger with dash, however it did not work as I have planned, so instead, visdom is used. Also, now the dataloader shuffles the dataset by default
Also, now the dataloader shuffles the dataset by default
@9773fd657775472d18b9139bcb67c52a74cd9623
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
... | ... | @@ -84,7 +84,7 @@ |
84 | 84 |
|
85 | 85 |
def init_hidden(self, batch_size, image_size, init=0.5): |
86 | 86 |
height, width = image_size |
87 |
- return torch.ones(batch_size, self.ch, height, width).to(self.conv_i.weight.device) * init |
|
87 |
+ return torch.ones(batch_size, self.ch, height, width).to(dtype=self.conv_i.weight.dtype , device=self.conv_i.weight.device) * init |
|
88 | 88 |
|
89 | 89 |
def forward(self, input_tensor, input_cell_state=None): |
90 | 90 |
if input_cell_state is None: |
--- model/Discriminator.py
+++ model/Discriminator.py
... | ... | @@ -54,7 +54,7 @@ |
54 | 54 |
|
55 | 55 |
batch_size, _, image_h, image_w = real_clean.size() |
56 | 56 |
attention_map = F.interpolate(attention_map[-1], size=(int(ceil(image_h/16)), int(ceil(image_w/16)))) |
57 |
- zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=torch.float32).to(attention_map.device) |
|
57 |
+ zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=attention_map.dtype).to(attention_map.device) |
|
58 | 58 |
|
59 | 59 |
# Inference function |
60 | 60 |
ret = self.forward(real_clean) |
... | ... | @@ -70,7 +70,7 @@ |
70 | 70 |
|
71 | 71 |
loss = entropy_loss + 0.05 * l_map |
72 | 72 |
|
73 |
- return loss |
|
73 |
+ return loss.to(dtype=attention_map.dtype) |
|
74 | 74 |
|
75 | 75 |
if __name__ == "__main__": |
76 | 76 |
import torch |
--- tools/logger.py
+++ tools/logger.py
... | ... | @@ -13,24 +13,6 @@ |
13 | 13 |
self.epoch_start_time = self.start_time |
14 | 14 |
self.losses = [] |
15 | 15 |
|
16 |
- self.app = dash.Dash(__name__) |
|
17 |
- |
|
18 |
- self.app.layout = html.Div([ |
|
19 |
- dcc.Graph(id='live-update-graph'), |
|
20 |
- dcc.Interval( |
|
21 |
- id='interval-component', |
|
22 |
- interval=1 * 1000, # in milliseconds |
|
23 |
- n_intervals=0 |
|
24 |
- ) |
|
25 |
- ]) |
|
26 |
- |
|
27 |
- @self.app.callback(Output('live-update-graph', 'figure'), |
|
28 |
- [Input('interval-component', 'n_intervals')]) |
|
29 |
- def update_graph_live(n): |
|
30 |
- # Create the graph with subplots |
|
31 |
- fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'}) |
|
32 |
- return fig |
|
33 |
- |
|
34 | 16 |
def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses): |
35 | 17 |
assert type(losses) == dict |
36 | 18 |
current_time = time.time() |
... | ... | @@ -54,6 +36,3 @@ |
54 | 36 |
f"estimated time remaining : {remaining_time}\n" |
55 | 37 |
f"{terminal_logging_string}\n" |
56 | 38 |
) |
57 |
- |
|
58 |
- def print_training_history(self): |
|
59 |
- self.app.run_server(debug=True) |
--- train.py
+++ train.py
... | ... | @@ -3,20 +3,27 @@ |
3 | 3 |
import torch |
4 | 4 |
import glob |
5 | 5 |
import numpy as np |
6 |
+import time |
|
6 | 7 |
import pandas as pd |
7 |
-import plotly.express as px |
|
8 |
+import subprocess |
|
9 |
+import atexit |
|
8 | 10 |
import torchvision.transforms |
11 |
+from visdom import Visdom |
|
9 | 12 |
from torchvision.utils import save_image |
10 | 13 |
from torch.utils.data import DataLoader |
11 | 14 |
from time import gmtime, strftime |
12 | 15 |
|
13 |
-from model import Autoencoder |
|
16 |
+from model.Autoencoder import AutoEncoder |
|
14 | 17 |
from model.Generator import Generator |
15 | 18 |
from model.Discriminator import DiscriminativeNet as Discriminator |
16 | 19 |
from model.AttentiveRNN import AttentiveRNN |
17 | 20 |
from tools.argparser import get_param |
18 | 21 |
from tools.logger import Logger |
19 | 22 |
from tools.dataloader import ImagePairDataset |
23 |
+ |
|
24 |
+ |
|
25 |
+ |
|
26 |
+ |
|
20 | 27 |
|
21 | 28 |
# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py |
22 | 29 |
# MIT license |
... | ... | @@ -49,7 +56,7 @@ |
49 | 56 |
cuda = True if torch.cuda.is_available() else False |
50 | 57 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
51 | 58 |
|
52 |
-generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device) # get network values and stuff |
|
59 |
+generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device) |
|
53 | 60 |
discriminator = Discriminator().to(device=device) |
54 | 61 |
|
55 | 62 |
if load is not None: |
... | ... | @@ -58,22 +65,33 @@ |
58 | 65 |
else: |
59 | 66 |
pass |
60 | 67 |
|
61 |
- |
|
62 |
- |
|
63 | 68 |
# 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임 |
64 | 69 |
rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png") |
65 | 70 |
rainy_data_path = sorted(rainy_data_path) |
66 | 71 |
clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png") |
67 | 72 |
clean_data_path = sorted(clean_data_path) |
68 | 73 |
|
69 |
-resize = torchvision.transforms.Resize((692, 776)) |
|
74 |
+resize = torchvision.transforms.Resize((692, 776), antialias=True) |
|
70 | 75 |
dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize) |
71 |
-dataloader = DataLoader(dataset, batch_size=batch_size) |
|
76 |
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
|
72 | 77 |
# declare generator loss |
78 |
+ |
|
73 | 79 |
|
74 | 80 |
optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate) |
75 | 81 |
optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate) |
76 | 82 |
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate) |
83 |
+ |
|
84 |
+# ------visdom visualizer ---------- |
|
85 |
+server_process = subprocess.Popen("python -m visdom.server", shell=True) |
|
86 |
+def cleanup(): |
|
87 |
+ server_process.terminate() |
|
88 |
+atexit.register(cleanup) |
|
89 |
+ |
|
90 |
+time.sleep(10) |
|
91 |
+vis = Visdom(server="http://localhost", port=8097) |
|
92 |
+ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss')) |
|
93 |
+AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss')) |
|
94 |
+Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss')) |
|
77 | 95 |
|
78 | 96 |
for epoch_num, epoch in enumerate(range(epochs)): |
79 | 97 |
for i, imgs in enumerate(dataloader): |
... | ... | @@ -81,6 +99,9 @@ |
81 | 99 |
img_batch = imgs |
82 | 100 |
clean_img = img_batch["clean_image"] / 255 |
83 | 101 |
rainy_img = img_batch["rainy_image"] / 255 |
102 |
+ |
|
103 |
+ # clean_img = clean_img.to() |
|
104 |
+ # rainy_img = rainy_img.to() |
|
84 | 105 |
|
85 | 106 |
clean_img = clean_img.to(device=device) |
86 | 107 |
rainy_img = rainy_img.to(device=device) |
... | ... | @@ -118,12 +139,21 @@ |
118 | 139 |
|
119 | 140 |
logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses) |
120 | 141 |
|
142 |
+ # visdom logger |
|
143 |
+ vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch * epoch_num + i]), win=ARNN_loss_window, |
|
144 |
+ update='append') |
|
145 |
+ vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch * epoch_num + i]), win=AE_loss_window, |
|
146 |
+ update='append') |
|
147 |
+ vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window, |
|
148 |
+ update='append') |
|
149 |
+ |
|
121 | 150 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
122 | 151 |
if epoch % save_interval == 0 and epoch != 0: |
123 | 152 |
torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{day}.pt") |
124 | 153 |
torch.save(generator.state_dict(), f"weight/Generator_{day}.pt") |
125 | 154 |
torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt") |
126 | 155 |
|
156 |
+server_process.terminate() |
|
127 | 157 |
|
128 | 158 |
|
129 | 159 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?