Add freeze vocoder generator and flow-based decoder option

pull/1032/head
Edresson 2021-09-19 21:06:58 -03:00 committed by Eren Gölge
parent de41165af4
commit 39aff6685e
1 changed files with 15 additions and 3 deletions

View File

@ -225,6 +225,8 @@ class VitsArgs(Coqpit):
freeze_encoder: bool = False
freeze_DP: bool = False
freeze_PE: bool = False
freeze_flow_decoder: bool = False
freeze_waveform_decoder: bool = False
@ -787,6 +789,8 @@ class Vits(BaseTTS):
if self.args.freeze_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
if hasattr(self, 'emb_l'):
for param in self.emb_l.parameters():
param.requires_grad = False
@ -798,6 +802,14 @@ class Vits(BaseTTS):
for param in self.duration_predictor.parameters():
param.requires_grad = False
if self.args.freeze_flow_decoder:
for param in self.flow.parameters():
param.requires_grad = False
if self.args.freeze_waveform_decoder:
for param in self.waveform_decoder.parameters():
param.requires_grad = False
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]