from torch.utils.data import Dataset
from torchvision.io import read_image
from typing import Union, List
import numpy as np


class ImagePairDataset(Dataset):
    """Dataset of (clean, rainy) image pairs loaded from file paths.

    Item ``idx`` pairs ``clean_img_dir[idx]`` with ``rainy_img_dirs[idx]``.
    Each rainy entry is either a single path, or a list/tuple of candidate
    paths from which one is drawn uniformly at random on every access.

    Args:
        clean_img_dir: Paths of the clean images. Despite the name, this is a
            list of file paths, not a directory.
        rainy_img_dirs: For each item, either one rainy image path or a list
            of candidate rainy paths. Indexed in lockstep with
            ``clean_img_dir``.
        transform: Optional callable applied to each image returned by
            ``read_image``.
    """

    def __init__(self, clean_img_dir: List[str], rainy_img_dirs: Union[List[str], List[List[str]]], transform=None):
        self.clean_img = clean_img_dir
        self.dirty_img = rainy_img_dirs
        self.transform = transform
        # NOTE(review): this Generator is copied into every DataLoader worker
        # process, so workers may draw identical candidate sequences. Consider
        # reseeding per worker (e.g. via ``worker_init_fn``) if that matters.
        self.rng = np.random.default_rng()

    def __len__(self):
        """Return the number of pairs, i.e. ``len(clean_img_dir)``."""
        return len(self.clean_img)

    def __getitem__(self, idx):
        """Load pair ``idx``.

        Returns:
            A dict with keys ``"clean_image"`` and ``"rainy_image"``, each the
            result of ``read_image`` (after ``transform``, if one is set).

        Raises:
            ValueError: if the rainy entry for ``idx`` is an empty list/tuple.
        """
        clean_img_path = self.clean_img[idx]
        rainy_img_paths = self.dirty_img[idx]
        if isinstance(rainy_img_paths, (list, tuple)):
            if not rainy_img_paths:
                raise ValueError(f"no rainy image candidates for index {idx}")
            # Generator.integers excludes `high` by default, so this draws
            # uniformly from [0, len), making every candidate selectable.
            i = self.rng.integers(low=0, high=len(rainy_img_paths))
            rainy_img_path = rainy_img_paths[i]
        else:
            rainy_img_path = rainy_img_paths

        clean_image = read_image(clean_img_path)
        rainy_image = read_image(rainy_img_path)

        if self.transform:
            # NOTE(review): the transform is called separately on each image.
            # If it is random (e.g. a random crop), the two images of a pair
            # will not receive the same augmentation. Confirm this is intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)

        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image
        }

    def __add__(self, other):
        """Return a new dataset containing the pairs of ``self`` followed by ``other``.

        The result uses ``self.transform`` if it is truthy, otherwise
        ``other.transform``. If both are set, ``other.transform`` is dropped.

        Raises:
            TypeError: if ``other`` is not an ``ImagePairDataset``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(other, ImagePairDataset):
            raise TypeError("other must be an instance of ImagePairDataset")
        return ImagePairDataset(
            clean_img_dir=list(self.clean_img) + list(other.clean_img),
            rainy_img_dirs=list(self.dirty_img) + list(other.dirty_img),
            transform=self.transform or other.transform
        )