--- model/Autoencoder.py
+++ model/Autoencoder.py
... | ... | @@ -83,12 +83,14 @@ |
83 | 83 |
|
84 | 84 |
skip_output_1 = self.skip_output1(relu12) |
85 | 85 |
skip_output_2 = self.skip_output2(relu14) |
86 |
- skip_output_3 = torch.tanh(self.skip_output3(relu16)) |
|
86 |
+ output = self.skip_output3(relu16) |
|
87 |
+ skip_output_3 = torch.tanh(output) |
|
87 | 88 |
|
88 | 89 |
ret = { |
89 | 90 |
'skip_1': skip_output_1, |
90 | 91 |
'skip_2': skip_output_2, |
91 | 92 |
'skip_3': skip_output_3, |
93 |
+ 'output': output |
|
92 | 94 |
} |
93 | 95 |
|
94 | 96 |
return ret |
--- train.py
+++ train.py
... | ... | @@ -154,7 +154,7 @@ |
154 | 154 |
vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window, |
155 | 155 |
update='append') |
156 | 156 |
vis.image(generator_attention_map[-1][0,0,:,:], win=Attention_map_visualizer, opts=dict(title="Attention Map")) |
157 |
- vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output")) |
|
157 |
+ vis.image(generator_result[-1]*255, win=Generator_output_visualizer, opts=dict(title="Generator Output")) |
|
158 | 158 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
159 | 159 |
if epoch % save_interval == 0 and epoch != 0: |
160 | 160 |
torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?