
--- model/Discriminator.py
+++ model/Discriminator.py
@@ -16,6 +16,31 @@
         self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
         self.fc1 = nn.Linear(32, 1) # You need to adjust the input dimension here depending on your input size
         self.fc2 = nn.Linear(1, 1)
+    def forward(self, x):
+        x1 = F.leaky_relu(self.conv1(x))
+        x2 = F.leaky_relu(self.conv2(x1))
+        x3 = F.leaky_relu(self.conv3(x2))
+        x4 = F.leaky_relu(self.conv4(x3))
+        x5 = F.leaky_relu(self.conv5(x4))
+        x6 = F.leaky_relu(self.conv6(x5))
+        attention_map = self.conv_attention(x6)
+        x7 = F.leaky_relu(self.conv7(attention_map * x6))
+        x8 = F.leaky_relu(self.conv8(x7))
+        x9 = F.leaky_relu(self.conv9(x8))
+        x9 = x9.view(x9.size(0), -1) # flatten the tensor
+        fc1 = self.fc1(x9)
+        fc_raw = self.fc2(fc1)
+        fc_out = F.sigmoid(fc_raw)
+
+        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
+        fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
+
+        ret = {
+            "fc_out" : fc_out,
+            "attention_map": attention_map,
+            "fc_raw" : fc_raw
+        }
+        return ret
 
     def loss(self, real_clean, label_tensor, attention_map):
         """
@@ -44,32 +69,6 @@
         loss = entropy_loss + 0.05 * l_map
 
         return fc_out_o, loss
-
-    def forward(self, x):
-        x1 = F.leaky_relu(self.conv1(x))
-        x2 = F.leaky_relu(self.conv2(x1))
-        x3 = F.leaky_relu(self.conv3(x2))
-        x4 = F.leaky_relu(self.conv4(x3))
-        x5 = F.leaky_relu(self.conv5(x4))
-        x6 = F.leaky_relu(self.conv6(x5))
-        attention_map = self.conv_attention(x6)
-        x7 = F.leaky_relu(self.conv7(attention_map * x6))
-        x8 = F.leaky_relu(self.conv8(x7))
-        x9 = F.leaky_relu(self.conv9(x8))
-        x9 = x9.view(x9.size(0), -1) # flatten the tensor
-        fc1 = self.fc1(x9)
-        fc_raw = self.fc2(fc1)
-        fc_out = F.sigmoid(fc_raw)
-
-        # Ensure fc_out is not exactly 0 or 1 for stability of log operation in loss
-        fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)
-
-        ret = {
-            "fc_out" : fc_out,
-            "attention_map": attention_map,
-            "fc_raw" : fc_raw
-        }
-        return ret
 
 if __name__ == "__main__":
     import torch
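
The clamp in the relocated forward exists for the reason its comment gives: a sigmoid output that saturates to exactly 0 or 1 sends the log terms of a cross-entropy-style loss to -inf. Note that the diff calls a bare clamp, so as written it relies on clamp being imported from torch; the sketch below is illustrative code (not from the repository) that uses torch.clamp, and torch.sigmoid in place of the deprecated F.sigmoid, to show the failure mode and the fix.

import torch

# logits that saturate the sigmoid at both ends
fc_raw = torch.tensor([-200.0, 0.0, 200.0])
fc_out = torch.sigmoid(fc_raw)       # -> [0.0, 0.5, 1.0] in float32
print(torch.log(fc_out))             # -inf at the saturated low end
print(torch.log(1 - fc_out))         # -inf at the saturated high end

# the clamp from the diff, fully qualified
fc_out = torch.clamp(fc_out, min=1e-7, max=1 - 1e-7)
print(torch.log(fc_out))             # finite: log(1e-7) is about -16.12 at worst
print(torch.log(1 - fc_out))         # finite as well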
--- model/Generator.py
+++ model/Generator.py
@@ -34,9 +34,22 @@
         }
         return ret
 
-    def loss(self, x, diff):
-        self.attentiveRNN.loss(x, diff)
-        self.autoencoder.loss(x)
+    def binary_diff_mask(self, clean, dirty, thresold=0.1):
+        clean = torch.pow(clean, 1/2.2)
+        dirty = torch.pow(dirty, 1/2.2)
+        diff = torch.abs(clean - dirty)
+        diff = torch.sum(diff, dim=1)
+        # this line is certainly cause problem for quantization
+        # like, hardcoding it, what could go wrong?
+        bin_diff = (diff > thresold).to(clean.dtype)
+
+        return bin_diff
+    def loss(self, clean, dirty, thresold=0.1):
+        # check diff if they are working as intended
+        diff = self.binary_diff_mask(clean, dirty, thresold)
+
+        self.attentiveRNN.loss(clean, diff)
+        self.autoencoder.loss(clean, dirty)
 
 if __name__ == "__main__":
     import torch
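
To make the new masking step concrete, here is a standalone sketch of what binary_diff_mask computes on dummy NCHW batches (the shapes and random tensors are assumptions for illustration; thresold is kept as spelled in the diff). Both inputs are gamma-decoded with an exponent of 1/2.2, then the per-pixel absolute difference is summed over channels and thresholded.

import torch

def binary_diff_mask(clean, dirty, thresold=0.1):
    # same steps as the method added in the diff, written as a free function
    clean = torch.pow(clean, 1 / 2.2)           # gamma-decode
    dirty = torch.pow(dirty, 1 / 2.2)
    diff = torch.abs(clean - dirty)             # per-pixel, per-channel difference
    diff = torch.sum(diff, dim=1)               # (N, C, H, W) -> (N, H, W)
    return (diff > thresold).to(clean.dtype)    # binary mask in the input dtype

clean = torch.rand(2, 3, 8, 8)                              # assumed clean batch in [0, 1]
dirty = (clean + 0.3 * torch.rand_like(clean)).clamp(0, 1)  # assumed degraded batch
mask = binary_diff_mask(clean, dirty)
print(mask.shape, mask.unique())    # torch.Size([2, 8, 8]), values from {0., 1.}

Note that summing over dim=1 drops the channel axis, so the mask comes back as (N, H, W); if attentiveRNN.loss expects a (N, 1, H, W) attention map, an unsqueeze(1) would be needed, but that contract is not visible in this diff.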