윤영준 윤영준 2023-06-27
Discriminator Loss
@415bdf44a549e535f776f2c23546a6f76ff563e5
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -17,6 +17,25 @@
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
 
+    def loss(self, input_tensor, label_tensor, attention_map, name):
+        batch_size, image_h, image_w, _ = input_tensor.size()
+
+        zeros_mask = torch.zeros([batch_size, image_h, image_w, 1], dtype=torch.float32)
+
+        # Inference function
+        fc_out_o, attention_mask_o, fc2_o = self.forward(input_tensor)
+        fc_out_r, attention_mask_r, fc2_r = self.forward(label_tensor)
+
+        l_map = F.mse_loss(attention_map, attention_mask_o) + \
+                F.mse_loss(attention_mask_r, zeros_mask)
+
+        entropy_loss = -torch.log(fc_out_r) - torch.log(-torch.sub(fc_out_o, 1.0))
+        entropy_loss = torch.mean(entropy_loss)
+
+        loss = entropy_loss + 0.05 * l_map
+
+        return fc_out_o, loss
+
     def forward(self, x):
         x1 = F.leaky_relu(self.conv1(x))
         x2 = F.leaky_relu(self.conv2(x1))
tools/dataloader.py
--- tools/dataloader.py
+++ tools/dataloader.py
@@ -8,8 +8,6 @@
         self.clean_img = clean_img_dir
         self.rainy_img = rainy_img_dirs
         self.transform = transform
-        # self.target_transform = target_transform
-        # print(self.transform)
 
     def __len__(self):
         return len(self.clean_img)
train.py
--- train.py
+++ train.py
@@ -44,9 +44,18 @@
     generator.apply(weights_init_normal)
     discriminator.apply(weights_init_normal)
 
-dataloader = Dataloader
+dataloader = Dataloader()
 
 
+# declare generator loss
+
+
+for epoch in range(epoch):
+    for i, (imgs, _) in enumerate(dataloader):
+
+
+torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path")
+
 
 
## RNN 따로 돌리고 CPU로 메모리 옮기고
Add a comment
List