import torch
from torch import nn
from torch.nn import functional as F


# nn.Sequential does not handle multiple inputs by design; this is a workaround.
# https://github.com/pytorch/pytorch/issues/19808
class mySequential(nn.Sequential):
    """``nn.Sequential`` variant whose ``forward`` accepts several positional inputs.

    A tuple returned by one module is unpacked into the positional arguments
    of the next module; any other return value (a tensor, a dict, ...) is
    passed on as a single argument.
    """

    def forward(self, *inputs):
        for module in self._modules.values():
            # Only unpack tuples: ``*`` on a tensor would split it along dim 0,
            # and ``*`` on a dict would pass its keys instead of the dict.
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1):
    """Return a ``nn.Conv2d`` with a 3x3 kernel."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding,
                     groups=groups, dilation=dilation)


def conv1x1(in_ch, out_ch, stride=1):
    """Return a ``nn.Conv2d`` with a 1x1 kernel."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)


def _same_padding(kernel_size):
    """Return padding that keeps H and W unchanged for an odd kernel at stride 1.

    ``kernel_size`` may be an int (an int is returned) or an iterable of ints
    (a tuple is returned).
    """
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return tuple(k // 2 for k in kernel_size)


class ResNetBlock(nn.Module):
    """Conv stem followed by ``blocks * layers`` conv + LeakyReLU layers.

    ``conv1`` maps ``input_ch`` to ``out_ch`` channels; every hidden layer maps
    ``out_ch`` to ``out_ch``. The output of ``conv1`` is added back as a
    shortcut at the positions selected in :meth:`forward`.
    """

    def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1,
                 padding=1, groups=1, dilation=1):
        """
        :param blocks: number of blocks of hidden layers.
        :param layers: hidden layers per block.
        :param input_ch: channels of the input tensor.
        :param out_ch: channels produced by every convolution in the stack.
        :type kernel_size: iterator or int
        :param kernel_size: kernel size of every convolution; defaults to ``[3, 3]``.
        :param stride: stride of every convolution.
        :param padding: padding of every convolution.
        :param groups: groups of every convolution.
        :param dilation: dilation of every convolution.
        """
        super(ResNetBlock, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.conv1 = nn.Conv2d(
            input_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding,
            groups=groups, dilation=dilation
        )

        def hidden_layer():
            # A fresh conv + LeakyReLU pair with its own parameters.
            return nn.Sequential(
                nn.Conv2d(
                    out_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding,
                    groups=groups, dilation=dilation
                ),
                nn.LeakyReLU()
            )

        # ``conv2`` is kept registered (and first) so state_dicts containing
        # ``conv2.*`` keys still load strictly; it doubles as hidden layer 0.
        self.conv2 = hidden_layer()
        n_hidden = max(blocks, 0) * max(layers, 0)
        # Every entry must be a distinct module: appending one shared instance
        # would make all hidden layers silently share a single set of weights.
        if n_hidden > 0:
            hidden = [self.conv2] + [hidden_layer() for _ in range(n_hidden - 1)]
        else:
            hidden = []
        self.conv_hidden = nn.ModuleList(hidden)
        self.blocks = blocks
        self.layers = layers

    def forward(self, x):
        """Apply ``conv1`` then the hidden layers, adding the ``conv1`` output back as a shortcut."""
        x = self.conv1(x)
        # The shortcut is always the stem output; it is never updated.
        shortcut = x
        for i, hidden_layer in enumerate(self.conv_hidden):
            x = hidden_layer(x)
            # NOTE(review): this fires after the *first* layer of every block
            # except block 0 (i % layers == 0 means layer 0 of a block), and it
            # applies a second leaky ReLU on top of the layer's own LeakyReLU.
            # A residual at the end of each block would test
            # ``(i + 1) % self.layers == 0`` -- confirm the intended placement
            # before changing it (it alters outputs).
            if i % self.layers == 0 and i != 0:
                x = F.leaky_relu(x)
                x = x + shortcut
        return x


class ConvLSTM(nn.Module):
    """LSTM-style convolutional cell that also emits a one-channel attention map.

    The input, forget and output gates and the candidate cell are each a
    bias-free convolution of ``input_tensor`` alone; only the cell state is
    carried between calls (there is no recurrent hidden-state input).
    """

    def __init__(self, ch, kernel_size=3):
        """
        :param ch: channels of the input tensor and of the cell state.
        :type kernel_size: iterator or int
        :param kernel_size: kernel size of every gate convolution.
        """
        super(ConvLSTM, self).__init__()
        # Derived from the kernel size (int or pair) so the new cell state keeps
        # the H and W of ``input_cell_state`` for any odd kernel; equals 1 for 3x3.
        self.padding = _same_padding(kernel_size)
        self.conv_i = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size,
                                stride=1, padding=self.padding, bias=False)
        self.conv_f = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size,
                                stride=1, padding=self.padding, bias=False)
        self.conv_c = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size,
                                stride=1, padding=self.padding, bias=False)
        self.conv_o = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size,
                                stride=1, padding=self.padding, bias=False)
        self.conv_attention_map = nn.Conv2d(in_channels=ch, out_channels=1, kernel_size=kernel_size,
                                            stride=1, padding=self.padding, bias=False)
        self.ch = ch

    def init_hidden(self, batch_size, image_size, init=0.5):
        """Return a ``(batch_size, ch, height, width)`` tensor filled with ``init``.

        dtype and device follow the convolution weights.
        """
        height, width = image_size
        weight = self.conv_i.weight
        return torch.full((batch_size, self.ch, height, width), init,
                          dtype=weight.dtype, device=weight.device)

    def forward(self, input_tensor, input_cell_state=None):
        """Run one cell step.

        :param input_tensor: feature map with ``ch`` channels.
        :param input_cell_state: previous cell state; a constant 0.5 tensor
            matching ``input_tensor``'s batch and spatial size when ``None``.
        :return: ``(attention_map, cell_state, lstm_feats)``; ``attention_map``
            has a single channel and is passed through a sigmoid.
        """
        if input_cell_state is None:
            batch_size, _, height, width = input_tensor.size()
            input_cell_state = self.init_hidden(batch_size, (height, width))
        sigmoid_i = torch.sigmoid(self.conv_i(input_tensor))
        sigmoid_f = torch.sigmoid(self.conv_f(input_tensor))
        cell_state = sigmoid_f * input_cell_state + sigmoid_i * torch.tanh(self.conv_c(input_tensor))
        sigmoid_o = torch.sigmoid(self.conv_o(input_tensor))
        lstm_feats = sigmoid_o * torch.tanh(cell_state)
        attention_map = torch.sigmoid(self.conv_attention_map(lstm_feats))
        return attention_map, cell_state, lstm_feats


