feat: add adjust_speech_rate function to modify speech speed with more robust latent handling; also add the missing tts speed implementations.

pull/4115/head
isikhi 2024-12-28 23:08:08 +03:00
parent dbf1a08a0d
commit 26128be422
3 changed files with 36 additions and 8 deletions
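For context, a minimal usage sketch (not part of the commit) showing how the new `speed` argument is exercised through the high-level API; the model name and file paths are illustrative:

# Minimal usage sketch (not part of the commit). Model name and paths are
# illustrative; `speed` is the new argument added by this change.
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

# Values above 1.0 are intended to speed speech up, values below 1.0 to
# slow it down, by resizing the GPT latents before vocoding.
tts.tts_to_file(
    text="Hello world!",
    speaker_wav="reference.wav",  # illustrative reference clip
    language="en",
    file_path="output.wav",
    speed=1.2,
)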


@@ -283,6 +283,7 @@ class TTS(nn.Module):
             style_text=None,
             reference_speaker_name=None,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         return wav
@@ -330,13 +331,13 @@ class TTS(nn.Module):
                 Additional arguments for the model.
         """
         self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
         wav = self.tts(
             text=text,
             speaker=speaker,
             language=language,
             speaker_wav=speaker_wav,
             split_sentences=split_sentences,
+            speed=speed,
             **kwargs,
         )
         self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)


@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -76,6 +77,33 @@ class BaseTTS(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
+    def adjust_speech_rate(self, gpt_latents, length_scale):
+        """Resize GPT latents along the time axis to change the speech rate."""
+        # No-op when the scale is effectively 1.0.
+        if abs(length_scale - 1.0) < 1e-6:
+            return gpt_latents
+        B, L, D = gpt_latents.shape
+        target_length = int(L * length_scale)
+        assert target_length > 0, f"Invalid target length: {target_length}"
+        try:
+            # Interpolate over time: (B, L, D) -> (B, D, L) -> resize -> (B, target_length, D).
+            resized = F.interpolate(
+                gpt_latents.transpose(1, 2),
+                size=target_length,
+                mode="linear",
+                align_corners=True,
+            ).transpose(1, 2)
+            # Fall back to the original latents if interpolation produced NaNs.
+            if torch.isnan(resized).any():
+                print("Warning: NaN values detected in adjust_speech_rate")
+                return gpt_latents
+            return resized
+        except RuntimeError as e:
+            print(f"Interpolation failed: {e}")
+            return gpt_latents
+
     def init_multispeaker(self, config: Coqpit, data: List = None):
         """Initialize a speaker embedding layer if needed and define the expected embedding
         channel size for setting the `in_channels` size of the connected layers.

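To make the behavior of `adjust_speech_rate` concrete, here is a standalone sketch (not part of the commit; shapes are illustrative) of the same time-axis interpolation on a dummy latent tensor:

# Standalone sketch of the interpolation used by adjust_speech_rate
# (not part of the commit; shapes are illustrative).
import torch
import torch.nn.functional as F

gpt_latents = torch.randn(1, 100, 1024)  # (batch, time steps, latent dim)
length_scale = 0.8                       # < 1.0 shortens the sequence

resized = F.interpolate(
    gpt_latents.transpose(1, 2),  # (B, D, L): interpolate resizes the last dim
    size=int(gpt_latents.shape[1] * length_scale),
    mode="linear",
    align_corners=True,
).transpose(1, 2)                 # back to (B, L', D)

print(resized.shape)  # torch.Size([1, 80, 1024])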

@@ -379,7 +379,7 @@ class Xtts(BaseTTS):
         return gpt_cond_latents, speaker_embedding
 
-    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
+    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed: float = 1.0, **kwargs):
         """Synthesize speech with the given input text.
 
         Args:
@@ -409,14 +409,14 @@ class Xtts(BaseTTS):
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
-            return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
+            return self.inference(text, language, gpt_cond_latent, speaker_embedding, speed=speed, **settings)
         settings.update({
             "gpt_cond_len": config.gpt_cond_len,
             "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         })
-        return self.full_inference(text, speaker_wav, language, **settings)
+        return self.full_inference(text, speaker_wav, language, speed=speed, **settings)
 
     @torch.inference_mode()
     def full_inference(
@@ -436,6 +436,7 @@ class Xtts(BaseTTS):
         gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
+        speed: float = 1.0,
         **hf_generate_kwargs,
     ):
         """
@@ -496,6 +497,7 @@ class Xtts(BaseTTS):
             top_k=top_k,
             top_p=top_p,
             do_sample=do_sample,
+            speed=speed,
             **hf_generate_kwargs,
         )
@@ -569,10 +571,7 @@ class Xtts(BaseTTS):
             )
 
-            if length_scale != 1.0:
-                gpt_latents = F.interpolate(
-                    gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-                ).transpose(1, 2)
+            gpt_latents = self.adjust_speech_rate(gpt_latents, length_scale)
             gpt_latents_list.append(gpt_latents.cpu())
             wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())
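For reference, the `speed` argument passed into `inference` presumably becomes the `length_scale` consumed above; a hedged sketch of that mapping, assuming the common reciprocal convention with a lower clamp (the helper name is hypothetical, the conversion is not shown in this diff):

# Hypothetical sketch of the speed -> length_scale mapping (an assumption,
# not shown in this diff): faster speech means shorter latents.
def speed_to_length_scale(speed: float, min_speed: float = 0.05) -> float:
    return 1.0 / max(speed, min_speed)

assert speed_to_length_scale(2.0) == 0.5  # 2x speed -> half-length latents
assert speed_to_length_scale(0.5) == 2.0  # half speed -> double-length latents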