data:image/s3,"s3://crabby-images/77fc1/77fc1ecd598263bdfa1d6248fbe60b3bfc41f6f8" alt=""
writing get_param() for get param for training in training code, Logger for logging the progress of training.
Logger for logging the progress of training.
@cabb483963c87e11f7952e7ea39b0d0053559eb1
+++ tools/argparser.py
... | ... | @@ -0,0 +1,22 @@ |
1 | +import argparse | |
2 | + | |
3 | + | |
4 | +def get_param(): | |
5 | + parser = argparse.ArgumentParser() | |
6 | + parser.add_argument("--epoch", type=int, help="how many epoch will you train?") | |
7 | + parser.add_argument("--batch_size", type=int, help="size of batch") | |
8 | + parser.add_argument("--learning_rate", type=float, default=0.002, help="learning_rate using ADAM optimizer") | |
9 | + parser.add_argument("--continue_from", type=str, default=None, help="continue training from: {your trained weight}") | |
10 | + parser.add_argument("--mini_epoch_discriminator", type=int, default=4, | |
11 | + help="how many epochs does discriminator trains over a epoch?") | |
12 | + parser.add_argument("--mini_epoch_generator", type=int, default=4, | |
13 | + help="how many epochs does generator trains over a epoch?") | |
14 | + parser.add_argument("--AttentiveRNNBLCKs", type=int, default=3, | |
15 | + help="how many LSTM blocks in generator?") | |
16 | + parser.add_argument("--AttentiveRNNResNetdepth", type=int, default=2, | |
17 | + help="how deep is each RNN blocks?") | |
18 | + parser.add_argument("--save_interval", type=int, default=10, | |
19 | + help="weight save interval") | |
20 | + parser.add_argument("--sample_interval", type=int, default=10, | |
21 | + help="sample image interval") | |
22 | + return parser |
+++ tools/dataloader.py
... | ... | @@ -0,0 +1,36 @@ |
1 | +from torch.utils.data import Dataset | |
2 | +from torchvision.io import read_image | |
3 | +import numpy as np | |
4 | + | |
5 | +# the dataset for this model needs to be in at least, a pair of clean and dirty image. | |
6 | +# in other words, 'annotations' is not really appropriate word | |
7 | +# however, the problem here is that we need should be able to pair more than one dirty image per clean image, | |
8 | +# I am lost now | |
9 | + | |
10 | +class ImageDataSet(Dataset): | |
11 | + def __init__(self, img_dir, annotations, transform=None): | |
12 | + self.annotations = annotations | |
13 | + self.img_dir = img_dir | |
14 | + self.transform = transform | |
15 | + # self.target_transform = target_transform | |
16 | + # print(self.transform) | |
17 | + | |
18 | + def __len__(self): | |
19 | + return len(self.annotations) | |
20 | + | |
21 | + def __getitem__(self, idx): | |
22 | + img_path = self.img_dir[idx] | |
23 | + image = read_image(img_path) | |
24 | + label = self.annotations[idx] | |
25 | + if self.transform: | |
26 | + image = self.transform(image) | |
27 | + # if self.target_transform(label): | |
28 | + # label = self.target_transform(label) | |
29 | + return image, label | |
30 | + | |
31 | + def __add__(self, other): | |
32 | + return ImageDataSet( | |
33 | + img_dir=self.img_dir+other.img_dir, | |
34 | + annotations=np.array(list(self.annotations)+list(other.annotations)), | |
35 | + transform=self.transform | |
36 | + ) |
+++ tools/logger.py
... | ... | @@ -0,0 +1,19 @@ |
1 | +import sys | |
2 | +import time | |
3 | +class Logger(): | |
4 | + def __init__(self): | |
5 | + self.start_time = time.time() | |
6 | + self.epoch_start_time = self.start_time | |
7 | + def print_training_log(self, current_epoch, total_epoch, **kargs): | |
8 | + current_time = time.time() | |
9 | + epoch_time = current_time - self.epoch_start_time | |
10 | + total_time = current_time - self.start_time | |
11 | + | |
12 | + estimated_total_time = total_time * total_epoch / (current_epoch + 1) | |
13 | + remaining_time = estimated_total_time - total_time | |
14 | + | |
15 | + self.epoch_start_time = current_time | |
16 | + | |
17 | + sys.stdout.write( | |
18 | + f"epoch : {current_epoch}/{total_epoch}\n" | |
19 | + )(No newline at end of file) |
--- train.py
+++ train.py
... | ... | @@ -1,11 +1,25 @@ |
1 |
+import sys |
|
2 |
+import os |
|
1 | 3 |
import torch |
2 | 4 |
import numpy as np |
3 | 5 |
import pandas as pd |
4 | 6 |
import plotly.express as px |
7 |
+from torchvision.utils import save_image |
|
8 |
+ |
|
5 | 9 |
from model import Autoencoder |
6 | 10 |
from model import Generator |
7 | 11 |
from model import Discriminator |
8 | 12 |
from model import AttentiveRNN |
13 |
+from tools.argparser import get_param |
|
14 |
+from tools.logger import Logger |
|
15 |
+ |
|
16 |
+param = get_param() |
|
17 |
+ |
|
18 |
+logger = Logger() |
|
19 |
+ |
|
20 |
+## RNN 따로 돌리고 CPU로 메모리 옳기고 |
|
21 |
+## Autoencoder 따로 돌리고 메모리 옳기고 |
|
22 |
+## 안되는가 |
|
9 | 23 |
|
10 | 24 |
## 대충 열심히 GAN 구성하는 코드 |
11 | 25 |
## 대충 그래서 weight export해서 inference용과 training용으로 나누는 코드 |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?