fix device mismatch wavegrad training

this should fixe the device mismatch as seen here https://github.com/mozilla/TTS/issues/622#issue-789802916
pull/367/head
Alexander Korolev 2021-01-29 15:18:59 +01:00 committed by Eren Gölge
parent bfb12732f1
commit ace430d5e6
1 changed files with 4 additions and 4 deletions

View File

@ -345,6 +345,10 @@ def main(args): # pylint: disable=redefined-outer-name
# setup criterion # setup criterion
criterion = torch.nn.L1Loss().cuda() criterion = torch.nn.L1Loss().cuda()
if use_cuda:
model.cuda()
criterion.cuda()
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu') checkpoint = torch.load(args.restore_path, map_location='cpu')
try: try:
@ -378,10 +382,6 @@ def main(args): # pylint: disable=redefined-outer-name
else: else:
args.restore_step = 0 args.restore_step = 0
if use_cuda:
model.cuda()
criterion.cuda()
# DISTRUBUTED # DISTRUBUTED
if num_gpus > 1: if num_gpus > 1:
model = DDP_th(model, device_ids=[args.rank]) model = DDP_th(model, device_ids=[args.rank])