import argparse


def get_param(argv=None) -> argparse.Namespace:
    """Parse the command-line arguments for GAN training.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which is the
            behavior of a plain ``get_param()`` call.

    Returns:
        The populated ``argparse.Namespace``. Attribute names match the long
        option names (e.g. ``args.generator_ARNN_learning_rate``).

    Raises:
        SystemExit: Raised by argparse when a required option is missing, a
            value fails type conversion or the ``choices`` check, or
            ``-h/--help`` is requested.

    Note:
        ``generator_ARNN_learning_rate`` and ``discriminator_learning_rate``
        are returned as ``None`` when omitted. This function does not apply
        the fallback to the generator learning rate that the help text
        describes; presumably the caller does (verify against the caller).
    """
    parser = argparse.ArgumentParser(description="GAN Training Arguments")

    # --- Training schedule (all required) ---
    parser.add_argument("--epochs", "-e", type=int, required=True,
                        help="Total number of epochs")
    parser.add_argument("--batch_size", "-b", type=int, required=True,
                        help="Size of single batch")
    parser.add_argument("--save_interval", "-s", type=int, required=True,
                        help="Interval for saving weights")
    parser.add_argument("--sample_interval", type=int, required=True,
                        help="Interval for saving inference result")

    # --- Runtime environment ---
    # The short flag is -j because make uses -j for its number of parallel
    # jobs.
    parser.add_argument("--num_worker", "-j", type=int, default=4,
                        help="Dataloader's number of threads")
    parser.add_argument("--device", "-d", type=str, default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device to use for computation")
    parser.add_argument("--load", "-l", type=str, default=None,
                        help="Path to previous weights for continuing "
                             "training")

    # --- Generator ---
    parser.add_argument("--generator_learning_rate", "-g_lr", type=float,
                        required=True, help="Learning rate of generator")
    parser.add_argument("--generator_ARNN_learning_rate", "-g_arnn_lr",
                        type=float, default=None,
                        help="learning rate of Attention RNN network, "
                             "default is same as the whole generator "
                             "(autoencoder)")
    parser.add_argument("--generator_learning_miniepoch", "-g_epoch",
                        type=int, default=1,
                        help="Number of times generator trains in a single "
                             "epoch")
    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b",
                        type=int, default=1,
                        help="Number of blocks of RNN in attention network")
    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int,
                        default=1,
                        help="Depth of ResNet in each attention RNN blocks")

    # --- Discriminator ---
    # The original adjacent string literals lacked separating spaces, which
    # produced "discriminator.If" and "bethe" in the rendered help.
    parser.add_argument("--discriminator_learning_rate", "-d_lr",
                        default=None, type=float,
                        help="Learning rate of discriminator. If not given, "
                             "it is assumed to be the same as the generator")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)