윤영준 윤영준 2023-06-26
writing get_param() for get param for training in training code, Logger for logging the progress of training.
Logger for logging the progress of training.
@cabb483963c87e11f7952e7ea39b0d0053559eb1
 
tools/argparser.py (added)
+++ tools/argparser.py
@@ -0,0 +1,22 @@
+import argparse
+
+
+def get_param():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epoch", type=int, help="how many epoch will you train?")
+    parser.add_argument("--batch_size", type=int, help="size of batch")
+    parser.add_argument("--learning_rate", type=float, default=0.002, help="learning_rate using ADAM optimizer")
+    parser.add_argument("--continue_from", type=str, default=None, help="continue training from: {your trained weight}")
+    parser.add_argument("--mini_epoch_discriminator", type=int, default=4,
+                        help="how many epochs does discriminator trains over a epoch?")
+    parser.add_argument("--mini_epoch_generator", type=int, default=4,
+                        help="how many epochs does generator trains over a epoch?")
+    parser.add_argument("--AttentiveRNNBLCKs", type=int, default=3,
+                        help="how many LSTM blocks in generator?")
+    parser.add_argument("--AttentiveRNNResNetdepth", type=int, default=2,
+                        help="how deep is each RNN blocks?")
+    parser.add_argument("--save_interval", type=int, default=10,
+                        help="weight save interval")
+    parser.add_argument("--sample_interval", type=int, default=10,
+                        help="sample image interval")
+    return parser
 
tools/dataloader.py (added)
+++ tools/dataloader.py
@@ -0,0 +1,36 @@
+from torch.utils.data import Dataset
+from torchvision.io import read_image
+import numpy as np
+
+# the dataset for this model needs to be in at least, a pair of clean and dirty image.
+# in other words, 'annotations' is not really appropriate word
+# however, the problem here is that we need should be able to pair more than one dirty image per clean image,
+# I am lost now
+
+class ImageDataSet(Dataset):
+    def __init__(self, img_dir, annotations, transform=None):
+        self.annotations = annotations
+        self.img_dir = img_dir
+        self.transform = transform
+        # self.target_transform = target_transform
+        # print(self.transform)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __getitem__(self, idx):
+        img_path = self.img_dir[idx]
+        image = read_image(img_path)
+        label = self.annotations[idx]
+        if self.transform:
+            image = self.transform(image)
+        # if self.target_transform(label):
+        #     label = self.target_transform(label)
+        return image, label
+
+    def __add__(self, other):
+        return ImageDataSet(
+            img_dir=self.img_dir+other.img_dir,
+            annotations=np.array(list(self.annotations)+list(other.annotations)),
+            transform=self.transform
+            )
 
tools/logger.py (added)
+++ tools/logger.py
@@ -0,0 +1,19 @@
+import sys
+import time
+class Logger():
+    def __init__(self):
+        self.start_time = time.time()
+        self.epoch_start_time = self.start_time
+    def print_training_log(self, current_epoch, total_epoch, **kargs):
+        current_time = time.time()
+        epoch_time = current_time - self.epoch_start_time
+        total_time = current_time - self.start_time
+
+        estimated_total_time = total_time * total_epoch / (current_epoch + 1)
+        remaining_time = estimated_total_time - total_time
+
+        self.epoch_start_time = current_time
+
+        sys.stdout.write(
+            f"epoch : {current_epoch}/{total_epoch}\n"
+        )(No newline at end of file)
train.py
--- train.py
+++ train.py
@@ -1,11 +1,25 @@
+import sys
+import os
 import torch
 import numpy as np
 import pandas as pd
 import plotly.express as px
+from torchvision.utils import save_image
+
 from model import Autoencoder
 from model import Generator
 from model import Discriminator
 from model import AttentiveRNN
+from tools.argparser import get_param
+from tools.logger import Logger
+
+param = get_param()
+
+logger = Logger()
+
+## RNN 따로 돌리고 CPU로 메모리 옳기고
+## Autoencoder 따로 돌리고 메모리 옳기고
+## 안되는가
 
 ## 대충 열심히 GAN 구성하는 코드
 ## 대충 그래서 weight export해서 inference용과 training용으로 나누는 코드
Add a comment
List