mirror of https://github.com/coqui-ai/TTS.git
Add device support in TTS and Synthesizer (#2855)
* fix: resolve merge conflicts
* fix: retain backwards compatibility in functions
* feature: utilize device for voice transfer
* feature: use device for vocoder
* chore: cleanup vocoder cpu logic
* fix: add necessary vocoder output device check
* fix: indentation
* fix: check if waveform is pt tensor before cpu conversion

---------

Co-authored-by: Jake Tae <jaketae@Jakes-MacBook-Pro-2.local>
parent febcaf710a
commit 409db505d2
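What this means for callers, as a minimal usage sketch (the model name here is illustrative, not part of this commit):

    import torch

    from TTS.api import TTS

    # Choose a target device; the boolean `gpu` flag is deprecated by this commit.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # TTS now subclasses nn.Module, so the standard .to(device) moves the whole pipeline.
    tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(device)
    tts.tts_to_file(text="Hello world!", file_path="output.wav")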
TTS/api.py
@@ -1,8 +1,10 @@
 import tempfile
+import warnings
 from pathlib import Path
 from typing import Union

 import numpy as np
+from torch import nn

 from TTS.cs_api import CS_API
 from TTS.utils.audio.numpy_transforms import save_wav
@@ -10,7 +12,7 @@ from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer


-class TTS:
+class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""

     def __init__(
@@ -62,6 +64,7 @@ class TTS:
                 Defaults to "XTTS".
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
+        super().__init__()
         self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

         self.synthesizer = None
@@ -70,6 +73,9 @@ class TTS:
         self.cs_api_model = cs_api_model
         self.model_name = None

+        if gpu:
+            warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
+
         if model_name is not None:
             if "tts_models" in model_name or "coqui_studio" in model_name:
                 self.load_tts_model_by_name(model_name, gpu)

TTS/tts/utils/synthesis.py
@@ -5,19 +5,21 @@ import torch
 from torch import nn


-def numpy_to_torch(np_array, dtype, cuda=False):
+def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
+    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
     return tensor


-def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
-    if cuda:
-        return style_mel.cuda()
+def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device,
+    ).unsqueeze(0)
     return style_mel

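The reworked helpers keep the legacy `cuda` flag working while adding an explicit `device` argument; `cuda=True` simply overrides `device`. A standalone sketch of the resulting behavior:

    import numpy as np
    import torch

    def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
        # Backwards compatibility: the legacy flag wins over the new argument.
        if cuda:
            device = "cuda"
        if np_array is None:
            return None
        # Creating the tensor directly on the target device avoids a CPU round trip.
        return torch.as_tensor(np_array, dtype=dtype, device=device)

    x = numpy_to_torch(np.zeros(3), torch.float)                      # CPU tensor
    # y = numpy_to_torch(np.zeros(3), torch.float, device="cuda")    # GPU, if available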
@@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def id_to_torch(aux_id, cuda=False):
+def id_to_torch(aux_id, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
+        aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id


-def embedding_to_torch(d_vector, cuda=False):
+def embedding_to_torch(d_vector, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
+        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
     return d_vector

@@ -162,6 +164,11 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # GST or Capacitron processing
     # TODO: need to handle the case of setting both gst and capacitron to true somewhere
     style_mel = None
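`synthesis()` now infers the target device from the model's own parameters, so input tensors follow the model wherever it has been moved, with `use_cuda=True` kept as an explicit override. A minimal sketch of the pattern (the helper name is hypothetical):

    import torch
    from torch import nn

    def resolve_device(model: nn.Module, use_cuda: bool = False) -> torch.device:
        # All parameters of a moved module share one device, so the first suffices.
        device = next(model.parameters()).device
        if use_cuda:
            device = torch.device("cuda")
        return device

    model = nn.Linear(4, 4)        # lives on CPU by default
    print(resolve_device(model))   # -> cpu; becomes cuda after model.to("cuda")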
@@ -169,10 +176,10 @@ def synthesis(
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, model.ap, device=device)

     if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+        style_mel = compute_style_mel(style_wav, model.ap, device=device)
         style_mel = style_mel.transpose(1, 2)  # [1, time, depth]

     language_name = None
@@ -188,26 +195,26 @@ def synthesis(
     )
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if language_id is not None:
-        language_id = id_to_torch(language_id, cuda=use_cuda)
+        language_id = id_to_torch(language_id, device=device)

     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
         if style_text is not None:
             style_text = np.asarray(
                 model.tokenizer.text_to_ids(style_text, language=language_id),
                 dtype=np.int32,
             )
-            style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
+            style_text = numpy_to_torch(style_text, torch.long, device=device)
             style_text = style_text.unsqueeze(0)

-    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
     text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
     outputs = run_model_torch(
@@ -290,22 +297,27 @@ def transfer_voice(
         do_trim_silence (bool):
             trim silence after synthesis. Defaults to False.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if reference_d_vector is not None:
-        reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
+        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)

     # load reference_wav audio
     reference_wav = embedding_to_torch(
         model.ap.load_wav(
             reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
         ),
-        cuda=use_cuda,
+        device=device,
     )

     if hasattr(model, "module"):

TTS/utils/synthesizer.py
@@ -5,6 +5,7 @@ from typing import List
 import numpy as np
 import pysbd
 import torch
+from torch import nn

 from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
@@ -21,7 +22,7 @@ from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


-class Synthesizer(object):
+class Synthesizer(nn.Module):
     def __init__(
         self,
         tts_checkpoint: str = "",
@@ -60,6 +61,7 @@ class Synthesizer(object):
            vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
            use_cuda (bool, optional): enable/disable cuda. Defaults to False.
        """
+        super().__init__()
        self.tts_checkpoint = tts_checkpoint
        self.tts_config_path = tts_config_path
        self.tts_speakers_file = tts_speakers_file
@@ -356,7 +358,12 @@ class Synthesizer(object):
         if speaker_wav is not None and self.tts_model.speaker_manager is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

+        vocoder_device = "cpu"
         use_gl = self.vocoder_model is None
+        if not use_gl:
+            vocoder_device = next(self.vocoder_model.parameters()).device
+        if self.use_cuda:
+            vocoder_device = "cuda"

         if not reference_wav:  # not voice conversion
             for sen in sens:
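The vocoder device is resolved with a small precedence chain: default to CPU, prefer the neural vocoder's own parameter device when one is loaded (`use_gl` means Griffin-Lim, i.e. no vocoder model), and let `use_cuda` override both. A standalone sketch of that logic (the helper name is hypothetical):

    from typing import Optional

    import torch
    from torch import nn

    def resolve_vocoder_device(vocoder_model: Optional[nn.Module], use_cuda: bool) -> torch.device:
        device = torch.device("cpu")
        use_gl = vocoder_model is None          # fall back to Griffin-Lim on CPU
        if not use_gl:
            device = next(vocoder_model.parameters()).device
        if use_cuda:                            # the explicit flag still wins
            device = torch.device("cuda")
        return device

    print(resolve_vocoder_device(None, use_cuda=False))  # -> cpu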
@@ -388,7 +395,6 @@ class Synthesizer(object):
                     mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                     # denormalize tts output based on tts audio config
                     mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                    device_type = "cuda" if self.use_cuda else "cpu"
                     # renormalize spectrogram based on vocoder config
                     vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                     # compute scale factor for possible sample rate mismatch
@@ -403,8 +409,8 @@ class Synthesizer(object):
                         vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                     # run vocoder model
                     # [1, T, C]
-                    waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-                if self.use_cuda and not use_gl:
+                    waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+                if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                     waveform = waveform.cpu()
                 if not use_gl:
                     waveform = waveform.numpy()
@@ -453,7 +459,6 @@ class Synthesizer(object):
             mel_postnet_spec = outputs[0].detach().cpu().numpy()
             # denormalize tts output based on tts audio config
             mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-            device_type = "cuda" if self.use_cuda else "cpu"
             # renormalize spectrogram based on vocoder config
             vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
             # compute scale factor for possible sample rate mismatch
@@ -468,8 +473,8 @@ class Synthesizer(object):
                 vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
             # run vocoder model
             # [1, T, C]
-            waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-        if self.use_cuda:
+            waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+        if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
             waveform = waveform.cpu()
         if not use_gl:
             waveform = waveform.numpy()
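The guarded `.cpu()` above exists because the waveform's type depends on the synthesis path: neural vocoders may return a torch tensor on GPU, while Griffin-Lim already yields a numpy array. A standalone sketch of the guard (the helper name is hypothetical):

    import torch

    def waveform_to_numpy(waveform):
        # Only tensors need a device move and numpy conversion;
        # numpy arrays from Griffin-Lim pass through unchanged.
        if torch.is_tensor(waveform):
            if waveform.device != torch.device("cpu"):
                waveform = waveform.cpu()
            waveform = waveform.numpy()
        return waveform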