mirror of https://github.com/coqui-ai/TTS.git
Raise an error when multiple GPUs are in use
User must define the target GPU by `CUDA_VISIBLE_DEVICES` and use `distribute.py` for multi-gpu training.pull/611/head
parent
270c3823eb
commit
a05b234080
|
@ -11,11 +11,15 @@ def is_apex_available():
|
|||
|
||||
|
||||
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus > 1:
|
||||
raise RuntimeError(
|
||||
f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
|
||||
)
|
||||
torch.backends.cudnn.enabled = cudnn_enable
|
||||
torch.backends.cudnn.benchmark = cudnn_benchmark
|
||||
torch.manual_seed(54321)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
num_gpus = torch.cuda.device_count()
|
||||
print(" > Using CUDA: ", use_cuda)
|
||||
print(" > Number of GPUs: ", num_gpus)
|
||||
return use_cuda, num_gpus
|
||||
|
|
Loading…
Reference in New Issue