
Merge remote-tracking branch 'origin/master'
@3a11ad1f257f24ab547ba9ec14ed8f0f2d0d0cad
--- train.py
+++ train.py
... | ... | @@ -71,7 +71,7 @@ |
71 | 71 |
|
72 | 72 |
resize = torchvision.transforms.Resize((692, 776), antialias=True) |
73 | 73 |
dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=resize) |
74 |
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
|
74 |
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4) |
|
75 | 75 |
# declare generator loss |
76 | 76 |
|
77 | 77 |
optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate) |
... | ... | @@ -158,7 +158,7 @@ |
158 | 158 |
|
159 | 159 |
day = strftime("%Y-%m-%d %H:%M:%S", gmtime()) |
160 | 160 |
if epoch % save_interval == 0 and epoch != 0: |
161 |
- torch.save(generator.attentionRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
|
161 |
+ torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt") |
|
162 | 162 |
torch.save(generator.state_dict(), f"weight/Generator_{epoch}_{day}.pt") |
163 | 163 |
torch.save(discriminator.state_dict(), f"weight/Discriminator_{epoch}_{day}.pt") |
164 | 164 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?