Add char limit warn (#3130)

* Add char limit warning

* Adding v2 langs

* cached_property for cutlet

* Fix import
pull/3168/head
Julian Weber 2023-11-08 10:24:23 +01:00 committed by GitHub
parent f846a9f300
commit ce1a39a9a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 41 additions and 1 deletions

View File

@ -8,6 +8,7 @@ from hangul_romanize import Transliter
from hangul_romanize.rule import academic from hangul_romanize.rule import academic
from num2words import num2words from num2words import num2words
from tokenizers import Tokenizer from tokenizers import Tokenizer
from functools import cached_property
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
@ -535,11 +536,50 @@ DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "
class VoiceBpeTokenizer: class VoiceBpeTokenizer:
def __init__(self, vocab_file=None): def __init__(self, vocab_file=None):
self.tokenizer = None self.tokenizer = None
self.katsu = None
if vocab_file is not None: if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file) self.tokenizer = Tokenizer.from_file(vocab_file)
self.char_limits = {
"en": 250,
"de": 253,
"fr": 273,
"es": 239,
"it": 213,
"pt": 203,
"pl": 224,
"zh-cn": 82,
"ar": 166,
"cs": 186,
"ru": 182,
"nl": 251,
"tr": 226,
"ja": 71,
"hu": 224,
"ko": 95,
}
@cached_property
def katsu(self):
import cutlet
return cutlet.Cutlet()
def check_input_length(self, txt, lang):
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
def preprocess_text(self, txt, lang):
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
txt = multilingual_cleaners(txt, lang)
if lang == "zh-cn":
txt = chinese_transliterate(txt)
elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu)
else:
raise NotImplementedError()
return txt
def encode(self, txt, lang): def encode(self, txt, lang):
self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang) txt = self.preprocess_text(txt, lang)
txt = f"[{lang}]{txt}" txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]") txt = txt.replace(" ", "[SPACE]")