diff --git a/train.py b/train.py index 7b5a02fe..6d4af558 100644 --- a/train.py +++ b/train.py @@ -333,7 +333,7 @@ def main(args): train_loader = DataLoader(train_dataset, batch_size=c.batch_size, shuffle=False, collate_fn=train_dataset.collate_fn, - drop_last=False, num_workers=c.num_loader_workers, + drop_last=True, num_workers=c.num_loader_workers, pin_memory=True) val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),