Configure saving and inference methods for the model
@f6d5f07f1fff63a462dab0488224b039b675e9dd
--- /dev/null
+++ inference.py
@@ -0,0 +1,24 @@
+import numpy as np
+import pandas as pd
+import json
+import torch
+
+from model.Generator import Generator
+from model.AttentiveRNN import AttentiveRNN
+from model.Autoencoder import AutoEncoder
+from model.Discriminator import DiscriminativeNet as Discriminator
+
+def load_config_from_json(filename):
+    with open(filename, 'r') as f:
+        config = json.load(f)
+    return config
+
+config = load_config_from_json('training_config.json')
+print(config)
+
+
+with torch.no_grad():
+    settings =
+    generator = Generator
+    generator.attentiveRNN.load_state_dict(torch.load(load))
+    generator.autoencoder.load_state_dict(torch.load(load))
\ No newline at end of file
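
Note that the inference.py added above is still a stub: the `settings =` assignment is incomplete, `Generator` is assigned as a class rather than instantiated, and `load` is never defined. A minimal runnable sketch of the intended inference path is shown below; the Generator constructor call, the checkpoint filenames, the device handling, and the dummy input are assumptions, not part of this commit.

```python
# Hypothetical sketch of a complete inference script. Constructor arguments,
# checkpoint paths, and the input tensor are assumptions, not part of the commit.
import json
import torch

from model.Generator import Generator


def load_config_from_json(filename):
    with open(filename, 'r') as f:
        return json.load(f)


config = load_config_from_json('training_config.json')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the generator before loading weights (the diff assigns the class itself).
generator = Generator().to(device=device)

# Assumed: the attentive RNN and autoencoder weights were saved to separate files.
generator.attentiveRNN.load_state_dict(torch.load("attentive_rnn.pt", map_location=device))
generator.autoencoder.load_state_dict(torch.load("autoencoder.pt", map_location=device))
generator.eval()

with torch.no_grad():
    # Placeholder input: one RGB image tensor of shape (N, C, H, W).
    rainy = torch.rand(1, 3, 256, 256, device=device)
    derained = generator(rainy)
```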
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -196,6 +196,7 @@
             self.generator_blocks.append(
                 self.generator_block
             )
+        self.name = "AttentiveRNN"
 
     def forward(self, x):
         cell_state = None
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -50,6 +50,7 @@
         for param in self.vgg.parameters():
             param.requires_grad = False
         self.lambda_i = [0.6, 0.8, 1.0]
+        self.name = "AutoEncoder"
 
     def forward(self, input_tensor):
         # Feed the input through each layer
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -18,6 +18,7 @@
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
+        self.name = "Discriminator of Attentive GAN"
     def forward(self, x):
         x1 = F.leaky_relu(self.conv1(x))
         x2 = F.leaky_relu(self.conv2(x1))
--- model/Generator.py
+++ model/Generator.py
@@ -25,6 +25,7 @@
         self.groups = groups
         self.dilation = dilation
         self.sigmoid = nn.Sigmoid()
+        self.name = "Generator of Attentive GAN"
 
     def forward(self, x):
         attentiveRNNresults = self.attentiveRNN(x)
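
Each of the four modules above now exposes a `self.name` attribute, presumably so the saving method referenced in the commit title can write per-module checkpoints under a readable label. A possible helper is sketched below; the filename scheme and the `save_dir` argument are assumptions, nothing in this commit defines them.

```python
# Hypothetical saving helper keyed off a module's `name` attribute; the file
# naming scheme is an assumption, not something this commit implements.
import os
import torch


def save_named_checkpoint(module, save_dir, epoch):
    # Fall back to the class name when no `name` attribute is set.
    label = getattr(module, "name", module.__class__.__name__).replace(" ", "_")
    path = os.path.join(save_dir, f"{label}_epoch{epoch:04d}.pt")
    torch.save(module.state_dict(), path)
    return path


# Example usage (objects assumed to exist in the training script):
# save_named_checkpoint(generator.attentiveRNN, "checkpoints", epoch)
# save_named_checkpoint(generator.autoencoder, "checkpoints", epoch)
# save_named_checkpoint(discriminator, "checkpoints", epoch)
```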
--- train.py
+++ train.py
@@ -59,7 +59,8 @@
 discriminator = Discriminator().to(device=device)
 
 if load is not None:
-    generator.load_state_dict(torch.load(load))
+    generator.attentiveRNN.load_state_dict(torch.load(load))
+    generator.autoencoder.load_state_dict(torch.load(load))
     discriminator.load_state_dict(torch.load(load))
 else:
     pass
@@ -135,7 +136,7 @@
 
         optimizer_D.step()
 
-
+        # Total loss
        optimizer_G.zero_grad()
         generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
             torch.log(torch.subtract(1, fake_clean_prediction))
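
In the first train.py hunk, both sub-module `load_state_dict` calls and the discriminator call read from the same `load` path, which only works if that single file contains matching keys for every module. A sketch of resuming from one checkpoint file per module, consistent with the naming idea above, is shown below; the directory layout and filenames are assumptions.

```python
# Hypothetical resume logic with one checkpoint file per module; paths and
# filenames are assumptions, not what train.py currently does.
import os
import torch


def load_checkpoints(generator, discriminator, load_dir, device):
    if load_dir is None:
        return
    generator.attentiveRNN.load_state_dict(
        torch.load(os.path.join(load_dir, "AttentiveRNN.pt"), map_location=device))
    generator.autoencoder.load_state_dict(
        torch.load(os.path.join(load_dir, "AutoEncoder.pt"), map_location=device))
    discriminator.load_state_dict(
        torch.load(os.path.join(load_dir, "Discriminator.pt"), map_location=device))
```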