윤영준 2023-06-29
Still working on loss functions
@84f45a018b9373cc36f33b3197f3b2200af5c983
model/Discriminator.py
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -16,6 +16,31 @@
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
         self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
+    def forward(self, x):
+        x1 = F.leaky_relu(self.conv1(x))
+        x2 = F.leaky_relu(self.conv2(x1))
+        x3 = F.leaky_relu(self.conv3(x2))
+        x4 = F.leaky_relu(self.conv4(x3))
+        x5 = F.leaky_relu(self.conv5(x4))
+        x6 = F.leaky_relu(self.conv6(x5))
+        attention_map = self.conv_attention(x6)
+        x7 = F.leaky_relu(self.conv7(attention_map * x6))
+        x8 = F.leaky_relu(self.conv8(x7))
+        x9 = F.leaky_relu(self.conv9(x8))
+        x9 = x9.view(x9.size(0), -1)  # flatten the tensor
+        fc1 = self.fc1(x9)
+        fc_raw = self.fc2(fc1)
+        fc_out = fc_raw.sigmoid()  # F.sigmoid is deprecated; use the tensor method instead
+
+        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
+        fc_out = fc_out.clamp(min=1e-7, max=1 - 1e-7)
+
+        ret = {
+            "fc_out" : fc_out,
+            "attention_map": attention_map,
+            "fc_raw" : fc_raw
+        }
+        return ret
 
     def loss(self, real_clean, label_tensor, attention_map):
         """
@@ -44,32 +69,6 @@
             loss = entropy_loss + 0.05 * l_map
 
         return fc_out_o, loss
-
-    def forward(self, x):
-        x1 = F.leaky_relu(self.conv1(x))
-        x2 = F.leaky_relu(self.conv2(x1))
-        x3 = F.leaky_relu(self.conv3(x2))
-        x4 = F.leaky_relu(self.conv4(x3))
-        x5 = F.leaky_relu(self.conv5(x4))
-        x6 = F.leaky_relu(self.conv6(x5))
-        attention_map = self.conv_attention(x6)
-        x7 = F.leaky_relu(self.conv7(attention_map * x6))
-        x8 = F.leaky_relu(self.conv8(x7))
-        x9 = F.leaky_relu(self.conv9(x8))
-        x9 = x9.view(x9.size(0), -1)  # flatten the tensor
-        fc1 = self.fc1(x9)
-        fc_raw = self.fc2(fc1)
-        fc_out = F.sigmoid(fc_raw)
-
-        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
-        fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
-
-        ret = {
-            "fc_out" : fc_out,
-            "attention_map": attention_map,
-            "fc_raw" : fc_raw
-        }
-        return ret
 
 if __name__ == "__main__":
     import torch
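
A minimal sketch (not part of this commit or the repository) of why the clamp on the sigmoid output matters: with the probabilities clamped into [1e-7, 1 - 1e-7], log(fc_out) and log(1 - fc_out) stay finite in a BCE-style discriminator loss even when the logits saturate. The helper name bce_from_probs and the toy batch are illustrative assumptions.

import torch

def bce_from_probs(probs, labels):
    # binary cross-entropy written out explicitly; safe only because probs is clamped
    return -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()

logits = torch.randn(4, 1) * 50                         # deliberately saturated logits
probs = logits.sigmoid().clamp(min=1e-7, max=1 - 1e-7)  # same clamp as Discriminator.forward
loss = bce_from_probs(probs, torch.ones(4, 1))          # finite, no NaN from log(0)
print(loss.item())
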
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -34,9 +34,22 @@
         }
         return ret
 
-    def loss(self, x, diff):
-        self.attentiveRNN.loss(x, diff)
-        self.autoencoder.loss(x)
+    def binary_diff_mask(self, clean, dirty, threshold=0.1):
+        clean = torch.pow(clean, 1/2.2)
+        dirty = torch.pow(dirty, 1/2.2)
+        diff = torch.abs(clean - dirty)
+        diff = torch.sum(diff, dim=1)
+        # NOTE: hardcoding this threshold will likely cause problems for
+        # quantization later on; it is left as a plain constant for now.
+        bin_diff = (diff > threshold).to(clean.dtype)
+
+        return bin_diff
+    def loss(self, clean, dirty, threshold=0.1):
+        # check that the binary diff mask behaves as intended
+        diff = self.binary_diff_mask(clean, dirty, threshold)
+
+        self.attentiveRNN.loss(clean, diff)
+        self.autoencoder.loss(clean, dirty)
 
 if __name__ == "__main__":
     import torch
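
A quick toy check (not part of this commit or the repository) of what binary_diff_mask computes: both images are gamma-decoded with pow(x, 1/2.2), the absolute difference is summed over the channel dimension, and the result is thresholded into a 0/1 mask of shape (N, H, W). The 8x8 tensors and the 0.1 cutoff are illustrative assumptions.

import torch

clean = torch.zeros(1, 3, 8, 8)
dirty = clean.clone()
dirty[:, :, :4, :4] = 0.5                                          # simulated degradation in one quadrant

diff = (clean.pow(1 / 2.2) - dirty.pow(1 / 2.2)).abs().sum(dim=1)  # shape (N, H, W)
mask = (diff > 0.1).to(clean.dtype)

assert mask[0, 0, 0] == 1                                          # changed region is flagged
assert mask[0, 7, 7] == 0                                          # unchanged region is not
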