mirror of https://github.com/coqui-ai/TTS.git
Make style
parent
f21067a84a
commit
a3279f9294
|
@ -17,7 +17,6 @@ from tqdm import tqdm
|
|||
|
||||
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
|
||||
|
||||
|
||||
try:
|
||||
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
|
||||
|
||||
|
|
|
@ -441,7 +441,9 @@ class GPT(nn.Module):
|
|||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
|
||||
|
||||
# Pad mel codes with stop_audio_token
|
||||
audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
|
||||
audio_codes = self.set_mel_padding(
|
||||
audio_codes, code_lengths - 3
|
||||
) # -3 to get the real code lengths without consider start and stop tokens that was not added yet
|
||||
|
||||
# Build input and target tensors
|
||||
# Prepend start token to inputs and append stop token to targets
|
||||
|
|
|
@ -1,23 +1,22 @@
|
|||
import os
|
||||
import re
|
||||
import torch
|
||||
import pypinyin
|
||||
import textwrap
|
||||
|
||||
from functools import cached_property
|
||||
|
||||
import pypinyin
|
||||
import torch
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
from num2words import num2words
|
||||
from spacy.lang.ar import Arabic
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.es import Spanish
|
||||
from spacy.lang.ja import Japanese
|
||||
from spacy.lang.zh import Chinese
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.zh import Chinese
|
||||
from spacy.lang.ja import Japanese
|
||||
from spacy.lang.ar import Arabic
|
||||
from spacy.lang.es import Spanish
|
||||
|
||||
|
||||
def get_spacy_lang(lang):
|
||||
if lang == "zh":
|
||||
|
@ -32,6 +31,7 @@ def get_spacy_lang(lang):
|
|||
# For most languages, Enlish does the job
|
||||
return English()
|
||||
|
||||
|
||||
def split_sentence(text, lang, text_split_length=250):
|
||||
"""Preprocess the input text"""
|
||||
text_splits = []
|
||||
|
@ -67,6 +67,7 @@ def split_sentence(text, lang, text_split_length=250):
|
|||
|
||||
return text_splits
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
|
@ -619,7 +620,7 @@ class VoiceBpeTokenizer:
|
|||
return cutlet.Cutlet()
|
||||
|
||||
def check_input_length(self, txt, lang):
|
||||
lang = lang.split("-")[0] # remove the region
|
||||
lang = lang.split("-")[0] # remove the region
|
||||
limit = self.char_limits.get(lang, 250)
|
||||
if len(txt) > limit:
|
||||
print(
|
||||
|
@ -640,7 +641,7 @@ class VoiceBpeTokenizer:
|
|||
return txt
|
||||
|
||||
def encode(self, txt, lang):
|
||||
lang = lang.split("-")[0] # remove the region
|
||||
lang = lang.split("-")[0] # remove the region
|
||||
self.check_input_length(txt, lang)
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
lang = "zh-cn" if lang == "zh" else lang
|
||||
|
|
|
@ -513,13 +513,13 @@ class Xtts(BaseTTS):
|
|||
enable_text_splitting=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
language = language.split("-")[0] # remove the country code
|
||||
language = language.split("-")[0] # remove the country code
|
||||
length_scale = 1.0 / max(speed, 0.05)
|
||||
if enable_text_splitting:
|
||||
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||
else:
|
||||
text = [text]
|
||||
|
||||
|
||||
wavs = []
|
||||
gpt_latents_list = []
|
||||
for sent in text:
|
||||
|
@ -563,9 +563,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
if length_scale != 1.0:
|
||||
gpt_latents = F.interpolate(
|
||||
gpt_latents.transpose(1, 2),
|
||||
scale_factor=length_scale,
|
||||
mode="linear"
|
||||
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
|
||||
).transpose(1, 2)
|
||||
|
||||
gpt_latents_list.append(gpt_latents.cpu())
|
||||
|
@ -623,7 +621,7 @@ class Xtts(BaseTTS):
|
|||
enable_text_splitting=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
language = language.split("-")[0] # remove the country code
|
||||
language = language.split("-")[0] # remove the country code
|
||||
length_scale = 1.0 / max(speed, 0.05)
|
||||
if enable_text_splitting:
|
||||
text = split_sentence(text, language, self.tokenizer.char_limits[language])
|
||||
|
@ -675,9 +673,7 @@ class Xtts(BaseTTS):
|
|||
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
|
||||
if length_scale != 1.0:
|
||||
gpt_latents = F.interpolate(
|
||||
gpt_latents.transpose(1, 2),
|
||||
scale_factor=length_scale,
|
||||
mode="linear"
|
||||
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
|
||||
).transpose(1, 2)
|
||||
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
|
||||
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
|
||||
|
|
|
@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
|
|||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
speed=1.5
|
||||
speed=1.5,
|
||||
)
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
|
|||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
speed=0.66
|
||||
speed=0.66,
|
||||
)
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
|
Loading…
Reference in New Issue