diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py index 02e68905..29915527 100644 --- a/TTS/utils/trainer_utils.py +++ b/TTS/utils/trainer_utils.py @@ -11,11 +11,15 @@ def is_apex_available(): def setup_torch_training_env(cudnn_enable, cudnn_benchmark): + num_gpus = torch.cuda.device_count() + if num_gpus > 1: + raise RuntimeError( + f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." + ) torch.backends.cudnn.enabled = cudnn_enable torch.backends.cudnn.benchmark = cudnn_benchmark torch.manual_seed(54321) use_cuda = torch.cuda.is_available() - num_gpus = torch.cuda.device_count() print(" > Using CUDA: ", use_cuda) print(" > Number of GPUs: ", num_gpus) return use_cuda, num_gpus