File name
Commit message
Commit date
File name
Commit message
Commit date
from model.AttentiveRNN import AttentiveRNN
from model.Autoencoder import AutoEncoder
from torch import nn
from torch import sum, pow, abs
class Generator(nn.Module):
    """Generator of the Attentive GAN: an attentive RNN followed by an autoencoder.

    ``forward`` runs the attentive RNN on the input, multiplies the RNN's
    returned ``'x'`` by its last attention map, and feeds that product to the
    autoencoder.

    Args:
        repetition: Forwarded positionally to ``AttentiveRNN`` (presumably the
            number of recurrent steps -- TODO confirm against ``AttentiveRNN``).
        blocks, layers, input_ch, out_ch, stride, padding, groups, dilation:
            Stored on the instance and forwarded unchanged to ``AttentiveRNN``.
        kernel_size: Stored on the instance and forwarded to ``AttentiveRNN``;
            ``None`` is replaced by ``[3, 3]``.
    """

    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32,
                 kernel_size=None, stride=1, padding=1, groups=1,
                 dilation=1):
        super().__init__()
        # Build the default here rather than in the signature so the list is
        # not shared between instances.
        if kernel_size is None:
            kernel_size = [3, 3]
        # Forward the resolved kernel_size; a literal None used to be passed
        # here, so any caller-supplied kernel_size was silently ignored.
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
            kernel_size=kernel_size, stride=stride, padding=padding,
            groups=groups, dilation=dilation,
        )
        self.autoencoder = AutoEncoder()
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.sigmoid = nn.Sigmoid()
        self.name = "Generator of Attentive GAN"

    def forward(self, x):
        """Run the attentive RNN, gate its output, and decode with the autoencoder.

        Args:
            x: Input batch passed directly to the attentive RNN (presumably an
                image tensor in NCHW layout -- TODO confirm).

        Returns:
            dict with ``'x'``: the autoencoder output, and ``'attention_maps'``:
            the RNN's full ``'attention_map_list'``.
        """
        rnn_out = self.attentiveRNN(x)
        # Only the final attention map gates the autoencoder input.
        gated = rnn_out['x'] * rnn_out['attention_map_list'][-1]
        x = self.autoencoder(gated)
        ret = {
            'x': x,
            'attention_maps': rnn_out['attention_map_list'],
        }
        return ret

    def binary_diff_mask(self, clean, dirty, thresold=0.1):
        """Build a per-pixel 0/1 mask from the difference of two image batches.

        Both inputs are raised to the power 0.45, their absolute difference is
        summed over dimension 1 (the channel axis if inputs are NCHW), and the
        result is 1 where that sum is below ``thresold`` and 0 elsewhere.

        Args:
            clean: Reference tensor; its dtype is used for the returned mask.
            dirty: Tensor with the same shape as ``clean``.
            thresold: Comparison threshold on the channel-summed difference.

        Returns:
            Tensor with dimension 1 reduced away, holding 1.0 where the images
            agree (sum < thresold) and 0.0 where they differ.
        """
        # Gamma step: sRGB values are not linear in light intensity, so the
        # difference is taken after a power transform.
        # NOTE(review): x**0.45 is the linear->sRGB (encode) direction; if the
        # inputs are already sRGB, linearizing would use x**2.2 -- confirm.
        # Negative inputs (e.g. [-1, 1] normalization) would give NaN here.
        clean = pow(clean, 0.45)
        dirty = pow(dirty, 0.45)
        diff = sum(abs(clean - dirty), dim=1)
        # NOTE(review): the mask is 1 where the images AGREE, i.e. the inverse
        # of a "changed pixels" mask. That matches forward() multiplying the
        # input by the attention map, but confirm it is intended.
        return (diff < thresold).to(clean.dtype)

    def loss(self, clean, dirty, thresold=0.1):
        """Compute the attentive-RNN and autoencoder losses.

        Args:
            clean: Clean target batch.
            dirty: Degraded input batch.
            thresold: Threshold forwarded to ``binary_diff_mask``.

        Returns:
            dict with ``"attentive_rnn_loss"`` and ``"autoencoder_loss"``.
        """
        # NOTE(review): verify the diff mask behaves as intended (see
        # binary_diff_mask), and that the sub-loss argument orders below
        # match the AttentiveRNN.loss / AutoEncoder.loss signatures.
        diff_mask = self.binary_diff_mask(clean, dirty, thresold)
        attentive_rnn_loss = self.attentiveRNN.loss(clean, diff_mask)
        autoencoder_loss = self.autoencoder.loss(clean, dirty)
        ret = {
            "attentive_rnn_loss": attentive_rnn_loss,
            "autoencoder_loss": autoencoder_loss,
        }
        return ret
if __name__ == "__main__":
    # Smoke check: build a small generator and print a layer summary for a
    # 720x720 RGB batch.
    import torch
    from torchinfo import summary

    # torch.set_default_tensor_type is deprecated since PyTorch 2.1;
    # set_default_dtype(float32) keeps the same float32 default on the CPU.
    torch.set_default_dtype(torch.float32)
    generator = Generator(3, blocks=2)
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 720, 720))