class AttentiveRNNBLCK(nn.Module):
    """One attentive step: ResNetBlock -> ConvLSTM -> image scaled by the attention map."""

    def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1,
                 padding=1, groups=1, dilation=1):
        """
        :type kernel_size: iterator or int
        :param kernel_size: kernel size for the ResNetBlock and ConvLSTM; defaults to ``[3, 3]``.

        The remaining arguments are forwarded to :class:`ResNetBlock`;
        ``out_ch`` is also the ConvLSTM channel count.
        """
        super(AttentiveRNNBLCK, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.sigmoid = nn.Sigmoid()
        self.resnet = nn.Sequential(
            ResNetBlock(
                blocks=self.blocks, layers=self.layers, input_ch=self.input_ch, out_ch=self.out_ch,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding,
                groups=self.groups, dilation=self.dilation
            )
        )
        self.LSTM = mySequential(
            ConvLSTM(
                ch=out_ch,
                kernel_size=kernel_size,
            )
        )

    def forward(self, original_image, prev_cell_state=None):
        """Return a dict with keys ``x``, ``attention_map``, ``cell_state`` and ``lstm_feats``.

        ``x`` is ``original_image`` multiplied by the attention map.
        """
        x = self.resnet(original_image)
        attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
        # The one-channel attention map broadcasts over the image channels.
        x = attention_map * original_image
        ret = {
            'x': x,
            'attention_map': attention_map,
            'cell_state': cell_state,
            'lstm_feats': lstm_feats
        }
        return ret


class AttentiveRNN(nn.Module):
    """Applies an :class:`AttentiveRNNBLCK` ``repetition`` times, threading the cell state."""

    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None,
                 stride=1, padding=1, groups=1, dilation=1):
        """
        :param repetition: number of attentive steps.
        :type kernel_size: iterator or int
        :param kernel_size: defaults to ``[3, 3]``.

        The remaining arguments are forwarded to :class:`AttentiveRNNBLCK`.
        """
        super(AttentiveRNN, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.repetition = repetition
        self.generator_block = mySequential(
            AttentiveRNNBLCK(blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
                             kernel_size=kernel_size, stride=stride, padding=padding,
                             groups=groups, dilation=dilation)
        )
        # Every entry is the SAME module instance, so all steps share weights
        # (recurrent weight tying). NOTE(review): confirm this is intended.
        self.generator_blocks = nn.ModuleList()
        for _ in range(repetition):
            self.generator_blocks.append(self.generator_block)

    def forward(self, x):
        """Run all steps on ``x``.

        Each step receives the previous step's ``x`` (the attention-scaled
        image) and cell state.

        :return: dict with the final ``x``, the per-step ``attention_map_list``
            and the per-step ``lstm_feats`` list.
        """
        cell_state = None
        attention_map = []
        lstm_feats = []
        for generator_block in self.generator_blocks:
            generator_block_return = generator_block(x, cell_state)
            cell_state = generator_block_return['cell_state']
            x = generator_block_return['x']
            attention_map.append(generator_block_return['attention_map'])
            lstm_feats.append(generator_block_return['lstm_feats'])
        ret = {
            'x': x,
            'attention_map_list': attention_map,
            'lstm_feats': lstm_feats
        }
        return ret

    def loss(self, input_image_tensor, difference_maskmap, theta=0.8):
        """Weighted sum of MSEs between each step's attention map and ``difference_maskmap``.

        Step ``index`` (0-based) of ``n`` steps is weighted by
        ``theta ** (n - index + 1)``, so later steps weigh more when ``theta < 1``.
        Also stores ``theta`` on ``self.theta``.
        """
        self.theta = theta
        inference_ret = self.forward(input_image_tensor)
        attention_maps = inference_ret['attention_map_list']
        n = len(attention_maps)
        loss = 0.0
        for index, attention_map in enumerate(attention_maps):
            # NOTE(review): the Attentive-GAN paper weights step t (1-based) by
            # theta ** (N - t); this exponent is larger by 2, i.e. a uniform
            # theta ** 2 rescale of the whole loss -- confirm intended.
            loss += (self.theta ** (n - index + 1)) * F.mse_loss(attention_map, difference_maskmap)
        return loss


# Need work
if __name__ == "__main__":
    from torchinfo import summary

    # set_default_tensor_type is deprecated; set_default_dtype keeps float32 tensors.
    torch.set_default_dtype(torch.float32)
    generator = AttentiveRNN(3, blocks=2)
    batch_size = 5
    summary(generator, input_size=(batch_size, 3, 960, 540))