diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index 43867239..a41e29a8 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -208,7 +208,7 @@ def main(args):  # pylint: disable=redefined-outer-name
 
 
 if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
+    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
 
     try:
         main(args)
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index c491700d..be8d6200 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -5,8 +5,8 @@ from TTS.trainer import Trainer, init_training
 
 
 def main():
    """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
-    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-    trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
+    args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger, cudnn_benchmark=False)
     trainer.fit()
 
diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py
index 868aae2e..1eac603e 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_vocoder.py
@@ -8,8 +8,8 @@ from TTS.utils.generic_utils import remove_experiment_folder
 
 def main():
     try:
-        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        args, config, output_path, _, c_logger, tb_logger, wandb_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, tb_logger, wandb_logger)
         trainer.fit()
     except KeyboardInterrupt:
         remove_experiment_folder(output_path)
diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index dd92da65..8d3b108e 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -231,6 +231,18 @@ class BaseTrainingConfig(Coqpit):
         tb_model_param_stats (bool):
             Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
             Defaults to ```False```.
+
+        wandb_disabled (bool):
+            Disable W&B logging. Defaults to ```False```.
+
+        wandb_project_name (str):
+            Name of the W&B project. Defaults to ```config.model```.
+
+        wandb_entity (str):
+            Name of W&B entity/team. Enables collaboration across a team or org.
+
+        wandb_log_model_step (int):
+            Number of steps required to log a checkpoint as a W&B artifact.
 
         save_step (int):
             Number of steps required to save the next checkpoint.
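As a usage sketch of the four options documented above (the field names and defaults come from this diff; the project and entity values are hypothetical):

```python
from TTS.config.shared_configs import BaseTrainingConfig

# Hypothetical values; only the field names and their defaults come from this diff.
config = BaseTrainingConfig(
    wandb_disabled=False,          # True skips W&B logging entirely
    wandb_project_name="my-tts",   # falls back to config.model when left as None
    wandb_entity="my-team",        # W&B team/org; None logs to the personal account
    wandb_log_model_step=10000,    # upload a checkpoint artifact every 10k steps; None disables
)
```

The attribute defaults follow in the next hunk.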
@@ -276,6 +288,10 @@ class BaseTrainingConfig(Coqpit):
     print_step: int = 25
     tb_plot_step: int = 100
     tb_model_param_stats: bool = False
+    wandb_disabled: bool = False
+    wandb_project_name: str = None
+    wandb_entity: str = None
+    wandb_log_model_step: int = None
     # checkpointing
     save_step: int = 10000
     checkpoint: bool = True
diff --git a/TTS/trainer.py b/TTS/trainer.py
index d39fd747..48ea92b5 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -159,7 +159,7 @@ class Trainer:
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
         if config is None:
             # parse config from console arguments
-            config, output_path, _, c_logger, tb_logger = process_args(args)
+            config, output_path, _, c_logger, tb_logger, wandb_logger = process_args(args)
 
         self.output_path = output_path
         self.args = args
@@ -657,6 +657,16 @@ class Trainer:
                     self.output_path,
                     model_loss=target_avg_loss,
                 )
+
+                if (
+                    self.wandb_logger
+                    and self.config.wandb_log_model_step
+                    and self.total_steps_done % self.config.wandb_log_model_step == 0
+                ):
+                    # log checkpoint as W&B artifact
+                    aliases = [f"epoch-{self.epochs_done}", f"step-{self.total_steps_done}"]
+                    self.wandb_logger.log_artifact(self.output_path, "checkpoint", "model", aliases)
+
         # training visualizations
         figures, audios = None, None
         if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"):
@@ -860,10 +870,15 @@ class Trainer:
         """Where the ✨️magic✨️ happens..."""
         try:
             self._fit()
+            if self.wandb_logger:
+                self.wandb_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
+            # finish the wandb run and sync data
+            if self.wandb_logger:
+                self.wandb_logger.finish()
             # stop without error signal
             try:
                 sys.exit(0)
@@ -1092,7 +1107,7 @@ def process_args(args, config=None):
             logging to the console.
         tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
             the TensorBoard logging.
-        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B Loggin
+        wandb_logger (TTS.utils.tensorboard.WandbLogger): Class that does the W&B logging.
 
     TODO:
         - Interactive config definition.
@@ -1106,17 +1121,15 @@ def process_args(args, config=None):
         args.restore_path, best_model = get_last_checkpoint(args.continue_path)
         if not args.best_path:
             args.best_path = best_model
-    # setup output paths and read configs
-    if config is None:
-        config = load_config(args.config_path)
-    # init config
+
+    # init config if not already defined
     if config is None:
         if args.config_path:
             # init from a file
             config = load_config(args.config_path)
         else:
             # init from console args
-            from TTS.config.shared_configs import BaseTrainingConfig
+            from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel
 
             config_base = BaseTrainingConfig()
             config_base.parse_known_args(coqpit_overrides)
@@ -1148,11 +1161,15 @@ def process_args(args, config=None):
 
     # write model desc to tensorboard
    tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
", 0) - wandb_logger = WandbLogger( - project=config.model, - name=os.path.basename(experiment_path), - config=config, - ) + if not config.wandb_disabled: + wandb_project_name = config.model + if config.wandb_project_name: + wandb_project_name = config.wandb_project_name + + wandb_logger = WandbLogger( + project=wandb_project_name, name=config.run_name, config=config, entity=config.wandb_entity + ) c_logger = ConsoleLogger() return config, experiment_path, audio_path, c_logger, tb_logger, wandb_logger diff --git a/TTS/utils/logging/wandb_logger.py b/TTS/utils/logging/wandb_logger.py index 4d6f0c30..e8f6765b 100644 --- a/TTS/utils/logging/wandb_logger.py +++ b/TTS/utils/logging/wandb_logger.py @@ -2,7 +2,7 @@ from pathlib import Path try: import wandb - from wandb import init, finish + from wandb import finish, init # pylint: disable=W0611 except ImportError: wandb = None @@ -15,10 +15,6 @@ class WandbLogger: self.log_dict = {} def log(self, log_dict, prefix="", flush=False): - """ - This function accumulates data in self.log_dict. If flush is set. - the accumulated metrics will be logged directly to wandb dashboard. - """ for key, value in log_dict.items(): self.log_dict[prefix + key] = value if flush: # for cases where you don't want to accumulate data @@ -53,13 +49,14 @@ class WandbLogger: self.log_dict = {} def finish(self): - """ - Finish this W&B run - """ - self.run.finish() + if self.run: + self.run.finish() - def log_artifact(self, file_or_dir, name, type, aliases=[]): - artifact = wandb.Artifact(name, type=type) + def log_artifact(self, file_or_dir, name, artifact_type, aliases=None): + if not self.run: + return + name = "_".join([self.run.id, name]) + artifact = wandb.Artifact(name, type=artifact_type) data_path = Path(file_or_dir) if data_path.is_dir(): artifact.add_dir(str(data_path))