Add device support in TTS and Synthesizer (#2855)

* fix: resolve merge conflicts

* fix: retain backwards compatibility in functions

* feature: utilize device for voice transfer

* feature: use device for vocoder

* chore: cleanup vocoder cpu logic

* fix: add necessary vocoder output device check

* fix: indentation

* fix: check if waveform is pt tensor before cpu conversion

---------

Co-authored-by: Jake Tae <jaketae@Jakes-MacBook-Pro-2.local>
Jake Tae 2023-08-14 15:04:44 -04:00 committed by GitHub
parent febcaf710a
commit 409db505d2
3 changed files with 58 additions and 35 deletions
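In practice, the new device handling looks like this (a minimal usage sketch; the model name, text, and output path are illustrative, not part of this commit):

    import torch
    from TTS.api import TTS

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # TTS is now an nn.Module, so the deprecated gpu=True flag can be
    # replaced with a standard PyTorch .to(device) call.
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC").to(device)
    tts.tts_to_file(text="Hello world!", file_path="output.wav")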

--- a/TTS/api.py
+++ b/TTS/api.py

@@ -1,8 +1,10 @@
 import tempfile
+import warnings
 from pathlib import Path
 from typing import Union
 
 import numpy as np
+from torch import nn
 
 from TTS.cs_api import CS_API
 from TTS.utils.audio.numpy_transforms import save_wav
@@ -10,7 +12,7 @@ from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 
 
-class TTS:
+class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""
 
     def __init__(
@@ -62,6 +64,7 @@ class TTS:
                 Defaults to "XTTS".
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
+        super().__init__()
         self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
 
         self.synthesizer = None
@@ -70,6 +73,9 @@
         self.cs_api_model = cs_api_model
         self.model_name = None
+        if gpu:
+            warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
+
         if model_name is not None:
             if "tts_models" in model_name or "coqui_studio" in model_name:
                 self.load_tts_model_by_name(model_name, gpu)
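Subclassing nn.Module is what makes tts.to(device) work: nn.Module.to() recurses into every registered submodule, so moving the top-level TTS object also moves the Synthesizer assigned to self.synthesizer once a model is loaded. A self-contained sketch of that mechanism (class names are illustrative, not from this commit):

    import torch
    from torch import nn

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = nn.Linear(4, 4)  # auto-registered, like self.synthesizer

    outer = Outer().to("cpu")
    # .to() recursed into self.inner, so its parameters follow the parent:
    print(next(outer.inner.parameters()).device)  # cpu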

--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py

@@ -5,19 +5,21 @@ import torch
 from torch import nn
 
 
-def numpy_to_torch(np_array, dtype, cuda=False):
+def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
+    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
     return tensor
 
 
-def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
+def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
     if cuda:
-        return style_mel.cuda()
+        device = "cuda"
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
+    ).unsqueeze(0)
     return style_mel
@@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def id_to_torch(aux_id, cuda=False):
+def id_to_torch(aux_id, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
+        aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id
 
 
-def embedding_to_torch(d_vector, cuda=False):
+def embedding_to_torch(d_vector, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
+        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
     return d_vector
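All three helpers follow the same pattern: device defaults to "cpu", and the legacy cuda flag is kept for backwards compatibility by overriding device with "cuda" when set. Both call styles therefore remain valid; a sketch with a placeholder embedding (the "cuda:1" line assumes a multi-GPU machine):

    import numpy as np
    from TTS.tts.utils.synthesis import embedding_to_torch

    embedding = np.random.rand(256).astype(np.float32)  # placeholder d-vector
    d_vector = embedding_to_torch(embedding, cuda=True)        # legacy flag, forces "cuda"
    d_vector = embedding_to_torch(embedding, device="cuda:1")  # new argument, any torch device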
@@ -162,6 +164,11 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # GST or Capacitron processing
     # TODO: need to handle the case of setting both gst and capacitron to true somewhere
     style_mel = None
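The next(model.parameters()).device idiom reads the target device off the model's own weights, so input tensors end up wherever the model already lives, while use_cuda still takes precedence for backwards compatibility. The idiom in isolation (illustrative stand-in model):

    import torch
    from torch import nn

    model = nn.Linear(8, 8)                   # stands in for the TTS model
    device = next(model.parameters()).device  # torch.device("cpu") here
    x = torch.zeros(1, 8, device=device)      # inputs created on the model's device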
@@ -169,10 +176,10 @@
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, model.ap, device=device)
     if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+        style_mel = compute_style_mel(style_wav, model.ap, device=device)
         style_mel = style_mel.transpose(1, 2)  # [1, time, depth]
 
     language_name = None
@@ -188,26 +195,26 @@
     )
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)
 
     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)
 
     if language_id is not None:
-        language_id = id_to_torch(language_id, cuda=use_cuda)
+        language_id = id_to_torch(language_id, device=device)
 
     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
         if style_text is not None:
             style_text = np.asarray(
                 model.tokenizer.text_to_ids(style_text, language=language_id),
                 dtype=np.int32,
             )
-            style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
+            style_text = numpy_to_torch(style_text, torch.long, device=device)
             style_text = style_text.unsqueeze(0)
-    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
     text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
     outputs = run_model_torch(
@@ -290,22 +297,27 @@ def transfer_voice(
         do_trim_silence (bool):
             trim silence after synthesis. Defaults to False.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)
 
     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)
 
     if reference_d_vector is not None:
-        reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
+        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)
 
     # load reference_wav audio
     reference_wav = embedding_to_torch(
         model.ap.load_wav(
             reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
         ),
-        cuda=use_cuda,
+        device=device,
     )
 
     if hasattr(model, "module"):

--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py

@@ -5,6 +5,7 @@ from typing import List
 import numpy as np
 import pysbd
 import torch
+from torch import nn
 
 from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
@@ -21,7 +22,7 @@ from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
 
 
-class Synthesizer(object):
+class Synthesizer(nn.Module):
     def __init__(
         self,
         tts_checkpoint: str = "",
@@ -60,6 +61,7 @@
             vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
+        super().__init__()
         self.tts_checkpoint = tts_checkpoint
         self.tts_config_path = tts_config_path
         self.tts_speakers_file = tts_speakers_file
@@ -356,7 +358,12 @@
         if speaker_wav is not None and self.tts_model.speaker_manager is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
 
+        vocoder_device = "cpu"
         use_gl = self.vocoder_model is None
+        if not use_gl:
+            vocoder_device = next(self.vocoder_model.parameters()).device
+        if self.use_cuda:
+            vocoder_device = "cuda"
 
         if not reference_wav:  # not voice conversion
             for sen in sens:
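The vocoder resolves its device the same way: default to CPU (which also covers the Griffin-Lim path, where no neural vocoder exists), otherwise follow the vocoder's own parameters, with use_cuda again taking precedence. Condensed into a standalone helper for clarity (a sketch, not code from this commit):

    import torch
    from torch import nn

    def resolve_vocoder_device(vocoder_model, use_cuda: bool) -> torch.device:
        # Default covers the Griffin-Lim / no-vocoder case.
        device = torch.device("cpu")
        if vocoder_model is not None:
            device = next(vocoder_model.parameters()).device
        if use_cuda:
            device = torch.device("cuda")  # legacy flag still takes precedence
        return device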
@@ -388,7 +395,6 @@
                     mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                     # denormalize tts output based on tts audio config
                     mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                    device_type = "cuda" if self.use_cuda else "cpu"
                     # renormalize spectrogram based on vocoder config
                     vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                     # compute scale factor for possible sample rate mismatch
@@ -403,8 +409,8 @@
                         vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                     # run vocoder model
                     # [1, T, C]
-                    waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-                if self.use_cuda and not use_gl:
+                    waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+                if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                     waveform = waveform.cpu()
                 if not use_gl:
                     waveform = waveform.numpy()
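The torch.is_tensor guard matters because the two paths return different types: Griffin-Lim yields a NumPy array, while a neural vocoder yields a torch.Tensor that may live on the GPU and must be moved to CPU before .numpy(). A standalone illustration (hypothetical helper name):

    import numpy as np
    import torch

    def to_numpy_waveform(waveform):
        # Neural vocoder output: a tensor, possibly on GPU -> move, then convert.
        if torch.is_tensor(waveform):
            waveform = waveform.cpu().numpy()
        return waveform  # Griffin-Lim output is already an np.ndarray

    print(type(to_numpy_waveform(torch.zeros(10))))  # <class 'numpy.ndarray'>
    print(type(to_numpy_waveform(np.zeros(10))))     # <class 'numpy.ndarray'>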
@@ -453,7 +459,6 @@
             mel_postnet_spec = outputs[0].detach().cpu().numpy()
             # denormalize tts output based on tts audio config
             mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-            device_type = "cuda" if self.use_cuda else "cpu"
             # renormalize spectrogram based on vocoder config
             vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
             # compute scale factor for possible sample rate mismatch
@@ -468,8 +473,8 @@
                 vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
             # run vocoder model
             # [1, T, C]
-            waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-        if self.use_cuda:
+            waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+        if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
             waveform = waveform.cpu()
         if not use_gl:
             waveform = waveform.numpy()