윤영준 2023-06-21
Creation of Attentive RNN GAN, change of naming scheme for the sake of conciseness.
@cecfc8dc40869fd782c33519b6f596a5f536c800
 
model/AttentiveRNN GAN.py (added)
+++ model/AttentiveRNN GAN.py
@@ -0,0 +1,38 @@
+from attentivernn import AttentiveRNN
+from autoencoder import AutoEncoder
+from torch import nn
+
+
+class Generator(nn.Module):
+    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
+                 dilation=1):
+        super(Generator, self).__init__()
+        if kernel_size is None:
+            kernel_size = [3, 3]
+        self.attentiveRNN = AttentiveRNN(repetition,
+            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch, kernel_size=kernel_size,
+            stride=stride, padding=padding, groups=groups, dilation=dilation)
+        self.autoencoder = AutoEncoder()
+        self.blocks = blocks
+        self.layers = layers
+        self.input_ch = input_ch
+        self.out_ch = out_ch
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+        self.dilation = dilation
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        x, attention_map = self.attentiveRNN(x)
+        return self.autoencoder(x * attention_map)
+
+if __name__ == "__main__":
+    import torch
+    from torchinfo import summary
+
+    torch.set_default_tensor_type(torch.FloatTensor)
+    generator = Generator(3, blocks=2)
+    batch_size = 2
+    summary(generator, input_size=(batch_size, 3, 960, 540))
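
For orientation, here is a minimal, self-contained sketch of the data flow that the new Generator.forward implements: the attentive stage returns features plus an attention map, and the autoencoder is fed the element-wise product of the two. The two placeholder modules and all variable names below are illustrative stand-ins, not the repository's AttentiveRNN/AutoEncoder classes.

import torch
from torch import nn

class _ToyAttentiveStage(nn.Module):
    # stand-in for AttentiveRNN: returns (features, attention_map)
    def forward(self, x):
        attention_map = torch.sigmoid(x.mean(dim=1, keepdim=True))  # B x 1 x H x W mask
        return x, attention_map

class _ToyAutoEncoder(nn.Module):
    # stand-in for AutoEncoder: maps the masked 3-channel input to a 3-channel image
    def forward(self, x):
        return x

rainy = torch.rand(2, 3, 96, 54)              # small dummy batch; the repo tests 960 x 540
attn, autoenc = _ToyAttentiveStage(), _ToyAutoEncoder()
features, attention_map = attn(rainy)
derained = autoenc(features * attention_map)  # same attention gating as Generator.forward above
print(derained.shape, attention_map.shape)    # [2, 3, 96, 54] and [2, 1, 96, 54]
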
model/attentivernn.py (Renamed from model/generator.py)
--- model/generator.py
+++ model/attentivernn.py
@@ -111,13 +111,13 @@
         return attention_map, cell_state, lstm_feats
 
 
-class GeneratorBlock(nn.Module):
+class AttentiveRNNBLCK(nn.Module):
     def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
                  dilation=1):
         """
         :type kernel_size: iterator or int
         """
-        super(GeneratorBlock, self).__init__()
+        super(AttentiveRNNBLCK, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
         self.blocks = blocks
@@ -156,13 +156,13 @@
         return x, attention_map, cell_state, lstm_feats
 
 
-class Generator(nn.Module):
+class AttentiveRNN(nn.Module):
     def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1,
                  groups=1, dilation=1):
         """
         :type kernel_size: iterator or int
         """
-        super(Generator, self).__init__()
+        super(AttentiveRNN, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
         self.blocks = blocks
@@ -176,15 +176,15 @@
         self.dilation = dilation
         self.repetition = repetition
         self.generator_block = mySequential(
-            GeneratorBlock(blocks=blocks,
-                           layers=layers,
-                           input_ch=input_ch,
-                           out_ch=out_ch,
-                           kernel_size=kernel_size,
-                           stride=stride,
-                           padding=padding,
-                           groups=groups,
-                           dilation=dilation)
+            AttentiveRNNBLCK(blocks=blocks,
+                             layers=layers,
+                             input_ch=input_ch,
+                             out_ch=out_ch,
+                             kernel_size=kernel_size,
+                             stride=stride,
+                             padding=padding,
+                             groups=groups,
+                             dilation=dilation)
         )
         self.generator_blocks = nn.ModuleList()
         for repetition in range(repetition):
@@ -206,7 +206,7 @@
 
     def forward(self, input_tensor, label_tensor):
         # Initialize attentive rnn model
-        attentive_rnn = Generator
+        attentive_rnn = AttentiveRNN
         inference_ret = attentive_rnn(input_tensor)
 
         loss = 0.0
@@ -258,6 +258,9 @@
 
 if __name__ == "__main__":
     from torchinfo import summary
-    generator = Generator(3)
-    batch_size = 1
+
+    torch.set_default_tensor_type(torch.FloatTensor)
+
+    generator = AttentiveRNN(3, blocks=2)
+    batch_size = 5
     summary(generator, input_size=(batch_size, 3, 960, 540))
model/autoencoder.py
--- model/autoencoder.py
+++ model/autoencoder.py
@@ -40,18 +40,21 @@
     # maybe change it into concat Networks? this seems way too cumbersome.
     def forward(self, input_tensor):
         # Feed the input through each layer
-        relu1 = torch.relu(self.conv1(input_tensor))
-        relu2 = torch.relu(self.conv2(relu1))
-        relu3 = torch.relu(self.conv3(relu2))
-        relu4 = torch.relu(self.conv4(relu3))
-        relu5 = torch.relu(self.conv5(relu4))
-        relu6 = torch.relu(self.conv6(relu5))
-        relu7 = torch.relu(self.dilated_conv1(relu6))
-        relu8 = torch.relu(self.dilated_conv2(relu7))
-        relu9 = torch.relu(self.dilated_conv3(relu8))
-        relu10 = torch.relu(self.dilated_conv4(relu9))
-        relu11 = torch.relu(self.conv7(relu10))
-        relu12 = torch.relu(self.conv8(relu11))
+        x = torch.relu(self.conv1(input_tensor))
+        relu1 = x
+        x = torch.relu(self.conv2(x))
+        x = torch.relu(self.conv3(x))
+        relu3 = x
+        x = torch.relu(self.conv4(x))
+        x = torch.relu(self.conv5(x))
+        x = torch.relu(self.conv6(x))
+        x = torch.relu(self.dilated_conv1(x))
+        x = torch.relu(self.dilated_conv2(x))
+        x = torch.relu(self.dilated_conv3(x))
+        x = torch.relu(self.dilated_conv4(x))
+        x = torch.relu(self.conv7(x))
+        x = torch.relu(self.conv8(x))
+        relu12 = x
 
         deconv1 = self.deconv1(relu12)
         avg_pool1 = self.avg_pool1(deconv1)
@@ -131,4 +134,12 @@
             x = layer(x)
             if layer_num in {3, 8, 15, 22, 29}:
                 feats.append(x)
-        return feats
(No newline at end of file)
+        return feats
+
+
+if __name__ == "__main__":
+    from torchinfo import summary
+    torch.set_default_tensor_type(torch.FloatTensor)
+    generator = AutoEncoder()
+    batch_size = 2
+    summary(generator, input_size=(batch_size, 3, 960, 540))
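
The autoencoder refactor above replaces the relu1…relu12 chain with a single x that is overwritten at each layer, keeping named references only to the activations that are reused later (relu1, relu3, relu12), presumably as decoder-side skip connections. A small illustrative sketch of that pattern, with hypothetical layer names rather than the repository's conv1..conv8 stack:

import torch
from torch import nn

enc_a = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # hypothetical encoder layers
enc_b = nn.Conv2d(8, 8, kernel_size=3, padding=1)
enc_c = nn.Conv2d(8, 8, kernel_size=3, padding=1)

x = torch.rand(1, 3, 32, 32)
x = torch.relu(enc_a(x))
skip_early = x                                      # kept, like relu1 in the diff
x = torch.relu(enc_b(x))
x = torch.relu(enc_c(x))
skip_mid = x                                        # kept, like relu3 in the diff
# every other intermediate is dropped as soon as x is reassigned, so only the
# tensors actually needed by later deconv/skip stages stay alive in memory
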