윤영준 2023-06-21
discriminator is separated
@bfdeaf185e37210e210842e96a8b72bfb902ec99
model/AttentiveRNN GAN.py
--- model/AttentiveRNN GAN.py
+++ model/AttentiveRNN GAN.py
@@ -10,7 +10,8 @@
         if kernel_size is None:
             kernel_size = [3, 3]
         self.attentiveRNN = AttentiveRNN( repetition,
-            blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1, dilation=1
+            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
+            kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation
         )
         self.autoencoder = AutoEncoder()
         self.blocks = blocks
@@ -27,6 +28,7 @@
     def forward(self, x):
         x, attention_map = self.attentiveRNN(x)
         x = self.autoencoder(x * attention_map)
+        return x
 
 if __name__ == "__main__":
     import torch
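
With the missing return added, the generator's forward pass now actually yields the autoencoder output instead of None. A minimal smoke test might look like the sketch below; the wrapper class name and constructor signature are not visible in this hunk, so AttentiveRNNGAN and repetition=4 are assumptions:

    import torch
    # hypothetical import path: the actual file name "model/AttentiveRNN GAN.py" contains a space
    from model.attentivernn_gan import AttentiveRNNGAN

    gan = AttentiveRNNGAN(repetition=4)  # assumed signature, mirroring the repetition argument forwarded to AttentiveRNN
    rainy = torch.randn(1, 3, 256, 256)
    derained = gan(rainy)  # previously returned None; now returns the autoencoder output
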
model/attentivernn.py
--- model/attentivernn.py
+++ model/attentivernn.py
@@ -218,43 +218,6 @@
         return loss, inference_ret['final_attention_map']
 
 # Need work
-class DiscriminativeNet(nn.Module):
-    def __init__(self, W, H):
-        super(DiscriminativeNet, self).__init__()
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
-        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
-        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
-        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
-        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
-        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
-        self.conv_map = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
-        self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
-        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
-        self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
-        self.fc1 = nn.Linear(32 * W * H,
-                             1024)  # You need to adjust the input dimension here depending on your input size
-        self.fc2 = nn.Linear(1024, 1)
-
-    def forward(self, x):
-        x1 = F.leaky_relu(self.conv1(x))
-        x2 = F.leaky_relu(self.conv2(x1))
-        x3 = F.leaky_relu(self.conv3(x2))
-        x4 = F.leaky_relu(self.conv4(x3))
-        x5 = F.leaky_relu(self.conv5(x4))
-        x6 = F.leaky_relu(self.conv6(x5))
-        attention_map = self.conv_map(x6)
-        x7 = F.leaky_relu(self.conv7(attention_map * x6))
-        x8 = F.leaky_relu(self.conv8(x7))
-        x9 = F.leaky_relu(self.conv9(x8))
-        x9 = x9.view(x9.size(0), -1)  # flatten the tensor
-        fc1 = self.fc1(x9)
-        fc2 = self.fc2(fc1)
-        fc_out = torch.sigmoid(fc2)
-
-        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
-        fc_out = torch.clamp(fc_out, min=1e-7, max=1 - 1e-7)
-
-        return fc_out, attention_map, fc2
 
 if __name__ == "__main__":
     from torchinfo import summary
 
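Since DiscriminativeNet has moved out of model/attentivernn.py, any training script that imported it from there needs its import updated, e.g. (module path assumed from the repository layout):

    # old: from model.attentivernn import DiscriminativeNet
    from model.discriminator import DiscriminativeNet
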
model/discriminator.py (added)
+++ model/discriminator.py
@@ -0,0 +1,40 @@
+from torch import nn, clamp, sigmoid
+import torch.nn.functional as F
+
+class DiscriminativeNet(nn.Module):
+    def __init__(self, W, H):
+        super(DiscriminativeNet, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
+        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
+        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
+        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
+        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
+        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
+        self.conv_map = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
+        self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
+        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
+        self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
+        self.fc1 = nn.Linear(32 * (W // 64) * (H // 64),
+                             1024)  # the three stride-4 convs downsample W and H by 64; assumes both are divisible by 64
+        self.fc2 = nn.Linear(1024, 1)
+
+    def forward(self, x):
+        x1 = F.leaky_relu(self.conv1(x))
+        x2 = F.leaky_relu(self.conv2(x1))
+        x3 = F.leaky_relu(self.conv3(x2))
+        x4 = F.leaky_relu(self.conv4(x3))
+        x5 = F.leaky_relu(self.conv5(x4))
+        x6 = F.leaky_relu(self.conv6(x5))
+        attention_map = self.conv_map(x6)
+        x7 = F.leaky_relu(self.conv7(attention_map * x6))
+        x8 = F.leaky_relu(self.conv8(x7))
+        x9 = F.leaky_relu(self.conv9(x8))
+        x9 = x9.view(x9.size(0), -1)  # flatten the tensor
+        fc1 = self.fc1(x9)
+        fc2 = self.fc2(fc1)
+        fc_out = sigmoid(fc2)  # torch.sigmoid; F.sigmoid is deprecated
+
+        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
+        fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
+
+        return fc_out, attention_map, fc2
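
For reference, a minimal sketch of driving the relocated discriminator, assuming a 256x256 RGB input (with the corrected fc1 sizing, W and H should be divisible by 64, since the three stride-4 convolutions downsample by a factor of 64 overall):

    import torch
    from model.discriminator import DiscriminativeNet

    disc = DiscriminativeNet(W=256, H=256)
    fake = torch.randn(1, 3, 256, 256)          # batch of one RGB image
    prob, attention_map, logits = disc(fake)
    # prob: (1, 1), clamped to (1e-7, 1 - 1e-7) for a stable log in the loss
    # attention_map: (1, 1, 256, 256) -- conv1..conv6 preserve spatial size
    # logits: (1, 1), the pre-sigmoid fc2 output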