윤영준 2023-06-30
Added comment
@fa5311286c683280f564362bb987c22045d6e9ee
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -42,10 +42,10 @@
         }
         return ret
 
-    def loss(self, real_clean, label_tensor, attention_map):
+    def loss(self, real_clean, generated_clean, attention_map):
         """
         :param real_clean:
-        :param label_tensor:
+        :param generated_clean:
         :param attention_map: This is the final attention map from the generator.
         :return:
         """
@@ -57,7 +57,7 @@
             # Inference function
             ret = self.forward(real_clean)
             fc_out_o, attention_mask_o, fc2_o = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
-            ret = self.forward(label_tensor)
+            ret = self.forward(generated_clean)
             fc_out_r, attention_mask_r, fc2_r = ret["fc_out"], ret["attention_map"], ret["fc_raw"]
 
             l_map = F.mse_loss(attention_map, attention_mask_o) + \
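A minimal, self-contained sketch of the attention-map consistency term above. The tensor shapes and the zero-target second term are assumptions (the hunk truncates after the first MSE term); only the first term is taken verbatim from the diff:

import torch
import torch.nn.functional as F

# stand-ins for the generator's final attention map and the discriminator's
# internal maps on the real and generated images (shapes are hypothetical)
attention_map    = torch.rand(2, 1, 64, 64)
attention_mask_o = torch.rand(2, 1, 64, 64)  # from forward(real_clean)
attention_mask_r = torch.rand(2, 1, 64, 64)  # from forward(generated_clean)

# a common attentive-GAN completion drives the second branch's map toward
# zero -- an assumption here, not necessarily this repo's exact code
l_map = F.mse_loss(attention_map, attention_mask_o) + \
        F.mse_loss(attention_mask_r, torch.zeros_like(attention_mask_r))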
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -35,7 +35,7 @@
         return ret
 
     def binary_diff_mask(self, clean, dirty, thresold=0.1):
-        # this parts corrects gamma, and always remember, sRGB values are not in linear scale with lights intensity
+        # this part corrects gamma; remember that sRGB values are not on a linear scale with light intensity
         clean = torch.pow(clean, 0.45)
         dirty = torch.pow(dirty, 0.45)
         diff = torch.abs(clean - dirty)
@@ -44,12 +44,18 @@
         bin_diff = (diff > thresold).to(clean.dtype)
 
         return bin_diff
+
     def loss(self, clean, dirty, thresold=0.1):
         # check diff if they are working as intended
-        diff = self.binary_diff_mask(clean, dirty, thresold)
+        diff_mask = self.binary_diff_mask(clean, dirty, thresold)
 
-        self.attentiveRNN.loss(clean, diff)
-        self.autoencoder.loss(clean, dirty)
+        attentive_rnn_loss = self.attentiveRNN.loss(clean, diff_mask)
+        autoencoder_loss = self.autoencoder.loss(clean, dirty)
+        ret = {
+            "attentive_rnn_loss" : attentive_rnn_loss,
+            "autoencoder_loss" : autoencoder_loss,
+        }
+        return ret
 
 if __name__ == "__main__":
     import torch
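For reference, the mask logic above can be reproduced standalone. The synthetic tensors below are illustrative assumptions; the pow-0.45 gamma approximation and the 0.1 threshold come from the diff:

import torch

clean = torch.rand(1, 3, 8, 8)
# brighten roughly 20% of pixels to imitate rain streaks (purely illustrative)
dirty = torch.clamp(clean + 0.3 * (torch.rand(1, 3, 8, 8) > 0.8), 0.0, 1.0)

# sRGB is roughly linear_intensity ** (1/2.2); pow(x, 0.45) approximates the inverse
clean_g = torch.pow(clean, 0.45)
dirty_g = torch.pow(dirty, 0.45)
diff = torch.abs(clean_g - dirty_g)
bin_diff = (diff > 0.1).to(clean.dtype)  # matches thresold=0.1 above
print(bin_diff.mean())  # fraction of pixels flagged as rainy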
tools/logger.py
--- tools/logger.py
+++ tools/logger.py
@@ -31,7 +31,7 @@
             fig = px.line(x=list(range(len(self.losses))), y=self.losses, labels={'x': 'Epoch', 'y': 'Loss'})
             return fig
 
-    def print_training_log(self, current_epoch, total_epoch, losses):
+    def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses):
         assert type(losses) == dict
         current_time = time.time()
         epoch_time = current_time - self.epoch_start_time
@@ -50,6 +50,7 @@
 
         sys.stdout.write(
             f"epoch : {current_epoch}/{total_epoch}\n"
+            f"batch : {current_batch}/{total_batch}"
             f"estimated time remaining : {remaining_time}"
             f"{terminal_logging_string}"
         )
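A hypothetical call site for the widened signature. The loss-dict keys are assumptions; only the parameter list comes from the diff above:

logger.print_training_log(
    current_epoch=3, total_epoch=100,
    current_batch=15, total_batch=500,
    losses={"generator_loss": 0.42, "discriminator_loss": 0.37},
)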
train.py
--- train.py
+++ train.py
@@ -46,12 +46,39 @@
 
 dataloader = Dataloader()
 
-
 # declare generator loss
 
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
+optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
 
-for epoch in range(epoch):
+for epoch_num in range(epochs):
     for i, (imgs, _) in enumerate(dataloader):
+        # imgs is assumed to be a dict of paired images from the dataloader
+        clean_img = imgs["clean_image"].to(device)
+        rainy_img = imgs["rainy_image"].to(device)
+
+        optimizer_G.zero_grad()
+        # the generator takes the rainy image and predicts its derained version
+        generator_outputs = generator(rainy_img)
+        generator_result = generator_outputs["x"]
+        generator_attention_map = generator_outputs["attention_map_list"]
+        generator_losses = generator.loss(clean_img, rainy_img)  # returns a dict of two terms
+        generator_loss = generator_losses["attentive_rnn_loss"] + generator_losses["autoencoder_loss"]
+        generator_loss.backward()
+        optimizer_G.step()
+
+        optimizer_D.zero_grad()
+        # loss() runs its own forward passes, so the raw images are passed in;
+        # the generator output is detached and only the final attention map is used
+        discriminator_loss = discriminator.loss(clean_img, generator_result.detach(),
+                                                generator_attention_map[-1])
+        discriminator_loss.backward()
+        optimizer_D.step()
+
+        logger.print_training_log(epoch_num, epochs, i, len(dataloader),
+                                  {"generator_loss": generator_loss.item(),
+                                   "discriminator_loss": discriminator_loss.item()})
 
torch.save(generator.attentiveRNN.state_dict(), "attentionRNN_model_path")
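A hedged companion to the save call above, for loading the weights back later (torch.load and load_state_dict are standard PyTorch; the path is the literal string used in the script):

state_dict = torch.load("attentionRNN_model_path")
generator.attentiveRNN.load_state_dict(state_dict)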