diff --git a/TTS/trainer.py b/TTS/trainer.py index d81132cf..8ec59f55 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -27,7 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict +from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.training import check_update, setup_torch_training_env @@ -377,18 +377,18 @@ class TrainerTTS: # dispatch batch to GPU if self.use_cuda: - text_input = text_input.cuda(non_blocking=True) - text_lengths = text_lengths.cuda(non_blocking=True) - mel_input = mel_input.cuda(non_blocking=True) - mel_lengths = mel_lengths.cuda(non_blocking=True) - linear_input = linear_input.cuda(non_blocking=True) if self.config.model.lower() in ["tacotron"] else None - stop_targets = stop_targets.cuda(non_blocking=True) - attn_mask = attn_mask.cuda(non_blocking=True) if attn_mask is not None else None - durations = durations.cuda(non_blocking=True) if attn_mask is not None else None + text_input = to_cuda(text_input) + text_lengths = to_cuda(text_lengths) + mel_input = to_cuda(mel_input) + mel_lengths = to_cuda(mel_lengths) + linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None + stop_targets = to_cuda(stop_targets) + attn_mask = to_cuda(attn_mask) if attn_mask is not None else None + durations = to_cuda(durations) if attn_mask is not None else None if speaker_ids is not None: - speaker_ids = speaker_ids.cuda(non_blocking=True) + speaker_ids = to_cuda(speaker_ids) if speaker_embeddings is not None: - speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) + speaker_embeddings = to_cuda(speaker_embeddings) return { "text_input": text_input, diff --git 
from typing import Optional


def to_cuda(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return a contiguous copy/view of ``x`` placed on the GPU when one is available.

    Args:
        x: Tensor to transfer, or ``None``.

    Returns:
        ``None`` when ``x`` is ``None``. Otherwise the contiguous version of
        ``x``, moved to the current CUDA device if ``torch.cuda.is_available()``
        is true and left on its original device if it is not.
    """
    # Optional batch fields are passed straight through so callers do not
    # have to guard every call site.
    if x is None:
        return None
    x = x.contiguous()
    if torch.cuda.is_available():
        # non_blocking=True requests an asynchronous host-to-device copy;
        # per the PyTorch docs this only has an effect for pinned host memory.
        x = x.cuda(non_blocking=True)
    return x