
Change return types to dicts so that adding outputs later doesn't require changing function signatures and every call site, just in case.
@2bbc71667c01b3eaf51dd4be8597f67421a644a2
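The pattern this commit adopts, in a minimal self-contained sketch (the function and key names below are illustrative, not from the repo): a dict return is read by key, so a new output can be added later without breaking existing call sites, whereas extending a positional tuple breaks every unpacking caller.

import torch

def forward_tuple(x):
    # positional return: every caller must unpack all values, in order,
    # and breaks as soon as a value is added or reordered
    return x, x.mean()

def forward_dict(x):
    # keyed return: callers read the fields they need by name, so the
    # dict can grow new keys without touching existing call sites
    return {'x': x, 'mean': x.mean()}

out = forward_dict(torch.randn(2, 3))
y = out['x']  # still valid if the dict later grows new keys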
--- model/AttentiveRNN.py
+++ model/AttentiveRNN.py
@@ -152,7 +152,13 @@
         x = self.resnet(original_image)
         attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
         x = attention_map * original_image
-        return x, attention_map, cell_state, lstm_feats
+        ret = {
+            'x' : x,
+            'attention_map' : attention_map,
+            'cell_state' : cell_state,
+            'lstm_feats' : lstm_feats
+        }
+        return ret
 
 
 class AttentiveRNN(nn.Module):
@@ -204,7 +210,7 @@
             'attention_map_list' : attention_map,
             'lstm_feats' : lstm_feats
         }
-        return x, attention_map, lstm_feats
+        return ret
 
 # need fixing
 class AttentiveRNNLoss(nn.Module):
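For orientation: the first hunk changes the per-step block to return {'x', 'attention_map', 'cell_state', 'lstm_feats'}, while the second shows the top-level AttentiveRNN already aggregating into a dict with an 'attention_map_list'. The recurrence that connects the two is not part of this diff, so the following is only a hedged sketch of how such a caller might thread 'cell_state' through the new dict:

import torch
from torch import nn

def run_steps(block: nn.Module, image: torch.Tensor, steps: int) -> dict:
    # Assumed loop shape only -- the real AttentiveRNN.forward is not shown
    # in this diff. Threads the cell state and collects attention maps.
    prev_cell_state = None
    attention_maps = []
    out = None
    for _ in range(steps):  # steps >= 1 assumed
        out = block(image, prev_cell_state)          # dict with the keys above
        prev_cell_state = out['cell_state']
        attention_maps.append(out['attention_map'])
    return {'x': out['x'],
            'attention_map_list': attention_maps,
            'lstm_feats': out['lstm_feats']}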
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -4,19 +4,18 @@
 class DiscriminativeNet(nn.Module):
     def __init__(self, W, H):
         super(DiscriminativeNet, self).__init__()
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
-        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
-        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
-        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2, padding=1)
+        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2)
+        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
+        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2)
         self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
         self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
-        self.conv_map = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
+        self.conv_attention = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
         self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
         self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
-        self.fc1 = nn.Linear(32 * W * H,
-                             1024)  # You need to adjust the input dimension here depending on your input size
-        self.fc2 = nn.Linear(1024, 1)
+        self.fc1 = nn.Linear(32, 1)  # You need to adjust the input dimension here depending on your input size
+        self.fc2 = nn.Linear(1, 1)
 
     def forward(self, x):
         x1 = F.leaky_relu(self.conv1(x))
@@ -25,19 +24,24 @@
         x4 = F.leaky_relu(self.conv4(x3))
         x5 = F.leaky_relu(self.conv5(x4))
         x6 = F.leaky_relu(self.conv6(x5))
-        attention_map = self.conv_map(x6)
+        attention_map = self.conv_attention(x6)
         x7 = F.leaky_relu(self.conv7(attention_map * x6))
         x8 = F.leaky_relu(self.conv8(x7))
         x9 = F.leaky_relu(self.conv9(x8))
         x9 = x9.view(x9.size(0), -1)  # flatten the tensor
         fc1 = self.fc1(x9)
-        fc2 = self.fc2(fc1)
-        fc_out = F.sigmoid(fc2)
+        fc_raw = self.fc2(fc1)
+        fc_out = F.sigmoid(fc_raw)
 
         # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
         fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
 
-        return fc_out, attention_map, fc2
+        ret = {
+            "fc_out" : fc_out,
+            "attention_map": attention_map,
+            "fc_raw" : fc_raw
+        }
+        return ret
 
 if __name__ == "__main__":
     import torch
@@ -45,5 +49,5 @@
 
     torch.set_default_tensor_type(torch.FloatTensor)
     generator = DiscriminativeNet(960,540)
-    batch_size = 2
+    batch_size = 1
     summary(generator, input_size=(batch_size, 3, 960,540))
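A sanity check on the fc1 change: with conv1-4 moved to stride 2 and conv7-9 at stride 4 (conv5, conv6, and conv_attention preserve spatial size at kernel 5, stride 1, padding 2), a 960x540 input collapses to 1x1 by conv9, so the flattened vector is just the 32 channels -- which is what allows fc1 to shrink from 32 * W * H to 32. A quick check with the standard Conv2d size formula:

def conv_out(n, kernel, stride, padding):
    # Conv2d output size: floor((n + 2*padding - kernel) / stride) + 1
    return (n + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the layers that change spatial size,
# as configured after this commit: conv1, conv2-4, then conv7-9
chain = [(5, 2, 1)] + [(5, 2, 2)] * 3 + [(5, 4, 2)] * 3
for n in (960, 540):
    for k, s, p in chain:
        n = conv_out(n, k, s, p)
    print(n)  # 1 for both sides: 32 channels * 1 * 1 = 32 features into fc1

The collapse to 1x1 depends on the input size, so the retained "You need to adjust the input dimension" comment still applies to substantially larger inputs.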
--- model/Generator.py
+++ model/Generator.py
@@ -9,7 +9,7 @@
         super(Generator, self).__init__()
         if kernel_size is None:
             kernel_size = [3, 3]
-        self.attentiveRNN = AttentiveRNN( repetition,
+        self.attentiveRNN = AttentiveRNN(repetition,
             blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
             kernel_size=None, stride=stride, padding=padding, groups=groups, dilation=dilation
         )
@@ -26,8 +26,8 @@
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x):
-        x, attention_map = self.attentiveRNN(x)
-        x = self.autoencoder(x * attention_map)
+        ret = self.attentiveRNN(x)
+        x = self.autoencoder(ret['x'] * ret['attention_map_list'][-1])
         return x
 
 if __name__ == "__main__":
@@ -37,4 +37,4 @@
     torch.set_default_tensor_type(torch.FloatTensor)
     generator = Generator(3, blocks=2)
     batch_size = 2
-    summary(generator, input_size=(batch_size, 3, 960,540))
+    summary(generator, input_size=(batch_size, 3, 720,720))
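End to end, the dict refactor stays internal: Generator.forward still returns a plain tensor. A minimal smoke test, assuming the module path from the file header and the constructor call already used in the __main__ block above:

import torch
from model.Generator import Generator

torch.set_default_tensor_type(torch.FloatTensor)
generator = Generator(3, blocks=2)
x = torch.randn(1, 3, 720, 720)   # matches the new summary() input size
out = generator(x)                # a tensor; the dicts never leave the model
print(out.shape)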