
File name
Commit message
Commit date

2023-07-04
File name
Commit message
Commit date

2023-07-04
from model.AttentiveRNN import AttentiveRNN
from model.Autoencoder import AutoEncoder
from torch import nn
from torch import sum, pow, abs
class Generator(nn.Module):
    """Two-stage generator: an attentive RNN followed by an autoencoder.

    The attentive RNN returns a dict holding an ``'x'`` tensor and an
    ``'attention_map_list'``.  The last attention map gates ``'x'`` and the
    product is fed to the autoencoder.

    Args:
        repetition: Forwarded positionally to ``AttentiveRNN``
            (presumably the number of recurrent steps -- TODO confirm).
        blocks, layers, input_ch, out_ch, stride, padding, groups, dilation:
            Forwarded to ``AttentiveRNN`` and stored on the instance.
        kernel_size: Forwarded to ``AttentiveRNN`` exactly as given, so
            ``None`` lets the RNN apply its own default.  The value stored
            on ``self.kernel_size`` falls back to ``[3, 3]``.
    """

    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32,
                 kernel_size=None, stride=1, padding=1, groups=1, dilation=1):
        super().__init__()

        # Forward kernel_size as supplied so an explicit value actually
        # reaches the RNN; when the caller omits it, None is passed through
        # and AttentiveRNN keeps using its own default.
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks,
            layers=layers,
            input_ch=input_ch,
            out_ch=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
        )
        self.autoencoder = AutoEncoder()

        # Keep the construction hyper-parameters around for introspection.
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = [3, 3] if kernel_size is None else kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the attentive RNN, gate its output, and decode it.

        Returns:
            dict with ``'x'`` (the autoencoder output) and
            ``'attention_maps'`` (the RNN's full attention-map list).
        """
        rnn_out = self.attentiveRNN(x)
        attention_maps = rnn_out['attention_map_list']
        # Only the final attention map is used to gate the RNN output.
        gated = rnn_out['x'] * attention_maps[-1]
        return {
            'x': self.autoencoder(gated),
            'attention_maps': attention_maps,
        }

    def binary_diff_mask(self, clean, dirty, thresold=0.1):
        """Return a 0/1 mask marking where ``clean`` and ``dirty`` differ.

        Both inputs are raised to the power 0.45 (gamma correction; sRGB
        values are not linear in light intensity), the absolute difference
        is summed over dim 1, and the result is thresholded.

        Args:
            clean: Reference tensor.  Assumes dim 1 is the channel axis,
                e.g. (batch, channels, H, W) -- TODO confirm with callers.
            dirty: Tensor to compare against ``clean``.
            thresold: Strict lower bound on the summed difference for a
                position to be marked 1.  (The spelling is part of the
                public interface.)

        Returns:
            Tensor with dim 1 reduced away, in ``clean``'s dtype.
        """
        gamma = 0.45
        # NOTE: `pow`, `abs` and `sum` here are the torch functions imported
        # at module level, not the Python builtins.
        per_element = abs(pow(clean, gamma) - pow(dirty, gamma))
        summed = sum(per_element, dim=1)
        return (summed > thresold).to(clean.dtype)

    def loss(self, clean, dirty, thresold=0.1):
        """Collect the losses of both sub-networks.

        The attentive RNN loss receives ``clean`` and the binary difference
        mask; the autoencoder loss receives ``clean`` and ``dirty``.

        Returns:
            dict with keys ``'attentive_rnn_loss'`` and
            ``'autoencoder_loss'``.
        """
        # TODO(review): check that the diff mask behaves as intended.
        diff_mask = self.binary_diff_mask(clean, dirty, thresold)
        return {
            "attentive_rnn_loss": self.attentiveRNN.loss(clean, diff_mask),
            "autoencoder_loss": self.autoencoder.loss(clean, dirty),
        }
if __name__ == "__main__":
    # Smoke test: build a small generator and print its layer summary.
    import torch
    from torchinfo import summary

    # torch.set_default_tensor_type is deprecated; set_default_dtype is the
    # documented replacement and selects the same float32 default.
    torch.set_default_dtype(torch.float32)
    generator = Generator(3, blocks=2)
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 720, 720))