Raise an error when multiple GPUs are in use

The user must define the target GPU via `CUDA_VISIBLE_DEVICES` and
use `distribute.py` for multi-GPU training.
pull/611/head
Eren Gölge 2021-07-04 11:25:49 +02:00
parent 270c3823eb
commit a05b234080
1 changed file with 5 additions and 1 deletion

View File

@ -11,11 +11,15 @@ def is_apex_available():
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
    """Configure the global PyTorch environment for single-GPU training.

    Enforces single-GPU execution (multi-GPU runs must go through
    ``TTS/bin/distribute.py``), applies the cuDNN settings, and seeds the
    PyTorch RNG for reproducibility.

    Args:
        cudnn_enable (bool): whether to enable the cuDNN backend.
        cudnn_benchmark (bool): whether to let cuDNN auto-tune convolution
            algorithms (faster for fixed input shapes, non-deterministic).

    Returns:
        tuple: ``(use_cuda, num_gpus)`` where ``use_cuda`` is a bool and
        ``num_gpus`` the number of visible CUDA devices (0 or 1 here).

    Raises:
        RuntimeError: if more than one CUDA device is visible.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        raise RuntimeError(
            f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
        )
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    # Fixed seed so repeated runs are reproducible.
    torch.manual_seed(54321)
    use_cuda = torch.cuda.is_available()
    # NOTE: the original re-queried torch.cuda.device_count() here; that is
    # redundant — the visible device set cannot change within this call, so
    # we reuse the value computed above.
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus