윤영준 윤영준 2023-07-12
... mask was inverted...
@c28c2ce8a36e074df672a432fe0de31474650dae
model/Generator.py
--- model/Generator.py
+++ model/Generator.py
@@ -43,7 +43,7 @@
         diff = abs(clean - dirty)
         diff = sum(diff, dim=1)
 
-        bin_diff = (diff < thresold).to(clean.dtype)
+        bin_diff = (diff >= thresold).to(clean.dtype)
 
         return bin_diff
 
train.py
--- train.py
+++ train.py
@@ -66,12 +66,12 @@
     pass
 
 # 이건 땜빵이고 차후에 데이터 관리 모듈 만들꺼임
-rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
+rainy_data_path = glob.glob("data/source/Peking_raindrop_dataset/dirty/*.png")
 rainy_data_path = sorted(rainy_data_path)
-clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
+clean_data_path = glob.glob("data/source/Peking_raindrop_dataset/clean/*.png")
 clean_data_path = sorted(clean_data_path)
 
-resize = torchvision.transforms.Resize((692, 776), antialias=True)
+resize = torchvision.transforms.Resize((480, 720), antialias=True)
 dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 # declare generator loss
@@ -116,7 +116,7 @@
 
         attentiveRNNresults = generator.attentiveRNN(rainy_img)
         generator_attention_map = attentiveRNNresults['attention_map_list']
-        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.2)
+        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.24)
         generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask)
         generator_loss_ARNN.backward()
         optimizer_G_ARNN.step()
Add a comment
List