diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index b2e4be9e..cb349ca2 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -171,6 +171,9 @@ class VitsArgs(Coqpit):
         speaker_encoder_model_path (str):
             Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
 
+        condition_dp_on_speaker (bool):
+            Condition the duration predictor on the speaker embedding. Defaults to True.
+
         freeze_encoder (bool):
             Freeze the encoder weigths during training. Defaults to False.
 
@@ -233,6 +236,7 @@ class VitsArgs(Coqpit):
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
+    condition_dp_on_speaker: bool = True
     freeze_encoder: bool = False
     freeze_DP: bool = False
     freeze_PE: bool = False
@@ -349,7 +353,7 @@ class Vits(BaseTTS):
                 3,
                 args.dropout_p_duration_predictor,
                 4,
-                cond_channels=self.embedded_speaker_dim,
+                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
@@ -358,7 +362,7 @@ class Vits(BaseTTS):
                 256,
                 3,
                 args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim,
+                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
 
@@ -595,12 +599,15 @@ class Vits(BaseTTS):
 
         # duration predictor
         attn_durations = attn.sum(3)
+        g_dp = None
+        if self.args.condition_dp_on_speaker:
+            g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
         if self.args.use_sdp:
             loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 attn_durations,
-                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                g=g_dp,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = loss_duration / torch.sum(x_mask)
@@ -609,7 +616,7 @@ class Vits(BaseTTS):
             log_durations = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
-                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                g=g_dp,
                 lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
@@ -685,10 +692,10 @@ class Vits(BaseTTS):
 
         if self.args.use_sdp:
             logw = self.duration_predictor(
-                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
             )
         else:
-            logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
+            logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
 
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
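
Usage note: with this change, setting `condition_dp_on_speaker=False` builds the (stochastic) duration predictor with `cond_channels=0` and passes `g=None` at both training and inference time, so all speakers share a single duration model while the rest of VITS stays speaker-conditioned. Below is a minimal sketch of setting the new flag; it assumes the usual Coqui TTS entry points (`VitsConfig`, `VitsArgs`, `Vits`), and the exact `Vits` constructor arguments beyond the config vary between versions.

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs

# Multi-speaker VITS whose duration predictor ignores the speaker embedding.
model_args = VitsArgs(
    use_speaker_embedding=True,
    num_speakers=10,
    condition_dp_on_speaker=False,  # the flag introduced by this diff
)
config = VitsConfig(model_args=model_args)
model = Vits(config)  # extra constructor args (tokenizer, speaker manager, ...) depend on the TTS version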