From 99d7f2a666f61dba206036dc457b6946765f834c Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Sat, 28 Sep 2019 15:31:18 +0200
Subject: [PATCH] update set_weight_decay

---
 utils/generic_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 983797ba..3188067f 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -182,9 +182,9 @@ def weight_decay(optimizer):
     return optimizer, current_lr
 
 
-def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """
-    Skip biases, BatchNorm parameters for weight decay
+    Skip biases, BatchNorm parameters, RNNs, and the attention projection layer v
     """
 
     decay = []
@@ -192,7 +192,8 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
     for name, param in model.named_parameters():
         if not param.requires_grad:
             continue
-        if len(param.shape) == 1 or name in skip_list:
+
+        if len(param.shape) == 1 or any(skip_name in name for skip_name in skip_list):
             no_decay.append(param)
         else:
             decay.append(param)
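
For reviewers, a minimal usage sketch of the updated function follows (not part of the patch). It assumes set_weight_decay still ends by returning the two optimizer parameter groups, since the hunks above do not show the unchanged tail of the function; ToyModel and all hyperparameter values here are hypothetical, for illustration only.

import torch
import torch.nn as nn


def set_weight_decay(model, weight_decay,
                     skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
    """
    Skip biases, BatchNorm parameters, RNNs, and the attention projection layer v
    """

    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # 1-D tensors (biases, norm weights) and any parameter whose name
        # contains a skip_list substring go to the zero-decay group.
        if len(param.shape) == 1 or any(skip_name in name for skip_name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    # Assumed return value: the two parameter groups expected by PyTorch
    # optimizers (this unchanged tail is not shown in the hunks above).
    return [{'params': no_decay, 'weight_decay': 0.},
            {'params': decay, 'weight_decay': weight_decay}]


class ToyModel(nn.Module):
    """Hypothetical stand-in for the TTS model, for illustration only."""
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 32)
        self.lstm = nn.LSTM(32, 32, batch_first=True)
        self.linear = nn.Linear(32, 10)


model = ToyModel()
groups = set_weight_decay(model, weight_decay=1e-6)
optimizer = torch.optim.Adam(groups, lr=1e-3)

# embedding.* and lstm.* match the new skip_list substrings, and linear.bias
# is 1-D, so only linear.weight receives weight decay.
print(len(groups[0]['params']), len(groups[1]['params']))  # -> 6 1

Before this change, only an exact name match against skip_list ("decoder.attention.v") was excluded; the substring test extends the exemption to every recurrent and embedding parameter without listing each one by name.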