File name
Commit message
Commit date
Tried to make a logger with Dash; however, it did not work as I had planned, so visdom is used instead. Also, the dataloader now shuffles the dataset by default.
2023-07-04
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg16, VGG16_Weights


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.dilated_conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=2, padding=2, bias=False)
        self.dilated_conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=4, padding=4, bias=False)
        self.dilated_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=8, padding=8, bias=False)
        self.dilated_conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=16, padding=16, bias=False)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool2 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.conv9 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv10 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output1 = nn.Conv2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        # Loss-specific definitions
        # TODO
        # The paper uses VGG16 for feature extraction in the loss function.
        # However, since VGG16 is not a lightweight model, we may consider
        # replacing it with another pretrained feature extractor, or dropping
        # the feature extractor from the loss entirely. I honestly do not
        # think a neural network such as VGG is strictly necessary here; it
        # might be replaceable with other image preprocessing such as MSCN,
        # which was implemented before.
        self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        # Per-scale weights for the multi-scale MSE loss (coarse to fine)
        self.lambda_i = [0.6, 0.8, 1.0]
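
    # A hedged sketch of the MSCN alternative floated in the TODO above: MSCN
    # (mean subtracted contrast normalized) coefficients divisively normalize
    # each pixel by local mean/std statistics. This helper is an assumed
    # re-implementation (the earlier version referenced in the comment is not
    # shown here) and uses a box window via avg_pool2d rather than the
    # Gaussian window MSCN classically uses.
    @staticmethod
    def mscn(img, kernel_size=7, eps=1e-6):
        pad = kernel_size // 2
        local_mean = F.avg_pool2d(img, kernel_size, stride=1, padding=pad)
        local_var = F.avg_pool2d(img * img, kernel_size, stride=1, padding=pad) - local_mean ** 2
        local_std = torch.sqrt(torch.clamp(local_var, min=0.0))
        return (img - local_mean) / (local_std + eps)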
    def forward(self, input_tensor):
        # Encoder: keep the activations that feed the skip connections
        x = torch.relu(self.conv1(input_tensor))
        relu1 = x
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        relu3 = x
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        x = torch.relu(self.conv6(x))
        # Dilated convolutions grow the receptive field without downsampling
        x = torch.relu(self.dilated_conv1(x))
        x = torch.relu(self.dilated_conv2(x))
        x = torch.relu(self.dilated_conv3(x))
        x = torch.relu(self.dilated_conv4(x))
        x = torch.relu(self.conv7(x))
        x = torch.relu(self.conv8(x))
        relu12 = x
        # Decoder: upsample and merge with the encoder skips
        deconv1 = self.deconv1(relu12)
        avg_pool1 = self.avg_pool1(deconv1)
        relu13 = torch.relu(avg_pool1)
        relu14 = torch.relu(self.conv9(relu13 + relu3))
        deconv2 = self.deconv2(relu14)
        avg_pool2 = self.avg_pool2(deconv2)
        relu15 = torch.relu(avg_pool2)
        relu16 = torch.relu(self.conv10(relu15 + relu1))
        # Side outputs at quarter, half, and full resolution; 'output' is the
        # full-resolution map before tanh, 'skip_3' its tanh-squashed version
        skip_output_1 = self.skip_output1(relu12)
        skip_output_2 = self.skip_output2(relu14)
        output = self.skip_output3(relu16)
        skip_output_3 = torch.tanh(output)
        ret = {
            'skip_1': skip_output_1,
            'skip_2': skip_output_2,
            'skip_3': skip_output_3,
            'output': output
        }
        return ret
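
    # Hedged sketch: loss() below uses only the final output of self.vgg, but
    # per-layer activations (closer to the paper's description of the feature
    # loss) could be collected like this. The layer indices are assumptions,
    # roughly relu1_2, relu2_2, and relu3_3 in torchvision's vgg16.features.
    def vgg_features(self, img, layer_ids=(3, 8, 15)):
        feats = []
        x = img
        for index, layer in enumerate(self.vgg):
            x = layer(x)
            if index in layer_ids:
                feats.append(x)
        return feats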
    def loss(self, input_clean_img, input_rainy_img):
        ori_height, ori_width = input_clean_img.shape[2:]
        # Rescale the labels to match the scales of the skip outputs
        label_tensor_resize_2 = F.interpolate(input_clean_img, size=(ori_height // 2, ori_width // 2))
        label_tensor_resize_4 = F.interpolate(input_clean_img, size=(ori_height // 4, ori_width // 4))
        label_list = [label_tensor_resize_4, label_tensor_resize_2, input_clean_img]
        inference_ret = self(input_rainy_img)
        output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]
        # Compute lm_loss: MSE at each output scale, weighted by lambda_i
        lm_loss = 0.0
        for index, output in enumerate(output_list):
            mse_loss = nn.MSELoss()(output, label_list[index]) * self.lambda_i[index]
            lm_loss += mse_loss
        # Compute lp_loss: MSE between VGG feature maps of the clean image and
        # of the full-resolution prediction. self.vgg is a Sequential, so
        # calling it yields a single tensor (the last feature map), not a
        # per-layer list; the MSE is therefore taken over that tensor directly.
        src_vgg_feats = self.vgg(input_clean_img)
        pred_vgg_feats = self.vgg(output_list[-1])
        lp_loss = nn.MSELoss()(pred_vgg_feats, src_vgg_feats)
        loss = lm_loss + lp_loss
        return loss
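

# A minimal usage sketch: run an assumed dummy clean/rainy pair through the
# model and the loss. The 1x3x480x720 shape is only an example; any spatial
# size divisible by 4 lines up the two stride-2 encoder steps with the two
# transposed-convolution upsampling steps.
if __name__ == '__main__':
    model = AutoEncoder()
    rainy = torch.rand(1, 3, 480, 720)
    clean = torch.rand(1, 3, 480, 720)
    outputs = model(rainy)
    for name, tensor in outputs.items():
        print(name, tuple(tensor.shape))
    print('loss:', model.loss(clean, rainy).item())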