mirror of https://github.com/coqui-ai/TTS.git
Fix XTTS GPT padding and inference issues (#3216)
* Fix end artifact for fine-tuning models
* Bug fix on zh-cn inference
* Remove unused code
parent 15f0ac57d6
commit 73a5bd08c0
TTS/tts/layers/xtts/gpt.py

@@ -426,15 +426,6 @@ class GPT(nn.Module):
         if max_mel_len > audio_codes.shape[-1]:
             audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
 
-        silence = True
-        for idx, l in enumerate(code_lengths):
-            length = l.item()
-            while silence:
-                if audio_codes[idx, length - 1] != 83:
-                    break
-                length -= 1
-            code_lengths[idx] = length
-
         # 💖 Lovely assertions
         assert (
             max_mel_len <= audio_codes.shape[-1]
@@ -450,7 +441,7 @@ class GPT(nn.Module):
         audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
 
         # Pad mel codes with stop_audio_token
-        audio_codes = self.set_mel_padding(audio_codes, code_lengths)
+        audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3)  # -3 to get the real code lengths without consider start and stop tokens that was not added yet
 
         # Build input and target tensors
         # Prepend start token to inputs and append stop token to targets
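Only the call site of set_mel_padding appears in this diff, so for context here is a sketch of what such a helper plausibly does, under the assumption that it overwrites everything past each row's real length with the stop token; the function body and token id below are illustrative, not the library's implementation:

import torch

STOP_AUDIO_TOKEN = 1025  # hypothetical id for illustration only

def set_mel_padding_sketch(audio_codes: torch.Tensor, code_lengths: torch.Tensor) -> torch.Tensor:
    """Fill every position at or past each row's real length with the stop token."""
    padded = audio_codes.clone()
    for row, length in enumerate(code_lengths):
        padded[row, int(length):] = STOP_AUDIO_TOKEN
    return padded

codes = torch.tensor([[7, 8, 9, 9, 9, 9]])
# After the fix the lengths passed in are reduced by 3, so the start/stop
# tokens that get added later are not counted as real codes.
print(set_mel_padding_sketch(codes, torch.tensor([6]) - 3))
# tensor([[   7,    8,    9, 1025, 1025, 1025]])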
TTS/tts/layers/xtts/tokenizer.py

@@ -115,7 +115,7 @@ _abbreviations = {
             # There are not many common abbreviations in Arabic as in English.
         ]
     ],
-    "zh": [
+    "zh-cn": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
             # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
@@ -280,7 +280,7 @@ _symbols_multilingual = {
             ("°", " درجة "),
         ]
     ],
-    "zh": [
+    "zh-cn": [
         # Chinese
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
@@ -571,7 +571,7 @@ class VoiceBpeTokenizer:
         )
 
     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
             txt = multilingual_cleaners(txt, lang)
             if lang in {"zh", "zh-cn"}:
                 txt = chinese_transliterate(txt)
@@ -682,8 +682,8 @@ def test_expand_numbers_multilingual():
         ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
         ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
         # Chinese (Simplified)
-        ("在12.5秒内", "在十二点五秒内", "zh"),
-        ("有50名士兵", "有五十名士兵", "zh"),
+        ("在12.5秒内", "在十二点五秒内", "zh-cn"),
+        ("有50名士兵", "有五十名士兵", "zh-cn"),
         # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
         # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
         # Turkish
@@ -764,7 +764,7 @@ def test_symbols_multilingual():
         ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
         ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
         ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
-        ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
+        ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
         ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
         ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
         ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
TTS/tts/models/xtts.py

@@ -7,7 +7,6 @@ import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
 
-from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
@@ -308,26 +307,6 @@ class Xtts(BaseTTS):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
-    @torch.inference_mode()
-    def get_diffusion_cond_latents(self, audio, sr):
-        from math import ceil
-
-        diffusion_conds = []
-        CHUNK_SIZE = 102400
-        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
-        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
-            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
-            current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
-            cond_mel = wav_to_univnet_mel(
-                current_sample.to(self.device),
-                do_normalization=False,
-                device=self.device,
-            )
-            diffusion_conds.append(cond_mel)
-        diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
-        return diffusion_latent
-
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
@@ -575,16 +554,6 @@ class Xtts(BaseTTS):
                 return_attentions=False,
                 return_latent=True,
             )
-            silence_token = 83
-            ctokens = 0
-            for k in range(gpt_codes.shape[-1]):
-                if gpt_codes[0, k] == silence_token:
-                    ctokens += 1
-                else:
-                    ctokens = 0
-                if ctokens > 8:
-                    gpt_latents = gpt_latents[:, :k]
-                    break
 
             if length_scale != 1.0:
                 gpt_latents = F.interpolate(
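The block removed here was the inference-side counterpart of the training-side trim above: it counted consecutive silence codes (id 83) in gpt_codes and cut gpt_latents short once more than eight appeared in a row. A standalone sketch of that cutoff rule, with hypothetical tensor shapes:

import torch

SILENCE_TOKEN = 83
MAX_RUN = 8  # the removed code used "ctokens > 8"

def cut_at_silence_run(gpt_codes: torch.Tensor, gpt_latents: torch.Tensor) -> torch.Tensor:
    """Truncate latents at the first run of more than MAX_RUN silence codes."""
    ctokens = 0
    for k in range(gpt_codes.shape[-1]):
        ctokens = ctokens + 1 if gpt_codes[0, k] == SILENCE_TOKEN else 0
        if ctokens > MAX_RUN:
            return gpt_latents[:, :k]
    return gpt_latents

codes = torch.cat([torch.full((1, 4), 10), torch.full((1, 12), SILENCE_TOKEN)], dim=1)
latents = torch.randn(1, 16, 1024)  # hypothetical latent width
print(cut_at_silence_run(codes, latents).shape)  # torch.Size([1, 12, 1024])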