diff --git a/precise/model.py b/precise/model.py
index 696dc85..6ee90ca 100644
--- a/precise/model.py
+++ b/precise/model.py
@@ -36,6 +36,7 @@ class ModelParams:
     extra_metrics = attr.ib(False)  # type: bool
     skip_acc = attr.ib(False)  # type: bool
     loss_bias = attr.ib(0.7)  # type: float
+    freeze_till = attr.ib(0)  # type: int
 
 
 def load_precise_model(model_name: str) -> Any:
@@ -76,5 +77,7 @@ def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
     load_keras()
     metrics = ['accuracy'] + params.extra_metrics * [false_pos, false_neg]
     set_loss_bias(params.loss_bias)
+    for layer in model.layers[:params.freeze_till]:
+        layer.trainable = False
     model.compile('rmsprop', weighted_log_loss, metrics=(not params.skip_acc) * metrics)
     return model
diff --git a/precise/scripts/train.py b/precise/scripts/train.py
index af964c1..c35b438 100755
--- a/precise/scripts/train.py
+++ b/precise/scripts/train.py
@@ -60,6 +60,10 @@ class Trainer:
 
         :-em --extra-metrics
             Add extra metrics during training
+
+        :-f --freeze-till int 0
+            Freeze all weights up to this layer index (non-inclusive).
+            Can be negative to wrap from the end
 
         ...
     '''
@@ -80,7 +84,7 @@ class Trainer:
         inject_params(args.model)
         save_params(args.model)
         params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics,
-                             loss_bias=1.0 - args.sensitivity)
+                             loss_bias=1.0 - args.sensitivity, freeze_till=args.freeze_till)
         self.model = create_model(args.model, params)
 
         self.train, self.test = self.load_data(self.args)
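
A note on the semantics: freeze_till is used as an ordinary Python slice bound, so model.layers[:freeze_till] covers layer indices 0 through freeze_till - 1. The default of 0 selects no layers (nothing frozen, so existing training runs are unaffected), and a negative value counts back from the last layer, e.g. -1 freezes everything except the output layer. Below is a minimal runnable sketch of the same slice-based freezing on a standalone Keras model; the GRU/Dense topology, layer sizes, and input shape are illustrative assumptions, not taken from this patch.

    # Sketch only: the slice-based freezing that create_model() now applies,
    # demonstrated on a standalone model. Topology and shapes are assumed.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import GRU, Dense

    model = Sequential([
        GRU(20, input_shape=(29, 13)),  # e.g. 29 feature frames x 13 MFCCs (assumed)
        Dense(1, activation='sigmoid'),
    ])

    freeze_till = 1  # freeze layers[0:1], i.e. only the GRU
    for layer in model.layers[:freeze_till]:
        layer.trainable = False  # layer is excluded from gradient updates

    # Keras only honors trainable changes made before compile(), which is why
    # the patched create_model() freezes layers ahead of model.compile().
    model.compile('rmsprop', 'binary_crossentropy')
    print([layer.trainable for layer in model.layers])  # -> [False, True]

In a two-layer model like this sketch, passing -f 1 would freeze the recurrent layer and fine-tune only what follows it.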