윤영준 윤영준 2023-07-13
dataloader refactoring
@7156acc18073f5d9abedcdd1bc2629c283d0dafc
tools/dataloader.py
--- tools/dataloader.py
+++ tools/dataloader.py
@@ -1,42 +1,46 @@
 from torch.utils.data import Dataset
 from torchvision.io import read_image
+
+from typing import Union, List
 import numpy as np
 
 
 class ImagePairDataset(Dataset):
-    def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
+    def __init__(self, clean_img_dir: List[str], rainy_img_dirs: Union[List[str], List[List[str]]], transform=None):
         self.clean_img = clean_img_dir
-        self.rainy_img = rainy_img_dirs
+        self.dirty_img = rainy_img_dirs
         self.transform = transform
+        self.rng = np.random.default_rng()
 
     def __len__(self):
         return len(self.clean_img)
 
     def __getitem__(self, idx):
         clean_img_path = self.clean_img[idx]
+        rainy_img_paths = self.dirty_img[idx]
 
-        i = 0
-        if len(self.rainy_img[idx]) is list:
-            rng = np.random.default_rng()
-            i = rng.integers(low=0, high=len(self.rainy_img)-1)
-            rainy_img_path = self.rainy_img[idx][i]
+        if isinstance(rainy_img_paths, list):
+            i = self.rng.integers(low=0, high=len(rainy_img_paths))  # `high` is exclusive, so this covers every index
+            rainy_img_path = rainy_img_paths[i]
         else:
-            rainy_img_path = self.rainy_img[idx]
+            rainy_img_path = rainy_img_paths
+
         clean_image = read_image(clean_img_path)
         rainy_image = read_image(rainy_img_path)
+
         if self.transform:
             clean_image = self.transform(clean_image)
             rainy_image = self.transform(rainy_image)
 
-        ret = {
-            "clean_image" : clean_image,
-            "rainy_image" : rainy_image
+        return {
+            "clean_image": clean_image,
+            "rainy_image": rainy_image
         }
-        return ret
 
     def __add__(self, other):
+        assert isinstance(other, ImagePairDataset), "other must be an instance of ImagePairDataset"
         return ImagePairDataset(
-            clean_img_dir=self.clean_img+other.clean_img,
-            rainy_img_dirs=self.rainy_img+other.rainy_img,
-            transform=self.transform
-            )
+            clean_img_dir=self.clean_img + other.clean_img,
+            rainy_img_dirs=self.dirty_img + other.dirty_img,
+            transform=self.transform or other.transform
+        )
\ No newline at end of file
train.py
--- train.py
+++ train.py
@@ -66,9 +66,9 @@
     pass
 
 # 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임
-rainy_data_path = glob.glob("data/source/Peking_raindrop_dataset/dirty/*.png")
+rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
 rainy_data_path = sorted(rainy_data_path)
-clean_data_path = glob.glob("data/source/Peking_raindrop_dataset/clean/*.png")
+clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
 clean_data_path = sorted(clean_data_path)
 
 resize = torchvision.transforms.Resize((480, 720), antialias=True)
Add a comment
List