import sys import time import plotly.express as px import dash from dash import dcc from dash import html from dash.dependencies import Input, Output class Logger(): def __init__(self): self.start_time = time.time() self.epoch_start_time = self.start_time self.losses = [] def print_training_log(self, current_epoch, total_epoch, current_batch, total_batch, losses): assert type(losses) == dict current_time = time.time() epoch_time = current_time - self.epoch_start_time total_time = current_time - self.start_time estimated_total_time = total_time * total_epoch / (current_epoch + 1) remaining_time = estimated_total_time - total_time self.epoch_start_time = current_time terminal_logging_string = "" for loss_name, loss_value in losses.items(): terminal_logging_string += f"{loss_name} : {loss_value}\n" if loss_name == 'loss': self.losses.append(loss_value) sys.stdout.write( f"epoch : {current_epoch}/{total_epoch}\n" f"batch : {current_batch}/{total_batch}\n" f"estimated time remaining : {remaining_time}\n" f"{terminal_logging_string}\n" )