윤영준 윤영준 2023-06-26
wrote draft of logger and writing train.py
@722005dde652e4d4c08c2e2f78555c53031069aa
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -26,9 +26,13 @@
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
-        ret = self.attentiveRNN(x)
-        x = self.autoencoder(ret['x'] * ret['attention_map_list'][-1])
-        return x
+        attentiveRNNresults = self.attentiveRNN(x)
+        x = self.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
+        ret = {
+            'x' : x,
+            'attention_maps' : attentiveRNNresults['attention_map_list']
+        }
+        return ret
 
 if __name__ == "__main__":
     import torch
tools/logger.py
--- tools/logger.py
+++ tools/logger.py
@@ -1,10 +1,38 @@
 import sys
 import time
+import plotly.express as px
+import dash
+import dash_core_components as dcc
+import dash_html_components as html
+from dash.dependencies import Input, Output
+
+
 class Logger():
     def __init__(self):
         self.start_time = time.time()
         self.epoch_start_time = self.start_time
-    def print_training_log(self, current_epoch, total_epoch, **kargs):
+        self.losses = []
+
+        self.app = dash.Dash(__name__)
+
+        self.app.layout = html.Div([
+            dcc.Graph(id='live-update-graph'),
+            dcc.Interval(
+                id='interval-component',
+                interval=1 * 1000,  # in milliseconds
+                n_intervals=0
+            )
+        ])
+
+        @self.app.callback(Output('live-update-graph', 'figure'),
+                           [Input('interval-component', 'n_intervals')])
+        def update_graph_live(n):
+            # Create the graph with subplots
+            fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'})
+            return fig
+
+    def print_training_log(self, current_epoch, total_epoch, losses):
+        assert type(losses) == dict
         current_time = time.time()
         epoch_time = current_time - self.epoch_start_time
         total_time = current_time - self.start_time
@@ -14,6 +42,17 @@
 
         self.epoch_start_time = current_time
 
+        terminal_logging_string = ""
+        for loss_name, loss_value in losses.items():
+            terminal_logging_string += f"{loss_name} : {loss_value}\n"
+            if loss_name == 'loss':
+                self.losses.append(loss_value)
+
         sys.stdout.write(
             f"epoch : {current_epoch}/{total_epoch}\n"
-        )
(No newline at end of file)
+            f"estimated time remaining : {remaining_time}"
+            f"{terminal_logging_string}"
+        )
+
+    def print_training_history(self):
+        self.app.run_server(debug=True)
train.py
--- train.py
+++ train.py
@@ -12,11 +12,42 @@
 from model import AttentiveRNN
 from tools.argparser import get_param
 from tools.logger import Logger
+from tools.dataloader import Dataset
+
+# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
 
 param = get_param()
 
 logger = Logger()
 
+cuda = True if torch.cuda.is_available() else False
+
+generator = Generator() # get network values and stuff
+discriminator = Discriminator()
+
+if cuda:
+    generator.cuda()
+    discriminator.cuda()
+
+if load is not False:
+    generator.load_state_dict(torch.load("example_path"))
+    discriminator.load_state_dict(torch.load("example_path"))
+else:
+    generator.apply(weights_init_normal)
+    discriminator.apply(weights_init_normal)
+
+dataloader = Dataloader
+
+
+
+
 ## RNN 따로 돌리고 CPU로 메모리 옳기고
 ## Autoencoder 따로 돌리고 메모리 옳기고
 ## 안되는가
Add a comment
List