
import torch
from torch import nn
import torch.nn.functional as F


class DiscriminativeNet(nn.Module):
    def __init__(self, W, H):
        super(DiscriminativeNet, self).__init__()
        # Feature extractor: six stride-1 convolutions that preserve the spatial size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
        # Single-channel attention map, broadcast-multiplied over the 128 feature channels.
        self.conv_map = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
        # Classifier head: three stride-4 convolutions, each shrinking H and W by a factor of 4.
        self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
        self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
        # With kernel_size=5, stride=4, padding=2, each conv maps a side of length n
        # to ceil(n / 4), so the flattened size after conv9 for a W x H input is
        # 32 * ceil(W / 64) * ceil(H / 64). Derive it here instead of hard-coding.
        h, w = H, W
        for _ in range(3):
            h = (h + 3) // 4
            w = (w + 3) // 4
        self.fc1 = nn.Linear(32 * h * w, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x):
        x1 = F.leaky_relu(self.conv1(x))
        x2 = F.leaky_relu(self.conv2(x1))
        x3 = F.leaky_relu(self.conv3(x2))
        x4 = F.leaky_relu(self.conv4(x3))
        x5 = F.leaky_relu(self.conv5(x4))
        x6 = F.leaky_relu(self.conv6(x5))
        attention_map = self.conv_map(x6)
        # Weight the features by the attention map (broadcast across the channel dim).
        x7 = F.leaky_relu(self.conv7(attention_map * x6))
        x8 = F.leaky_relu(self.conv8(x7))
        x9 = F.leaky_relu(self.conv9(x8))
        x9 = x9.view(x9.size(0), -1)  # flatten to (batch, features)
        fc1 = self.fc1(x9)
        fc2 = self.fc2(fc1)
        fc_out = torch.sigmoid(fc2)
        # Keep fc_out strictly inside (0, 1) so log() in the loss stays finite.
        fc_out = torch.clamp(fc_out, min=1e-7, max=1 - 1e-7)
        return fc_out, attention_map, fc2
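

# A minimal smoke test (illustrative, not part of the original file): the input
# size below is an assumption; any W and H work, since the fully connected input
# dimension is derived from them in __init__.
if __name__ == "__main__":
    W, H = 256, 256
    net = DiscriminativeNet(W, H)
    dummy = torch.randn(4, 3, H, W)  # batch of 4 RGB images, NCHW layout
    fc_out, attention_map, logits = net(dummy)
    print(fc_out.shape)         # torch.Size([4, 1]) -- clamped probabilities
    print(attention_map.shape)  # torch.Size([4, 1, 256, 256]) -- 1-channel spatial attention
    print(logits.shape)         # torch.Size([4, 1]) -- raw scores before the sigmoid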