update set_weight_decay

Eren Golge 2019-09-28 15:31:18 +02:00
parent 8565c508e4
commit 99d7f2a666
1 changed file with 4 additions and 3 deletions

@@ -182,9 +182,9 @@ def weight_decay(optimizer):
     return optimizer, current_lr
-def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """
-    Skip biases, BatchNorm parameters for weight decay
+    Skip biases, BatchNorm parameters, rnns.
     and attention projection layer v
     """
     decay = []
@@ -192,7 +192,8 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue
-        if len(param.shape) == 1 or name in skip_list:
+        if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
             no_decay.append(param)
         else:
             decay.append(param)
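
For context, a minimal self-contained sketch of how the updated grouping logic can be wired into training. The diff only shows the changed hunks, so the return value and parameter-group layout below are assumptions, as is the Toy model used for illustration; the skip condition mirrors the new substring check, and passing per-group weight_decay dicts to torch.optim.Adam is standard PyTorch.

import torch

class Toy(torch.nn.Module):
    # Hypothetical model, only here to produce realistic parameter names.
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(100, 32)
        self.gru = torch.nn.GRU(32, 64)
        self.proj = torch.nn.Linear(64, 10)

def set_weight_decay(model, weight_decay,
                     skip_list=("decoder.attention.v", "rnn", "lstm", "gru", "embedding")):
    """Split parameters into decay / no-decay groups.

    1-D tensors (biases, BatchNorm/LayerNorm weights) and any parameter
    whose name contains a skip_list substring are exempt from decay.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Substring match (the new behaviour): "gru.weight_ih_l0" is caught
        # by the "gru" entry, "embedding.weight" by the "embedding" entry.
        if len(param.shape) == 1 or any(skip_name in name for skip_name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    # Assumed return shape: per-group dicts as accepted by torch optimizers.
    return [{"params": no_decay, "weight_decay": 0.0},
            {"params": decay, "weight_decay": weight_decay}]

model = Toy()
# Only proj.weight lands in the decay group here; every other parameter is
# either 1-D or matched by a skip_list substring.
optimizer = torch.optim.Adam(set_weight_decay(model, weight_decay=1e-6), lr=1e-3)

The sketch uses a tuple default for skip_list rather than the mutable set literal in the committed code; both behave the same here since the function never mutates it.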