Fix style tests

pull/1374/head
Edresson Casanova 2022-03-23 15:31:33 -03:00
parent ab20a34170
commit 397b3e9baf
6 changed files with 18 additions and 22 deletions

View File

@@ -12,7 +12,7 @@ from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.encoder_config.class_name_key
-map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
+map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
class_acc_dict = {}

View File

@@ -279,9 +279,7 @@ class BaseTTS(BaseTrainerModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
-speaker_id_mapping = (
-self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
-)
+speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
@@ -293,9 +291,7 @@ class BaseTTS(BaseTrainerModel):
# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
-language_id_mapping = (
-self.language_manager.ids if self.args.use_language_embedding else None
-)
+language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
else:
language_id_mapping = None

View File

@@ -1304,11 +1304,7 @@ class Vits(BaseTTS):
d_vectors = torch.FloatTensor(d_vectors)
# get language ids from language names
-if (
-self.language_manager is not None
-and self.language_manager.ids
-and self.args.use_language_embedding
-):
+if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
if language_ids is not None:

View File

@@ -1,5 +1,5 @@
import os
-from typing import Dict, List, Any
+from typing import Any, Dict, List
import fsspec
import numpy as np
@@ -9,6 +9,7 @@ from coqpit import Coqpit
from TTS.config import check_config_and_model_args
from TTS.tts.utils.managers import BaseIDManager
class LanguageManager(BaseIDManager):
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.

View File

@@ -12,13 +12,11 @@ from TTS.utils.audio import AudioProcessor
class BaseIDManager:
-""" Base `ID` Manager class. Every new `ID` manager must inherit this.
+"""Base `ID` Manager class. Every new `ID` manager must inherit this.
It defines common `ID` manager specific functions.
"""
-def __init__(
-self,
-id_file_path: str = ""
-):
+def __init__(self, id_file_path: str = ""):
self.ids = {}
if id_file_path:
@@ -85,10 +83,12 @@ class BaseIDManager:
ids = {name: i for i, name in enumerate(classes)}
return ids
class EmbeddingManager(BaseIDManager):
-""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
+"""Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
It defines common `Embedding` manager specific functions.
"""
def __init__(
self,
embedding_file_path: str = "",
@@ -225,7 +225,9 @@ class EmbeddingManager(BaseIDManager):
"""
self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config)
-self.encoder_criterion = self.encoder.load_checkpoint(self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
+self.encoder_criterion = self.encoder.load_checkpoint(
+self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
+)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:

View File

@@ -10,6 +10,7 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by speaker or clip.
@@ -64,8 +65,8 @@ class SpeakerManager(EmbeddingManager):
id_file_path=speaker_id_file_path,
encoder_model_path=encoder_model_path,
encoder_config_path=encoder_config_path,
-use_cuda=use_cuda
-)
+use_cuda=use_cuda,
+)
if data_items:
self.set_ids_from_data(data_items, parse_key="speaker_name")