윤영준 윤영준 2023-06-27
somewhat finished ImageDataSet
@5cda0f71483a6c0fb82f343bee2489db30c62869
tools/dataloader.py
--- tools/dataloader.py
+++ tools/dataloader.py
@@ -2,35 +2,41 @@
 from torchvision.io import read_image
 import numpy as np
 
-# the dataset for this model needs to be in at least, a pair of clean and dirty image.
-# in other words, 'annotations' is not really appropriate word
-# however, the problem here is that we need should be able to pair more than one dirty image per clean image,
-# I am lost now
 
 class ImageDataSet(Dataset):
-    def __init__(self, img_dir, annotations, transform=None):
-        self.annotations = annotations
-        self.img_dir = img_dir
+    def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
+        self.clean_img = clean_img_dir
+        self.rainy_img = rainy_img_dirs
         self.transform = transform
         # self.target_transform = target_transform
         # print(self.transform)
 
     def __len__(self):
-        return len(self.annotations)
+        return len(self.clean_img)
 
     def __getitem__(self, idx):
-        img_path = self.img_dir[idx]
-        image = read_image(img_path)
-        label = self.annotations[idx]
+        clean_img_path = self.clean_img[idx]
+
+        i = 0
+        if len(self.rainy_img) != 1:
+            rng = np.random.default_rng()
+            i = rng.integers(low=0, high=len(self.rainy_img)-1)
+        rainy_img_path = self.rainy_img[idx][i]
+        clean_image = read_image(clean_img_path)
+        rainy_image = read_image(rainy_img_path)
         if self.transform:
-            image = self.transform(image)
-        # if self.target_transform(label):
-        #     label = self.target_transform(label)
-        return image, label
+            clean_image = self.transform(clean_image)
+            rainy_image = self.rainy_img(rainy_image)
+
+        ret = {
+            "clean_image" : clean_image,
+            "rainy_image" : rainy_image
+        }
+        return ret
 
     def __add__(self, other):
         return ImageDataSet(
-            img_dir=self.img_dir+other.img_dir,
-            annotations=np.array(list(self.annotations)+list(other.annotations)),
+            clean_img_dir=self.clean_img+other.clean_img,
+            rainy_img_dirs=self.rainy_img+other.rainy_img,
             transform=self.transform
             )
train.py
--- train.py
+++ train.py
@@ -32,9 +32,10 @@
 generator = Generator() # get network values and stuff
 discriminator = Discriminator()
 
-if cuda:
-    generator.cuda()
-    discriminator.cuda()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+generator = Generator().to(device)
+discriminator = Discriminator().to(device)
 
 if load is not False:
     generator.load_state_dict(torch.load("example_path"))
Add a comment
List