From c614f219828139f972a0e893178d8448dca8aa41 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Tue, 12 Jul 2022 14:25:21 +0200
Subject: [PATCH] Add durations as aux input for VITS (#1694)

* Add durations as aux input for VITS

* Make style

* Fix tts_tests

* Fix test_get_aux_input
---
 TTS/tts/models/vits.py | 47 +++++++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index a6b1c743..9263c0b1 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -786,7 +786,7 @@ class Vits(BaseTTS):
             print(" > Text Encoder was reinit.")
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     def _freeze_layers(self):
@@ -817,7 +817,7 @@ class Vits(BaseTTS):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid = None, None, None
+        sid, g, lid, durations = None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -832,7 +832,10 @@ class Vits(BaseTTS):
             if lid.ndim == 0:
                 lid = lid.unsqueeze_(0)
 
-        return sid, g, lid
+        if "durations" in aux_input and aux_input["durations"] is not None:
+            durations = aux_input["durations"]
+
+        return sid, g, lid, durations
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -946,7 +949,7 @@ class Vits(BaseTTS):
            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
        """
        outputs = {}
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1028,7 +1031,9 @@ class Vits(BaseTTS):
 
     @torch.no_grad()
     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+        self,
+        x,
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1048,7 +1053,7 @@ class Vits(BaseTTS):
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
        """
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, durations = self._set_cond_input(aux_input)
        x_lengths = self._set_x_lengths(x, aux_input)
 
        # speaker embedding
@@ -1062,21 +1067,25 @@ class Vits(BaseTTS):
 
        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
 
-        if self.args.use_sdp:
-            logw = self.duration_predictor(
-                x,
-                x_mask,
-                g=g if self.args.condition_dp_on_speaker else None,
-                reverse=True,
-                noise_scale=self.inference_noise_scale_dp,
-                lang_emb=lang_emb,
-            )
+        if durations is None:
+            if self.args.use_sdp:
+                logw = self.duration_predictor(
+                    x,
+                    x_mask,
+                    g=g if self.args.condition_dp_on_speaker else None,
+                    reverse=True,
+                    noise_scale=self.inference_noise_scale_dp,
+                    lang_emb=lang_emb,
+                )
+            else:
+                logw = self.duration_predictor(
+                    x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+                )
+            w = torch.exp(logw) * x_mask * self.length_scale
        else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
-            )
+            assert durations.shape[-1] == x.shape[-1]
+            w = durations.unsqueeze(0)
 
-        w = torch.exp(logw) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
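
Usage sketch (not part of the patch): one way the new "durations" aux input could be exercised at inference time after this change. The `model` object is assumed to be an already-initialized Vits instance, the token ids and per-token frame counts are illustrative placeholders, and the output key is an assumption.

import torch

# x: [1, T_enc] input token ids (placeholder values).
x = torch.randint(0, 100, (1, 12))

# durations: [1, T_enc] frames per input token; must match x along the last
# axis, as enforced by the assert added to Vits.inference in this patch.
durations = torch.full((1, 12), 4.0)

outputs = model.inference(
    x,
    aux_input={
        "x_lengths": torch.tensor([x.shape[1]]),
        "d_vectors": None,
        "speaker_ids": None,
        "language_ids": None,
        "durations": durations,  # bypasses the duration predictor branch
    },
)
wav = outputs["model_outputs"]  # assumed output key for the synthesized waveform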