mirror of https://github.com/coqui-ai/TTS.git
use `to_cuda()` for moving data in `format_batch()`
parent 877bf66b61
commit 9042ae9195
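In short: `format_batch()` previously moved each batch field to the GPU with an explicit `.cuda(non_blocking=True)` call behind the `self.use_cuda` flag; it now routes every field through a `to_cuda()` helper added to TTS/utils/generic_utils.py, which makes the tensor contiguous, passes None through, and only transfers when CUDA is actually available. A minimal sketch of the pattern, with the helper body taken from the diff below and two illustrative stand-in fields (the tensor values are made up):

import torch

def to_cuda(x: torch.Tensor) -> torch.Tensor:
    # pass None through, make the tensor contiguous, move to the GPU only if one exists
    if x is None:
        return None
    x = x.contiguous()
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return x

# illustrative stand-ins for two of the fields handled by format_batch()
mel_input = torch.zeros(2, 80, 100)
attn_mask = None

mel_input = to_cuda(mel_input)                                     # on the GPU if available, else unchanged
attn_mask = to_cuda(attn_mask) if attn_mask is not None else None  # optional fields stay None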
@@ -27,7 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
-from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
+from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
 from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.training import check_update, setup_torch_training_env
@@ -377,18 +377,18 @@ class TrainerTTS:
 
         # dispatch batch to GPU
         if self.use_cuda:
-            text_input = text_input.cuda(non_blocking=True)
-            text_lengths = text_lengths.cuda(non_blocking=True)
-            mel_input = mel_input.cuda(non_blocking=True)
-            mel_lengths = mel_lengths.cuda(non_blocking=True)
-            linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None
-            stop_targets = stop_targets.cuda(non_blocking=True)
-            attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None
-            durations = durations.cuda(non_blocking=True) if attn_mask is not None else None
+            text_input = to_cuda(text_input)
+            text_lengths = to_cuda(text_lengths)
+            mel_input = to_cuda(mel_input)
+            mel_lengths = to_cuda(mel_lengths)
+            linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
+            stop_targets = to_cuda(stop_targets)
+            attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
+            durations = to_cuda(durations) if attn_mask is not None else None
             if speaker_ids is not None:
-                speaker_ids = speaker_ids.cuda(non_blocking=True)
+                speaker_ids = to_cuda(speaker_ids)
             if speaker_embeddings is not None:
-                speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
+                speaker_embeddings = to_cuda(speaker_embeddings)
 
         return {
             "text_input": text_input,
@@ -282,7 +282,7 @@ class TTSDataset(Dataset):
         """
 
         # Puts each data field into a tensor with outer dimension batch size
-        if isinstance(batch[0], collections.Mapping):
+        if isinstance(batch[0], collections.abc.Mapping):
 
             text_lenghts = np.array([len(d["text"]) for d in batch])
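The TTSDataset hunk looks like a Python-compatibility fix riding along with the GPU change: `collections.Mapping` has been a deprecated alias of `collections.abc.Mapping` since Python 3.3 and is removed in Python 3.10, so the isinstance check has to use the `collections.abc` name. A small standalone illustration of the same check (the sample batch items are made up):

import collections.abc

batch = [{"text": "hello"}, {"text": "world"}]  # illustrative batch items

# dict items are Mappings, so the per-field collation branch applies
assert isinstance(batch[0], collections.abc.Mapping)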
@@ -13,6 +13,15 @@ from typing import Dict
 import torch
 
 
+def to_cuda(x: torch.Tensor) -> torch.Tensor:
+    if x is None:
+        return None
+    x = x.contiguous()
+    if torch.cuda.is_available():
+        x = x.cuda(non_blocking=True)
+    return x
+
+
 def get_cuda():
     use_cuda = torch.cuda.is_available()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
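Typical use of the new helper from training code, assuming the TTS package is installed so TTS.utils.generic_utils is importable; the helper owns the contiguity call, the CUDA-availability check, and the None handling:

import torch
from TTS.utils.generic_utils import to_cuda  # helper added in this commit

stop_targets = to_cuda(torch.zeros(2, 100))  # moved to the GPU when one is available, otherwise left on the CPU
attn_mask = to_cuda(None)                    # None passes straight through
print(stop_targets.device, attn_mask)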