
--- model/Discriminator.py
+++ model/Discriminator.py
... | ... | @@ -42,10 +42,10 @@ |
42 | 42 |
} |
43 | 43 |
return ret |
44 | 44 |
|
45 |
- def loss(self, real_clean, label_tensor, attention_map): |
|
45 |
+ def loss(self, real_clean, generated_clean, attention_map): |
|
46 | 46 |
""" |
47 | 47 |
:param real_clean: |
48 |
- :param label_tensor: |
|
48 |
+ :param generated_clean: |
|
49 | 49 |
:param attention_map: This is the final attention map from the generator. |
50 | 50 |
:return: |
51 | 51 |
""" |
... | ... | @@ -57,7 +57,7 @@ |
57 | 57 |
# Inference function |
58 | 58 |
ret = self.forward(real_clean) |
59 | 59 |
fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
60 |
- ret = self.forward(label_tensor) |
|
60 |
+ ret = self.forward(generated_clean) |
|
61 | 61 |
fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"] |
62 | 62 |
|
63 | 63 |
l_map = F.mse_loss(attention_map, attention_mask_o) + \ |
--- model/Generator.py
+++ model/Generator.py
... | ... | @@ -35,7 +35,7 @@ |
35 | 35 |
return ret |
36 | 36 |
|
37 | 37 |
def binary_diff_mask(self, clean, dirty, thresold=0.1): |
38 |
- # this parts corrects gamma, and always remember, sRGB values are not in linear scale with lights intensity |
|
38 |
+ # this part corrects gamma; remember, sRGB values are not on a linear scale with light intensity, |
|
39 | 39 |
clean = torch.pow(clean, 0.45) |
40 | 40 |
dirty = torch.pow(dirty, 0.45) |
41 | 41 |
diff = torch.abs(clean - dirty) |
... | ... | @@ -44,12 +44,18 @@ |
44 | 44 |
bin_diff = (diff > thresold).to(clean.dtype) |
45 | 45 |
|
46 | 46 |
return bin_diff |
47 |
+ |
|
47 | 48 |
def loss(self, clean, dirty, thresold=0.1): |
48 | 49 |
# check diff if they are working as intended |
49 |
- diff = self.binary_diff_mask(clean, dirty, thresold) |
|
50 |
+ diff_mask = self.binary_diff_mask(clean, dirty, thresold) |
|
50 | 51 |
|
51 |
- self.attentiveRNN.loss(clean, diff) |
|
52 |
- self.autoencoder.loss(clean, dirty) |
|
52 |
+ attentive_rnn_loss = self.attentiveRNN.loss(clean, diff_mask) |
|
53 |
+ autoencoder_loss = self.autoencoder.loss(clean, dirty) |
|
54 |
+ ret = { |
|
55 |
+ "attentive_rnn_loss" : attentive_rnn_loss, |
|
56 |
+ "autoencoder_loss" : autoencoder_loss, |
|
57 |
+ } |
|
58 |
+ return ret |
|
53 | 59 |
|
54 | 60 |
if __name__ == "__main__": |
55 | 61 |
import torch |
--- tools/logger.py
+++ tools/logger.py
... | ... | @@ -31,7 +31,7 @@ |
31 | 31 |
fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'}) |
32 | 32 |
return fig |
33 | 33 |
|
34 |
- def print_training_log(self, current_epoch, total_epoch, losses): |
|
34 |
+ def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses): |
|
35 | 35 |
assert type(losses) == dict |
36 | 36 |
current_time = time.time() |
37 | 37 |
epoch_time = current_time - self.epoch_start_time |
... | ... | @@ -50,6 +50,7 @@ |
50 | 50 |
|
51 | 51 |
sys.stdout.write( |
52 | 52 |
f"epoch : {current_epoch}/{total_epoch}\n" |
53 |
+ f"batch : {current_batch}/{total_batch}" |
|
53 | 54 |
f"estimated time remaining : {remaining_time}" |
54 | 55 |
f"{terminal_logging_string}" |
55 | 56 |
) |
--- train.py
+++ train.py
... | ... | @@ -46,12 +46,39 @@ |
46 | 46 |
|
47 | 47 |
dataloader = Dataloader() |
48 | 48 |
|
49 |
- |
|
50 | 49 |
# declare generator loss |
51 | 50 |
|
51 |
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr) |
|
52 |
+optimizer_D = torch.optim.Adam(generator.parameters(), lr=lr) |
|
52 | 53 |
|
53 |
-for epoch in range(epoch): |
|
54 |
+for epoch_num, epoch in enumerate(range(epochs)): |
|
54 | 55 |
for i, (imgs, _) in enumerate(dataloader): |
56 |
+ logger.print_training_log(epoch_num, epochs, i, len(enumerate(dataloader))) |
|
57 |
+ |
|
58 |
+ img_batch = data[0].to(device) |
|
59 |
+ clean_img = img_batch["clean_image"] |
|
60 |
+ rainy_img = img_batch["rainy_image"] |
|
61 |
+ |
|
62 |
+ optimizer_G.zero_grad() |
|
63 |
+ generator_outputs = generator(clean_img) |
|
64 |
+ generator_result = generator_outputs["x"] |
|
65 |
+ generator_attention_map = generator_outputs["attention_map_list"] |
|
66 |
+ generator_loss = generator.loss(clean_img, rainy_img) |
|
67 |
+ generator_loss.backward() |
|
68 |
+ optimizer_G.step() |
|
69 |
+ |
|
70 |
+ optimizer_D.zero_grad() |
|
71 |
+ real_clean_prediction = discriminator(clean_img) |
|
72 |
+ discriminator_loss = discriminator.loss(real_clean_prediction, generator_result, generator_attention_map) |
|
73 |
+ |
|
74 |
+ discriminator_loss.backward() |
|
75 |
+ optimizer_D.step() |
|
76 |
+ |
|
77 |
+ |
|
78 |
+ |
|
79 |
+ |
|
80 |
+ |
|
81 |
+ |
|
55 | 82 |
|
56 | 83 |
|
57 | 84 |
torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path") |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?