data:image/s3,"s3://crabby-images/77fc1/77fc1ecd598263bdfa1d6248fbe60b3bfc41f6f8" alt=""
--- model/Generator.py
+++ model/Generator.py
... | ... | @@ -26,9 +26,13 @@ |
26 | 26 |
self.sigmoid = nn.Sigmoid() |
27 | 27 |
|
28 | 28 |
def forward(self, x): |
29 |
- ret = self.attentiveRNN(x) |
|
30 |
- x = self.autoencoder(ret['x'] * ret['attention_map_list'][-1]) |
|
31 |
- return x |
|
29 |
+ attentiveRNNresults = self.attentiveRNN(x) |
|
30 |
+ x = self.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1]) |
|
31 |
+ ret = { |
|
32 |
+ 'x' : x, |
|
33 |
+ 'attention_maps' : attentiveRNNresults['attention_map_list'] |
|
34 |
+ } |
|
35 |
+ return ret |
|
32 | 36 |
|
33 | 37 |
if __name__ == "__main__": |
34 | 38 |
import torch |
--- tools/logger.py
+++ tools/logger.py
... | ... | @@ -1,10 +1,38 @@ |
1 | 1 |
import sys |
2 | 2 |
import time |
3 |
+import plotly.express as px |
|
4 |
+import dash |
|
5 |
+import dash_core_components as dcc |
|
6 |
+import dash_html_components as html |
|
7 |
+from dash.dependencies import Input, Output |
|
8 |
+ |
|
9 |
+ |
|
3 | 10 |
class Logger(): |
4 | 11 |
def __init__(self): |
5 | 12 |
self.start_time = time.time() |
6 | 13 |
self.epoch_start_time = self.start_time |
7 |
- def print_training_log(self, current_epoch, total_epoch, **kargs): |
|
14 |
+ self.losses = [] |
|
15 |
+ |
|
16 |
+ self.app = dash.Dash(__name__) |
|
17 |
+ |
|
18 |
+ self.app.layout = html.Div([ |
|
19 |
+ dcc.Graph(id='live-update-graph'), |
|
20 |
+ dcc.Interval( |
|
21 |
+ id='interval-component', |
|
22 |
+ interval=1 * 1000, # in milliseconds |
|
23 |
+ n_intervals=0 |
|
24 |
+ ) |
|
25 |
+ ]) |
|
26 |
+ |
|
27 |
+ @self.app.callback(Output('live-update-graph', 'figure'), |
|
28 |
+ [Input('interval-component', 'n_intervals')]) |
|
29 |
+ def update_graph_live(n): |
|
30 |
+ # Create the graph with subplots |
|
31 |
+ fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'}) |
|
32 |
+ return fig |
|
33 |
+ |
|
34 |
+ def print_training_log(self, current_epoch, total_epoch, losses): |
|
35 |
+ assert type(losses) == dict |
|
8 | 36 |
current_time = time.time() |
9 | 37 |
epoch_time = current_time - self.epoch_start_time |
10 | 38 |
total_time = current_time - self.start_time |
... | ... | @@ -14,6 +42,17 @@ |
14 | 42 |
|
15 | 43 |
self.epoch_start_time = current_time |
16 | 44 |
|
45 |
+ terminal_logging_string = "" |
|
46 |
+ for loss_name, loss_value in losses.items(): |
|
47 |
+ terminal_logging_string += f"{loss_name} : {loss_value}\n" |
|
48 |
+ if loss_name == 'loss': |
|
49 |
+ self.losses.append(loss_value) |
|
50 |
+ |
|
17 | 51 |
sys.stdout.write( |
18 | 52 |
f"epoch : {current_epoch}/{total_epoch}\n" |
19 |
- )(No newline at end of file) |
|
53 |
+ f"estimated time remaining : {remaining_time}" |
|
54 |
+ f"{terminal_logging_string}" |
|
55 |
+ ) |
|
56 |
+ |
|
57 |
+ def print_training_history(self): |
|
58 |
+ self.app.run_server(debug=True) |
--- train.py
+++ train.py
... | ... | @@ -12,11 +12,42 @@ |
12 | 12 |
from model import AttentiveRNN |
13 | 13 |
from tools.argparser import get_param |
14 | 14 |
from tools.logger import Logger |
15 |
+from tools.dataloader import Dataset |
|
16 |
+ |
|
17 |
+# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py |
|
18 |
+def weights_init_normal(m): |
|
19 |
+ classname = m.__class__.__name__ |
|
20 |
+ if classname.find("Conv") != -1: |
|
21 |
+ torch.nn.init.normal_(m.weight.data, 0.0, 0.02) |
|
22 |
+ elif classname.find("BatchNorm2d") != -1: |
|
23 |
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02) |
|
24 |
+ torch.nn.init.constant_(m.bias.data, 0.0) |
|
15 | 25 |
|
16 | 26 |
param = get_param() |
17 | 27 |
|
18 | 28 |
logger = Logger() |
19 | 29 |
|
30 |
+cuda = True if torch.cuda.is_available() else False |
|
31 |
+ |
|
32 |
+generator = Generator() # get network values and stuff |
|
33 |
+discriminator = Discriminator() |
|
34 |
+ |
|
35 |
+if cuda: |
|
36 |
+ generator.cuda() |
|
37 |
+ discriminator.cuda() |
|
38 |
+ |
|
39 |
+if load is not False: |
|
40 |
+ generator.load_state_dict(torch.load("example_path")) |
|
41 |
+ discriminator.load_state_dict(torch.load("example_path")) |
|
42 |
+else: |
|
43 |
+ generator.apply(weights_init_normal) |
|
44 |
+ discriminator.apply(weights_init_normal) |
|
45 |
+ |
|
46 |
+dataloader = Dataloader |
|
47 |
+ |
|
48 |
+ |
|
49 |
+ |
|
50 |
+ |
|
20 | 51 |
## RNN 따로 돌리고 CPU로 메모리 옳기고 |
21 | 52 |
## Autoencoder 따로 돌리고 메모리 옳기고 |
22 | 53 |
## 안되는가 |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?