File name
Commit message
Commit date
File name
Commit message
Commit date
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg16
class AutoEncoder(nn.Module):
    """Convolutional encoder/decoder with a dilated bottleneck.

    ``forward`` returns a dict with three 3-channel predictions:

    * ``'skip_1'`` -- taken from the bottleneck, i.e. after both stride-2
      convolutions (``conv2`` and ``conv4``).
    * ``'skip_2'`` -- taken after the first transposed convolution and its
      merge with the ``conv3`` activations.
    * ``'skip_3'`` -- taken after the second transposed convolution and its
      merge with the ``conv1`` activations, squashed with ``tanh``.

    All convolutions are bias-free.  The additive skip connections require
    the up-sampled tensors to have the same spatial size as the encoder
    activations they are added to.
    """

    def __init__(self):
        super().__init__()

        def conv(c_in, c_out, kernel=3, stride=1, dilation=1):
            # Padding is chosen so that a stride-1 convolution keeps the
            # spatial size: 2 for the 5x5 kernel, ``dilation`` for 3x3 ones.
            return nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel,
                stride=stride,
                padding=dilation * (kernel - 1) // 2,
                dilation=dilation,
                bias=False,
            )

        def deconv(c_in, c_out):
            # 4x4 transposed convolution with stride 2 and padding 1.
            return nn.ConvTranspose2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )

        def smooth():
            # 3x3 average pooling that leaves the spatial size unchanged.
            return nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        # NOTE: the construction order below is kept on purpose -- it fixes
        # both the parameter order and the random initialisation sequence.

        # Encoder.
        self.conv1 = conv(3, 64, kernel=5)
        self.conv2 = conv(64, 128, stride=2)
        self.conv3 = conv(128, 128)
        self.conv4 = conv(128, 128, stride=2)
        self.conv5 = conv(128, 256)
        self.conv6 = conv(256, 256)

        # Dilated bottleneck with dilation rates 2, 4, 8 and 16.
        self.dilated_conv1 = conv(256, 256, dilation=2)
        self.dilated_conv2 = conv(256, 256, dilation=4)
        self.dilated_conv3 = conv(256, 256, dilation=8)
        self.dilated_conv4 = conv(256, 256, dilation=16)
        self.conv7 = conv(256, 256)
        self.conv8 = conv(256, 256)

        # Decoder.
        self.deconv1 = deconv(256, 128)
        self.avg_pool1 = smooth()
        self.deconv2 = deconv(128, 64)
        self.avg_pool2 = smooth()
        self.conv9 = conv(128, 128)
        self.conv10 = conv(64, 32)

        # Output heads, one per returned prediction.
        self.skip_output1 = conv(256, 3)
        self.skip_output2 = conv(128, 3)
        self.skip_output3 = conv(32, 3)
        # TODO: maybe change this into concatenated sub-networks? Listing
        # every layer by hand seems way too cumbersome.

    def forward(self, input_tensor):
        """Run the network and return the ``skip_1``/``skip_2``/``skip_3`` dict."""
        # Encoder activations that are re-used by the additive skips.
        enc_first = F.relu(self.conv1(input_tensor))
        enc_third = F.relu(self.conv3(F.relu(self.conv2(enc_first))))

        # Second down-sampling step followed by the dilated bottleneck.
        bottleneck = enc_third
        for layer in (
            self.conv4,
            self.conv5,
            self.conv6,
            self.dilated_conv1,
            self.dilated_conv2,
            self.dilated_conv3,
            self.dilated_conv4,
            self.conv7,
            self.conv8,
        ):
            bottleneck = F.relu(layer(bottleneck))

        # First up-sampling stage, merged with the conv3 activations.
        up_first = F.relu(self.avg_pool1(self.deconv1(bottleneck)))
        dec_first = F.relu(self.conv9(up_first + enc_third))

        # Second up-sampling stage, merged with the conv1 activations.
        up_second = F.relu(self.avg_pool2(self.deconv2(dec_first)))
        dec_second = F.relu(self.conv10(up_second + enc_first))

        # Only the last head is passed through tanh.
        return {
            'skip_1': self.skip_output1(bottleneck),
            'skip_2': self.skip_output2(dec_first),
            'skip_3': torch.tanh(self.skip_output3(dec_second)),
        }
