mirror of https://github.com/coqui-ai/TTS.git
Fix style tests
parent
ab20a34170
commit
397b3e9baf
|
@ -12,7 +12,7 @@ from TTS.tts.utils.speakers import SpeakerManager
|
|||
def compute_encoder_accuracy(dataset_items, encoder_manager):
|
||||
|
||||
class_name_key = encoder_manager.encoder_config.class_name_key
|
||||
map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None)
|
||||
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)
|
||||
|
||||
class_acc_dict = {}
|
||||
|
||||
|
|
|
@ -279,9 +279,7 @@ class BaseTTS(BaseTrainerModel):
|
|||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||
if hasattr(config, "model_args"):
|
||||
speaker_id_mapping = (
|
||||
self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
|
||||
)
|
||||
speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
|
||||
config.use_d_vector_file = config.model_args.use_d_vector_file
|
||||
else:
|
||||
|
@ -293,9 +291,7 @@ class BaseTTS(BaseTrainerModel):
|
|||
|
||||
# setup multi-lingual attributes
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
language_id_mapping = (
|
||||
self.language_manager.ids if self.args.use_language_embedding else None
|
||||
)
|
||||
language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
|
||||
else:
|
||||
language_id_mapping = None
|
||||
|
||||
|
|
|
@ -1304,11 +1304,7 @@ class Vits(BaseTTS):
|
|||
d_vectors = torch.FloatTensor(d_vectors)
|
||||
|
||||
# get language ids from language names
|
||||
if (
|
||||
self.language_manager is not None
|
||||
and self.language_manager.ids
|
||||
and self.args.use_language_embedding
|
||||
):
|
||||
if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
|
||||
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]
|
||||
|
||||
if language_ids is not None:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
@ -9,6 +9,7 @@ from coqpit import Coqpit
|
|||
from TTS.config import check_config_and_model_args
|
||||
from TTS.tts.utils.managers import BaseIDManager
|
||||
|
||||
|
||||
class LanguageManager(BaseIDManager):
|
||||
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
|
||||
in a way that can be queried by language.
|
||||
|
|
|
@ -12,13 +12,11 @@ from TTS.utils.audio import AudioProcessor
|
|||
|
||||
|
||||
class BaseIDManager:
|
||||
""" Base `ID` Manager class. Every new `ID` manager must inherit this.
|
||||
"""Base `ID` Manager class. Every new `ID` manager must inherit this.
|
||||
It defines common `ID` manager specific functions.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
id_file_path: str = ""
|
||||
):
|
||||
|
||||
def __init__(self, id_file_path: str = ""):
|
||||
self.ids = {}
|
||||
|
||||
if id_file_path:
|
||||
|
@ -85,10 +83,12 @@ class BaseIDManager:
|
|||
ids = {name: i for i, name in enumerate(classes)}
|
||||
return ids
|
||||
|
||||
|
||||
class EmbeddingManager(BaseIDManager):
|
||||
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
|
||||
"""Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
|
||||
It defines common `Embedding` manager specific functions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_file_path: str = "",
|
||||
|
@ -225,7 +225,9 @@ class EmbeddingManager(BaseIDManager):
|
|||
"""
|
||||
self.encoder_config = load_config(config_path)
|
||||
self.encoder = setup_encoder_model(self.encoder_config)
|
||||
self.encoder_criterion = self.encoder.load_checkpoint(self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
|
||||
self.encoder_criterion = self.encoder.load_checkpoint(
|
||||
self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda
|
||||
)
|
||||
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)
|
||||
|
||||
def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
|
||||
|
|
|
@ -10,6 +10,7 @@ from coqpit import Coqpit
|
|||
from TTS.config import get_from_config_or_model_args_with_default
|
||||
from TTS.tts.utils.managers import EmbeddingManager
|
||||
|
||||
|
||||
class SpeakerManager(EmbeddingManager):
|
||||
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
|
||||
in a way that can be queried by speaker or clip.
|
||||
|
@ -64,8 +65,8 @@ class SpeakerManager(EmbeddingManager):
|
|||
id_file_path=speaker_id_file_path,
|
||||
encoder_model_path=encoder_model_path,
|
||||
encoder_config_path=encoder_config_path,
|
||||
use_cuda=use_cuda
|
||||
)
|
||||
use_cuda=use_cuda,
|
||||
)
|
||||
|
||||
if data_items:
|
||||
self.set_ids_from_data(data_items, parse_key="speaker_name")
|
||||
|
|
Loading…
Reference in New Issue