mirror of https://github.com/coqui-ai/TTS.git
Reformat multi-speaker handling in GlowTTS
parent
8d41060d36
commit
d847a68e42
|
@ -109,6 +109,10 @@ class GlowTTS(BaseTTS):
|
|||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
if config.use_d_vector_file:
|
||||
self.external_d_vector_dim = config.d_vector_dim
|
||||
else:
|
||||
self.external_d_vector_dim = 0
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = self.c_in_channels
|
||||
|
@ -129,7 +133,7 @@ class GlowTTS(BaseTTS):
|
|||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -143,8 +147,8 @@ class GlowTTS(BaseTTS):
|
|||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -181,7 +185,7 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
It's similar to the teacher forcing in Tacotron.
|
||||
|
@ -198,12 +202,11 @@ class GlowTTS(BaseTTS):
|
|||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.external_d_vector_dim:
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
|
@ -243,7 +246,7 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
@torch.no_grad()
|
||||
def decoder_inference(
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -275,7 +278,7 @@ class GlowTTS(BaseTTS):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
|
||||
|
@ -326,8 +329,9 @@ class GlowTTS(BaseTTS):
|
|||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
|
||||
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
|
|
Loading…
Reference in New Issue