from torch.utils.data import Dataset
from torchvision.io import read_image
import numpy as np


class ImageDataSet(Dataset):
    """Dataset of paired clean / rainy images.

    Each sample pairs one clean image with one of its rainy variants. When a
    sample has more than one rainy variant, one is chosen uniformly at random
    on every access.

    Args:
        clean_img_dir: Indexable, concatenable sequence (e.g. a list) of
            clean-image file paths, one per sample. Despite the name, this is
            a sequence of paths, not a directory.
        rainy_img_dirs: Sequence parallel to ``clean_img_dir``; item ``idx`` is
            itself a sequence of rainy-image paths belonging to sample ``idx``.
        transform: Optional callable applied to each loaded image tensor.
    """

    def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
        self.clean_img = clean_img_dir
        self.rainy_img = rainy_img_dirs
        self.transform = transform

    def __len__(self):
        """Return the number of samples (number of clean images)."""
        return len(self.clean_img)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict of clean and rainy image tensors.

        Returns:
            dict with keys ``"clean_image"`` and ``"rainy_image"``.
        """
        clean_img_path = self.clean_img[idx]

        # Choose among the rainy variants of *this* sample. NumPy's
        # Generator.integers excludes ``high``, so passing the variant count
        # makes every variant (including the last one) reachable.
        variants = self.rainy_img[idx]
        i = 0
        if len(variants) > 1:
            rng = np.random.default_rng()
            i = int(rng.integers(low=0, high=len(variants)))
        rainy_img_path = variants[i]

        clean_image = read_image(clean_img_path)
        rainy_image = read_image(rainy_img_path)

        if self.transform is not None:
            # NOTE(review): the transform is called separately for each image,
            # so a random transform (crop, flip, ...) draws independent
            # parameters for the two images of a pair — confirm this is intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)

        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image,
        }

    def __add__(self, other):
        """Concatenate two datasets, keeping this dataset's transform.

        Overrides ``Dataset.__add__`` so the result is another ``ImageDataSet``
        rather than a generic concatenated dataset.
        """
        return ImageDataSet(
            clean_img_dir=self.clean_img + other.clean_img,
            rainy_img_dirs=self.rainy_img + other.rainy_img,
            transform=self.transform,
        )