mirror of https://github.com/coqui-ai/TTS.git
Run `make style` & re-enable it in CI (#3127)
parent 6fef4f9067
commit 38f6f8f0bb
@@ -42,6 +42,5 @@ jobs:
       run: |
         python3 -m pip install .[all]
         python3 setup.py egg_info
-      # - name: Lint check
-      #   run: |
-      #     make lint
+      - name: Style check
+        run: make style
@@ -264,7 +264,7 @@ class TTS(nn.Module):
         language: str = None,
         emotion: str = None,
         speed: float = 1.0,
-        pipe_out = None,
+        pipe_out=None,
         file_path: str = None,
     ) -> Union[np.ndarray, str]:
         """Convert text to speech using Coqui Studio models. Use `CS_API` class if you are only interested in the API.
@@ -359,7 +359,7 @@ class TTS(nn.Module):
         speaker_wav: str = None,
         emotion: str = None,
         speed: float = 1.0,
-        pipe_out = None,
+        pipe_out=None,
         file_path: str = "output.wav",
         **kwargs,
     ):
@@ -460,7 +460,7 @@ class TTS(nn.Module):
         """
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
             # Lazy code... save it to a temp file to resample it while reading it for VC
-            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name,speaker_wav=speaker_wav)
+            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name, speaker_wav=speaker_wav)
         if self.voice_converter is None:
             self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
         wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
@@ -427,7 +427,9 @@ def main():
         tts_path = model_path
         tts_config_path = config_path
         if "default_vocoder" in model_item:
-            args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+            args.vocoder_name = (
+                model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+            )
 
     # voice conversion model
     if model_item["model_type"] == "voice_conversion_models":
@@ -1,12 +1,12 @@
+import json
 import os
 import re
-import json
 
-import torch
-from tokenizers import Tokenizer
-
 import pypinyin
+import torch
 from num2words import num2words
+from tokenizers import Tokenizer
 
 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
 
 _whitespace_re = re.compile(r"\s+")
@@ -87,7 +87,7 @@ _abbreviations = {
     "it": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            #("sig.ra", "signora"),
+            # ("sig.ra", "signora"),
             ("sig", "signore"),
             ("dr", "dottore"),
             ("st", "santo"),
@@ -121,49 +121,51 @@ _abbreviations = {
     "cs": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            ("dr", "doktor"), # doctor
-            ("ing", "inženýr"), # engineer
-            ("p", "pan"), # Could also map to pani for woman but no easy way to do it
+            ("dr", "doktor"),  # doctor
+            ("ing", "inženýr"),  # engineer
+            ("p", "pan"),  # Could also map to pani for woman but no easy way to do it
             # Other abbreviations would be specialized and not as common.
         ]
     ],
     "ru": [
         (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
         for x in [
-            ("г-жа", "госпожа"), # Mrs.
-            ("г-н", "господин"), # Mr.
-            ("д-р", "доктор"), # doctor
+            ("г-жа", "госпожа"),  # Mrs.
+            ("г-н", "господин"),  # Mr.
+            ("д-р", "доктор"),  # doctor
             # Other abbreviations are less common or specialized.
         ]
     ],
     "nl": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            ("dhr", "de heer"), # Mr.
+            ("dhr", "de heer"),  # Mr.
             ("mevr", "mevrouw"),  # Mrs.
-            ("dr", "dokter"), # doctor
-            ("jhr", "jonkheer"), # young lord or nobleman
+            ("dr", "dokter"),  # doctor
+            ("jhr", "jonkheer"),  # young lord or nobleman
             # Dutch uses more abbreviations, but these are the most common ones.
         ]
     ],
     "tr": [
         (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         for x in [
-            ("b", "bay"), # Mr.
+            ("b", "bay"),  # Mr.
             ("byk", "büyük"),  # büyük
-            ("dr", "doktor"), # doctor
+            ("dr", "doktor"),  # doctor
             # Add other Turkish abbreviations here if needed.
         ]
     ],
 }
 
-def expand_abbreviations_multilingual(text, lang='en'):
+
+def expand_abbreviations_multilingual(text, lang="en"):
     for regex, replacement in _abbreviations[lang]:
         text = re.sub(regex, replacement, text)
     return text
 
 
 _symbols_multilingual = {
-    'en': [
+    "en": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " and "),
@@ -172,10 +174,10 @@ _symbols_multilingual = {
             ("#", " hash "),
             ("$", " dollar "),
             ("£", " pound "),
-            ("°", " degree ")
+            ("°", " degree "),
         ]
     ],
-    'es': [
+    "es": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " y "),
@@ -184,10 +186,10 @@ _symbols_multilingual = {
             ("#", " numeral "),
             ("$", " dolar "),
             ("£", " libra "),
-            ("°", " grados ")
+            ("°", " grados "),
         ]
     ],
-    'fr': [
+    "fr": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " et "),
@@ -196,10 +198,10 @@ _symbols_multilingual = {
             ("#", " dièse "),
             ("$", " dollar "),
             ("£", " livre "),
-            ("°", " degrés ")
+            ("°", " degrés "),
         ]
     ],
-    'de': [
+    "de": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " und "),
@@ -208,10 +210,10 @@ _symbols_multilingual = {
             ("#", " raute "),
             ("$", " dollar "),
             ("£", " pfund "),
-            ("°", " grad ")
+            ("°", " grad "),
         ]
     ],
-    'pt': [
+    "pt": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " e "),
@@ -220,10 +222,10 @@ _symbols_multilingual = {
             ("#", " cardinal "),
             ("$", " dólar "),
             ("£", " libra "),
-            ("°", " graus ")
+            ("°", " graus "),
         ]
     ],
-    'it': [
+    "it": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " e "),
@@ -232,10 +234,10 @@ _symbols_multilingual = {
             ("#", " cancelletto "),
             ("$", " dollaro "),
             ("£", " sterlina "),
-            ("°", " gradi ")
+            ("°", " gradi "),
         ]
     ],
-    'pl': [
+    "pl": [
         (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
         for x in [
             ("&", " i "),
@@ -244,7 +246,7 @@ _symbols_multilingual = {
             ("#", " krzyżyk "),
             ("$", " dolar "),
             ("£", " funt "),
-            ("°", " stopnie ")
+            ("°", " stopnie "),
         ]
     ],
     "ar": [
@@ -257,7 +259,7 @@ _symbols_multilingual = {
             ("#", " رقم "),
             ("$", " دولار "),
             ("£", " جنيه "),
-            ("°", " درجة ")
+            ("°", " درجة "),
         ]
     ],
     "zh-cn": [
@@ -270,7 +272,7 @@ _symbols_multilingual = {
             ("#", " 号 "),
             ("$", " 美元 "),
             ("£", " 英镑 "),
-            ("°", " 度 ")
+            ("°", " 度 "),
         ]
     ],
     "cs": [
@@ -283,7 +285,7 @@ _symbols_multilingual = {
             ("#", " křížek "),
             ("$", " dolar "),
             ("£", " libra "),
-            ("°", " stupně ")
+            ("°", " stupně "),
         ]
     ],
     "ru": [
@@ -296,7 +298,7 @@ _symbols_multilingual = {
             ("#", " номер "),
             ("$", " доллар "),
             ("£", " фунт "),
-            ("°", " градус ")
+            ("°", " градус "),
         ]
     ],
     "nl": [
@@ -309,7 +311,7 @@ _symbols_multilingual = {
             ("#", " hekje "),
             ("$", " dollar "),
             ("£", " pond "),
-            ("°", " graden ")
+            ("°", " graden "),
         ]
     ],
     "tr": [
@@ -321,15 +323,16 @@ _symbols_multilingual = {
             ("#", " diyez "),
             ("$", " dolar "),
             ("£", " sterlin "),
-            ("°", " derece ")
+            ("°", " derece "),
         ]
     ],
 }
 
-def expand_symbols_multilingual(text, lang='en'):
+
+def expand_symbols_multilingual(text, lang="en"):
     for regex, replacement in _symbols_multilingual[lang]:
         text = re.sub(regex, replacement, text)
-    text = text.replace('  ', ' ')  # Ensure there are no double spaces
+    text = text.replace("  ", " ")  # Ensure there are no double spaces
     return text.strip()
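The symbol expander works the same way; the trailing double-space cleanup matters because every replacement is padded with spaces on both sides. A quick check with two entries from the English table:

```python
import re

_symbols_en = [
    (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
    for x in [("&", " and "), ("°", " degree ")]
]


def expand_symbols_en(text):
    for regex, replacement in _symbols_en:
        text = re.sub(regex, replacement, text)
    text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()


print(expand_symbols_en("fish & chips"))  # fish and chips
print(expand_symbols_en("20° today"))  # 20 degree today
```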
@@ -342,41 +345,45 @@ _ordinal_re = {
     "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
     "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
     "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
-    "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
+    "cs": re.compile(r"([0-9]+)\.(?=\s|$)"),  # In Czech, a dot is often used after the number to indicate ordinals.
     "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
     "nl": re.compile(r"([0-9]+)(de|ste|e)"),
     "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
 }
 _number_re = re.compile(r"[0-9]+")
 _currency_re = {
-    'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
-    'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
-    'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
+    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
+    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
+    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
 }
 
 _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
 _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
 _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
 
 
 def _remove_commas(m):
     text = m.group(0)
     if "," in text:
         text = text.replace(",", "")
     return text
 
 
 def _remove_dots(m):
     text = m.group(0)
     if "." in text:
         text = text.replace(".", "")
     return text
 
-def _expand_decimal_point(m, lang='en'):
+
+def _expand_decimal_point(m, lang="en"):
     amount = m.group(1).replace(",", ".")
     return num2words(float(amount), lang=lang if lang != "cs" else "cz")
 
-def _expand_currency(m, lang='en', currency='USD'):
-    amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
-    full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
+
+def _expand_currency(m, lang="en", currency="USD"):
+    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
+    full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
 
     and_equivalents = {
         "en": ", ",
@@ -400,13 +407,16 @@ def _expand_currency(m, lang='en', currency='USD'):
 
     return full_amount
 
-def _expand_ordinal(m, lang='en'):
+
+def _expand_ordinal(m, lang="en"):
     return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
 
-def _expand_number(m, lang='en'):
+
+def _expand_number(m, lang="en"):
     return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
 
-def expand_numbers_multilingual(text, lang='en'):
+
+def expand_numbers_multilingual(text, lang="en"):
     if lang == "zh-cn":
         text = zh_num2words()(text)
     else:
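The ordinal and cardinal helpers above are thin wrappers over num2words:

```python
from num2words import num2words

print(num2words(3, ordinal=True, lang="en"))  # third
print(num2words(42, lang="en"))  # forty-two
```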
@@ -415,9 +425,9 @@ def expand_numbers_multilingual(text, lang='en'):
         else:
             text = re.sub(_dot_number_re, _remove_dots, text)
         try:
-            text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
-            text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
-            text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
+            text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
+            text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
+            text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
         except:
             pass
         if lang != "tr":
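For reference, the USD pattern from `_currency_re` accepts the symbol on either side of the amount, which is why `_expand_currency` has to strip non-numeric characters before parsing:

```python
import re

# Pattern copied from the _currency_re table above.
_currency_usd = re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))")

print(_currency_usd.search("coffee costs $3.50 here").group(0))  # $3.50
print(_currency_usd.search("coffee costs 3.50$ here").group(0))  # 3.50$
```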
@@ -426,15 +436,18 @@ def expand_numbers_multilingual(text, lang='en'):
         text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
     return text
 
 
 def lowercase(text):
     return text.lower()
 
 
 def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)
 
 
 def multilingual_cleaners(text, lang):
-    text = text.replace('"', '')
-    if lang=="tr":
+    text = text.replace('"', "")
+    if lang == "tr":
         text = text.replace("İ", "i")
         text = text.replace("Ö", "ö")
         text = text.replace("Ü", "ü")
@@ -445,20 +458,26 @@ def multilingual_cleaners(text, lang):
     text = collapse_whitespace(text)
     return text
 
 
 def basic_cleaners(text):
     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
     text = lowercase(text)
     text = collapse_whitespace(text)
     return text
 
 
 def chinese_transliterate(text):
-    return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
+    return "".join(
+        p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
+    )
 
 
 def japanese_cleaners(text, katsu):
     text = katsu.romaji(text)
     text = lowercase(text)
     return text
 
 
 class VoiceBpeTokenizer:
     def __init__(self, vocab_file=None, preprocess=None):
         self.tokenizer = None
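`chinese_transliterate`, reformatted above into a parenthesized generator, keeps only the first pinyin candidate per character, with numeric tone marks (TONE3) and the neutral tone written as 5. Assuming pypinyin is installed:

```python
import pypinyin

romanized = "".join(
    p[0]
    for p in pypinyin.pinyin("你好", style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
)
print(romanized)  # ni3hao3
```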
@@ -485,6 +504,7 @@ class VoiceBpeTokenizer:
         elif lang == "ja":
             if self.katsu is None:
                 import cutlet
+
                 self.katsu = cutlet.Cutlet()
             txt = japanese_cleaners(txt, self.katsu)
         else:
@@ -2,9 +2,14 @@
 # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
 # 2019.9 - 2022 Jiayu DU
 
-import sys, os, argparse
-import string, re
-import csv
+import argparse
+import csv
+import os
+import re
+import string
+import sys
 
+# fmt: off
 
 # ================================================================================ #
 # basic constant
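The added `# fmt: off` is a black pragma: formatting is suspended from that line until a matching `# fmt: on` (or, as here, the end of the file), which protects this file's manually aligned constant tables from being reflowed by `make style`. An illustrative, hypothetical snippet of the kind of alignment it preserves:

```python
# fmt: off
DIGITS = {  # hypothetical table, for illustration only
    0: "零", 1: "一", 2: "二",
    3: "三", 4: "四", 5: "五",
}
# fmt: on
```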
@@ -2,10 +2,10 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 
+import librosa
 import torch
 import torch.nn.functional as F
 import torchaudio
-import librosa
 from coqpit import Coqpit
 
 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -386,10 +386,12 @@ class Xtts(BaseTTS):
     @torch.inference_mode()
     def get_speaker_embedding(self, audio, sr):
         audio_16k = torchaudio.functional.resample(audio, sr, 16000)
-        return self.hifigan_decoder.speaker_encoder.forward(
-            audio_16k.to(self.device), l2_norm=True
-        ).unsqueeze(-1).to(self.device)
+        return (
+            self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
+            .unsqueeze(-1)
+            .to(self.device)
+        )
 
     @torch.inference_mode()
     def get_conditioning_latents(
         self,
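The `get_speaker_embedding` rewrite is behavior-preserving: black only splits the call chain one method per line. The surrounding logic resamples input audio to the 16 kHz the speaker encoder expects; a minimal shape check, assuming torchaudio is installed:

```python
import torch
import torchaudio

audio = torch.randn(1, 48000)  # one second of audio at 48 kHz
audio_16k = torchaudio.functional.resample(audio, 48000, 16000)
print(audio_16k.shape)  # torch.Size([1, 16000])
```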
@@ -398,7 +400,7 @@ class Xtts(BaseTTS):
         max_ref_length=10,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        ):
+    ):
         speaker_embedding = None
         diffusion_cond_latents = None
 
@@ -647,13 +649,19 @@ class Xtts(BaseTTS):
                 break
 
         if decoder == "hifigan":
-            assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+            assert hasattr(
+                self, "hifigan_decoder"
+            ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
         elif decoder == "ne_hifigan":
-            assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+            assert hasattr(
+                self, "ne_hifigan_decoder"
+            ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
             wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
         else:
-            assert hasattr(self, "diffusion_decoder"), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
+            assert hasattr(
+                self, "diffusion_decoder"
+            ), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
             mel = do_spectrogram_diffusion(
                 self.diffusion_decoder,
                 diffuser,
@@ -742,10 +750,14 @@ class Xtts(BaseTTS):
             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
                 if decoder == "hifigan":
-                    assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
+                    assert hasattr(
+                        self, "hifigan_decoder"
+                    ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
                     wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                 elif decoder == "ne_hifigan":
-                    assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
+                    assert hasattr(
+                        self, "ne_hifigan_decoder"
+                    ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
                     wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                 else:
                     raise NotImplementedError("Diffusion for streaming inference not implemented.")
@@ -756,10 +768,14 @@ class Xtts(BaseTTS):
                     yield wav_chunk
 
     def forward(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )
 
     def eval_step(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )
 
     @staticmethod
     def init_from_config(config: "XttsConfig", **kwargs):  # pylint: disable=unused-argument
@@ -835,12 +851,18 @@ class Xtts(BaseTTS):
         self.load_state_dict(checkpoint, strict=strict)
 
         if eval:
-            if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
-            if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
-            if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
-            if hasattr(self, "vocoder"): self.vocoder.eval()
+            if hasattr(self, "hifigan_decoder"):
+                self.hifigan_decoder.eval()
+            if hasattr(self, "ne_hifigan_decoder"):
+                self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"):
+                self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"):
+                self.vocoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
 
     def train_step(self):
-        raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
+        raise NotImplementedError(
+            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
+        )
@@ -428,7 +428,7 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False,
     return x
 
 
-def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out = None, **kwargs) -> None:
+def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None:
     """Save float waveform to a file using Scipy.
 
     Args:
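These `pipe_out` parameters back the CLI's `--pipe_out` flag (exercised by the test at the bottom of this diff). A hedged sketch of the idea, not the actual TTS implementation: encode the wav once, write it to disk, and mirror the bytes to a writable binary stream such as `sys.stdout.buffer`, so output can be piped straight into a player like `aplay`:

```python
import io

import numpy as np
import scipy.io.wavfile


def save_wav_sketch(*, wav: np.ndarray, path: str, sample_rate: int = 22050, pipe_out=None) -> None:
    """Illustrative only: save a 16-bit wav to `path` and, if given, duplicate
    the encoded bytes to `pipe_out` (e.g. sys.stdout.buffer)."""
    wav_norm = (wav * 32767).astype(np.int16)
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, sample_rate, wav_norm)
    data = buf.getvalue()
    if pipe_out is not None:
        pipe_out.write(data)
        pipe_out.flush()
    with open(path, "wb") as f:
        f.write(data)
```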
@@ -694,7 +694,7 @@ class AudioProcessor(object):
         x = self.rms_volume_norm(x, self.db_level)
         return x
 
-    def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out = None) -> None:
+    def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None:
         """Save a waveform to a file using Scipy.
 
         Args:
@@ -235,7 +235,7 @@ class Synthesizer(nn.Module):
         """
         return self.seg.segment(text)
 
-    def save_wav(self, wav: List[int], path: str, pipe_out = None) -> None:
+    def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None:
         """Save the waveform as a file.
 
         Args:
@@ -7,7 +7,6 @@ from TTS.tts.datasets import load_tts_samples
 from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
 from TTS.utils.manage import ModelManager
 
-
 # Logging parameters
 RUN_NAME = "GPT_XTTS_LJSpeech_FT"
 PROJECT_NAME = "XTTS_trainer"
@@ -60,13 +59,15 @@ TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vo
 XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
 
 # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
-TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
-XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
+TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1])  # vocab.json file
+XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1])  # model.pth file
 
 # download XTTS v1.1 files if needed
 if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
     print(" > Downloading XTTS v1.1 files!")
-    ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+    ModelManager._download_model_files(
+        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+    )
 
 
 # Training sentences generations
@@ -93,7 +94,7 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
+        use_ne_hifigan=True,  # if it is true it will keep the non-enhanced keys on the output checkpoint
     )
     # define audio config
     audio_config = XttsAudioConfig(
@@ -22,7 +22,4 @@ def test_synthesize():
     )
 
     # test pipe_out command
-    run_cli(
-        'tts --text "test." --pipe_out '
-        f'--out_path "{output_path}" | aplay'
-    )
+    run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')