# File name
# Commit message
# Commit date
# 2023-10-24
# 2023-10-24
# 2023-10-24
# File name
# Commit message
# Commit date
from torch import nn
class Conv3by3(nn.Module):
    """A 3x3 convolution followed by a ReLU activation.

    The convolution uses ``padding=1`` and the default stride of 1, so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=3,
            padding=1,
        )
        activation = nn.ReLU()
        # Stored as one ``nn.Sequential`` named ``conv3by3`` so that the
        # state_dict keys (``conv3by3.0.weight`` / ``conv3by3.0.bias``) are
        # what existing checkpoints expect.
        self.conv3by3 = nn.Sequential(conv, activation)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        activated = self.conv3by3(x)
        return activated
class Resnet(nn.Module):
    """Small three-stage residual network for image classification.

    Layout:
      * ``firstconv``: 7x7 conv (no padding, stride 1) -> ReLU -> 2x2 avg pool.
      * Stage 1 (64 ch): three two-conv blocks, each with an identity shortcut.
      * Stage 2 (128 ch): ``block2_1`` widens 64 -> 128 and has no shortcut;
        its output is downsampled by ``blockshort_1to2``. Then two blocks with
        identity shortcuts.
      * Stage 3 (256 ch): ``block3_1`` widens 128 -> 256 and has no shortcut;
        its output goes through ``blockshort_2to3`` (2x2 avg pool + 1x1 conv).
        Then two blocks with identity shortcuts.
      * Global average pool -> linear layer producing ``classes`` logits.

    Assumes a batched ``(N, in_ch, H, W)`` input with H and W at least 14, so
    that each of the three 2x2 average pools has at least 2 pixels to pool.

    Args:
        classes: Number of output logits.
        in_ch: Number of input image channels.
    """

    def __init__(self, classes=2, in_ch=3):
        super().__init__()
        self.firstconv = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=64, kernel_size=7),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        self.block1_1 = self._two_convs(64, 64)
        self.block1_2 = self._two_convs(64, 64)
        self.block1_3 = self._two_convs(64, 64)

        # NOTE(review): despite the name, this is applied on the main path
        # after block2_1 (see forward), not as a projection shortcut.
        self.blockshort_1to2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        self.block2_1 = self._two_convs(64, 128)
        self.block2_2 = self._two_convs(128, 128)
        self.block2_3 = self._two_convs(128, 128)

        # NOTE(review): also applied on the main path, after block3_1, which
        # is why its 1x1 conv is 256 -> 256. Kept unchanged so existing
        # checkpoints still load.
        self.blockshort_2to3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1),
        )
        self.block3_1 = self._two_convs(128, 256)
        self.block3_2 = self._two_convs(256, 256)
        self.block3_3 = self._two_convs(256, 256)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, classes)

    @staticmethod
    def _two_convs(in_ch, out_ch):
        """Build two stacked ``Conv3by3`` layers (in_ch -> out_ch -> out_ch)."""
        return nn.Sequential(
            Conv3by3(in_ch, out_ch),
            Conv3by3(out_ch, out_ch),
        )

    def forward(self, x):
        """Return logits of shape ``(N, classes)`` for the batch ``x``."""
        out = self.firstconv(x)

        # Each same-shape block adds its *own* input (x + F(x)). The previous
        # code reused the stage input for every block, which is not a
        # residual connection.
        out = out + self.block1_1(out)
        out = out + self.block1_2(out)
        out = out + self.block1_3(out)

        # Channel-widening block: no identity shortcut is possible here
        # without a projection, so only its output is downsampled.
        out = self.blockshort_1to2(self.block2_1(out))
        out = out + self.block2_2(out)
        out = out + self.block2_3(out)

        out = self.blockshort_2to3(self.block3_1(out))
        out = out + self.block3_2(out)
        out = out + self.block3_3(out)

        out = self.global_pool(out)
        out = out.flatten(1)  # (N, 256, 1, 1) -> (N, 256)
        return self.fc(out)