
Added a Visdom window for the attention map to monitor training
@472af24ca65bbc328559b5050b4548e225b0b586
--- train.py
+++ train.py
... | ... | @@ -74,7 +74,7 @@ |
74 | 74 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
75 | 75 |
# declare generator loss |
76 | 76 |
|
77 |
- |
|
77 |
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate) |
|
78 | 78 |
optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_learning_rate) |
79 | 79 |
optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate) |
80 | 80 |
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate) |
... | ... | @@ -91,6 +91,8 @@ |
91 | 91 |
ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss')) |
92 | 92 |
AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss')) |
93 | 93 |
Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss')) |
94 |
+attenton_map_visualizer = vis.image(np.zeros((692,776)), opts=dict(title='Attention Map')) |
|
95 |
+ |
|
94 | 96 |
|
95 | 97 |
for epoch_num, epoch in enumerate(range(epochs)): |
96 | 98 |
for i, imgs in enumerate(dataloader): |
... | ... | @@ -125,10 +127,17 @@ |
125 | 127 |
|
126 | 128 |
optimizer_D.zero_grad() |
127 | 129 |
real_clean_prediction = discriminator(clean_img) |
130 |
+ fake_clean_prediction = discriminator(generator_result)["fc_out"] |
|
128 | 131 |
discriminator_loss = discriminator.loss(clean_img, generator_result, generator_attention_map) |
129 |
- |
|
130 | 132 |
discriminator_loss.backward() |
133 |
+ |
|
131 | 134 |
optimizer_D.step() |
135 |
+ |
|
136 |
+ |
|
137 |
+ optimizer_G.zero_grad() |
|
138 |
+ generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean( |
|
139 |
+ torch.log(torch.subtract(1, fake_clean_prediction))) |
|
140 |
+ optimizer_G.step() |
|
132 | 141 |
|
133 | 142 |
losses = { |
134 | 143 |
"generator_loss_ARNN": generator_loss_ARNN, |
... | ... | @@ -145,6 +154,7 @@ |
145 | 154 |
update='append') |
146 | 155 |
vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([epoch * epoch_num + i]), win=Discriminator_loss_window, |
147 | 156 |
update='append') |
157 |
+ vis.image(generator_attention_map[-1], win=attenton_map_visualizer, opts=dict(title="Attention Map")) |
|
148 | 158 |
|
149 | 159 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
150 | 160 |
if epoch % save_interval == 0 and epoch != 0: |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?