mirror of https://github.com/coqui-ai/TTS.git
Implement get_state_dict
parent dfc19cd3ae
commit 4d1718a19a
@@ -95,6 +95,9 @@ class ForwardTTSArgs(Coqpit):
        num_speakers (int):
            Number of speakers for the speaker embedding layer. Defaults to 1.

        use_speaker_embedding (bool):
            Whether to use a speaker embedding layer. Defaults to False.

        speakers_file (str):
            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
@@ -107,8 +110,10 @@ class ForwardTTSArgs(Coqpit):
        d_vector_dim (int):
            Number of d-vector channels. Defaults to None.
-    """
+        d_vector_file (str):
+            Path to the d-vector file. Defaults to None.
+    """
    num_chars: int = None
    out_channels: int = 80
    hidden_channels: int = 384
@@ -148,6 +153,7 @@ class ForwardTTSArgs(Coqpit):
    max_duration: int = 75
    num_speakers: int = 1
    use_speaker_embedding: bool = False
    speaker_embedding_channels: int = 256
    speakers_file: str = None
    use_d_vector_file: bool = False
    d_vector_dim: int = None
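For illustration, a minimal sketch of setting these fields for a multi-speaker setup (field names are from the args above; the values and the import path are assumptions, not from this commit):

    from TTS.tts.models.forward_tts import ForwardTTSArgs

    # Learned speaker-embedding setup: one embedding row per speaker.
    args = ForwardTTSArgs(
        num_chars=120,
        num_speakers=10,
        use_speaker_embedding=True,
        speakers_file="speakers.json",  # hypothetical Speaker Manager mapping file
    )

    # Or condition on external d-vectors instead of a learned table.
    args_dvec = ForwardTTSArgs(num_chars=120, use_d_vector_file=True, d_vector_dim=256)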
@@ -177,9 +183,18 @@ class ForwardTTS(BaseTTS):
            Defaults to None.

    Examples:
        Instantiate the model directly:

        >>> from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
        >>> config = ForwardTTSArgs()
        >>> model = ForwardTTS(config)

        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs
        >>> args = ForwardTTSE2eArgs()
        >>> model = ForwardTTSE2e(args)

        Instantiate the model from config:

        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e
        >>> from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig
        >>> config = FastPitchE2eConfig(num_chars=10)
        >>> model = ForwardTTSE2e.init_from_config(config)
    """

    # pylint: disable=dangerous-default-value
@@ -272,6 +287,7 @@ class ForwardTTS(BaseTTS):
            self._init_d_vector()

    def _init_speaker_embedding(self):
        """Init class arguments for training with a speaker embedding layer."""
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")

@@ -279,6 +295,7 @@ class ForwardTTS(BaseTTS):
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self):
        """Init class arguments for training with external speaker embeddings."""
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
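A minimal sketch of what the emb_g layer initialized above does at lookup time (dimensions are illustrative):

    import torch
    from torch import nn

    num_speakers, embedded_speaker_dim = 10, 256
    emb_g = nn.Embedding(num_speakers, embedded_speaker_dim)

    speaker_ids = torch.tensor([3, 7])    # (B,)
    g = emb_g(speaker_ids).unsqueeze(-1)  # (B, C, 1) speaker conditioning vector
    print(g.shape)                        # torch.Size([2, 256, 1])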
@@ -286,7 +303,7 @@ class ForwardTTS(BaseTTS):

    @staticmethod
    def _set_cond_input(aux_input: Dict):
-        """Set the speaker conditioning input based on the multi-speaker mode."""
+        """Set auxiliary model inputs based on the model configuration."""
        sid, g, lid = None, None, None
        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
            sid = aux_input["speaker_ids"]

@@ -305,12 +322,19 @@ class ForwardTTS(BaseTTS):
        return sid, g, lid

    def get_aux_input(self, aux_input: Dict):
        """Get auxiliary model inputs based on the model configuration."""
        sid, g, lid = self._set_cond_input(aux_input)
        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
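A hedged sketch of the conditioning plumbing above, assuming `model` is a constructed ForwardTTS instance:

    aux_input = {"speaker_ids": torch.tensor([0]), "d_vectors": None, "language_ids": None}
    sid, g, lid = ForwardTTS._set_cond_input(aux_input)  # -> (tensor([0]), None, None)
    cond = model.get_aux_input(aux_input)
    # {"speaker_ids": tensor([0]), "style_wav": None, "d_vectors": None, "language_ids": None}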
    @staticmethod
    def generate_attn(dr, x_mask, y_mask=None):
-        """Generate an attention mask from the durations.
+        """Generate an attention mask from the linear scale durations.

        Args:
            dr (Tensor): Linear scale durations.
            x_mask (Tensor): Mask for the input (character) sequence.
            y_mask (Tensor): Mask for the output (spectrogram) sequence. If None, it is computed
                from the durations. Defaults to None.

        Shapes:
            - dr: :math:`(B, T_{en})`

@@ -327,8 +351,14 @@ class ForwardTTS(BaseTTS):
        return attn
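To make the duration-to-alignment step concrete, a self-contained sketch of the same idea as a naive loop (the library version is vectorized; this is not the exact implementation):

    import torch

    def durations_to_attn(dr: torch.Tensor) -> torch.Tensor:
        """dr: (B, T_en) integer durations -> attn: (B, T_en, T_de) binary alignment."""
        batch, t_en = dr.shape
        t_de = int(dr.sum(dim=1).max().item())
        attn = torch.zeros(batch, t_en, t_de)
        for b in range(batch):
            frame = 0
            for t in range(t_en):
                d = int(dr[b, t].item())
                attn[b, t, frame : frame + d] = 1.0  # token t covers d output frames
                frame += d
        return attn

    print(durations_to_attn(torch.tensor([[2, 1, 3]])).shape)  # torch.Size([1, 3, 6])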
    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
+        """Generate an attention alignment map from linear scale durations and
+        expand the encoder outputs.

        Args:
            en (Tensor): Encoder outputs.
            dr (Tensor): Linear scale durations.
            x_mask (Tensor): Mask for the input (character) sequence.
            y_mask (Tensor): Mask for the output (spectrogram) sequence.

        Shapes:
            - en: :math:`(B, D_{en}, T_{en})`
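Given the alignment map, the expansion itself is a batched matrix product; a hedged sketch reusing durations_to_attn from the sketch above:

    en = torch.randn(1, 384, 3)                          # (B, D_en, T_en) encoder outputs
    attn = durations_to_attn(torch.tensor([[2, 1, 3]]))  # (B, T_en, T_de)
    o_en_ex = torch.matmul(en, attn)                     # (B, D_en, T_de): frame t repeats dr[t] times
    print(o_en_ex.shape)                                 # torch.Size([1, 384, 6])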
@@ -360,8 +390,8 @@ class ForwardTTS(BaseTTS):
        5. Round the duration values.

        Args:
-            o_dr_log: Log scale durations.
-            x_mask: Input text mask.
+            o_dr_log (Tensor): Log scale durations.
+            x_mask (Tensor): Input text mask.

        Shapes:
            - o_dr_log: :math:`(B, T_{de})`
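A hedged sketch of the log-to-linear conversion described here, assuming it follows the listed steps (shapes simplified; length_scale is assumed to be a model attribute controlling speech speed):

    o_dr_log = torch.zeros(1, 5)  # predicted log scale durations, (B, T_en)
    x_mask = torch.ones(1, 5)     # input text mask
    length_scale = 1.0            # >1 slows speech down, <1 speeds it up

    o_dr = (torch.exp(o_dr_log) - 1) * x_mask * length_scale  # back to linear scale, masked
    o_dr[o_dr < 1] = 1.0                                      # floor each duration at one frame
    o_dr = torch.round(o_dr)                                  # step 5: round to integer frame counts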
@@ -474,31 +474,6 @@ class ForwardTTSE2e(BaseTTSE2E):
        model_outputs = {**encoder_outputs}
        return model_outputs

-    @staticmethod
-    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
-        """Initiate model from config
-
-        Args:
-            config (ForwardTTSE2eConfig): Model config.
-            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
-                Defaults to None.
-        """
-        from TTS.utils.audio.processor import AudioProcessor
-
-        tokenizer, new_config = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config, samples)
-        # language_manager = LanguageManager.init_from_config(config)
-        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
-
-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
-        self.load_state_dict(state["model"])
-        if eval:
-            self.eval()
-            assert not self.training

    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        if optimizer_idx == 0:
            tokens = batch["text_input"]
@@ -1000,3 +975,51 @@ class ForwardTTSE2e(BaseTTSE2E):
            mel_fmax=self.config.audio.mel_fmax,
            mel_fmin=self.config.audio.mel_fmin,
        )

+    @staticmethod
+    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
+        """Initialize the model from config.
+
+        Args:
+            config (ForwardTTSE2eConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        # language_manager = LanguageManager.init_from_config(config)
+        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):
+        """Load the model from a checkpoint created by the 👟 Trainer."""
+        # pylint: disable=unused-argument, redefined-builtin
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            assert not self.training
+    def get_state_dict(self):
+        """Custom state dict of the model with all the necessary components for inference."""
+        save_state = {
+            "config": self.config.to_dict(),
+            "args": self.args.to_dict(),
+            "model": self.state_dict(),
+        }
+
+        if hasattr(self, "emb_g"):
+            save_state["speaker_ids"] = self.speaker_manager.speaker_ids
+
+        if self.args.use_d_vector_file:
+            # TODO: implement saving of d_vectors
+            ...
+        return save_state
+
+    def save(self, config, checkpoint_path):  # pylint: disable=unused-argument
+        """Save model to a file."""
+        save_state = self.get_state_dict()
+        torch.save(save_state, checkpoint_path)
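A hedged end-to-end sketch of the save/load roundtrip these methods enable (the config class is taken from the class docstring example; the path is illustrative):

    import torch
    from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig

    config = FastPitchE2eConfig(num_chars=10)
    model = ForwardTTSE2e.init_from_config(config)
    model.save(config, "checkpoint.pth")       # writes config, args, and weights

    state = torch.load("checkpoint.pth", map_location="cpu")
    print(sorted(state.keys()))                # e.g. ['args', 'config', 'model']

    model2 = ForwardTTSE2e.init_from_config(config)
    model2.load_checkpoint(config, "checkpoint.pth", eval=True)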