use `to_cuda()` for moving data in `format_batch()`

pull/506/head
Eren Gölge 2021-06-03 15:05:39 +02:00
parent 877bf66b61
commit 9042ae9195
3 changed files with 21 additions and 12 deletions

View File

@ -27,7 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env
@ -377,18 +377,18 @@ class TrainerTTS:
# dispatch batch to GPU
if self.use_cuda:
text_input = text_input.cuda(non_blocking=True)
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
text_input = to_cuda(text_input)
text_lengths = to_cuda(text_lengths)
mel_input = to_cuda(mel_input)
mel_lengths = to_cuda(mel_lengths)
linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
stop_targets = to_cuda(stop_targets)
attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
durations = to_cuda(durations) if attn_mask is not None else None
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
speaker_ids = to_cuda(speaker_ids)
if speaker_embeddings is not None:
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
speaker_embeddings = to_cuda(speaker_embeddings)
return {
"text_input": text_input,

View File

@ -282,7 +282,7 @@ class TTSDataset(Dataset):
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
if isinstance(batch[0], collections.abc.Mapping):
text_lenghts = np.array([len(d["text"]) for d in batch])

View File

@ -13,6 +13,15 @@ from typing import Dict
import torch
def to_cuda(x: torch.Tensor) -> torch.Tensor:
    """Return a contiguous copy/view of ``x``, placed on CUDA when available.

    Args:
        x: Tensor to transfer. ``None`` is accepted and handed back unchanged,
            so callers can pass optional batch fields without guarding them.

    Returns:
        ``None`` when ``x`` is ``None``. Otherwise the contiguous tensor, moved
        with ``.cuda(non_blocking=True)`` if ``torch.cuda.is_available()``
        reports a device, or left on its current device if not.
    """
    if x is None:
        return x
    # Contiguity is enforced regardless of whether a transfer happens.
    packed = x.contiguous()
    return packed.cuda(non_blocking=True) if torch.cuda.is_available() else packed
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")