diff --git a/precise/scripts/train.py b/precise/scripts/train.py index 2e76401..6c48eb5 100755 --- a/precise/scripts/train.py +++ b/precise/scripts/train.py @@ -164,8 +164,8 @@ class Trainer: train_inputs, train_outputs = self.sampled_data self.model.fit( train_inputs, train_outputs, self.args.batch_size, - self.epoch + self.args.epochs, validation_data=self.test, initial_epoch=self.epoch, - callbacks=self.callbacks + self.epoch + self.args.epochs, validation_data=self.test, + initial_epoch=self.epoch, callbacks=self.callbacks ) except KeyboardInterrupt: print() diff --git a/precise/scripts/train_incremental.py b/precise/scripts/train_incremental.py index 5396187..6804dd8 100755 --- a/precise/scripts/train_incremental.py +++ b/precise/scripts/train_incremental.py @@ -98,10 +98,11 @@ class IncrementalTrainer(Trainer): train_data = TrainData.merge(train_data, self.sampled_data) test_data = TrainData.merge(test_data, self.test) + train_inputs, train_outputs = train_data print() try: self.listener.runner.model.fit( - train_data[0], train_data[1], self.args.batch_size, self.epoch + self.args.epochs, + train_inputs, train_outputs, self.args.batch_size, self.epoch + self.args.epochs, validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: