mirror of https://github.com/coqui-ai/TTS.git
linter fix
parent
99420a2d9b
commit
6a661f98e2
4
train.py
4
train.py
|
@ -198,7 +198,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
# detach loss values
|
||||
loss_dict_new = dict()
|
||||
for key, value in loss_dict.items():
|
||||
if isinstance(value, int) or isinstance(value, float):
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict_new[key] = value
|
||||
else:
|
||||
loss_dict_new[key] = value.item()
|
||||
|
@ -336,7 +336,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
# detach loss values
|
||||
loss_dict_new = dict()
|
||||
for key, value in loss_dict.items():
|
||||
if isinstance(value, int) or isinstance(value, float):
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict_new[key] = value
|
||||
else:
|
||||
loss_dict_new[key] = value.item()
|
||||
|
|
|
@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
|||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env, NoamLR
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
|
|
Loading…
Reference in New Issue