Fix GPU init in tests

pull/800/head
Eren Gölge 2021-09-10 08:28:10 +00:00
parent 0541a25e90
commit 3abc3a1d32
1 changed files with 2 additions and 2 deletions

View File

@ -7,8 +7,8 @@ from TTS.utils.generic_utils import get_cuda
def get_device_id():
use_cuda, _ = get_cuda()
if use_cuda:
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != "":
GPU_ID = os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0]
if "CUDA_VISIBLE_DEVICES" in os.environ and os.environ["CUDA_VISIBLE_DEVICES"] != "":
GPU_ID = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
else:
GPU_ID = "0"
else: