File name
Commit message
Commit date
File name
Commit message
Commit date
from torch.utils.data import Dataset
from torchvision.io import read_image
from typing import Union, List
import numpy as np
class ImagePairDataset(Dataset):
    """Dataset yielding a clean image and a matching rainy image per index.

    Sample ``i`` pairs ``clean_img_dir[i]`` with ``rainy_img_dirs[i]``. A
    rainy entry is either a single path or a list/tuple of candidate paths;
    for the latter, one candidate is drawn uniformly at random on every
    ``__getitem__`` call.

    Args:
        clean_img_dir: File paths of the clean images, one per sample
            (despite the name, these are paths, not a directory).
        rainy_img_dirs: For each sample, one rainy-image path or a
            list/tuple of candidate rainy-image paths. Expected to be
            index-aligned with ``clean_img_dir``; the dataset length comes
            from ``clean_img_dir`` only, so surplus rainy entries are ignored
            and missing ones raise ``IndexError`` on access.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(
        self,
        clean_img_dir: List[str],
        rainy_img_dirs: Union[List[str], List[List[str]]],
        transform=None,
    ):
        self.clean_img = clean_img_dir
        self.dirty_img = rainy_img_dirs
        self.transform = transform
        # NOTE(review): this unseeded Generator is copied as-is into each
        # DataLoader worker, so workers may draw identical candidate
        # sequences; reseed per worker (e.g. via worker_init_fn) if that
        # matters for training.
        self.rng = np.random.default_rng()

    def __len__(self):
        """Return the number of samples (the number of clean images)."""
        return len(self.clean_img)

    def _select_rainy_path(self, idx):
        """Return the rainy-image path for sample ``idx``.

        A list/tuple entry is sampled uniformly, so every candidate
        (including the last one) can be chosen.

        Raises:
            ValueError: If the entry is an empty list/tuple of candidates.
        """
        candidates = self.dirty_img[idx]
        if not isinstance(candidates, (list, tuple)):
            return candidates
        if not candidates:
            raise ValueError(f"no rainy image candidates for index {idx}")
        # Generator.integers excludes `high` by default, so high=len(...)
        # yields indices 0..len-1 (the old `len - 1` skipped the last one).
        choice = self.rng.integers(low=0, high=len(candidates))
        return candidates[choice]

    def __getitem__(self, idx):
        """Load the clean image and one rainy image for sample ``idx``.

        Returns:
            dict with keys ``"clean_image"`` and ``"rainy_image"`` holding the
            result of ``read_image`` for each path, passed through
            ``transform`` when one is set.
        """
        clean_img_path = self.clean_img[idx]
        rainy_img_path = self._select_rainy_path(idx)

        clean_image = read_image(clean_img_path)
        rainy_image = read_image(rainy_img_path)

        if self.transform is not None:
            # NOTE(review): the transform is called separately on each image,
            # so a random transform (crop/flip) will not be applied
            # identically to both halves of the pair — confirm intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)

        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image
        }

    def __add__(self, other):
        """Return a new ImagePairDataset with ``self``'s samples then ``other``'s.

        The result uses ``self.transform`` when set, otherwise
        ``other.transform`` (so ``other``'s transform is dropped when both are
        set). Non-ImagePairDataset operands yield ``NotImplemented``, which
        makes Python raise ``TypeError``.
        """
        if not isinstance(other, ImagePairDataset):
            return NotImplemented
        # list(...) so list and tuple inputs can be concatenated together.
        return ImagePairDataset(
            clean_img_dir=list(self.clean_img) + list(other.clean_img),
            rainy_img_dirs=list(self.dirty_img) + list(other.dirty_img),
            transform=self.transform if self.transform is not None else other.transform,
        )