윤영준 2023-07-11
configuring the saving and inference methods for the model
@f6d5f07f1fff63a462dab0488224b039b675e9dd
 
inference.py (added)
+++ inference.py
@@ -0,0 +1,33 @@
+import numpy as np
+import pandas as pd
+import json
+import torch
+
+from model.Generator import Generator
+from model.AttentiveRNN import AttentiveRNN
+from model.Autoencoder import AutoEncoder
+from model.Discriminator import DiscriminativeNet as Discriminator
+
+def load_config_from_json(filename):
+    with open(filename, 'r') as f:
+        config = json.load(f)
+    return config
+
+config = load_config_from_json('training_config.json')
+print(config)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# NOTE: the checkpoint paths below are assumed config keys, not defined in
+# this commit; point them at whatever train.py actually saves. Generator is
+# likewise assumed to be constructible with its default arguments.
+generator = Generator().to(device)
+generator.eval()
+
+with torch.no_grad():
+    generator.attentiveRNN.load_state_dict(
+        torch.load(config["attentive_rnn_checkpoint"], map_location=device)
+    )
+    generator.autoencoder.load_state_dict(
+        torch.load(config["autoencoder_checkpoint"], map_location=device)
+    )
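For completeness, a minimal usage sketch of what could follow once the weights are restored. The torchvision preprocessing, the placeholder image path, and treating the generator's return value as the final derained output are assumptions; this commit does not show Generator's output format.

    from PIL import Image
    from torchvision import transforms

    # Hypothetical single-image inference; "rainy.png" is a placeholder path.
    image = transforms.ToTensor()(Image.open("rainy.png").convert("RGB"))
    image = image.unsqueeze(0).to(device)  # add a batch dimension: (1, 3, H, W)

    with torch.no_grad():
        output = generator(image)  # return structure (attention maps vs. final
                                   # image) is not shown in this diff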
model/AttentiveRNN.py
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -196,6 +196,7 @@
             self.generator_blocks.append(
                 self.generator_block
             )
+        self.name = "AttentiveRNN"
 
     def forward(self, x):
         cell_state = None
model/Autoencoder.py
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -50,6 +50,7 @@
         for param in self.vgg.parameters():
             param.requires_grad = False
         self.lambda_i = [0.6, 0.8, 1.0]
+        self.name = "AutoEncoder"
 
     def forward(self, input_tensor):
         # Feed the input through each layer
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -18,6 +18,7 @@
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
+        self.name = "Discriminator of Attentive GAN"
     def forward(self, x):
         x1 = F.leaky_relu(self.conv1(x))
         x2 = F.leaky_relu(self.conv2(x1))
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -25,6 +25,7 @@
         self.groups = groups
         self.dilation = dilation
         self.sigmoid = nn.Sigmoid()
+        self.name = "Generator of Attentive GAN"
 
     def forward(self, x):
         attentiveRNNresults = self.attentiveRNN(x)
train.py
--- train.py
+++ train.py
@@ -59,7 +59,11 @@
 discriminator = Discriminator().to(device=device)
 
 if load is not None:
-    generator.load_state_dict(torch.load(load))
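+    # NOTE: `load` is a single checkpoint path, yet it is fed to the attentive
+    # RNN, the autoencoder, and the discriminator below; each call needs its
+    # own per-module state-dict file for these loads to succeed.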
+    generator.attentiveRNN.load_state_dict(torch.load(load))
+    generator.autoencoder.load_state_dict(torch.load(load))
     discriminator.load_state_dict(torch.load(load))
 else:
     pass
@@ -135,7 +139,7 @@
 
         optimizer_D.step()
 
-
+        # Total loss
         optimizer_G.zero_grad()
         generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
             torch.log(torch.subtract(1, fake_clean_prediction))
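The commit message also mentions a saving method, and the new `name` attributes added to each model class suggest per-module checkpoint files. A minimal sketch of such a helper under that assumption (the function, directory, and filename scheme are hypothetical, not part of this commit):

    import os
    import torch

    def save_checkpoint(module, directory="checkpoints"):
        # One state-dict file per submodule, named after the module's new
        # `name` attribute (hypothetical naming scheme).
        os.makedirs(directory, exist_ok=True)
        path = os.path.join(directory, f"{module.name}.pt")
        torch.save(module.state_dict(), path)
        return path

    # e.g. save_checkpoint(generator.attentiveRNN); save_checkpoint(discriminator)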