diff --git a/precise/scripts/train.py b/precise/scripts/train.py index 1794244..2e76401 100755 --- a/precise/scripts/train.py +++ b/precise/scripts/train.py @@ -19,7 +19,6 @@ from os.path import splitext, isfile from prettyparse import add_to_parser from typing import Any, Tuple -from precise.functions import set_loss_bias from precise.model import create_model, ModelParams from precise.params import inject_params, save_params from precise.train_data import TrainData @@ -82,8 +81,8 @@ class Trainer: save_params(args.model) self.train, self.test = self.load_data(self.args) - set_loss_bias(1.0 - args.sensitivity) - params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics) + params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics, + loss_bias=1.0 - args.sensitivity) self.model = create_model(args.model, params) self.model.summary()