diff --git a/TTS/_version.py b/TTS/_version.py index f4956698..311f216e 100644 --- a/TTS/_version.py +++ b/TTS/_version.py @@ -1 +1 @@ -__version__ = "0.0.14.1" +__version__ = "0.0.14" diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index 72ab160e..271b1734 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -424,3 +424,17 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]: wav_path = os.path.join(root_path, "clips_22", wav_name) items.append([text, wav_path, speaker_name]) return items + + +def kokoro(root_path, meta_file): + """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kokoro" + with open(txt_file, "r") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + '.wav') + text = cols[2].replace(" ", "") + items.append([text, wav_file, speaker_name]) + return items diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 9367e6e2..f9f44167 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -6,6 +6,7 @@ from packaging import version from TTS.tts.utils.text import cleaners from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols # pylint: disable=unnecessary-comprehension @@ -39,6 +40,11 @@ def text2phone(text, language): if language == "zh-CN": ph = chinese_text_to_phonemes(text) return ph + + if language == "ja-jp": + ph = japanese_text_to_phonemes(text) + return ph + raise ValueError(f" [!] Language {language} is not supported for phonemization.") diff --git a/TTS/tts/utils/text/japanese/__init__.py b/TTS/tts/utils/text/japanese/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py new file mode 100644 index 00000000..f09d5b05 --- /dev/null +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -0,0 +1,379 @@ +# Convert Japanese text to phonemes which is +# compatible with Julius https://github.com/julius-speech/segmentation-kit + +import re +import MeCab + +_CONVRULES = [ + # Conversion of 2 letters + 'アァ/ a a', + 'イィ/ i i', + 'イェ/ i e', + 'イャ/ y a', + 'ウゥ/ u:', + 'エェ/ e e', + 'オォ/ o:', + 'カァ/ k a:', + 'キィ/ k i:', + 'クゥ/ k u:', + 'クャ/ ky a', + 'クュ/ ky u', + 'クョ/ ky o', + 'ケェ/ k e:', + 'コォ/ k o:', + 'ガァ/ g a:', + 'ギィ/ g i:', + 'グゥ/ g u:', + 'グャ/ gy a', + 'グュ/ gy u', + 'グョ/ gy o', + 'ゲェ/ g e:', + 'ゴォ/ g o:', + 'サァ/ s a:', + 'シィ/ sh i:', + 'スゥ/ s u:', + 'スャ/ sh a', + 'スュ/ sh u', + 'スョ/ sh o', + 'セェ/ s e:', + 'ソォ/ s o:', + 'ザァ/ z a:', + 'ジィ/ j i:', + 'ズゥ/ z u:', + 'ズャ/ zy a', + 'ズュ/ zy u', + 'ズョ/ zy o', + 'ゼェ/ z e:', + 'ゾォ/ z o:', + 'タァ/ t a:', + 'チィ/ ch i:', + 'ツァ/ ts a', + 'ツィ/ ts i', + 'ツゥ/ ts u:', + 'ツャ/ ch a', + 'ツュ/ ch u', + 'ツョ/ ch o', + 'ツェ/ ts e', + 'ツォ/ ts o', + 'テェ/ t e:', + 'トォ/ t o:', + 'ダァ/ d a:', + 'ヂィ/ j i:', + 'ヅゥ/ d u:', + 'ヅャ/ zy a', + 'ヅュ/ zy u', + 'ヅョ/ zy o', + 'デェ/ d e:', + 'ドォ/ d o:', + 'ナァ/ n a:', + 'ニィ/ n i:', + 'ヌゥ/ n u:', + 'ヌャ/ ny a', + 'ヌュ/ ny u', + 'ヌョ/ ny o', + 'ネェ/ n e:', + 'ノォ/ n o:', + 'ハァ/ h a:', + 'ヒィ/ h i:', + 'フゥ/ f u:', + 'フャ/ hy a', + 'フュ/ hy u', + 'フョ/ hy o', + 'ヘェ/ h e:', + 'ホォ/ h o:', + 'バァ/ b a:', + 'ビィ/ b i:', + 'ブゥ/ b u:', + 'フャ/ hy a', + 'ブュ/ by u', + 'フョ/ hy o', + 'ベェ/ b e:', + 'ボォ/ b o:', + 'パァ/ p a:', + 'ピィ/ p i:', + 'プゥ/ p u:', + 'プャ/ py a', + 'プュ/ py u', + 'プョ/ py o', + 'ペェ/ p e:', + 'ポォ/ p o:', + 'マァ/ m a:', + 'ミィ/ m i:', + 'ムゥ/ m u:', + 'ムャ/ my a', + 'ムュ/ my u', + 'ムョ/ my o', + 'メェ/ m e:', + 'モォ/ m o:', + 'ヤァ/ y a:', + 'ユゥ/ y u:', + 'ユャ/ y a:', + 'ユュ/ y u:', + 'ユョ/ y o:', + 'ヨォ/ y o:', + 'ラァ/ r a:', + 'リィ/ r i:', + 'ルゥ/ r u:', + 'ルャ/ ry a', + 'ルュ/ ry u', + 'ルョ/ ry o', + 'レェ/ r e:', + 'ロォ/ r o:', + 'ワァ/ w a:', + 'ヲォ/ o:', + 'ディ/ d i', + 'デェ/ d e:', + 'デャ/ dy a', + 'デュ/ dy u', + 'デョ/ dy o', + 'ティ/ t i', + 'テェ/ t e:', + 'テャ/ ty a', + 'テュ/ ty u', + 'テョ/ ty o', + 'スィ/ s i', + 'ズァ/ z u a', + 'ズィ/ z i', + 'ズゥ/ z u', + 'ズャ/ zy a', + 'ズュ/ zy u', + 'ズョ/ zy o', + 'ズェ/ z e', + 'ズォ/ z o', + 'キャ/ ky a', + 'キュ/ ky u', + 'キョ/ ky o', + 'シャ/ sh a', + 'シュ/ sh u', + 'シェ/ sh e', + 'ショ/ sh o', + 'チャ/ ch a', + 'チュ/ ch u', + 'チェ/ ch e', + 'チョ/ ch o', + 'トゥ/ t u', + 'トャ/ ty a', + 'トュ/ ty u', + 'トョ/ ty o', + 'ドァ/ d o a', + 'ドゥ/ d u', + 'ドャ/ dy a', + 'ドュ/ dy u', + 'ドョ/ dy o', + 'ドォ/ d o:', + 'ニャ/ ny a', + 'ニュ/ ny u', + 'ニョ/ ny o', + 'ヒャ/ hy a', + 'ヒュ/ hy u', + 'ヒョ/ hy o', + 'ミャ/ my a', + 'ミュ/ my u', + 'ミョ/ my o', + 'リャ/ ry a', + 'リュ/ ry u', + 'リョ/ ry o', + 'ギャ/ gy a', + 'ギュ/ gy u', + 'ギョ/ gy o', + 'ヂェ/ j e', + 'ヂャ/ j a', + 'ヂュ/ j u', + 'ヂョ/ j o', + 'ジェ/ j e', + 'ジャ/ j a', + 'ジュ/ j u', + 'ジョ/ j o', + 'ビャ/ by a', + 'ビュ/ by u', + 'ビョ/ by o', + 'ピャ/ py a', + 'ピュ/ py u', + 'ピョ/ py o', + 'ウァ/ u a', + 'ウィ/ w i', + 'ウェ/ w e', + 'ウォ/ w o', + 'ファ/ f a', + 'フィ/ f i', + 'フゥ/ f u', + 'フャ/ hy a', + 'フュ/ hy u', + 'フョ/ hy o', + 'フェ/ f e', + 'フォ/ f o', + 'ヴァ/ b a', + 'ヴィ/ b i', + 'ヴェ/ b e', + 'ヴォ/ b o', + 'ヴュ/ by u', + + # Conversion of 1 letter + 'ア/ a', + 'イ/ i', + 'ウ/ u', + 'エ/ e', + 'オ/ o', + 'カ/ k a', + 'キ/ k i', + 'ク/ k u', + 'ケ/ k e', + 'コ/ k o', + 'サ/ s a', + 'シ/ sh i', + 'ス/ s u', + 'セ/ s e', + 'ソ/ s o', + 'タ/ t a', + 'チ/ ch i', + 'ツ/ ts u', + 'テ/ t e', + 'ト/ t o', + 'ナ/ n a', + 'ニ/ n i', + 'ヌ/ n u', + 'ネ/ n e', + 'ノ/ n o', + 'ハ/ h a', + 'ヒ/ h i', + 'フ/ f u', + 'ヘ/ h e', + 'ホ/ h o', + 'マ/ m a', + 'ミ/ m i', + 'ム/ m u', + 'メ/ m e', + 'モ/ m o', + 'ラ/ r a', + 'リ/ r i', + 'ル/ r u', + 'レ/ r e', + 'ロ/ r o', + 'ガ/ g a', + 'ギ/ g i', + 'グ/ g u', + 'ゲ/ g e', + 'ゴ/ g o', + 'ザ/ z a', + 'ジ/ j i', + 'ズ/ z u', + 'ゼ/ z e', + 'ゾ/ z o', + 'ダ/ d a', + 'ヂ/ j i', + 'ヅ/ z u', + 'デ/ d e', + 'ド/ d o', + 'バ/ b a', + 'ビ/ b i', + 'ブ/ b u', + 'ベ/ b e', + 'ボ/ b o', + 'パ/ p a', + 'ピ/ p i', + 'プ/ p u', + 'ペ/ p e', + 'ポ/ p o', + 'ヤ/ y a', + 'ユ/ y u', + 'ヨ/ y o', + 'ワ/ w a', + 'ヰ/ i', + 'ヱ/ e', + 'ヲ/ o', + 'ン/ N', + 'ッ/ q', + 'ヴ/ b u', + 'ー/:', + + # Try converting broken text + 'ァ/ a', + 'ィ/ i', + 'ゥ/ u', + 'ェ/ e', + 'ォ/ o', + 'ヮ/ w a', + 'ォ/ o', + + # Symbols + '、/ ,', + '。/ .', + '!/ !', + '?/ ?', + '・/ ,' +] + +_COLON_RX = re.compile(':+') +_REJECT_RX = re.compile('[^ a-zA-Z:,.?]') + +def _makerulemap(): + l = [tuple(x.split('/')) for x in _CONVRULES] + return tuple( + {k: v for k, v in l if len(k) == i} + for i in (1, 2) + ) + +_RULEMAP1, _RULEMAP2 = _makerulemap() + +def kata2phoneme(text: str) -> str: + """Convert katakana text to phonemes. + """ + text = text.strip() + res = '' + while text: + if len(text) >= 2: + x = _RULEMAP2.get(text[:2]) + if x is not None: + text = text[2:] + res += x + continue + x = _RULEMAP1.get(text[0]) + if x is not None: + text = text[1:] + res += x + continue + res += ' ' + text[0] + text = text[1:] + res = _COLON_RX.sub(':', res) + return res[1:] + +_KATAKANA = ''.join(chr(ch) for ch in range(ord('ァ'), ord('ン') + 1)) +_HIRAGANA = ''.join(chr(ch) for ch in range(ord('ぁ'), ord('ん') + 1)) +_HIRA2KATATRANS = str.maketrans(_HIRAGANA, _KATAKANA) + +def hira2kata(text: str) -> str: + text = text.translate(_HIRA2KATATRANS) + return text.replace('う゛', 'ヴ') + +_SYMBOL_TOKENS = set(list('・、。?!')) +_NO_YOMI_TOKENS = set(list('「」『』―()[][] …')) +_TAGGER = MeCab.Tagger() + +def text2kata(text: str) -> str: + parsed = _TAGGER.parse(text) + res = [] + for line in parsed.split('\n'): + if line == 'EOS': + break + parts = line.split('\t') + + word, yomi = parts[0], parts[1] + if yomi: + res.append(yomi) + else: + if word in _SYMBOL_TOKENS: + res.append(word) + elif word in ('っ', 'ッ'): + res.append('ッ') + elif word in _NO_YOMI_TOKENS: + pass + else: + res.append(word) + return hira2kata(''.join(res)) + +def japanese_text_to_phonemes(text: str) -> str: + """Convert Japanese text to phonemes. + """ + res = text2kata(text) + res = kata2phoneme(res) + return res.replace(' ', '') diff --git a/recipes/kokoro/tacotron2-DDC/run.sh b/recipes/kokoro/tacotron2-DDC/run.sh new file mode 100644 index 00000000..86fda642 --- /dev/null +++ b/recipes/kokoro/tacotron2-DDC/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# take the scripts's parent's directory to prefix all the output paths. +RUN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +CORPUS=kokoro-speech-v1_1-small +echo $RUN_DIR +if [ \! -d $RUN_DIR/$CORPUS ] ; then + echo "$RUN_DIR/$CORPUS doesn't exist." + echo "Follow the instruction of https://github.com/kaiidams/Kokoro-Speech-Dataset to make the corpus." + exit 1 +fi +# create train-val splits +shuf $RUN_DIR/$CORPUS/metadata.csv > $RUN_DIR/$CORPUS/metadata_shuf.csv +head -n 8000 $RUN_DIR/$CORPUS/metadata_shuf.csv > $RUN_DIR/$CORPUS/metadata_train.csv +tail -n 812 $RUN_DIR/$CORPUS/metadata_shuf.csv > $RUN_DIR/$CORPUS/metadata_val.csv +# compute dataset mean and variance for normalization +python TTS/bin/compute_statistics.py $RUN_DIR/tacotron2-DDC.json $RUN_DIR/scale_stats.npy --data_path $RUN_DIR/$CORPUS/wavs/ +# training .... +# change the GPU id if needed +CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tacotron.py --config_path $RUN_DIR/tacotron2-DDC.json \ + --coqpit.output_path $RUN_DIR \ + --coqpit.datasets.0.path $RUN_DIR/$CORPUS \ + --coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \ + --coqpit.phoneme_cache_path $RUN_DIR/phoneme_cache \ \ No newline at end of file diff --git a/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json b/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json new file mode 100644 index 00000000..b3630055 --- /dev/null +++ b/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json @@ -0,0 +1,125 @@ +{ + "datasets": [ + { + "name": "kokoro", + "path": "DEFINE THIS", + "meta_file_train": "metadata.csv", + "meta_file_val": null + } + ], + "audio": { + "fft_size": 1024, + "win_length": 1024, + "hop_length": 256, + "frame_length_ms": null, + "frame_shift_ms": null, + "sample_rate": 22050, + "preemphasis": 0.0, + "ref_level_db": 20, + "do_trim_silence": true, + "trim_db": 60, + "power": 1.5, + "griffin_lim_iters": 60, + "num_mels": 80, + "mel_fmin": 50.0, + "mel_fmax": 7600.0, + "spec_gain": 1, + "signal_norm": true, + "min_level_db": -100, + "symmetric_norm": true, + "max_norm": 4.0, + "clip_norm": true, + "stats_path": "scale_stats.npy" + }, + "gst":{ + "gst_style_input": null, + + + + "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10, + "gst_use_speaker_embedding": false + }, + "model": "Tacotron2", + "run_name": "kokoro-ddc", + "run_description": "tacotron2 with DDC and differential spectral loss.", + "batch_size": 32, + "eval_batch_size": 16, + "mixed_precision": true, + "distributed": { + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + "reinit_layers": [], + "loss_masking": true, + "decoder_loss_alpha": 0.5, + "postnet_loss_alpha": 0.25, + "postnet_diff_spec_alpha": 0.25, + "decoder_diff_spec_alpha": 0.25, + "decoder_ssim_alpha": 0.5, + "postnet_ssim_alpha": 0.25, + "ga_alpha": 5.0, + "stopnet_pos_weight": 15.0, + "run_eval": true, + "test_delay_epochs": 10, + "test_sentences_file": null, + "noam_schedule": false, + "grad_clip": 1.0, + "epochs": 1000, + "lr": 0.0001, + "wd": 0.000001, + "warmup_steps": 4000, + "seq_len_norm": false, + "memory_size": -1, + "prenet_type": "original", + "prenet_dropout": true, + "attention_type": "original", + "windowing": false, + "use_forward_attn": false, + "forward_attn_mask": false, + "transition_agent": false, + "location_attn": true, + "bidirectional_decoder": false, + "double_decoder_consistency": true, + "ddc_r": 7, + "attention_heads": 4, + "attention_norm": "sigmoid", + "r": 7, + "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], + "stopnet": true, + "separate_stopnet": true, + "print_step": 25, + "tb_plot_step": 100, + "print_eval": false, + "save_step": 10000, + "checkpoint": true, + "keep_all_best": false, + "keep_after": 10000, + "tb_model_param_stats": false, + "text_cleaner": "basic_cleaners", + "enable_eos_bos_chars": false, + "num_loader_workers": 4, + "num_val_loader_workers": 4, + "batch_group_size": 4, + "min_seq_len": 6, + "max_seq_len": 153, + "compute_input_seq_cache": false, + "use_noise_augment": true, + "output_path": "DEFINE THIS", + "phoneme_cache_path": "DEFINE THIS", + "use_phonemes": true, + "phoneme_language": "ja-jp", + "characters": { + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations": "!'(),-.:;? ", + "phonemes": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + }, + "use_speaker_embedding": false, + "use_gst": false, + "use_external_speaker_embedding_file": false, + "external_speaker_embedding_file": "../../speakers-vctk-en.json" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b376eb1b..ab828503 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,5 @@ numba==0.52 umap-learn==0.4.6 anyascii coqpit +mecab-python3 +unidic-lite diff --git a/tests/tts_tests/test_japanese_phonemizer.py b/tests/tts_tests/test_japanese_phonemizer.py new file mode 100644 index 00000000..437042f0 --- /dev/null +++ b/tests/tts_tests/test_japanese_phonemizer.py @@ -0,0 +1,22 @@ +import unittest +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes + +_TEST_CASES = ''' +どちらに行きますか?/dochiraniikimasuka? +今日は温泉に、行きます。/kyo:waoNseNni,ikimasu. +「A」から「Z」までです。/AkaraZmadedesu. +そうですね!/so:desune! +クジラは哺乳類です。/kujirawahonyu:ruidesu. +ヴィディオを見ます。/bidioomimasu. +ky o: w a o N s e N n i , i k i m a s u ./kyo:waoNseNni,ikimasu. +''' + +class TestText(unittest.TestCase): + + def test_japanese_text_to_phonemes(self): + for line in _TEST_CASES.strip().split('\n'): + text, phone = line.split('/') + self.assertEqual(japanese_text_to_phonemes(text), phone) + +if __name__ == '__main__': + unittest.main()