data:image/s3,"s3://crabby-images/77fc1/77fc1ecd598263bdfa1d6248fbe60b3bfc41f6f8" alt=""
--- tools/dataloader.py
+++ tools/dataloader.py
... | ... | @@ -2,35 +2,41 @@ |
2 | 2 |
from torchvision.io import read_image |
3 | 3 |
import numpy as np |
4 | 4 |
|
5 |
-# the dataset for this model needs to be in at least, a pair of clean and dirty image. |
|
6 |
-# in other words, 'annotations' is not really appropriate word |
|
7 |
-# however, the problem here is that we need should be able to pair more than one dirty image per clean image, |
|
8 |
-# I am lost now |
|
9 | 5 |
|
10 | 6 |
class ImageDataSet(Dataset): |
11 |
- def __init__(self, img_dir, annotations, transform=None): |
|
12 |
- self.annotations = annotations |
|
13 |
- self.img_dir = img_dir |
|
7 |
+ def __init__(self, clean_img_dir, rainy_img_dirs, transform=None): |
|
8 |
+ self.clean_img = clean_img_dir |
|
9 |
+ self.rainy_img = rainy_img_dirs |
|
14 | 10 |
self.transform = transform |
15 | 11 |
# self.target_transform = target_transform |
16 | 12 |
# print(self.transform) |
17 | 13 |
|
18 | 14 |
def __len__(self): |
19 |
- return len(self.annotations) |
|
15 |
+ return len(self.clean_img) |
|
20 | 16 |
|
21 | 17 |
def __getitem__(self, idx): |
22 |
- img_path = self.img_dir[idx] |
|
23 |
- image = read_image(img_path) |
|
24 |
- label = self.annotations[idx] |
|
18 |
+ clean_img_path = self.clean_img[idx] |
|
19 |
+ |
|
20 |
+ i = 0 |
|
21 |
+ if len(self.rainy_img) != 1: |
|
22 |
+ rng = np.random.default_rng() |
|
23 |
+ i = rng.integers(low=0, high=len(self.rainy_img)-1) |
|
24 |
+ rainy_img_path = self.rainy_img[idx][i] |
|
25 |
+ clean_image = read_image(clean_img_path) |
|
26 |
+ rainy_image = read_image(rainy_img_path) |
|
25 | 27 |
if self.transform: |
26 |
- image = self.transform(image) |
|
27 |
- # if self.target_transform(label): |
|
28 |
- # label = self.target_transform(label) |
|
29 |
- return image, label |
|
28 |
+ clean_image = self.transform(clean_image) |
|
29 |
+ rainy_image = self.rainy_img(rainy_image) |
|
30 |
+ |
|
31 |
+ ret = { |
|
32 |
+ "clean_image" : clean_image, |
|
33 |
+ "rainy_image" : rainy_image |
|
34 |
+ } |
|
35 |
+ return ret |
|
30 | 36 |
|
31 | 37 |
def __add__(self, other): |
32 | 38 |
return ImageDataSet( |
33 |
- img_dir=self.img_dir+other.img_dir, |
|
34 |
- annotations=np.array(list(self.annotations)+list(other.annotations)), |
|
39 |
+ clean_img_dir=self.clean_img+other.clean_img, |
|
40 |
+ rainy_img_dirs=self.rainy_img+other.rainy_img, |
|
35 | 41 |
transform=self.transform |
36 | 42 |
) |
--- train.py
+++ train.py
... | ... | @@ -32,9 +32,10 @@ |
32 | 32 |
generator = Generator() # get network values and stuff |
33 | 33 |
discriminator = Discriminator() |
34 | 34 |
|
35 |
-if cuda: |
|
36 |
- generator.cuda() |
|
37 |
- discriminator.cuda() |
|
35 |
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
36 |
+ |
|
37 |
+generator = Generator().to(device) |
|
38 |
+discriminator = Discriminator().to(device) |
|
38 | 39 |
|
39 | 40 |
if load is not False: |
40 | 41 |
generator.load_state_dict(torch.load("example_path")) |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?