Add --freeze-till training parameter for transfer learning
For example, --freeze-till -2 will freeze layers[0:-2]pull/81/head
parent
f870827116
commit
06430cd24d
|
@ -36,6 +36,7 @@ class ModelParams:
|
|||
extra_metrics = attr.ib(False) # type: bool
|
||||
skip_acc = attr.ib(False) # type: bool
|
||||
loss_bias = attr.ib(0.7) # type: float
|
||||
freeze_till = attr.ib(0) # type: bool
|
||||
|
||||
|
||||
def load_precise_model(model_name: str) -> Any:
|
||||
|
@ -76,5 +77,7 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential'
|
|||
load_keras()
|
||||
metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
|
||||
set_loss_bias(params.loss_bias)
|
||||
for i in model.layers[:params.freeze_till]:
|
||||
i.trainable = False
|
||||
model.compile('rmsprop', weighted_log_loss, metrics=(not params.skip_acc) * metrics)
|
||||
return model
|
||||
|
|
|
@ -60,6 +60,10 @@ class Trainer:
|
|||
|
||||
:-em --extra-metrics
|
||||
Add extra metrics during training
|
||||
|
||||
:-f --freeze-till int 0
|
||||
Freeze all weights up to this index (non-inclusive).
|
||||
Can be negative to wrap from end
|
||||
|
||||
...
|
||||
'''
|
||||
|
@ -80,7 +84,7 @@ class Trainer:
|
|||
inject_params(args.model)
|
||||
save_params(args.model)
|
||||
params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics,
|
||||
loss_bias=1.0 - args.sensitivity)
|
||||
loss_bias=1.0 - args.sensitivity, freeze_till=args.freeze_till)
|
||||
self.model = create_model(args.model, params)
|
||||
self.train, self.test = self.load_data(self.args)
|
||||
|
||||
|
|
Loading…
Reference in New Issue