import torch
import torch.nn.functional as F
from torch import clamp, nn


def _conv5x5(in_channels, out_channels, stride=1, bias=True):
    """Build a 5x5 convolution with padding 2 (size-preserving at stride 1)."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=5,
        stride=stride,
        padding=2,
        bias=bias,
    )


class DiscriminativeNet(nn.Module):
    """Convolutional network yielding a per-sample score and a 1-channel map.

    Six stride-1 5x5 convolutions extract 128-channel features at the input
    resolution.  A bias-free convolution turns them into a single-channel map
    (called ``attention_map`` in the code) which rescales those features by
    broadcasting.  Three stride-4 convolutions then downsample the result, and
    two linear layers reduce it to one logit per sample.

    Args:
        W: Width of the feature map produced by ``conv9``.
        H: Height of the feature map produced by ``conv9``.

    ``fc1`` is built with ``32 * W * H`` input features, so ``W * H`` must
    equal the number of spatial positions left after ``conv9``.  Each stride-4
    convolution (kernel 5, padding 2) maps a side of length ``n`` to
    ``ceil(n / 4)``, so an input side ``n`` ends up as ``ceil(n / 64)``.
    For example, a 256x256 input needs ``W = H = 4`` -- NOT the image size.
    """

    def __init__(self, W: int, H: int):
        super().__init__()
        # Layers are created in the original order so that parameter ordering,
        # state_dict keys and seeded initialisation remain unchanged.
        self.conv1 = _conv5x5(3, 8)
        self.conv2 = _conv5x5(8, 16)
        self.conv3 = _conv5x5(16, 32)
        self.conv4 = _conv5x5(32, 64)
        self.conv5 = _conv5x5(64, 128)
        self.conv6 = _conv5x5(128, 128)
        # Single-channel map computed from the 128-channel features (no bias).
        self.conv_map = _conv5x5(128, 1, bias=False)
        # Downsampling path: every layer shrinks each spatial side ~4x.
        self.conv7 = _conv5x5(128, 64, stride=4)
        self.conv8 = _conv5x5(64, 64, stride=4)
        self.conv9 = _conv5x5(64, 32, stride=4)
        # in_features must match the flattened conv9 output (see class docstring).
        self.fc1 = nn.Linear(32 * W * H, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of 3-channel images.

        Args:
            x: Tensor of shape ``(batch, 3, height, width)``.

        Returns:
            A tuple ``(fc_out, attention_map, fc2)`` where

            * ``fc_out`` is ``sigmoid(fc2)`` clamped to ``[1e-7, 1 - 1e-7]``,
              shape ``(batch, 1)``;
            * ``attention_map`` is the raw ``conv_map`` output, shape
              ``(batch, 1, height, width)``;
            * ``fc2`` is the unclamped logit, shape ``(batch, 1)``.
        """
        x1 = F.leaky_relu(self.conv1(x))
        x2 = F.leaky_relu(self.conv2(x1))
        x3 = F.leaky_relu(self.conv3(x2))
        x4 = F.leaky_relu(self.conv4(x3))
        x5 = F.leaky_relu(self.conv5(x4))
        x6 = F.leaky_relu(self.conv6(x5))

        attention_map = self.conv_map(x6)

        # The 1-channel map broadcasts across all 128 feature channels.
        x7 = F.leaky_relu(self.conv7(attention_map * x6))
        x8 = F.leaky_relu(self.conv8(x7))
        x9 = F.leaky_relu(self.conv9(x8))

        # Flatten everything except the batch dimension for the linear layers.
        x9 = x9.view(x9.size(0), -1)

        # NOTE(review): there is no activation between fc1 and fc2; kept as-is
        # so existing checkpoints and behaviour are preserved.
        fc1 = self.fc1(x9)
        fc2 = self.fc2(fc1)

        # torch.sigmoid is the supported spelling of the legacy F.sigmoid.
        fc_out = torch.sigmoid(fc2)
        # Keep fc_out strictly inside (0, 1) so log() in a loss stays finite.
        fc_out = clamp(fc_out, min=1e-7, max=1 - 1e-7)

        return fc_out, attention_map, fc2