import torch
from torch import nn
from torch.nn import functional as F
# nn.Sequential does not handle multiple inputs by design; this subclass is a workaround
# https://github.com/pytorch/pytorch/issues/19808#
class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            inputs = module(*inputs)
        return inputs
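# Illustrative sketch (not part of the model): with mySequential, a module that
# takes several positional inputs can be wrapped directly, e.g.
#
#   class AddBoth(nn.Module):              # hypothetical helper, for illustration only
#       def forward(self, a, b):
#           return a + b
#
#   seq = mySequential(AddBoth())
#   out = seq(torch.ones(1), torch.ones(1))    # -> tensor([2.])
#
# A plain nn.Sequential would raise a TypeError here, because its forward
# accepts only a single input.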
class ResNetBlock(nn.Module):
def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
dilation=1):
"""
:type kernel_size: iterator or int
"""
super(ResNetBlock, self).__init__()
if kernel_size is None:
kernel_size = [3, 3]
self.conv1 = nn.Conv2d(
input_ch, out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
dilation=dilation
)
self.conv2 = nn.Sequential(
nn.Conv2d(
out_ch, out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
dilation=dilation
),
nn.LeakyReLU()
)
        # Note: the same conv2 instance is appended for every hidden layer, so
        # all blocks/layers share one set of weights.
        self.conv_hidden = nn.ModuleList()
        for _ in range(blocks):
            for _ in range(layers):
                self.conv_hidden.append(
                    self.conv2
                )
self.blocks = blocks
self.layers = layers
def forward(self, x):
x = self.conv1(x)
x = F.leaky_relu(x)
shortcut = x
for i, hidden_layer in enumerate(self.conv_hidden):
x = hidden_layer(x)
            # Add the residual shortcut at the end of each block of `layers` hidden layers.
            if (i % self.layers == 0) and (i != 0):
                x = x + shortcut
return x
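# Shape note: with the defaults (kernel 3, stride 1, padding 1) ResNetBlock maps a
# (B, input_ch, H, W) tensor to (B, out_ch, H, W); the spatial size is preserved.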
class ConvLSTM(nn.Module):
    def __init__(self, ch, kernel_size=3):
        super(ConvLSTM, self).__init__()
        # Same-size padding for odd kernels; kernel_size may be an int or an iterable.
        if isinstance(kernel_size, int):
            self.padding = (kernel_size - 1) // 2
        else:
            self.padding = tuple((k - 1) // 2 for k in kernel_size)
        self.conv_i = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_f = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_c = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_o = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_attention_map = nn.Conv2d(in_channels=ch, out_channels=1, kernel_size=kernel_size, stride=1,
                                            padding=self.padding, bias=False)
self.ch = ch
    def init_hidden(self, batch_size, image_size, init=0.0):
        height, width = image_size
        return torch.full((batch_size, self.ch, height, width), init,
                          dtype=self.conv_i.weight.dtype, device=self.conv_i.weight.device)
def forward(self, input_tensor, input_cell_state=None):
if input_cell_state is None:
batch_size, _, height, width = input_tensor.size()
input_cell_state = self.init_hidden(batch_size, (height, width))
conv_i = self.conv_i(input_tensor)
sigmoid_i = torch.sigmoid(conv_i)
conv_f = self.conv_f(input_tensor)
sigmoid_f = torch.sigmoid(conv_f)
cell_state = sigmoid_f * input_cell_state + sigmoid_i * torch.tanh(self.conv_c(input_tensor))
conv_o = self.conv_o(input_tensor)
sigmoid_o = torch.sigmoid(conv_o)
lstm_feats = sigmoid_o * torch.tanh(cell_state)
attention_map = self.conv_attention_map(lstm_feats)
attention_map = torch.sigmoid(attention_map)
ret = {
"attention_map" : attention_map,
"cell_state" : cell_state,
"lstm_feats" : lstm_feats
}
return ret
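# Gating implemented above, written out for reference. Conv(.) denotes the
# corresponding convolution and all products are elementwise; every gate is
# computed from the current input only, so the cell state is the sole
# recurrent quantity:
#   i_t = sigmoid(Conv_i(x_t))
#   f_t = sigmoid(Conv_f(x_t))
#   c_t = f_t * c_{t-1} + i_t * tanh(Conv_c(x_t))
#   o_t = sigmoid(Conv_o(x_t))
#   h_t = o_t * tanh(c_t)              # returned as "lstm_feats"
#   A_t = sigmoid(Conv_a(h_t))         # returned as "attention_map"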
class AttentiveRNNBLCK(nn.Module):
def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
dilation=1):
"""
:type kernel_size: iterator or int
"""
super(AttentiveRNNBLCK, self).__init__()
if kernel_size is None:
kernel_size = [3, 3]
self.blocks = blocks
self.layers = layers
self.input_ch = input_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.sigmoid = nn.Sigmoid()
self.resnet = nn.Sequential(
ResNetBlock(
blocks=self.blocks,
layers=self.layers,
input_ch=self.input_ch,
out_ch=self.out_ch,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups,
dilation=self.dilation
)
)
self.LSTM = mySequential(
ConvLSTM(
ch=out_ch, kernel_size=kernel_size,
)
)
def forward(self, original_image, prev_cell_state=None):
x = self.resnet(original_image)
lstm_ret = self.LSTM(x, prev_cell_state)
attention_map = lstm_ret["attention_map"]
cell_state = lstm_ret['cell_state']
lstm_feats = lstm_ret["lstm_feats"]
x = attention_map * original_image
ret = {
'x' : x,
'attention_map' : attention_map,
'cell_state' : cell_state,
'lstm_feats' : lstm_feats
}
return ret
class AttentiveRNN(nn.Module):
def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1,
groups=1, dilation=1):
"""
:type kernel_size: iterator or int
"""
super(AttentiveRNN, self).__init__()
if kernel_size is None:
kernel_size = [3, 3]
self.blocks = blocks
self.layers = layers
self.input_ch = input_ch
self.out_ch = out_ch
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.repetition = repetition
self.arnn_block = mySequential(
AttentiveRNNBLCK(
blocks=blocks,
layers=layers,
input_ch=input_ch,
out_ch=out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
dilation=dilation
)
)
        # Note: the same AttentiveRNNBLCK instance is appended `repetition` times,
        # so every recurrent step shares one set of weights.
        self.arnn_blocks = nn.ModuleList()
        for _ in range(repetition):
            self.arnn_blocks.append(
                self.arnn_block
            )
        self.name = "AttentiveRNN"
def forward(self, input_img):
cell_state = None
attention_map = []
lstm_feats = []
x = input_img
for arnn_block in self.arnn_blocks:
arnn_block_return = arnn_block(x, cell_state)
attention_map_i = arnn_block_return['attention_map']
lstm_feats_i = arnn_block_return['lstm_feats']
cell_state = arnn_block_return['cell_state']
x = arnn_block_return['x']
attention_map.append(attention_map_i)
lstm_feats.append(lstm_feats_i)
        ret = {
            'x' : x,
            'attention_map_list' : attention_map,
            'lstm_feats' : lstm_feats
        }
        return ret
    def loss(self, input_image_tensor, difference_maskmap, theta=0.8):
        """
        Attention loss: a theta-weighted sum of MSE between each step's
        attention map and the rain difference mask, with earlier steps
        down-weighted more strongly.
        """
        self.theta = theta
        # Run the attentive RNN and collect the attention map of every step.
        inference_ret = self.forward(input_image_tensor)
        loss = 0.0
        n = len(inference_ret['attention_map_list'])
        for index, attention_map in enumerate(inference_ret['attention_map_list']):
            mse_loss = (self.theta ** (n - index + 1)) * F.mse_loss(attention_map, difference_maskmap)
            loss += mse_loss
        return loss
# Needs work
if __name__ == "__main__":
from torchinfo import summary
torch.set_default_tensor_type(torch.FloatTensor)
generator = AttentiveRNN(3, blocks=2)
batch_size = 5
    summary(generator, input_size=(batch_size, 3, 960, 540))
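    # A minimal smoke test (sketch, not from the original source): run a forward
    # pass and the attention loss on small random inputs. The binary mask below
    # is a stand-in for the real rain/clean difference mask.
    dummy_rainy = torch.rand(1, 3, 128, 128)
    dummy_mask = (torch.rand(1, 1, 128, 128) > 0.5).float()
    out = generator(dummy_rainy)
    print("output shape:", out['x'].shape)                      # (1, 3, 128, 128)
    print("attention loss:", generator.loss(dummy_rainy, dummy_mask).item())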