import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg16, VGG16_Weights


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)

        self.dilated_conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=2, padding=2, bias=False)
        self.dilated_conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=4, padding=4, bias=False)
        self.dilated_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=8, padding=8, bias=False)
        self.dilated_conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=16, padding=16, bias=False)

        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)

        self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool2 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.conv9 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv10 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)

        self.skip_output1 = nn.Conv2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)

        # Loss-specific definitions
        # TODO: the paper uses VGG16 for feature extraction in the perceptual loss. Since VGG16
        # is not a lightweight model, we may want to replace it with a smaller pretrained
        # feature extractor, or drop the feature extractor altogether. I honestly do not think
        # a neural network like VGG is strictly necessary here; it could be replaced with other
        # image preprocessing such as MSCN, which was implemented before.
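        # One lighter option (an assumption, not something the paper specifies) would be to keep
        # only the early VGG stages, e.g. vgg16(...).features[:16], or to skip the network
        # entirely and compare MSCN-normalised images; see also the VGGStages sketch at the
        # bottom of this file.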
        self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        # Per-scale weights for the multi-scale MSE loss (smallest scale first).
        self.lambda_i = [0.6, 0.8, 1.0]

    def forward(self, input_tensor):
        # Encoder: feed the input through each layer.
        x = torch.relu(self.conv1(input_tensor))
        relu1 = x
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        relu3 = x
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        x = torch.relu(self.conv6(x))
        x = torch.relu(self.dilated_conv1(x))
        x = torch.relu(self.dilated_conv2(x))
        x = torch.relu(self.dilated_conv3(x))
        x = torch.relu(self.dilated_conv4(x))
        x = torch.relu(self.conv7(x))
        x = torch.relu(self.conv8(x))
        relu12 = x

        # Decoder with skip connections back to the encoder activations.
        deconv1 = self.deconv1(relu12)
        avg_pool1 = self.avg_pool1(deconv1)
        relu13 = torch.relu(avg_pool1)
        relu14 = torch.relu(self.conv9(relu13 + relu3))

        deconv2 = self.deconv2(relu14)
        avg_pool2 = self.avg_pool2(deconv2)
        relu15 = torch.relu(avg_pool2)
        relu16 = torch.relu(self.conv10(relu15 + relu1))

        # Side outputs at 1/4, 1/2 and full resolution.
        skip_output_1 = self.skip_output1(relu12)
        skip_output_2 = self.skip_output2(relu14)
        output = self.skip_output3(relu16)
        skip_output_3 = torch.tanh(output)

        ret = {
            'skip_1': skip_output_1,
            'skip_2': skip_output_2,
            'skip_3': skip_output_3,
            'output': output
        }
        return ret

    def loss(self, input_clean_img, input_rainy_img):
        ori_height, ori_width = input_clean_img.shape[2:]

        # Rescale the labels to match the scales of the side outputs.
        label_tensor_resize_2 = F.interpolate(input_clean_img, size=(ori_height // 2, ori_width // 2))
        label_tensor_resize_4 = F.interpolate(input_clean_img, size=(ori_height // 4, ori_width // 4))
        label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_img]

        inference_ret = self.forward(input_rainy_img)
        output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]

        # Compute lm_loss: weighted multi-scale MSE between side outputs and resized labels.
        lm_loss = 0.0
        for index, output in enumerate(output_list):
            mse_loss = F.mse_loss(output, label_list[index]) * self.lambda_i[index]
            lm_loss += mse_loss

        # Compute lp_loss (perceptual loss).
        # NOTE: self.vgg is an nn.Sequential, so it returns only the final feature map; the
        # loop below therefore runs over the batch dimension, not over multiple VGG layers,
        # and the result equals a plain MSE on those final features.
        src_vgg_feats = self.vgg(input_clean_img)
        pred_vgg_feats = self.vgg(output_list[-1])

        lp_losses = []
        for index in range(len(src_vgg_feats)):
            lp_losses.append(F.mse_loss(src_vgg_feats[index], pred_vgg_feats[index]))
        lp_loss = torch.mean(torch.stack(lp_losses))

        loss = lm_loss + lp_loss
        return loss
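
# --- Sketches below are illustrations, not part of the model above. ---

# The perceptual loss in `AutoEncoder.loss` compares only the final VGG feature map, because
# `self.vgg` is a single nn.Sequential. If per-layer features are wanted (a common
# perceptual-loss setup; the layer indices below are assumptions, not taken from the paper),
# the extractor could be split into stages like this minimal sketch:
class VGGStages(nn.Module):
    """Return activations after selected ReLU layers of VGG16 (layer indices are assumptions)."""

    def __init__(self, stage_ends=(3, 8, 15, 22)):
        super().__init__()
        features = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.eval()
        for param in features.parameters():
            param.requires_grad = False
        # Split the Sequential into consecutive stages so one forward pass yields every feature map.
        stages, prev = [], 0
        for end in stage_ends:
            stages.append(features[prev:end + 1])
            prev = end + 1
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        return feats


# Minimal smoke test (an assumed usage example, not the paper's training code): random tensors
# stand in for a clean/rainy image pair; the spatial size must be divisible by 4 because of the
# two stride-2 convolutions in the encoder.
if __name__ == "__main__":
    model = AutoEncoder()
    clean = torch.rand(1, 3, 64, 64)
    rainy = torch.rand(1, 3, 64, 64)

    outputs = model(rainy)
    print({key: tuple(value.shape) for key, value in outputs.items()})  # 1/4, 1/2 and full resolution
    print("loss:", model.loss(clean, rainy).item())

    # Per-layer features from the sketch above.
    extractor = VGGStages()
    print([tuple(feat.shape) for feat in extractor(clean)])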