diff --git a/train.py b/train.py index 02f28c1d..f6d73a4a 100644 --- a/train.py +++ b/train.py @@ -198,7 +198,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # detach loss values loss_dict_new = dict() for key, value in loss_dict.items(): - if isinstance(value, int) or isinstance(value, float): + if isinstance(value, (int, float)): loss_dict_new[key] = value else: loss_dict_new[key] = value.item() @@ -336,7 +336,7 @@ def evaluate(model, criterion, ap, global_step, epoch): # detach loss values loss_dict_new = dict() for key, value in loss_dict.items(): - if isinstance(value, int) or isinstance(value, float): + if isinstance(value, (int, float)): loss_dict_new[key] = value else: loss_dict_new[key] = value.item() diff --git a/vocoder/train.py b/vocoder/train.py index 091a6932..ce8f111a 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.io import copy_config_file, load_config from TTS.utils.radam import RAdam from TTS.utils.tensorboard_logger import TensorboardLogger -from TTS.utils.training import setup_torch_training_env, NoamLR +from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.preprocess import load_wav_data # from distribute import (DistributedSampler, apply_gradient_allreduce,