mirror of https://github.com/coqui-ai/TTS.git
Add freeze vocoder generator and flow-based decoder option
parent
de41165af4
commit
39aff6685e
|
@ -225,6 +225,8 @@ class VitsArgs(Coqpit):
|
|||
freeze_encoder: bool = False
|
||||
freeze_DP: bool = False
|
||||
freeze_PE: bool = False
|
||||
freeze_flow_decoder: bool = False
|
||||
freeze_waveform_decoder: bool = False
|
||||
|
||||
|
||||
|
||||
|
@ -787,9 +789,11 @@ class Vits(BaseTTS):
|
|||
if self.args.freeze_encoder:
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
for param in self.emb_l.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
if hasattr(self, 'emb_l'):
|
||||
for param in self.emb_l.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_PE:
|
||||
for param in self.posterior_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
@ -798,6 +802,14 @@ class Vits(BaseTTS):
|
|||
for param in self.duration_predictor.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_flow_decoder:
|
||||
for param in self.flow.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.args.freeze_waveform_decoder:
|
||||
for param in self.waveform_decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
|
Loading…
Reference in New Issue