from torch import nn


class Conv3by3(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by ReLU.

    With ``kernel_size=3`` and ``padding=1`` the spatial size of the input is
    preserved; only the channel count changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute name and layer order are part of the state_dict keys
        # (``conv3by3.0.weight`` ...), so keep them as they are.
        self.conv3by3 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv3by3(x)


def _conv_pair(in_ch, out_ch):
    """Build two stacked Conv3by3 layers: ``in_ch -> out_ch -> out_ch``."""
    return nn.Sequential(
        Conv3by3(in_ch, out_ch),
        Conv3by3(out_ch, out_ch),
    )


class Resnet(nn.Module):
    """Small three-stage residual CNN classifier.

    Layout:
        * stem: 7x7 conv (no padding) -> 64 channels, ReLU, 2x2 avg-pool.
        * stage 1: three residual units at 64 channels.
        * transition: conv pair 64 -> 128, 2x2 avg-pool.
        * stage 2: two residual units at 128 channels.
        * transition: conv pair 128 -> 256, 2x2 avg-pool, 1x1 conv.
        * stage 3: two residual units at 256 channels.
        * global average pool and a linear layer 256 -> ``classes``.

    Each residual unit computes ``y = F(x) + x``, where ``F`` is a pair of
    Conv3by3 layers. ReLU is applied inside ``F`` (before the addition), not
    after it. The two transitions have no skip connection; despite their
    names, the ``blockshort_*`` modules are applied on the main path.

    Args:
        classes: number of outputs of the final linear layer (default 2).
        in_ch: number of channels of the input images (default 3).
    """

    def __init__(self, classes=2, in_ch=3):
        super().__init__()
        # Modules are created in the same order and under the same names as
        # before, so parameter ordering and state_dict keys are unchanged.
        self.firstconv = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=64, kernel_size=7),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.block1_1 = _conv_pair(64, 64)
        self.block1_2 = _conv_pair(64, 64)
        self.block1_3 = _conv_pair(64, 64)

        self.blockshort_1to2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.block2_1 = _conv_pair(64, 128)
        self.block2_2 = _conv_pair(128, 128)
        self.block2_3 = _conv_pair(128, 128)

        self.blockshort_2to3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1),
        )

        self.block3_1 = _conv_pair(128, 256)
        self.block3_2 = _conv_pair(256, 256)
        self.block3_3 = _conv_pair(256, 256)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, classes)  # one output per class

    @staticmethod
    def _residual(blocks, out):
        """Apply each block as a residual unit: ``out = block(out) + out``.

        The skip input is the unit's own input, not the stage input.
        """
        for block in blocks:
            out = block(out) + out
        return out

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: batched image tensor, presumably ``(N, in_ch, H, W)``. H and W
                must be at least 14: the unpadded 7x7 stem removes 6 pixels
                and three 2x2 pools follow.

        Returns:
            Tensor of shape ``(N, classes)`` with unnormalized scores (no
            softmax is applied).
        """
        out = self.firstconv(x)
        out = self._residual((self.block1_1, self.block1_2, self.block1_3), out)

        # Transition 1 -> 2: widen 64 -> 128 channels, then halve H and W.
        # No skip connection spans this transition.
        out = self.blockshort_1to2(self.block2_1(out))
        out = self._residual((self.block2_2, self.block2_3), out)

        # Transition 2 -> 3: widen 128 -> 256, halve H and W, then a 1x1 conv.
        out = self.blockshort_2to3(self.block3_1(out))
        out = self._residual((self.block3_2, self.block3_3), out)

        out = self.global_pool(out)  # (N, 256, 1, 1)
        out = out.flatten(1)  # (N, 256)
        return self.fc(out)