mirror of https://github.com/coqui-ai/TTS.git
Update Logger API, recipes
parent
f63cf46c55
commit
936a47504d
|
@ -116,12 +116,12 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
||||||
"step_time": step_time,
|
"step_time": step_time,
|
||||||
"avg_loader_time": avg_loader_time,
|
"avg_loader_time": avg_loader_time,
|
||||||
}
|
}
|
||||||
tb_logger.tb_train_epoch_stats(global_step, train_stats)
|
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
||||||
figures = {
|
figures = {
|
||||||
# FIXME: not constant
|
# FIXME: not constant
|
||||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
|
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
|
||||||
}
|
}
|
||||||
tb_logger.tb_train_figures(global_step, figures)
|
dashboard_logger.train_figures(global_step, figures)
|
||||||
|
|
||||||
if global_step % c.print_step == 0:
|
if global_step % c.print_step == 0:
|
||||||
print(
|
print(
|
||||||
|
|
|
@ -184,7 +184,6 @@ class Trainer:
|
||||||
if not self.config.log_model_step:
|
if not self.config.log_model_step:
|
||||||
self.config.log_model_step = self.config.save_step
|
self.config.log_model_step = self.config.save_step
|
||||||
|
|
||||||
|
|
||||||
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
|
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
|
||||||
self._setup_logger_config(log_file)
|
self._setup_logger_config(log_file)
|
||||||
|
|
||||||
|
@ -1147,7 +1146,7 @@ def process_args(args, config=None):
|
||||||
os.chmod(experiment_path, 0o775)
|
os.chmod(experiment_path, 0o775)
|
||||||
|
|
||||||
if config.dashboard_logger == "tensorboard":
|
if config.dashboard_logger == "tensorboard":
|
||||||
dashboard_logger = TensorboardLogger(output_path, model_name=config.model)
|
dashboard_logger = TensorboardLogger(config.output_path, model_name=config.model)
|
||||||
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
dashboard_logger.add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
|
|
||||||
elif config.dashboard_logger == "wandb":
|
elif config.dashboard_logger == "wandb":
|
||||||
|
@ -1162,7 +1161,6 @@ def process_args(args, config=None):
|
||||||
entity=config.wandb_entity,
|
entity=config.wandb_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,6 @@ class TensorboardLogger(object):
|
||||||
def __init__(self, log_dir, model_name):
|
def __init__(self, log_dir, model_name):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.writer = SummaryWriter(log_dir)
|
self.writer = SummaryWriter(log_dir)
|
||||||
self.train_stats = {}
|
|
||||||
self.eval_stats = {}
|
|
||||||
|
|
||||||
def model_weights(self, model, step):
|
def model_weights(self, model, step):
|
||||||
layer_num = 1
|
layer_num = 1
|
||||||
|
@ -71,11 +69,11 @@ class TensorboardLogger(object):
|
||||||
def add_text(self, title, text, step):
|
def add_text(self, title, text, step):
|
||||||
self.writer.add_text(title, text, step)
|
self.writer.add_text(title, text, step)
|
||||||
|
|
||||||
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None):
|
def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201
|
||||||
return
|
yield
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
return
|
self.writer.flush()
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
return
|
self.writer.close()
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from pathlib import Path
|
# pylint: disable=W0613
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb
|
||||||
|
@ -29,10 +31,8 @@ class WandbLogger:
|
||||||
self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
|
self.dict_to_scalar("weights", {"layer{}-{}/min".format(layer_num, name): param.min()})
|
||||||
self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
|
self.dict_to_scalar("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()})
|
||||||
self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
|
self.dict_to_scalar("weights", {"layer{}-{}/std".format(layer_num, name): param.std()})
|
||||||
'''
|
self.log_dict["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
|
||||||
self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
|
self.log_dict["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad)
|
||||||
self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
|
|
||||||
'''
|
|
||||||
layer_num += 1
|
layer_num += 1
|
||||||
|
|
||||||
def dict_to_scalar(self, scope_name, stats):
|
def dict_to_scalar(self, scope_name, stats):
|
||||||
|
@ -52,7 +52,6 @@ class WandbLogger:
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
def log(self, log_dict, prefix="", flush=False):
|
def log(self, log_dict, prefix="", flush=False):
|
||||||
for key, value in log_dict.items():
|
for key, value in log_dict.items():
|
||||||
self.log_dict[prefix + key] = value
|
self.log_dict[prefix + key] = value
|
||||||
|
|
|
@ -25,6 +25,6 @@ config = AlignTTSConfig(
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -25,6 +25,6 @@ config = GlowTTSConfig(
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
datasets=[dataset_config],
|
datasets=[dataset_config],
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -24,6 +24,6 @@ config = HifiganConfig(
|
||||||
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -24,6 +24,6 @@ config = MultibandMelganConfig(
|
||||||
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -24,6 +24,6 @@ config = UnivnetConfig(
|
||||||
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -22,6 +22,6 @@ config = WavegradConfig(
|
||||||
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
|
@ -24,6 +24,6 @@ config = WavernnConfig(
|
||||||
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
)
|
)
|
||||||
args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(TrainingArgs(), config)
|
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=True)
|
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=True)
|
||||||
trainer.fit()
|
trainer.fit()
|
||||||
|
|
Loading…
Reference in New Issue