
--- train.py
+++ train.py
... | ... | @@ -71,7 +71,7 @@ |
71 | 71 |
|
72 | 72 |
resize = torchvision.transforms.Resize((692, 776), antialias=True) |
73 | 73 |
dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize) |
74 |
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
|
74 |
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) |
|
75 | 75 |
# declare generator loss |
76 | 76 |
|
77 | 77 |
|
... | ... | @@ -148,7 +148,7 @@ |
148 | 148 |
|
149 | 149 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
150 | 150 |
if epoch % save_interval == 0 and epoch != 0: |
151 |
- torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
|
151 |
+ torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
|
152 | 152 |
torch.save(generator.state_dict(), f"weight/Generator_{epoch}_{day}.pt") |
153 | 153 |
torch.save(discriminator.state_dict(), f"weight/Discriminator_{epoch}_{day}.pt") |
154 | 154 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?