
Creation of the Attentive RNN GAN; the naming scheme is changed for the sake of conciseness.
@cecfc8dc40869fd782c33519b6f596a5f536c800
+++ model/AttentiveRNN GAN.py
@@ -0,0 +1,38 @@
+from attentivernn import AttentiveRNN
+from autoencoder import AutoEncoder
+from torch import nn
+
+
+class Generator(nn.Module):
+    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
+                 dilation=1):
+        super(Generator, self).__init__()
+        if kernel_size is None:
+            kernel_size = [3, 3]
+        self.attentiveRNN = AttentiveRNN(repetition,
+            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation
+        )
+        self.autoencoder = AutoEncoder()
+        self.blocks = blocks
+        self.layers = layers
+        self.input_ch = input_ch
+        self.out_ch = out_ch
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+        self.dilation = dilation
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        x, attention_map = self.attentiveRNN(x)
+        return self.autoencoder(x * attention_map)
+
+if __name__ == "__main__":
+    import torch
+    from torchinfo import summary
+
+    torch.set_default_tensor_type(torch.FloatTensor)
+    generator = Generator(3, blocks=2)
+    batch_size = 2
+    summary(generator, input_size=(batch_size, 3, 960, 540))
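For readers skimming the diff: the new Generator simply chains the two existing modules, with the AttentiveRNN's attention map gating the input element-wise before it reaches the AutoEncoder. The sketch below illustrates that wiring only; the `_Toy*` classes and the single-channel attention map are assumptions standing in for the real modules, not the repository's implementation.

import torch
from torch import nn

class _ToyAttentiveRNN(nn.Module):
    # stand-in: returns its input plus a single-channel attention mask
    def forward(self, x):
        attention_map = torch.sigmoid(x.mean(dim=1, keepdim=True))
        return x, attention_map

class _ToyAutoEncoder(nn.Module):
    # stand-in: identity "reconstruction"
    def forward(self, x):
        return x

rnn, autoencoder = _ToyAttentiveRNN(), _ToyAutoEncoder()
frames = torch.rand(2, 3, 96, 54)               # small batch of RGB inputs
feats, attention_map = rnn(frames)              # attention highlights degraded regions
restored = autoencoder(feats * attention_map)   # autoencoder sees the gated input
print(restored.shape)                           # torch.Size([2, 3, 96, 54])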
--- model/generator.py
+++ model/attentivernn.py
@@ -111,13 +111,13 @@
         return attention_map, cell_state, lstm_feats
 
 
-class GeneratorBlock(nn.Module):
+class AttentiveRNNBLCK(nn.Module):
     def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
                  dilation=1):
         """
         :type kernel_size: iterator or int
         """
-        super(GeneratorBlock, self).__init__()
+        super(AttentiveRNNBLCK, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
         self.blocks = blocks
@@ -156,13 +156,13 @@
         return x, attention_map, cell_state, lstm_feats
 
 
-class Generator(nn.Module):
+class AttentiveRNN(nn.Module):
     def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1,
                  groups=1, dilation=1):
         """
         :type kernel_size: iterator or int
         """
-        super(Generator, self).__init__()
+        super(AttentiveRNN, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
         self.blocks = blocks
@@ -176,15 +176,15 @@
         self.dilation = dilation
         self.repetition = repetition
         self.generator_block = mySequential(
-            GeneratorBlock(blocks=blocks,
-                           layers=layers,
-                           input_ch=input_ch,
-                           out_ch=out_ch,
-                           kernel_size=kernel_size,
-                           stride=stride,
-                           padding=padding,
-                           groups=groups,
-                           dilation=dilation)
+            AttentiveRNNBLCK(blocks=blocks,
+                             layers=layers,
+                             input_ch=input_ch,
+                             out_ch=out_ch,
+                             kernel_size=kernel_size,
+                             stride=stride,
+                             padding=padding,
+                             groups=groups,
+                             dilation=dilation)
         )
         self.generator_blocks = nn.ModuleList()
         for repetition in range(repetition):
@@ -206,7 +206,7 @@
 
     def forward(self, input_tensor, label_tensor):
         # Initialize attentive rnn model
-        attentive_rnn = Generator
+        attentive_rnn = AttentiveRNN
         inference_ret = attentive_rnn(input_tensor)
 
         loss = 0.0
@@ -258,6 +258,9 @@
 
 if __name__ == "__main__":
     from torchinfo import summary
-    generator = Generator(3)
-    batch_size = 1
+
+    torch.set_default_tensor_type(torch.FloatTensor)
+
+    generator = AttentiveRNN(3, blocks=2)
+    batch_size = 5
     summary(generator, input_size=(batch_size, 3, 960,540))
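The renamed AttentiveRNN keeps the old Generator's structure: one block definition is instantiated `repetition` times and the copies are iterated in the forward pass. Below is a minimal sketch of that stacking pattern, using a simplified stand-in block (the real AttentiveRNNBLCK carries convolutional-LSTM state and feature outputs that are omitted here); class names ending in `Sketch` are illustrative only.

import torch
from torch import nn

class BlockSketch(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        attention_map = torch.sigmoid(self.conv(x))  # stand-in for the LSTM attention
        return x * attention_map, attention_map

class StackedSketch(nn.Module):
    def __init__(self, repetition=3):
        super().__init__()
        # one block per repetition, applied sequentially in forward()
        self.blocks = nn.ModuleList([BlockSketch() for _ in range(repetition)])

    def forward(self, x):
        attention_maps = []
        for block in self.blocks:
            x, attention_map = block(x)
            attention_maps.append(attention_map)
        return x, attention_maps

out, maps = StackedSketch(repetition=3)(torch.rand(1, 3, 64, 64))
print(out.shape, len(maps))  # torch.Size([1, 3, 64, 64]) 3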
--- model/autoencoder.py
+++ model/autoencoder.py
@@ -40,18 +40,21 @@
     # maybe change it into concat Networks? this seems way to cumbersome.
     def forward(self, input_tensor):
         # Feed the input through each layer
-        relu1 = torch.relu(self.conv1(input_tensor))
-        relu2 = torch.relu(self.conv2(relu1))
-        relu3 = torch.relu(self.conv3(relu2))
-        relu4 = torch.relu(self.conv4(relu3))
-        relu5 = torch.relu(self.conv5(relu4))
-        relu6 = torch.relu(self.conv6(relu5))
-        relu7 = torch.relu(self.dilated_conv1(relu6))
-        relu8 = torch.relu(self.dilated_conv2(relu7))
-        relu9 = torch.relu(self.dilated_conv3(relu8))
-        relu10 = torch.relu(self.dilated_conv4(relu9))
-        relu11 = torch.relu(self.conv7(relu10))
-        relu12 = torch.relu(self.conv8(relu11))
+        x = torch.relu(self.conv1(input_tensor))
+        relu1 = x
+        x = torch.relu(self.conv2(x))
+        x = torch.relu(self.conv3(x))
+        relu3 = x
+        x = torch.relu(self.conv4(x))
+        x = torch.relu(self.conv5(x))
+        x = torch.relu(self.conv6(x))
+        x = torch.relu(self.dilated_conv1(x))
+        x = torch.relu(self.dilated_conv2(x))
+        x = torch.relu(self.dilated_conv3(x))
+        x = torch.relu(self.dilated_conv4(x))
+        x = torch.relu(self.conv7(x))
+        x = torch.relu(self.conv8(x))
+        relu12 = x
 
         deconv1 = self.deconv1(relu12)
         avg_pool1 = self.avg_pool1(deconv1)
@@ -131,4 +134,12 @@
             x = layer(x)
             if layer_num in {3, 8, 15, 22, 29}:
                 feats.append(x)
-        return feats
\ No newline at end of file
+        return feats
+
+
+if __name__ == "__main__":
+    from torchinfo import summary
+    torch.set_default_tensor_type(torch.FloatTensor)
+    generator = AutoEncoder()
+    batch_size = 2
+    summary(generator, input_size=(batch_size, 3, 960,540))
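The forward() rewrite above threads the encoder through a single reusable `x` and keeps dedicated names only for relu1, relu3, and relu12, the activations the decoder consumes further down the function. A minimal sketch of that refactor pattern follows, with hypothetical layer names rather than the AutoEncoder's own:

import torch
from torch import nn

conv_a = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv_b = nn.Conv2d(8, 8, kernel_size=3, padding=1)
conv_c = nn.Conv2d(8, 8, kernel_size=3, padding=1)

def encode(input_tensor):
    x = torch.relu(conv_a(input_tensor))
    skip = x                       # alias only what a later stage will reuse
    x = torch.relu(conv_b(x))      # other intermediates stay anonymous in x
    x = torch.relu(conv_c(x))
    return x, skip

features, skip = encode(torch.rand(1, 3, 32, 32))
print(features.shape, skip.shape)  # both torch.Size([1, 8, 32, 32])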