from AttentiveRNN import AttentiveRNN
from Autoencoder import AutoEncoder
from torch import nn


class Generator(nn.Module):
    """Attention-guided generator: an AttentiveRNN stage followed by an AutoEncoder.

    ``forward`` runs the input through ``AttentiveRNN``, which returns a pair
    ``(features, attention_map)``. It multiplies the two element-wise and
    passes the product to ``AutoEncoder``. Only the AutoEncoder output is
    returned; the attention map is not exposed by ``forward``.

    Args:
        repetition: Passed positionally to ``AttentiveRNN`` as its first
            argument (meaning defined by ``AttentiveRNN``).
        blocks, layers, input_ch, out_ch, stride, padding, groups, dilation:
            Forwarded unchanged to ``AttentiveRNN`` and also stored on the
            instance.
        kernel_size: Forwarded unchanged to ``AttentiveRNN``; ``None`` lets
            ``AttentiveRNN`` apply its own default. Stored on the instance
            as ``[3, 3]`` when ``None``.
    """

    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32,
                 kernel_size=None, stride=1, padding=1, groups=1, dilation=1):
        super().__init__()
        # Forward the caller's kernel_size (previously a literal None was
        # passed, silently discarding any user-specified value). Passing the
        # raw argument keeps default construction identical: None still
        # reaches AttentiveRNN when the caller did not choose a size.
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks,
            layers=layers,
            input_ch=input_ch,
            out_ch=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
        )
        self.autoencoder = AutoEncoder()

        # Configuration kept on the instance for introspection.
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        # Fresh list per instance, so no shared mutable default.
        self.kernel_size = [3, 3] if kernel_size is None else kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation

        # NOTE(review): not used in forward(); kept so external code that
        # references ``generator.sigmoid`` keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the AutoEncoder output for input ``x``.

        Assumes ``attention_map`` is broadcast-compatible with the features
        returned by ``AttentiveRNN`` (TODO confirm against AttentiveRNN).
        """
        x, attention_map = self.attentiveRNN(x)
        # Gate the features with the attention map before reconstruction.
        x = self.autoencoder(x * attention_map)
        return x


if __name__ == "__main__":
    import torch
    from torchinfo import summary

    # float32 is the dtype of torch.FloatTensor; set_default_tensor_type is
    # deprecated in favour of set_default_dtype.
    torch.set_default_dtype(torch.float32)
    generator = Generator(3, blocks=2)
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 960, 540))