File name
Commit message
Commit date
File name
Commit message
Commit date
from torch.utils.data import Dataset
from torchvision.io import read_image
import numpy as np
class ImagePairDataset(Dataset):
    """Dataset of (clean, rainy) image pairs loaded from file paths.

    Despite the parameter names, ``clean_img_dir`` and ``rainy_img_dirs`` are
    indexable sequences of image file paths (they are indexed per sample and
    concatenated with ``+`` in ``__add__``). They are aligned by position:
    entry ``idx`` of ``rainy_img_dirs`` is the rainy counterpart of entry
    ``idx`` of ``clean_img_dir``. A rainy entry may itself be a list/tuple of
    candidate paths; one of them is then picked uniformly at random on every
    access.

    Args:
        clean_img_dir: Sequence of paths to clean images.
        rainy_img_dirs: Sequence aligned with ``clean_img_dir`` whose items
            are either a single path or a list/tuple of candidate paths.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
        self.clean_img = clean_img_dir
        self.rainy_img = rainy_img_dirs
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per clean image path)."""
        return len(self.clean_img)

    def _select_rainy_path(self, idx):
        """Return the rainy image path for sample ``idx``.

        If the entry is a list/tuple of candidates, one is chosen uniformly at
        random (every candidate, including the last, can be chosen).

        Raises:
            ValueError: if the entry is an empty list/tuple.
        """
        entry = self.rainy_img[idx]
        if not isinstance(entry, (list, tuple)):
            return entry
        if not entry:
            raise ValueError(f"no rainy image paths for sample {idx}")
        # A generator seeded from OS entropy on each call (rather than one
        # created in __init__) avoids DataLoader worker processes inheriting
        # an identical RNG state and picking the same variants.
        rng = np.random.default_rng()
        # `integers(n)` samples from [0, n): the upper bound is exclusive.
        return entry[rng.integers(len(entry))]

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{"clean_image": ..., "rainy_image": ...}``."""
        clean_image = read_image(self.clean_img[idx])
        rainy_image = read_image(self._select_rainy_path(idx))
        if self.transform is not None:
            # NOTE(review): the transform is invoked separately per image, so a
            # random transform (crop/flip) will not be shared by the pair —
            # confirm this is intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)
        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image,
        }

    def __add__(self, other):
        """Concatenate two ImagePairDatasets' path lists.

        The result keeps ``self``'s transform. Any other dataset type falls
        back to ``Dataset.__add__`` (a ``ConcatDataset``).
        """
        if not isinstance(other, ImagePairDataset):
            return super().__add__(other)
        return ImagePairDataset(
            clean_img_dir=self.clean_img + other.clean_img,
            rainy_img_dirs=self.rainy_img + other.rainy_img,
            transform=self.transform,
        )