diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 93a5bad2..acf750a0 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -599,7 +599,6 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_fake,
         feats_disc_real,
         loss_duration,
-        fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
         syn_spk_emb=None,
@@ -623,14 +622,9 @@ class VitsGeneratorLoss(nn.Module):
         # compute mel spectrograms from the waveforms
         mel = self.stft(waveform)
         mel_hat = self.stft(waveform_hat)
+
         # compute losses
-
-        # ignore tts model loss if fine tunning mode is on
-        if fine_tuning_mode:
-            loss_kl = 0.0
-        else:
-            loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
-
+        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
         loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 9f895fc1..0abf0ca3 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -167,11 +167,6 @@ class VitsArgs(Coqpit):
         speaker_encoder_model_path (str):
            Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
 
-        fine_tuning_mode (int):
-            Fine tuning only the vocoder part of the model, while the rest will be frozen. Defaults to 0.
-            Mode 0: Disabled;
-            Mode 1: uses the distribution predicted by the encoder and It's recommended for TTS;
-            Mode 2: uses the distribution predicted by the encoder and It's recommended for voice conversion.
     """
 
     num_chars: int = 100
@@ -219,7 +214,6 @@ class VitsArgs(Coqpit):
     use_speaker_encoder_as_loss: bool = False
     speaker_encoder_config_path: str = ""
     speaker_encoder_model_path: str = ""
-    fine_tuning_mode: int = 0
     freeze_encoder: bool = False
     freeze_DP: bool = False
     freeze_PE: bool = False
@@ -672,122 +666,6 @@ class Vits(BaseTTS):
         )
         return outputs
 
-    def forward_fine_tuning(
-        self,
-        x: torch.tensor,
-        x_lengths: torch.tensor,
-        y: torch.tensor,
-        y_lengths: torch.tensor,
-        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
-        waveform=None,
-    ) -> Dict:
-        """Forward pass of the model.
-
-        Args:
-            x (torch.tensor): Batch of input character sequence IDs.
-            x_lengths (torch.tensor): Batch of input character sequence lengths.
-            y (torch.tensor): Batch of input spectrograms.
-            y_lengths (torch.tensor): Batch of input spectrogram lengths.
-            aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
-
-        Returns:
-            Dict: model outputs keyed by the output name.
-
-        Shapes:
-            - x: :math:`[B, T_seq]`
-            - x_lengths: :math:`[B]`
-            - y: :math:`[B, C, T_spec]`
-            - y_lengths: :math:`[B]`
-            - d_vectors: :math:`[B, C, 1]`
-            - speaker_ids: :math:`[B]`
-        """
-        with torch.no_grad():
-            outputs = {}
-            sid, g, lid = self._set_cond_input(aux_input)
-            # speaker embedding
-            if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
-                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-
-            # language embedding
-            lang_emb = None
-            if self.args.use_language_embedding and lid is not None:
-                lang_emb = self.emb_l(lid).unsqueeze(-1)
-
-            x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
-
-            # posterior encoder
-            z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
-
-            # flow layers
-            z_p = self.flow(z, y_mask, g=g)
-
-            # find the alignment path
-            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-            with torch.no_grad():
-                o_scale = torch.exp(-2 * logs_p)
-                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-                logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
-                logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-                logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-                logp = logp2 + logp3 + logp1 + logp4
-                attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-
-            # expand prior
-            m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-            logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
-
-            # mode 1: like SC-GlowTTS paper; mode 2: recommended for voice conversion
-            if self.args.fine_tuning_mode == 1:
-                z_ft = m_p
-            elif self.args.fine_tuning_mode == 2:
-                z_ft = z_p
-            else:
-                raise RuntimeError(" [!] Invalid Fine Tunning Mode !")
-
-            # inverse decoder and get the output
-            z_f_pred = self.flow(z_ft, y_mask, g=g, reverse=True)
-            z_slice, slice_ids = rand_segments(z_f_pred, y_lengths, self.spec_segment_size)
-
-            o = self.waveform_decoder(z_slice, g=g)
-
-            wav_seg = segment(
-                waveform.transpose(1, 2),
-                slice_ids * self.config.audio.hop_length,
-                self.args.spec_segment_size * self.config.audio.hop_length,
-            )
-
-            if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
-                # concate generated and GT waveforms
-                wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
-
-                # resample audio to speaker encoder sample_rate
-                if self.audio_transform is not None:
-                    wavs_batch = self.audio_transform(wavs_batch)
-
-                pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
-
-                # split generated and GT speaker embeddings
-                gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
-            else:
-                gt_spk_emb, syn_spk_emb = None, None
-
-            outputs.update(
-                {
-                    "model_outputs": o,
-                    "alignments": attn.squeeze(1),
-                    "loss_duration": 0.0,
-                    "z": z,
-                    "z_p": z_p,
-                    "m_p": m_p,
-                    "logs_p": logs_p,
-                    "m_q": m_q,
-                    "logs_q": logs_q,
-                    "waveform_seg": wav_seg,
-                    "gt_spk_emb": gt_spk_emb,
-                    "syn_spk_emb": syn_spk_emb,
-                }
-            )
-        return outputs
 
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
@@ -869,15 +747,6 @@ class Vits(BaseTTS):
         if optimizer_idx not in [0, 1]:
             raise ValueError(" [!] Unexpected `optimizer_idx`.")
 
-        # generator pass
-        if self.args.fine_tuning_mode:
-            # ToDo: find better place fot it
-            # force eval mode
-            self.eval()
-            # restore train mode for the vocoder part
-            self.waveform_decoder.train()
-            self.disc.train()
-
         if self.args.freeze_encoder:
             for param in self.text_encoder.parameters():
                 param.requires_grad = False
@@ -913,25 +782,14 @@ class Vits(BaseTTS):
             waveform = batch["waveform"]
 
             # generator pass
-            if self.args.fine_tuning_mode:
-                # model forward
-                outputs = self.forward_fine_tuning(
-                    text_input,
-                    text_lengths,
-                    linear_input.transpose(1, 2),
-                    mel_lengths,
-                    aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
-                    waveform=waveform,
-                )
-            else:
-                outputs = self.forward(
-                    text_input,
-                    text_lengths,
-                    linear_input.transpose(1, 2),
-                    mel_lengths,
-                    aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
-                    waveform=waveform,
-                )
+            outputs = self.forward(
+                text_input,
+                text_lengths,
+                linear_input.transpose(1, 2),
+                mel_lengths,
+                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
+                waveform=waveform,
+            )
 
             # cache tensors for the discriminator
             self.y_disc_cache = None
@@ -958,7 +816,6 @@ class Vits(BaseTTS):
                     feats_disc_fake=outputs["feats_disc_fake"],
                     feats_disc_real=outputs["feats_disc_real"],
                     loss_duration=outputs["loss_duration"],
-                    fine_tuning_mode=self.args.fine_tuning_mode,
                     use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                     gt_spk_emb=outputs["gt_spk_emb"],
                     syn_spk_emb=outputs["syn_spk_emb"],