윤영준 윤영준 2023-07-03
Its running!
still not sure if its really learning
@2fbce641ac973499763f83fc8be17e51f803bbae
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -1,3 +1,5 @@
+import torch
+from numpy import ceil
 from torch import nn, clamp
 from torch.functional import F
 
@@ -50,9 +52,9 @@
         :return:
         """
 
-        batch_size, image_h, image_w, _ = real_clean.size()
-
-        zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+        batch_size, _, image_h, image_w = real_clean.size()
+        attention_map = F.interpolate(attention_map[-1], size=(int(ceil(image_h/16)), int(ceil(image_w/16))))
+        zeros_mask = torch.zeros([batch_size, 1, int(ceil(image_h/16)), int(ceil(image_w/16))], dtype=torch.float32).to(attention_map.device)
 
         # Inference function
         ret = self.forward(real_clean)
@@ -68,7 +70,7 @@
 
         loss = entropy_loss + 0.05 * l_map
 
-        return fc_out_o, loss
+        return loss
 
 if __name__ == "__main__":
     import torch
tools/logger.py
--- tools/logger.py
+++ tools/logger.py
@@ -50,7 +50,7 @@
 
         sys.stdout.write(
             f"epoch : {current_epoch}/{total_epoch}\n"
-            f"batch : {current_batch}/{total_batch}"
+            f"batch : {current_batch}/{total_batch}\n"
             f"estimated time remaining : {remaining_time}"
             f"{terminal_logging_string}"
         )
train.py
--- train.py
+++ train.py
@@ -105,7 +105,7 @@
 
         optimizer_D.zero_grad()
         real_clean_prediction = discriminator(clean_img)
-        discriminator_loss = discriminator.loss(real_clean_prediction, generator_result, generator_attention_map)
+        discriminator_loss = discriminator.loss(clean_img, generator_result, generator_attention_map)
 
         discriminator_loss.backward()
         optimizer_D.step()
@@ -123,11 +123,6 @@
         torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{day}.pt")
         torch.save(generator.state_dict(), f"weight/Generator_{day}.pt")
         torch.save(discriminator.state_dict(), f"weight/Discriminator_{day}.pt")
-
-
-
-
-
 
 
 
Add a comment
List