윤영준 윤영준 2023-06-30
working on parser
@ea5d2cb61ec40026cd9d394b40ae9c5d3207d4ab
tools/argparser.py
--- tools/argparser.py
+++ tools/argparser.py
@@ -2,21 +2,19 @@
 
 
 def get_param():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--epoch", type=int, help="how many epoch will you train?")
-    parser.add_argument("--batch_size", type=int, help="size of batch")
-    parser.add_argument("--learning_rate", type=float, default=0.002, help="learning_rate using ADAM optimizer")
-    parser.add_argument("--continue_from", type=str, default=None, help="continue training from: {your trained weight}")
-    parser.add_argument("--mini_epoch_discriminator", type=int, default=4,
-                        help="how many epochs does discriminator trains over a epoch?")
-    parser.add_argument("--mini_epoch_generator", type=int, default=4,
-                        help="how many epochs does generator trains over a epoch?")
-    parser.add_argument("--AttentiveRNNBLCKs", type=int, default=3,
-                        help="how many LSTM blocks in generator?")
-    parser.add_argument("--AttentiveRNNResNetdepth", type=int, default=2,
-                        help="how deep is each RNN blocks?")
-    parser.add_argument("--save_interval", type=int, default=10,
-                        help="weight save interval")
-    parser.add_argument("--sample_interval", type=int, default=10,
-                        help="sample image interval")
-    return parser
+    parser = argparse.ArgumentParser(description="GAN Training Arguments")
+
+    parser.add_argument("--epochs", "-e", type=int, required=True, help="Total number of epochs")
+    parser.add_argument("--batch_size", "-b", type=int, required=True, help="Size of single batch")
+    parser.add_argument("--save_interval", "-s", type=int, required=True, help="Interval for saving weights")
+    parser.add_argument("--sample_interval", type=int, required=True, help="Interval for saving inference result")
+    parser.add_argument("--device", "-d", type=str, default="cpu", choices=["cpu", "cuda"], help="Device to use for computation")
+    parser.add_argument("--load", "-l", type=str, help="Path to previous weights for continuing training")
+    parser.add_argument("--generator_learning_rate", "-g_lr", type=float, required=True, help="Learning rate of generator")
+    parser.add_argument("--generator_learning_miniepoch", "-g_epoch", type=int, default=1, help="Number of times generator trains in a single epoch")
+    parser.add_argument("--generator_attentivernn_blocks", "-g_arnn_b", type=int, default=1, help="Number of blocks of RNN in attention network")
+    parser.add_argument("--generator_resnet_depth", "-g_depth", type=int, default=1, help="Depth of ResNet in each attention RNN blocks")
+    parser.add_argument("--discriminator_learning_rate", "-d_lr", type=float, help="Learning rate of discriminator. If not given, it is assumed to be the same as the generator")
+
+    args = parser.parse_args()
+    return args
train.py
--- train.py
+++ train.py
@@ -23,7 +23,19 @@
         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
         torch.nn.init.constant_(m.bias.data, 0.0)
 
-param = get_param()
+args = get_param()
+
+epochs = args.epochs
+batch_size = args.batch_size
+save_interval = args.save_interval
+sample_interval = args.sample_interval
+device = args.device
+load = args.load
+generator_learning_rate = args.generator_learning_rate
+generator_learning_miniepoch = args.generator_learning_miniepoch
+generator_attentivernn_blocks = args.generator_attentivernn_blocks
+generator_resnet_depth = args.generator_resnet_depth
+discriminator_learning_rate = args.discriminator_learning_rate if args.discriminator_learning_rate is not None else args.generator_learning_rate
 
 logger = Logger()
 
Add a comment
List