윤영준 윤영준 2023-07-05
corrected the inference output
@c3d88f8b15bd0c0d9de156d60fe1e421b403541f
model/Autoencoder.py
--- model/Autoencoder.py
+++ model/Autoencoder.py
@@ -83,12 +83,14 @@
 
         skip_output_1 = self.skip_output1(relu12)
         skip_output_2 = self.skip_output2(relu14)
-        skip_output_3 = torch.tanh(self.skip_output3(relu16))
+        output = self.skip_output3(relu16)
+        skip_output_3 = torch.tanh(output)
 
         ret = {
             'skip_1': skip_output_1,
             'skip_2': skip_output_2,
             'skip_3': skip_output_3,
+            'output': output
         }
 
         return ret
train.py
--- train.py
+++ train.py
@@ -154,7 +154,7 @@
         vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window,
                  update='append')
         vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
-        vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
+        vis.image(generator_result[-1]*255, win=Generator_output_visualizer, opts=dict(title="Generator Output"))
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
         torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
Add a comment
List