--- model/Generator.py
+++ model/Generator.py
... | ... | @@ -43,7 +43,7 @@ |
43 | 43 |
diff = abs(clean - dirty) |
44 | 44 |
diff = sum(diff, dim=1) |
45 | 45 |
|
46 |
- bin_diff = (diff < thresold).to(clean.dtype) |
|
46 |
+ bin_diff = (diff >= thresold).to(clean.dtype) |
|
47 | 47 |
|
48 | 48 |
return bin_diff |
49 | 49 |
|
--- train.py
+++ train.py
... | ... | @@ -66,12 +66,12 @@ |
66 | 66 |
pass |
67 | 67 |
|
68 | 68 |
# 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임 |
69 |
-rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png") |
|
69 |
+rainy_data_path = glob.glob("data/source/Peking_raindrop_dataset/dirty/*.png") |
|
70 | 70 |
rainy_data_path = sorted(rainy_data_path) |
71 |
-clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png") |
|
71 |
+clean_data_path = glob.glob("data/source/Peking_raindrop_dataset/clean/*.png") |
|
72 | 72 |
clean_data_path = sorted(clean_data_path) |
73 | 73 |
|
74 |
-resize = torchvision.transforms.Resize((692, 776), antialias=True) |
|
74 |
+resize = torchvision.transforms.Resize((480, 720), antialias=True) |
|
75 | 75 |
dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize) |
76 | 76 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) |
77 | 77 |
# declare generator loss |
... | ... | @@ -116,7 +116,7 @@ |
116 | 116 |
|
117 | 117 |
attentiveRNNresults = generator.attentiveRNN(rainy_img) |
118 | 118 |
generator_attention_map = attentiveRNNresults['attention_map_list'] |
119 |
- binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.2) |
|
119 |
+ binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.24) |
|
120 | 120 |
generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask) |
121 | 121 |
generator_loss_ARNN.backward() |
122 | 122 |
optimizer_G_ARNN.step() |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?