diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index 73802c63..b104652d 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -344,6 +344,10 @@ def main(args): # pylint: disable=redefined-outer-name # setup criterion criterion = torch.nn.L1Loss().cuda() + + if use_cuda: + model.cuda() + criterion.cuda() if args.restore_path: checkpoint = torch.load(args.restore_path, map_location='cpu') @@ -378,10 +382,6 @@ def main(args): # pylint: disable=redefined-outer-name else: args.restore_step = 0 - if use_cuda: - model.cuda() - criterion.cuda() - # DISTRUBUTED if num_gpus > 1: model = DDP_th(model, device_ids=[args.rank])