mirror of https://github.com/coqui-ai/TTS.git
feat: add adjust_speech_rate function to modify speech speed with more durable latents. also missed tts speed implementations added.
parent
dbf1a08a0d
commit
26128be422
|
@ -283,6 +283,7 @@ class TTS(nn.Module):
|
|||
style_text=None,
|
||||
reference_speaker_name=None,
|
||||
split_sentences=split_sentences,
|
||||
speed=speed,
|
||||
**kwargs,
|
||||
)
|
||||
return wav
|
||||
|
@ -330,13 +331,13 @@ class TTS(nn.Module):
|
|||
Additional arguments for the model.
|
||||
"""
|
||||
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
|
||||
|
||||
wav = self.tts(
|
||||
text=text,
|
||||
speaker=speaker,
|
||||
language=language,
|
||||
speaker_wav=speaker_wav,
|
||||
split_sentences=split_sentences,
|
||||
speed=speed,
|
||||
**kwargs,
|
||||
)
|
||||
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -76,6 +77,33 @@ class BaseTTS(BaseTrainerModel):
|
|||
else:
|
||||
raise ValueError("config must be either a *Config or *Args")
|
||||
|
||||
def adjust_speech_rate(self, gpt_latents, length_scale):
|
||||
if abs(length_scale - 1.0) < 1e-6:
|
||||
return gpt_latents
|
||||
|
||||
B, L, D = gpt_latents.shape
|
||||
target_length = int(L * length_scale)
|
||||
|
||||
assert target_length > 0, f"Invalid target length: {target_length}"
|
||||
|
||||
try:
|
||||
resized = F.interpolate(
|
||||
gpt_latents.transpose(1, 2),
|
||||
size=target_length,
|
||||
mode="linear",
|
||||
align_corners=True
|
||||
).transpose(1, 2)
|
||||
|
||||
if torch.isnan(resized).any():
|
||||
print("Warning: NaN values detected on adjust speech rate")
|
||||
return gpt_latents
|
||||
|
||||
return resized
|
||||
|
||||
except RuntimeError as e:
|
||||
print(f"Interpolation failed: {e}")
|
||||
return gpt_latents
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
|
||||
`in_channels` size of the connected layers.
|
||||
|
|
|
@ -379,7 +379,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
return gpt_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
|
||||
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
|
||||
"""Synthesize speech with the given input text.
|
||||
|
||||
Args:
|
||||
|
@ -409,14 +409,14 @@ class Xtts(BaseTTS):
|
|||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
if speaker_id is not None:
|
||||
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
|
||||
settings.update({
|
||||
"gpt_cond_len": config.gpt_cond_len,
|
||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||
"max_ref_len": config.max_ref_len,
|
||||
"sound_norm_refs": config.sound_norm_refs,
|
||||
})
|
||||
return self.full_inference(text, speaker_wav, language, **settings)
|
||||
return self.full_inference(text, speaker_wav, language, speed=speed, **settings)
|
||||
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
|
@ -436,6 +436,7 @@ class Xtts(BaseTTS):
|
|||
gpt_cond_chunk_len=6,
|
||||
max_ref_len=10,
|
||||
sound_norm_refs=False,
|
||||
speed: float = 1.0,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -496,6 +497,7 @@ class Xtts(BaseTTS):
|
|||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
speed=speed,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
|
@ -569,10 +571,7 @@ class Xtts(BaseTTS):
|
|||
)
|
||||
|
||||
if length_scale != 1.0:
|
||||
gpt_latents = F.interpolate(
|
||||
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
|
||||
).transpose(1, 2)
|
||||
|
||||
gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
|
||||
gpt_latents_list.append(gpt_latents.cpu())
|
||||
wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
|
||||
|
||||
|
|
Loading…
Reference in New Issue