mirror of https://github.com/coqui-ai/TTS.git
Fix multi-speaker init of Tacotron models & tests
parent 01324c8e70
commit e4648ffef1
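The models used to build their speaker layers only when `num_speakers > 1`, a value that is not known until the dataset has been read; the fix keys the initialization on the config flags instead and lets training re-run `init_multispeaker` once the real speaker count is inferred. A minimal sketch of the resulting pattern (the `StubTacotron` class, the 256/512 channel sizes, and the `SimpleNamespace` config are illustrative assumptions, not the repo's code):

    from types import SimpleNamespace

    from torch import nn


    class StubTacotron(nn.Module):
        """Sketch of the guarded multi-speaker init this commit introduces."""

        def __init__(self, config):
            super().__init__()
            self.decoder_in_features = 256  # assumed base decoder width
            self.embedded_speaker_dim = 0
            # key on config flags, not on a speaker count that is unknown
            # before the dataset is read
            if config.use_speaker_embedding or config.use_d_vector_file:
                self.init_multispeaker(config)
                self.decoder_in_features += self.embedded_speaker_dim

        def init_multispeaker(self, config):
            if config.use_d_vector_file:
                # external d-vectors: only the channel size matters here
                self.embedded_speaker_dim = config.d_vector_dim
            else:
                # trainable lookup table, rebuilt in training once the real
                # number of speakers is inferred from the dataset
                self.embedded_speaker_dim = 512
                self.speaker_embedding = nn.Embedding(config.num_speakers, self.embedded_speaker_dim)


    cfg = SimpleNamespace(use_speaker_embedding=True, use_d_vector_file=False, num_speakers=5, d_vector_dim=0)
    model = StubTacotron(cfg)
    assert model.decoder_in_features == 256 + 512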
@@ -8,10 +8,10 @@ class GST(nn.Module):
     See https://arxiv.org/pdf/1803.09017"""

-    def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
+    def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
         super().__init__()
         self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
-        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim)
+        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim)

     def forward(self, inputs, speaker_embedding=None):
         enc_out = self.encoder(inputs)
@@ -83,19 +83,19 @@ class ReferenceEncoder(nn.Module):
 class StyleTokenLayer(nn.Module):
     """NN Module attending to style tokens based on prosody encodings."""

-    def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None):
+    def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
         super().__init__()

-        self.query_dim = embedding_dim // 2
+        self.query_dim = gst_embedding_dim // 2

         if d_vector_dim:
             self.query_dim += d_vector_dim

-        self.key_dim = embedding_dim // num_heads
+        self.key_dim = gst_embedding_dim // num_heads
         self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
         nn.init.normal_(self.style_tokens, mean=0, std=0.5)
         self.attention = MultiHeadAttention(
-            query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
+            query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
         )

     def forward(self, inputs):
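With the rename, `gst_embedding_dim` alone drives both derived sizes: the attention query is half the GST embedding (plus the speaker channel when one is appended), and each style-token key is the embedding split across heads. A worked example with assumed values (256/4/55 are illustrative, not defaults taken from the repo):

    gst_embedding_dim = 256
    num_heads = 4
    d_vector_dim = 55

    query_dim = gst_embedding_dim // 2 + d_vector_dim  # 128 + 55 = 183
    key_dim = gst_embedding_dim // num_heads           # 256 // 4 = 64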
@@ -76,9 +76,6 @@ class BaseTacotron(BaseTTS):
         self.decoder_backward = None
         self.coarse_decoder = None

-        # init multi-speaker layers
-        self.init_multispeaker(config)
-
     @staticmethod
     def _format_aux_input(aux_input: Dict) -> Dict:
         return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
@@ -237,6 +234,7 @@ class BaseTacotron(BaseTTS):
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
         if isinstance(style_input, dict):
+            # multiply each style token with a weight
             query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
@@ -248,8 +246,10 @@ class BaseTacotron(BaseTTS):
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
+            # ignore style token and return zero tensor
             gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
+            # compute style tokens
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
         return inputs
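`compute_gst` threads the optional `speaker_embedding` into the GST layer and dispatches on the type of the style input: a dict of `{token_id: weight}` amplifies individual style tokens, `None` disables style conditioning, and anything else is treated as a reference mel to encode. A condensed, standalone restatement of those branches (a sketch with `gst_layer` and `gst_embedding_dim` as stand-ins for the model attributes, not the method's exact body):

    import torch

    def compute_gst_sketch(inputs, style_input, gst_layer, gst_embedding_dim, speaker_embedding=None):
        if isinstance(style_input, dict):
            # multiply selected style tokens with user-given weights
            query = torch.zeros(1, 1, gst_embedding_dim // 2).type_as(inputs)
            if speaker_embedding is not None:
                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
            gst_outputs = torch.zeros(1, 1, gst_embedding_dim).type_as(inputs)
            for token_id, amplifier in style_input.items():
                key = torch.tanh(gst_layer.style_token_layer.style_tokens[int(token_id)]).unsqueeze(0).expand(1, -1, -1)
                gst_outputs = gst_outputs + gst_layer.style_token_layer.attention(query, key) * amplifier
        elif style_input is None:
            # no style conditioning: a zero style vector
            gst_outputs = torch.zeros(1, 1, gst_embedding_dim).type_as(inputs)
        else:
            # a reference mel-spectrogram: encode it into a style embedding
            gst_outputs = gst_layer(style_input, speaker_embedding)
        # broadcast the style vector over time and concatenate to the encoder outputs
        return torch.cat([inputs, gst_outputs.expand(inputs.size(0), inputs.size(1), -1)], dim=-1)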
@@ -12,15 +12,19 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec


 class GlowTTS(BaseTTS):
-    """Glow TTS models from https://arxiv.org/abs/2005.11129
+    """GlowTTS model.

-    Paper abstract:
+    Paper::
+        https://arxiv.org/abs/2005.11129
+
+    Paper abstract::
         Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
         mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
         without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
@@ -145,7 +149,6 @@ class GlowTTS(BaseTTS):
             g = F.normalize(g).unsqueeze(-1)
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
-
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
         # drop redisual frames wrt num_squeeze and set y_lengths.
@@ -362,12 +365,49 @@ class GlowTTS(BaseTTS):
         train_audio = ap.inv_melspectrogram(pred_spec.T)
         return figures, {"audio": train_audio}

+    @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)

     def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
         return self.train_log(ap, batch, outputs)

+    @torch.no_grad()
+    def test_run(self, ap):
+        """Generic test run for `tts` models used by `Trainer`.
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            outputs = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            )
+
+            test_audios["{}-audio".format(idx)] = outputs["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
             y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
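`test_run` leans on `get_aux_input` for the conditioning signals it hands to `synthesis`; with all multi-speaker and style options disabled it amounts to a dict of the following shape (the keys appear in the hunk above, the values and comments are illustrative assumptions):

    aux_inputs = {
        "speaker_id": None,  # integer index when a speaker embedding layer is used
        "d_vector": None,    # external speaker vector when use_d_vector_file is set
        "style_wav": None,   # reference audio for GST-style models
    }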
@@ -106,7 +106,7 @@ class SpeedySpeech(BaseTTS):
             if isinstance(config.model_args.length_scale, int)
             else config.model_args.length_scale
         )
-        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+        self.emb = nn.Embedding(self.num_chars, config.model_args.hidden_channels)
         self.encoder = Encoder(
             config.model_args.hidden_channels,
             config.model_args.hidden_channels,
@@ -228,6 +228,7 @@ class SpeedySpeech(BaseTTS):
         outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
         return outputs

+    @torch.no_grad()
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
         """
         Shapes:
@@ -30,12 +30,11 @@ class Tacotron(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers infered from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@@ -75,13 +74,11 @@ class Tacotron(BaseTacotron):
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
             )

         # backward pass decoder
         if self.bidirectional_decoder:
             self._init_backward_decoder()
@@ -106,7 +103,9 @@ class Tacotron(BaseTacotron):
             self.max_decoder_steps,
         )

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -115,6 +114,7 @@ class Tacotron(BaseTacotron):
             mel_lengths: [B]
             aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
         """
+        aux_input = self._format_aux_input(aux_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
         inputs = self.embedding(text)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -125,12 +125,10 @@ class Tacotron(BaseTacotron):
         # global style token
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         # speaker embedding
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -182,7 +180,7 @@ class Tacotron(BaseTacotron):
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
                 # reshape embedded_speakers
@@ -31,12 +31,11 @@ class Tacotron2(BaseTacotron):
         for key in config:
             setattr(self, key, config[key])

-        # speaker embedding layer
-        if self.num_speakers > 1:
+        # set speaker embedding channel size for determining `in_channels` for the connected layers.
+        # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based
+        # on the number of speakers infered from the dataset.
+        if self.use_speaker_embedding or self.use_d_vector_file:
             self.init_multispeaker(config)
-
-        # speaker and gst embeddings is concat in decoder input
-        if self.num_speakers > 1:
             self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim

         if self.use_gst:
@@ -47,6 +46,7 @@ class Tacotron2(BaseTacotron):

         # base model layers
         self.encoder = Encoder(self.encoder_in_features)
+
         self.decoder = Decoder(
             self.decoder_in_features,
             self.decoder_output_dim,
@@ -73,9 +73,6 @@ class Tacotron2(BaseTacotron):
         if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=self.decoder_output_dim,
-                d_vector_dim=self.d_vector_dim
-                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
-                else None,
                 num_heads=self.gst.gst_num_heads,
                 num_style_tokens=self.gst.gst_num_style_tokens,
                 gst_embedding_dim=self.gst.gst_embedding_dim,
@@ -110,7 +107,9 @@ class Tacotron2(BaseTacotron):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
+    def forward(  # pylint: disable=dangerous-default-value
+        self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None}
+    ):
         """
         Shapes:
             text: [B, T_in]
@@ -130,11 +129,10 @@ class Tacotron2(BaseTacotron):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(
-                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
-            )
-        if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
             else:
@@ -186,8 +184,9 @@ class Tacotron2(BaseTacotron):
         if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
+            if not self.use_d_vector_file:
                 embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
                 # reshape embedded_speakers
                 if embedded_speakers.ndim == 1:
@@ -360,10 +360,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
         elif c.use_d_vector_file and c.d_vector_file:
             # new speaker manager with external speaker embeddings.
             speaker_manager.set_d_vectors_from_file(c.d_vector_file)
-        elif c.use_d_vector_file and not c.d_vector_file:  # new speaker manager with speaker IDs file.
-            raise "use_d_vector_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+        elif c.use_d_vector_file and not c.d_vector_file:
+            raise "use_d_vector_file is True, so you need pass a external speaker embedding file."
+        elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
+            # new speaker manager with speaker IDs file.
+            speaker_manager.set_speaker_ids_from_file(c.speakers_file)
         print(
-            " > Training with {} speakers: {}".format(
+            " > Speaker manager is loaded with {} speakers: {}".format(
                 speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
             )
         )
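Spelled out, the new branch order resolves the speaker source as follows (a condensed sketch assuming a Coqpit-style config `c` that supports `in`; note that the diff's `raise "..."` raises a bare string, itself a `TypeError` at runtime, so the sketch substitutes a real exception type):

    def resolve_speaker_source(c):
        # condensed restatement of the branches visible in the hunk above
        if c.use_d_vector_file and c.d_vector_file:
            return "d_vectors_from_file"    # external speaker embeddings
        if c.use_d_vector_file and not c.d_vector_file:
            raise ValueError("use_d_vector_file is True, so you need to pass an external speaker embedding file.")
        if c.use_speaker_embedding and "speakers_file" in c and c.speakers_file:
            return "speaker_ids_from_file"  # trainable embeddings, speaker IDs loaded from file
        return None  # remaining cases fall outside this hunk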
@@ -29,6 +29,7 @@ config = Tacotron2Config(
         "Be a voice, not an echo.",
     ],
     d_vector_file="tests/data/ljspeech/speakers.json",
+    d_vector_dim=256,
    max_decoder_steps=50,
)
@@ -25,8 +25,68 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


 class TacotronTrainTest(unittest.TestCase):
+    """Test vanilla Tacotron2 model."""
+
     def test_train_step(self):  # pylint: disable=no-self-use
         config = config_global.copy()
+        config.use_speaker_embedding = False
+        config.num_speakers = 1
+
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
+        input_lengths = torch.sort(input_lengths, descending=True)[0]
+        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[0] = 30
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()) :, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = MSELossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron2(config).to(device)
+        model.train()
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=config.lr)
+        for i in range(5):
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
+            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
+            optimizer.zero_grad()
+            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
+            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
+            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-higway layer since it works conditional
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+
+class MultiSpeakerTacotronTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with speaker embedding layer"""
+
+    @staticmethod
+    def test_train_step():
+        config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 5
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
@@ -45,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):

         criterion = MSELossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron2(config).to(device)
         model.train()
         model_ref = copy.deepcopy(model)
@@ -76,65 +137,18 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1


-class MultiSpeakeTacotronTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
-        config = config_global.copy()
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 128, (8,)).long().to(device)
-        input_lengths = torch.sort(input_lengths, descending=True)[0]
-        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        mel_postnet_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        mel_lengths[0] = 30
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_ids = torch.rand(8, 55).to(device)
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()) :, 0] = 1.0
-
-        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        criterion = MSELossMasked(seq_len_norm=False).to(device)
-        criterion_st = nn.BCEWithLogitsLoss().to(device)
-        config.d_vector_dim = 55
-        model = Tacotron2(config).to(device)
-        model.train()
-        model_ref = copy.deepcopy(model)
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizer = optim.Adam(model.parameters(), lr=config.lr)
-        for i in range(5):
-            outputs = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_ids}
-            )
-            assert torch.sigmoid(outputs["stop_tokens"]).data.max() <= 1.0
-            assert torch.sigmoid(outputs["stop_tokens"]).data.min() >= 0.0
-            optimizer.zero_grad()
-            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
-            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
-            loss = loss + criterion(outputs["model_outputs"], mel_postnet_spec, mel_lengths) + stop_loss
-            loss.backward()
-            optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            # ignore pre-higway layer since it works conditional
-            # if count not in [145, 59]:
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
-            )
-            count += 1
-
-
 class TacotronGSTTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with Global Style Token and Speaker Embedding"""
+
     # pylint: disable=no-self-use
     def test_train_step(self):
         # with random gst mel style
         config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 10
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
@@ -247,9 +261,17 @@ class TacotronGSTTrainTest(unittest.TestCase):


 class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
+    """Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""
+
     @staticmethod
     def test_train_step():
+
         config = config_global.copy()
+        config.use_d_vector_file = True
+
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 128, (8,)).long().to(device)
         input_lengths = torch.sort(input_lengths, descending=True)[0]
@@ -32,6 +32,61 @@ class TacotronTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         config = config_global.copy()
+        config.use_speaker_embedding = False
+        config.num_speakers = 1
+
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
+        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
+        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()) :, 0] = 1.0
+
+        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+
+        criterion = L1LossMasked(seq_len_norm=False).to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
+        model.train()
+        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=config.lr)
+        for _ in range(5):
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+            optimizer.zero_grad()
+            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
+            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
+            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            # ignore pre-higway layer since it works conditional
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+
+class MultiSpeakeTacotronTrainTest(unittest.TestCase):
+    @staticmethod
+    def test_train_step():
+        config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 5
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
@@ -50,6 +105,7 @@ class TacotronTrainTest(unittest.TestCase):

         criterion = L1LossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
+        config.d_vector_dim = 55
         model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
@@ -80,63 +136,14 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1


-class MultiSpeakeTacotronTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
-        config = config_global.copy()
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
-        input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, config.audio["num_mels"]).to(device)
-        linear_spec = torch.rand(8, 30, config.audio["fft_size"] // 2 + 1).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        mel_lengths[-1] = mel_spec.size(1)
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_embeddings = torch.rand(8, 55).to(device)
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()) :, 0] = 1.0
-
-        stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // config.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        criterion = L1LossMasked(seq_len_norm=False).to(device)
-        criterion_st = nn.BCEWithLogitsLoss().to(device)
-        config.d_vector_dim = 55
-        model = Tacotron(config).to(device)  # FIXME: missing num_speakers parameter to Tacotron ctor
-        model.train()
-        print(" > Num parameters for Tacotron model:%s" % (count_parameters(model)))
-        model_ref = copy.deepcopy(model)
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizer = optim.Adam(model.parameters(), lr=config.lr)
-        for _ in range(5):
-            outputs = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, aux_input={"d_vectors": speaker_embeddings}
-            )
-            optimizer.zero_grad()
-            loss = criterion(outputs["decoder_outputs"], mel_spec, mel_lengths)
-            stop_loss = criterion_st(outputs["stop_tokens"], stop_targets)
-            loss = loss + criterion(outputs["model_outputs"], linear_spec, mel_lengths) + stop_loss
-            loss.backward()
-            optimizer.step()
-        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            # ignore pre-higway layer since it works conditional
-            # if count not in [145, 59]:
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
-            )
-            count += 1
-
-
 class TacotronGSTTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         config = config_global.copy()
+        config.use_speaker_embedding = True
+        config.num_speakers = 10
+        config.use_gst = True
+        config.gst = GSTConfig()
         # with random gst mel style
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
@@ -244,6 +251,11 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
     @staticmethod
     def test_train_step():
         config = config_global.copy()
+        config.use_d_vector_file = True
+
+        config.use_gst = True
+        config.gst = GSTConfig()
+
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128