
--- tools/argparser.py
+++ tools/argparser.py
... | ... | @@ -2,21 +2,19 @@ |
2 | 2 |
|
3 | 3 |
|
4 | 4 |
def get_param(): |
5 |
- parser = argparse.ArgumentParser() |
|
6 |
- parser.add_argument("--epoch", type=int, help="how many epoch will you train?") |
|
7 |
- parser.add_argument("--batch_size", type=int, help="size of batch") |
|
8 |
- parser.add_argument("--learning_rate", type=float, default=0.002, help="learning_rate using ADAM optimizer") |
|
9 |
- parser.add_argument("--continue_from", type=str, default=None, help="continue training from: {your trained weight}") |
|
10 |
- parser.add_argument("--mini_epoch_discriminator", type=int, default=4, |
|
11 |
- help="how many epochs does discriminator trains over a epoch?") |
|
12 |
- parser.add_argument("--mini_epoch_generator", type=int, default=4, |
|
13 |
- help="how many epochs does generator trains over a epoch?") |
|
14 |
- parser.add_argument("--AttentiveRNNBLCKs", type=int, default=3, |
|
15 |
- help="how many LSTM blocks in generator?") |
|
16 |
- parser.add_argument("--AttentiveRNNResNetdepth", type=int, default=2, |
|
17 |
- help="how deep is each RNN blocks?") |
|
18 |
- parser.add_argument("--save_interval", type=int, default=10, |
|
19 |
- help="weight save interval") |
|
20 |
- parser.add_argument("--sample_interval", type=int, default=10, |
|
21 |
- help="sample image interval") |
|
22 |
- return parser |
|
5 |
+ parser = argparse.ArgumentParser(description="GAN Training Arguments") |
|
6 |
+ |
|
7 |
+ parser.add_argument("--epochs", "-e", type=int, required=True, help="Total number of epochs") |
|
8 |
+ parser.add_argument("--batch_size", "-b", type=int, required=True, help="Size of single batch") |
|
9 |
+ parser.add_argument("--save_interval", "-s", type=int, required=True, help="Interval for saving weights") |
|
10 |
+ parser.add_argument("--sample_interval", type=int, required=True, help="Interval for saving inference result") |
|
11 |
+ parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for computation") |
|
12 |
+ parser.add_argument("--load", "-l", type=str, help="Path to previous weights for continuing training") |
|
13 |
+ parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of generator") |
|
14 |
+ parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times generator trains in a single epoch") |
|
15 |
+ parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks of RNN in attention network") |
|
16 |
+ parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each attention RNN blocks") |
|
17 |
+ parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. If not given, it is assumed to be the same as the generator") |
|
18 |
+ |
|
19 |
+ args = parser.parse_args() |
|
20 |
+ return args |
--- train.py
+++ train.py
... | ... | @@ -23,7 +23,19 @@ |
23 | 23 |
torch.nn.init.normal_(m.weight.data, 1.0, 0.02) |
24 | 24 |
torch.nn.init.constant_(m.bias.data, 0.0) |
25 | 25 |
|
26 |
-param = get_param() |
|
26 |
+args = get_param() |
|
27 |
+ |
|
28 |
+epochs = args.epochs |
|
29 |
+batch_size = args.batch_size |
|
30 |
+save_interval = args.save_interval |
|
31 |
+sample_interval = args.sample_interval |
|
32 |
+device = args.device |
|
33 |
+load = args.load |
|
34 |
+generator_learning_rate = args.generator_learning_rate |
|
35 |
+generator_learning_miniepoch = args.generator_learning_miniepoch |
|
36 |
+generator_attentivernn_blocks = args.generator_attentivernn_blocks |
|
37 |
+generator_resnet_depth = args.generator_resnet_depth |
|
38 |
+discriminator_learning_rate = args.discriminator_learning_rate if args.discriminator_learning_rate is not None else args.generator_learning_rate |
|
27 | 39 |
|
28 | 40 |
logger = Logger() |
29 | 41 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?