윤영준 2023-07-04
Tried to make the logger with Dash; however, it did not work as planned, so Visdom is used instead. Also, the dataloader now shuffles the dataset by default.
@9773fd657775472d18b9139bcb67c52a74cd9623
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -84,7 +84,7 @@
 
     def init_hidden(self, batch_size, image_size, init=0.5):
         height, width = image_size
-        return torch.ones(batch_size, self.ch, height, width).to(self.conv_i.weight.device) * init
+        return torch.ones(batch_size, self.ch, height, width).to(dtype=self.conv_i.weight.dtype , device=self.conv_i.weight.device) * init
 
     def forward(self, input_tensor, input_cell_state=None):
         if input_cell_state is None:
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -54,7 +54,7 @@
 
         batch_size, _, image_h, image_w = real_clean.size()
         attention_map = F.interpolate(attention_map[-1], size=(int(ceil(image_h/16)), int(ceil(image_w/16))))
-        zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=torch.float32).to(attention_map.device)
+        zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=attention_map.dtype).to(attention_map.device)
 
         # Inference function
         ret = self.forward(real_clean)
@@ -70,7 +70,7 @@
 
         loss = entropy_loss + 0.05 * l_map
 
-        return loss
+        return loss.to(dtype=attention_map.dtype)
 
 if __name__ == "__main__":
     import torch
tools/logger.py
--- tools/logger.py
+++ tools/logger.py
@@ -13,24 +13,6 @@
         self.epoch_start_time = self.start_time
         self.losses = []
 
-        self.app = dash.Dash(__name__)
-
-        self.app.layout = html.Div([
-            dcc.Graph(id='live-update-graph'),
-            dcc.Interval(
-                id='interval-component',
-                interval=1 * 1000,  # in milliseconds
-                n_intervals=0
-            )
-        ])
-
-        @self.app.callback(Output('live-update-graph', 'figure'),
-                           [Input('interval-component', 'n_intervals')])
-        def update_graph_live(n):
-            # Create the graph with subplots
-            fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'})
-            return fig
-
     def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses):
         assert type(losses) == dict
         current_time = time.time()
@@ -54,6 +36,3 @@
             f"estimated time remaining : {remaining_time}\n"
             f"{terminal_logging_string}\n"
         )
-
-    def print_training_history(self):
-        self.app.run_server(debug=True)
train.py
--- train.py
+++ train.py
@@ -3,20 +3,27 @@
 import torch
 import glob
 import numpy as np
+import time
 import pandas as pd
-import plotly.express as px
+import subprocess
+import atexit
 import torchvision.transforms
+from visdom import Visdom
 from torchvision.utils import save_image
 from torch.utils.data import DataLoader
 from time import gmtime, strftime
 
-from model import Autoencoder
+from model.Autoencoder import AutoEncoder
 from model.Generator import Generator
 from model.Discriminator import DiscriminativeNet as Discriminator
 from model.AttentiveRNN import AttentiveRNN
 from tools.argparser import get_param
 from tools.logger import Logger
 from tools.dataloader import ImagePairDataset
+
+
+
+
 
 # this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
 # MIT license
@@ -49,7 +56,7 @@
 cuda = True if torch.cuda.is_available() else False
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device) # get network values and stuff
+generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device)
 discriminator = Discriminator().to(device=device)
 
 if load is not None:
@@ -58,22 +65,33 @@
 else:
     pass
 
-
-
 # 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임
 rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
 rainy_data_path = sorted(rainy_data_path)
 clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
 clean_data_path = sorted(clean_data_path)
 
-resize = torchvision.transforms.Resize((692, 776))
+resize = torchvision.transforms.Resize((692, 776), antialias=True)
 dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
-dataloader = DataLoader(dataset, batch_size=batch_size)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
 # declare generator loss
+
 
 optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate)
 optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate)
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
+
+# ------visdom visualizer ----------
+server_process = subprocess.Popen("python -m visdom.server", shell=True)
+def cleanup():
+    server_process.terminate()
+atexit.register(cleanup)
+
+time.sleep(10)
+vis = Visdom(server="http://localhost", port=8097)
+ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
+AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
+Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
 
 for epoch_num, epoch in enumerate(range(epochs)):
     for i, imgs in enumerate(dataloader):
@@ -81,6 +99,9 @@
         img_batch = imgs
         clean_img = img_batch["clean_image"] / 255
         rainy_img = img_batch["rainy_image"] / 255
+
+        # clean_img = clean_img.to()
+        # rainy_img = rainy_img.to()
 
         clean_img = clean_img.to(device=device)
         rainy_img = rainy_img.to(device=device)
@@ -118,12 +139,21 @@
 
         logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)
 
+        # visdom logger
+        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([epoch * epoch_num + i]), win=ARNN_loss_window,
+                 update='append')
+        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([epoch * epoch_num + i]), win=AE_loss_window,
+                 update='append')
+        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
+                 update='append')
+
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
         torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{day}.pt")
         torch.save(generator.state_dict(), f"weight/Generator_{day}.pt")
         torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt")
 
+server_process.terminate()
 
 
 
Add a comment
List