
File name
Commit message
Commit date
File name
Commit message
Commit date
import argparse
def get_param(argv=None) -> argparse.Namespace:
    """Parse the command-line arguments for GAN training.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, so existing
            ``get_param()`` calls behave exactly as before. Passing an
            explicit list lets tests or other scripts reuse the parser.

    Returns:
        An ``argparse.Namespace`` whose attributes are the long option names
        without leading dashes (e.g. ``args.epochs``,
        ``args.generator_learning_rate``, ``args.discriminator_learning_rate``).

    Raises:
        SystemExit: raised by argparse when a required option is missing, a
            value fails ``type`` conversion or ``choices`` validation, or
            ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="GAN Training Arguments")
    parser.add_argument("--epochs", "-e", type=int, required=True,
                        help="Total number of epochs")
    parser.add_argument("--batch_size", "-b", type=int, required=True,
                        help="Size of single batch")
    parser.add_argument("--save_interval", "-s", type=int, required=True,
                        help="Interval for saving weights")
    parser.add_argument("--sample_interval", type=int, required=True,
                        help="Interval for saving inference result")
    # The short flag is -j because `make` uses -j for its number of parallel jobs.
    parser.add_argument("--num_worker", "-j", type=int, default=4,
                        help="Dataloader's number of threads")
    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"],
                        help="Device to use for computation")
    parser.add_argument("--load", "-l", type=str, default=None,
                        help="Path to previous weights for continuing training")
    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True,
                        help="Learning rate of generator")
    # NOTE(review): this option defaults to None; the "same as the whole generator"
    # fallback described in its help is not applied here — presumably the caller
    # resolves None to generator_learning_rate. Confirm.
    parser.add_argument("--generator_ARNN_learning_rate", "-g_arnn_lr", type=float,
                        help="learning rate of Attention RNN network, default is "
                             "same as the whole generator (autoencoder)")
    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1,
                        help="Number of times generator trains in a single epoch")
    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1,
                        help="Number of blocks of RNN in attention network")
    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1,
                        help="Depth of ResNet in each attention RNN blocks")
    # NOTE(review): as above, None is returned when omitted; the fallback to the
    # generator learning rate is presumably handled by the caller.
    # The adjacent literals below previously lacked separating spaces
    # ("discriminator.If", "to bethe"); each fragment now ends with a space.
    parser.add_argument("--discriminator_learning_rate", "-d_lr", default=None, type=float,
                        help="Learning rate of discriminator. "
                             "If not given, it is assumed to be "
                             "the same as the generator")
    # parse_args(None) falls back to sys.argv[1:], preserving the original behavior.
    args = parser.parse_args(argv)
    return args