mirror of https://github.com/coqui-ai/TTS.git
fix device mismatch wavegrad training
This should fix the device mismatch as seen here: https://github.com/mozilla/TTS/issues/622#issue-789802916
parent bfb12732f1
commit ace430d5e6
@@ -345,6 +345,10 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup criterion
     criterion = torch.nn.L1Loss().cuda()
 
+    if use_cuda:
+        model.cuda()
+        criterion.cuda()
+
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
@@ -378,10 +382,6 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         args.restore_step = 0
 
-    if use_cuda:
-        model.cuda()
-        criterion.cuda()
-
     # DISTRUBUTED
     if num_gpus > 1:
         model = DDP_th(model, device_ids=[args.rank])