class LossFunction(nn.Module):
    """Multi-scale MSE plus VGG16 feature loss for an ``AutoEncoder``.

    The auto-encoder is owned by this module (``self.autoencoder``), so its
    weights persist between calls, show up in ``parameters()`` /
    ``state_dict()`` and move together with ``.to(device)``.  Previously a
    fresh, randomly initialised ``AutoEncoder`` was created on every
    ``forward`` call, which meant nothing could ever be trained.

    Args:
        autoencoder: Optional model to evaluate.  It must map an input batch
            to a dict with the keys ``'skip_1'``, ``'skip_2'`` and
            ``'skip_3'``.  When omitted, a new ``AutoEncoder`` is created.
    """

    # Indices into ``vgg16().features`` whose outputs are compared.
    _VGG_FEATURE_LAYERS = (3, 8, 15, 22, 29)

    def __init__(self, autoencoder=None):
        super().__init__()
        # Pre-trained VGG16 feature extractor, frozen and in eval mode.
        self.vgg = vgg16(pretrained=True).features
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        # Weights of the MSE terms, ordered like skip_1, skip_2, skip_3.
        self.lambda_i = [0.6, 0.8, 1.0]
        # Build the network once so that it is registered as a sub-module.
        self.autoencoder = AutoEncoder() if autoencoder is None else autoencoder

    def forward(self, input_tensor, label_tensor):
        """Compute the combined loss for one batch.

        Args:
            input_tensor: Batch passed to the auto-encoder.
            label_tensor: Target batch; its last two dimensions are used as
                height and width.

        Returns:
            A ``(loss, prediction)`` tuple where ``loss`` is the weighted
            multi-scale MSE plus the mean VGG feature MSE, and ``prediction``
            is the auto-encoder's ``'skip_3'`` output.
        """
        ori_height, ori_width = label_tensor.shape[2:]
        # Targets for skip_1, skip_2 and skip_3: the label at 1/4, 1/2 and
        # full size (integer division, default interpolation mode).
        label_list = [
            F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4)),
            F.interpolate(label_tensor, size=(ori_height // 2, ori_width // 2)),
            label_tensor,
        ]

        inference_ret = self.autoencoder(input_tensor)
        output_list = [inference_ret[key] for key in ('skip_1', 'skip_2', 'skip_3')]

        # Weighted multi-scale reconstruction loss.
        lm_loss = 0.0
        for weight, output, label in zip(self.lambda_i, output_list, label_list):
            lm_loss = lm_loss + F.mse_loss(output, label) * weight

        # Feature loss between the label and the final prediction.
        src_vgg_feats = self.extract_vgg_feats(label_tensor)
        pred_vgg_feats = self.extract_vgg_feats(output_list[-1])
        lp_loss = torch.stack([
            F.mse_loss(src_feat, pred_feat)
            for src_feat, pred_feat in zip(src_vgg_feats, pred_vgg_feats)
        ]).mean()

        loss = lm_loss + lp_loss
        return loss, inference_ret['skip_3']

    def extract_vgg_feats(self, input_tensor):
        """Return the outputs of the tapped VGG layers for ``input_tensor``.

        The layers of ``self.vgg`` are applied in order and the activations
        at the indices in ``_VGG_FEATURE_LAYERS`` are collected.
        """
        last_needed = max(self._VGG_FEATURE_LAYERS)
        feats = []
        x = input_tensor
        for layer_num, layer in enumerate(self.vgg):
            x = layer(x)
            if layer_num in self._VGG_FEATURE_LAYERS:
                feats.append(x)
            if layer_num >= last_needed:
                # Nothing after the last tapped layer is used.
                break
        return feats
if __name__ == "__main__":
    # Smoke test: print a layer-by-layer summary of the generator.
    # torchinfo is only needed here, so it is imported lazily.
    from torchinfo import summary

    # torch.set_default_tensor_type is deprecated (PyTorch >= 2.1);
    # selecting float32 as the default dtype is the supported equivalent
    # of the former ``torch.FloatTensor`` setting.
    torch.set_default_dtype(torch.float32)
    generator = AutoEncoder()
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 960, 540))