from torch.utils.data import Dataset from torchvision.io import read_image import numpy as np # the dataset for this model needs to be in at least, a pair of clean and dirty image. # in other words, 'annotations' is not really appropriate word # however, the problem here is that we need should be able to pair more than one dirty image per clean image, # I am lost now class ImageDataSet(Dataset): def __init__(self, img_dir, annotations, transform=None): self.annotations = annotations self.img_dir = img_dir self.transform = transform # self.target_transform = target_transform # print(self.transform) def __len__(self): return len(self.annotations) def __getitem__(self, idx): img_path = self.img_dir[idx] image = read_image(img_path) label = self.annotations[idx] if self.transform: image = self.transform(image) # if self.target_transform(label): # label = self.target_transform(label) return image, label def __add__(self, other): return ImageDataSet( img_dir=self.img_dir+other.img_dir, annotations=np.array(list(self.annotations)+list(other.annotations)), transform=self.transform )