--- tools/dataloader.py
+++ tools/dataloader.py
... | ... | @@ -1,42 +1,46 @@ |
1 | 1 |
from torch.utils.data import Dataset |
2 | 2 |
from torchvision.io import read_image |
3 |
+ |
|
4 |
+from typing import Union, List |
|
3 | 5 |
import numpy as np |
4 | 6 |
|
5 | 7 |
|
6 | 8 |
class ImagePairDataset(Dataset): |
7 |
- def __init__(self, clean_img_dir, rainy_img_dirs, transform=None): |
|
9 |
+ def __init__(self, clean_img_dir: List[str], rainy_img_dirs: Union[List[str], List[List[str]]], transform=None): |
|
8 | 10 |
self.clean_img = clean_img_dir |
9 |
- self.rainy_img = rainy_img_dirs |
|
11 |
+ self.dirty_img = rainy_img_dirs |
|
10 | 12 |
self.transform = transform |
13 |
+ self.rng = np.random.default_rng() |
|
11 | 14 |
|
12 | 15 |
def __len__(self): |
13 | 16 |
return len(self.clean_img) |
14 | 17 |
|
15 | 18 |
def __getitem__(self, idx): |
16 | 19 |
clean_img_path = self.clean_img[idx] |
20 |
+ rainy_img_paths = self.dirty_img[idx] |
|
17 | 21 |
|
18 |
- i = 0 |
|
19 |
- if len(self.rainy_img[idx]) is list: |
|
20 |
- rng = np.random.default_rng() |
|
21 |
- i = rng.integers(low=0, high=len(self.rainy_img)-1) |
|
22 |
- rainy_img_path = self.rainy_img[idx][i] |
|
22 |
+ if isinstance(rainy_img_paths, list): |
|
23 |
+ i = self.rng.integers(low=0, high=len(rainy_img_paths) - 1) |
|
24 |
+ rainy_img_path = rainy_img_paths[i] |
|
23 | 25 |
else: |
24 |
- rainy_img_path = self.rainy_img[idx] |
|
26 |
+ rainy_img_path = rainy_img_paths |
|
27 |
+ |
|
25 | 28 |
clean_image = read_image(clean_img_path) |
26 | 29 |
rainy_image = read_image(rainy_img_path) |
30 |
+ |
|
27 | 31 |
if self.transform: |
28 | 32 |
clean_image = self.transform(clean_image) |
29 | 33 |
rainy_image = self.transform(rainy_image) |
30 | 34 |
|
31 |
- ret = { |
|
32 |
- "clean_image" : clean_image, |
|
33 |
- "rainy_image" : rainy_image |
|
35 |
+ return { |
|
36 |
+ "clean_image": clean_image, |
|
37 |
+ "rainy_image": rainy_image |
|
34 | 38 |
} |
35 |
- return ret |
|
36 | 39 |
|
37 | 40 |
def __add__(self, other): |
41 |
+ assert isinstance(other, ImagePairDataset), "other must be an instance of ImagePairDataset" |
|
38 | 42 |
return ImagePairDataset( |
39 |
- clean_img_dir=self.clean_img+other.clean_img, |
|
40 |
- rainy_img_dirs=self.rainy_img+other.rainy_img, |
|
41 |
- transform=self.transform |
|
42 |
- ) |
|
43 |
+ clean_img_dir=self.clean_img + other.clean_img, |
|
44 |
+ rainy_img_dirs=self.dirty_img + other.dirty_img, |
|
45 |
+ transform=self.transform or other.transform |
|
46 |
+ )(파일 끝에 줄바꿈 문자 없음) |
--- train.py
+++ train.py
... | ... | @@ -66,9 +66,9 @@ |
66 | 66 |
pass |
67 | 67 |
|
68 | 68 |
# 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임 |
69 |
-rainy_data_path = glob.glob("data/source/Peking_raindrop_dataset/dirty/*.png") |
|
69 |
+rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png") |
|
70 | 70 |
rainy_data_path = sorted(rainy_data_path) |
71 |
-clean_data_path = glob.glob("data/source/Peking_raindrop_dataset/clean/*.png") |
|
71 |
+clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png") |
|
72 | 72 |
clean_data_path = sorted(clean_data_path) |
73 | 73 |
|
74 | 74 |
resize = torchvision.transforms.Resize((480, 720), antialias=True) |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?