윤영준 윤영준 2023-07-05
Merge remote-tracking branch 'origin/master'
@3a11ad1f257f24ab547ba9ec14ed8f0f2d0d0cad
train.py
--- train.py
+++ train.py
@@ -71,7 +71,7 @@
 
 resize = torchvision.transforms.Resize((692, 776), antialias=True)
 dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize)
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 # declare generator loss
 
 optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
@@ -158,7 +158,7 @@
 
     day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
     if epoch % save_interval == 0 and epoch != 0:
-        torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
+        torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
         torch.save(generator.state_dict(), f"weight/Generator_{epoch}_{day}.pt")
         torch.save(discriminator.state_dict(), f"weight/Discriminator_{epoch}_{day}.pt")
 
Add a comment
List