mirror of https://github.com/coqui-ai/TTS.git
update set_weight_decay
parent 8565c508e4
commit 99d7f2a666
@@ -182,9 +182,9 @@ def weight_decay(optimizer):
     return optimizer, current_lr


-def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """
-    Skip biases, BatchNorm parameters for weight decay
+    Skip biases, BatchNorm parameters, rnns.
     and attention projection layer v
     """
     decay = []
@@ -192,7 +192,8 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue
-        if len(param.shape) == 1 or name in skip_list:
+
+        if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
             no_decay.append(param)
         else:
             decay.append(param)
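The behavioral change in the second hunk is the skip test: the old check, name in skip_list, required an exact parameter-name match, whereas the new any(skip_name in name for skip_name in skip_list) treats each skip_list entry as a substring of the qualified parameter name. A minimal sketch of the new rule against a hypothetical toy module (the module and its attribute names are assumptions, chosen only so the new skip_list entries have something to match):

import torch.nn as nn

# Hypothetical toy module; attribute names ("embedding", "lstm") exist
# only to exercise the new skip_list substrings.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.lstm = nn.LSTM(4, 8)
        self.proj = nn.Linear(8, 2)

model = Toy()
skip_list = {"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}

for name, param in model.named_parameters():
    # Rule from the diff: skip 1-D tensors (biases, norm gains) and any
    # parameter whose qualified name contains a skip_list substring.
    skipped = len(param.shape) == 1 or any(s in name for s in skip_list)
    print(name, "->", "no_decay" if skipped else "decay")

# embedding.weight   -> no_decay  ("embedding" is a substring; the old
#                                  exact check `name in skip_list` missed it)
# lstm.weight_ih_l0  -> no_decay  ("lstm" is a substring)
# lstm.bias_ih_l0    -> no_decay  (1-D parameter)
# proj.weight        -> decay
# proj.bias          -> no_decay  (1-D parameter)

Both hunks truncate before the function's return statement, so how decay and no_decay are packaged is not shown here. A common pattern is to return them as two optimizer parameter groups, e.g. [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}], but that is an assumption about code outside this diff.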