diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
index 7ed79b36..bbc88fb6 100644
--- a/TTS/bin/find_unique_phonemes.py
+++ b/TTS/bin/find_unique_phonemes.py
@@ -1,14 +1,15 @@
 """Find all the unique characters in a dataset"""
 import argparse
+import multiprocessing
 from argparse import RawTextHelpFormatter
 
+import numpy
+from tqdm.contrib.concurrent import process_map
+
 from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
-
-import numpy
-import multiprocessing
 from TTS.tts.utils.text import text2phone
-from tqdm.contrib.concurrent import process_map
+
 
 def compute_phonemes(item):
     try:
@@ -18,7 +19,8 @@ def compute_phonemes(item):
     except:
         return []
     return list(set(ph))
-
+
+
 def main():
     global c
     # pylint: disable=bad-option-value
@@ -51,8 +53,6 @@ def main():
 
     phones_force_lower = [c.lower() for c in phones]
     phones_force_lower = set(phones_force_lower)
-
-
     print(f" > Number of unique phonemes: {len(phones)}")
     print(f" > Unique phonemes: {''.join(sorted(phones))}")
    print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
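Reviewer note: the import reshuffle above feeds `process_map`, which is what parallelizes phonemization across CPU cores. A minimal, self-contained sketch of the pattern — `simple_phonemizer` is a hypothetical stand-in for `text2phone`, which needs a loaded config:

```python
from tqdm.contrib.concurrent import process_map


def simple_phonemizer(text):
    # hypothetical stand-in for TTS.tts.utils.text.text2phone(text, language)
    return list(text.replace(" ", ""))


def compute_phonemes(item):
    text = item[0]  # dataset items are (text, wav_file, speaker_name)
    try:
        ph = simple_phonemizer(text)
    except Exception:
        return []
    return list(set(ph))


if __name__ == "__main__":  # required guard for multiprocessing on spawn platforms
    items = [("hello world", "wavs/a.wav", "spk1")] * 100
    # chunksize > 1 amortizes inter-process overhead over many short texts
    phonemes = process_map(compute_phonemes, items, max_workers=4, chunksize=15)
    print(len({p for phs in phonemes for p in phs}), "unique phonemes")
```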
diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py
index c7541cc8..25ae26ef 100755
--- a/TTS/bin/remove_silence_using_vad.py
+++ b/TTS/bin/remove_silence_using_vad.py
@@ -1,26 +1,27 @@
 # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import os
-import tqdm
-import glob
 import argparse
-import pathlib
-
 import collections
 import contextlib
+import glob
+import multiprocessing
+import os
+import pathlib
 import sys
 import wave
+from itertools import chain
+
 import numpy as np
+import tqdm
 import webrtcvad
 from tqdm.contrib.concurrent import process_map
-import multiprocessing
-from itertools import chain
+
 
 def read_wave(path):
     """Reads a .wav file.
 
     Takes the path, and returns (PCM audio data, sample rate).
     """
-    with contextlib.closing(wave.open(path, 'rb')) as wf:
+    with contextlib.closing(wave.open(path, "rb")) as wf:
         num_channels = wf.getnchannels()
         assert num_channels == 1
         sample_width = wf.getsampwidth()
@@ -36,7 +37,7 @@ def write_wave(path, audio, sample_rate):
 
     Takes path, PCM audio data, and sample rate.
     """
-    with contextlib.closing(wave.open(path, 'wb')) as wf:
+    with contextlib.closing(wave.open(path, "wb")) as wf:
         wf.setnchannels(1)
         wf.setsampwidth(2)
         wf.setframerate(sample_rate)
@@ -45,6 +46,7 @@ def write_wave(path, audio, sample_rate):
 
 class Frame(object):
     """Represents a "frame" of audio data."""
+
     def __init__(self, bytes, timestamp, duration):
         self.bytes = bytes
         self.timestamp = timestamp
@@ -64,13 +66,12 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
     timestamp = 0.0
     duration = (float(n) / sample_rate) / 2.0
     while offset + n < len(audio):
-        yield Frame(audio[offset:offset + n], timestamp, duration)
+        yield Frame(audio[offset : offset + n], timestamp, duration)
         timestamp += duration
         offset += n
 
 
-def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, vad, frames):
+def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Filters out non-voiced audio frames.

    Given a webrtcvad.Vad and a source of audio frames, yields only
@@ -133,25 +134,26 @@ def vad_collector(sample_rate, frame_duration_ms,
             # unvoiced, then enter NOTTRIGGERED and yield whatever
             # audio we've collected.
             if num_unvoiced > 0.9 * ring_buffer.maxlen:
-                #sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
+                # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                 triggered = False
-                yield b''.join([f.bytes for f in voiced_frames])
+                yield b"".join([f.bytes for f in voiced_frames])
                 ring_buffer.clear()
                 voiced_frames = []
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b"".join([f.bytes for f in voiced_frames])
+
 
 def remove_silence(filepath):
     filename = os.path.basename(filepath)
-    output_path = filepath.replace(os.path.join(args.input_dir, ''),os.path.join(args.output_dir, ''))
-    # ignore if the file exists 
+    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
+    # ignore if the file exists
     if os.path.exists(output_path) and not args.force:
         return False
     # create all directory structure
     pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-    padding_duration_ms = 300 # default 300
+    padding_duration_ms = 300  # default 300
     audio, sample_rate = read_wave(filepath)
     vad = webrtcvad.Vad(int(args.aggressiveness))
     frames = frame_generator(30, audio, sample_rate)
@@ -180,6 +182,7 @@ def remove_silence(filepath):
         # if fail to remove silence just write the file
         write_wave(output_path, audio, sample_rate)
 
+
 def preprocess_audios():
     files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
     print("> Number of files: ", len(files))
@@ -193,21 +196,31 @@ def preprocess_audios():
     else:
         print("> No files Found !")
 
+
 if __name__ == "__main__":
     """
     usage
-    python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2 
+    python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--input_dir', type=str, default='../VCTK-Corpus',
-                        help='Dataset root dir')
-    parser.add_argument('-o', '--output_dir', type=str, default='../VCTK-Corpus-removed-silence',
-                        help='Output Dataset dir')
-    parser.add_argument('-f', '--force', type=bool, default=True,
-                        help='Force the replace of exists files')
-    parser.add_argument('-g', '--glob', type=str, default='**/*.wav',
-                        help='path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav')
-    parser.add_argument('-a', '--aggressiveness', type=int, default=2,
-                        help='set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.')
+    parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
+    parser.add_argument(
+        "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
+    )
+    parser.add_argument("-f", "--force", type=bool, default=True, help="Force overwriting existing files")
+    parser.add_argument(
+        "-g",
+        "--glob",
+        type=str,
+        default="**/*.wav",
+        help="path in glob format for accessing wavs from input_dir. ex: wav48/*/*.wav",
+    )
+    parser.add_argument(
+        "-a",
+        "--aggressiveness",
+        type=int,
+        default=2,
+        help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
+    )
     args = parser.parse_args()
     preprocess_audios()
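Reviewer note: for anyone testing this script, the whole pipeline reduces to four calls. A sketch assuming the functions above are importable and `webrtcvad` is installed (webrtcvad only accepts mono 16-bit PCM at 8/16/32/48 kHz and 10/20/30 ms frames):

```python
import webrtcvad


def trim_silence(in_path, out_path, aggressiveness=2):
    audio, sample_rate = read_wave(in_path)        # asserts mono, 16-bit, 8/16/32/48 kHz
    vad = webrtcvad.Vad(aggressiveness)            # 0 = most permissive, 3 = most aggressive
    frames = list(frame_generator(30, audio, sample_rate))              # 30 ms frames
    segments = list(vad_collector(sample_rate, 30, 300, vad, frames))   # 300 ms padding
    voiced = b"".join(segments)
    # fall back to the original audio if VAD removed everything
    write_wave(out_path, voiced if voiced else audio, sample_rate)
```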
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index beeb5ae1..42f041b4 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -5,20 +5,20 @@ import torch.nn as nn
 
 from TTS.utils.io import load_fsspec
 
+
 class PreEmphasis(torch.nn.Module):
     def __init__(self, coefficient=0.97):
         super().__init__()
         self.coefficient = coefficient
-        self.register_buffer(
-            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
-        )
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
 
     def forward(self, x):
         assert len(x.size()) == 2
 
-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
 
+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -110,8 +110,15 @@ class ResNetSpeakerEncoder(nn.Module):
         if self.use_torch_spec:
             self.torch_spec = torch.nn.Sequential(
                 PreEmphasis(audio_config["preemphasis"]),
-                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
-            )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
         else:
             self.torch_spec = None
 
@@ -213,7 +220,7 @@ class ResNetSpeakerEncoder(nn.Module):
         """
         # map to the waveform size
         if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config['hop_length']
+            num_frames = num_frames * self.audio_config["hop_length"]
 
         max_len = x.shape[1]
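Reviewer note: the `use_torch_spec` branch above moves feature extraction into the model so the speaker encoder can consume raw waveforms. A sketch of what that `Sequential` computes, with made-up audio settings (the real values come from the model's `audio` config):

```python
import torch
import torchaudio

from TTS.speaker_encoder.models.resnet import PreEmphasis

# hypothetical settings; the PR reads these from audio_config
cfg = {"preemphasis": 0.97, "sample_rate": 16000, "fft_size": 512,
       "win_length": 400, "hop_length": 160, "num_mels": 64}

torch_spec = torch.nn.Sequential(
    PreEmphasis(cfg["preemphasis"]),  # high-pass-like filter defined in resnet.py above
    torchaudio.transforms.MelSpectrogram(
        sample_rate=cfg["sample_rate"],
        n_fft=cfg["fft_size"],
        win_length=cfg["win_length"],
        hop_length=cfg["hop_length"],
        window_fn=torch.hamming_window,
        n_mels=cfg["num_mels"],
    ),
)

wavs = torch.randn(2, 16000)   # a batch of two 1 s waveforms
mels = torch_spec(wavs)        # -> [2, num_mels, n_frames]
```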
diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py
index 3714e3c4..c926e215 100644
--- a/TTS/speaker_encoder/utils/generic_utils.py
+++ b/TTS/speaker_encoder/utils/generic_utils.py
@@ -179,10 +179,12 @@ def setup_model(c):
             c.model_params["num_lstm_layers"],
         )
     elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
+        model = ResNetSpeakerEncoder(
+            input_dim=c.model_params["input_dim"],
+            proj_dim=c.model_params["proj_dim"],
             log_input=c.model_params.get("log_input", False),
             use_torch_spec=c.model_params.get("use_torch_spec", False),
-            audio_config=c.audio
+            audio_config=c.audio,
         )
     return model
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 665f2589..c151e716 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -265,7 +265,9 @@ class Trainer:
         config = self.config.model_args if hasattr(self.config, "model_args") else self.config
         # save speakers json
         if config.use_language_embedding and self.model.language_manager.num_languages > 1:
-            self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
+            self.model.language_manager.save_language_ids_to_file(
+                os.path.join(self.output_path, "language_ids.json")
+            )
             if hasattr(self.config, "model_args"):
                 self.config.model_args["num_languages"] = self.model.language_manager.num_languages
             else:
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index fc51c766..6d177743 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -542,6 +542,7 @@ class TTSDataset(Dataset):
                 )
             )
 
+
 class PitchExtractor:
     """Pitch Extractor for computing F0 from wav files.
     Args:
@@ -645,4 +646,4 @@ class PitchExtractor:
             stats_path = os.path.join(cache_path, "pitch_stats.npy")
             stats = np.load(stats_path, allow_pickle=True).item()
             self.mean = stats["mean"].astype(np.float32)
-            self.std = stats["std"].astype(np.float32)
\ No newline at end of file
+            self.std = stats["std"].astype(np.float32)
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index 651b3197..7e65f21a 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -304,7 +304,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):
     return items
 
 
-def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):  # pylint: disable=unused-argument
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
     items = []
     txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index cd2903b0..93a5bad2 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -602,7 +602,7 @@ class VitsGeneratorLoss(nn.Module):
         fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
-        syn_spk_emb=None
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -638,7 +638,9 @@ class VitsGeneratorLoss(nn.Module):
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
 
         if use_speaker_encoder_as_loss:
-            loss_se = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss_se = (
+                -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            )
             loss += loss_se
             return_dict["loss_spk_encoder"] = loss_se
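Reviewer note: the new `loss_spk_encoder` term rewards the generator when the speaker embedding of the synthesized segment matches the embedding of the ground-truth segment. A toy sketch of the arithmetic (`alpha` is a made-up stand-in for `self.spk_encoder_loss_alpha`):

```python
import torch
import torch.nn.functional as F

alpha = 9.0                          # hypothetical loss weight
gt_spk_emb = torch.randn(8, 512)     # speaker encoder output for real audio
syn_spk_emb = torch.randn(8, 512)    # speaker encoder output for generated audio

# cosine_similarity is in [-1, 1]; negating it turns "maximize similarity"
# into a minimizable loss (a perfect match contributes -alpha).
loss_se = -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * alpha
```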
diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py
index 8ec7c866..7c25156a 100644
--- a/TTS/tts/layers/vits/stochastic_duration_predictor.py
+++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py
@@ -178,7 +178,14 @@ class StochasticDurationPredictor(nn.Module):
     """
 
     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=None,
     ):
         super().__init__()
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index df6c52f3..de00f6c7 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -246,7 +246,9 @@ class BaseTTS(BaseModel):
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
-                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    speaker_id_mapping = (
+                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    )
                     d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                     config.use_d_vector_file = config.model_args.use_d_vector_file
                 else:
@@ -262,7 +264,9 @@ class BaseTTS(BaseModel):
                 custom_symbols = self.make_symbols(self.config)
 
             if hasattr(self, "language_manager"):
-                language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                language_id_mapping = (
+                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                )
             else:
                 language_id_mapping = None
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index bc503cb5..c185150b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -229,7 +229,6 @@ class VitsArgs(Coqpit):
     freeze_waveform_decoder: bool = False
 
 
-
 class Vits(BaseTTS):
     """VITS TTS model
@@ -306,7 +305,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim
+            language_emb_dim=self.embedded_language_dim,
         )
 
         self.posterior_encoder = PosteriorEncoder(
@@ -389,16 +388,26 @@ class Vits(BaseTTS):
         # TODO: make this a function
         if config.use_speaker_encoder_as_loss:
             if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
-                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
-            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+                raise RuntimeError(
+                    " [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+                )
+            self.speaker_manager.init_speaker_encoder(
+                config.speaker_encoder_model_path, config.speaker_encoder_config_path
+            )
             self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
             for param in self.speaker_encoder.parameters():
                 param.requires_grad = False
 
             print(" > External Speaker Encoder Loaded !!")
 
-            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
-                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
+            if (
+                hasattr(self.speaker_encoder, "audio_config")
+                and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+            ):
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.audio_config["sample_rate"],
+                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                )
             else:
                 self.audio_transform = None
         else:
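Reviewer note: the `audio_transform` guard in the hunk above matters because a pretrained speaker encoder is often trained at a different sample rate than the TTS output. A sketch with made-up rates:

```python
import torch
import torchaudio

model_sr, encoder_sr = 22050, 16000   # hypothetical: TTS output rate vs. encoder training rate
audio_transform = (
    torchaudio.transforms.Resample(orig_freq=model_sr, new_freq=encoder_sr)
    if model_sr != encoder_sr
    else None
)

wav = torch.randn(1, 22050)           # 1 s of generated audio
if audio_transform is not None:
    wav = audio_transform(wav)        # -> [1, 16000] before computing the embedding
```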
@@ -529,7 +538,13 @@ class Vits(BaseTTS):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]
 
-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }
 
     def forward(
         self,
@@ -567,7 +582,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
 
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
@@ -621,9 +636,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
-            waveform.transpose(1, 2), 
-            slice_ids * self.config.audio.hop_length, 
-            self.args.spec_segment_size * self.config.audio.hop_length, 
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -653,7 +668,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -695,7 +710,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
 
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
@@ -737,9 +752,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
-            waveform.transpose(1, 2), 
-            slice_ids * self.config.audio.hop_length, 
-            self.args.spec_segment_size * self.config.audio.hop_length, 
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -770,7 +785,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -790,14 +805,16 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)
 
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
 
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
 
         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+            )
         else:
             logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
@@ -866,7 +883,7 @@ class Vits(BaseTTS):
             for param in self.text_encoder.parameters():
                 param.requires_grad = False
 
-            if hasattr(self, 'emb_l'):
+            if hasattr(self, "emb_l"):
                 for param in self.emb_l.parameters():
                     param.requires_grad = False
@@ -932,7 +949,7 @@ class Vits(BaseTTS):
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     waveform_hat=outputs["model_outputs"].float(),
-                    waveform= outputs["waveform_seg"].float(),
+                    waveform=outputs["waveform_seg"].float(),
                     z_p=outputs["z_p"].float(),
                     logs_q=outputs["logs_q"].float(),
                     m_p=outputs["m_p"].float(),
@@ -945,7 +962,7 @@ class Vits(BaseTTS):
                     fine_tuning_mode=self.args.fine_tuning_mode,
                     use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                     gt_spk_emb=outputs["gt_spk_emb"],
-                    syn_spk_emb=outputs["syn_spk_emb"]
+                    syn_spk_emb=outputs["syn_spk_emb"],
                 )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode:
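Reviewer note: all of the `lang_emb` plumbing above follows one pattern — look up a per-language vector and pass it as extra conditioning to the text encoder and duration predictor. A toy sketch (sizes are made up; in the PR they come from `VitsArgs`):

```python
import torch
import torch.nn as nn

num_languages, embedded_language_dim = 2, 4
emb_l = nn.Embedding(num_languages, embedded_language_dim)

lid = torch.tensor([1])              # language id for a batch of one
lang_emb = emb_l(lid).unsqueeze(-1)  # [b, embedded_language_dim, 1]
# VITS broadcasts this along time and concatenates it to the encoder input channels
```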
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index 94be914c..5bacc259 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -1,13 +1,14 @@
-import os
 import json
-import torch
+import os
+from typing import Dict, List, Tuple
+
 import fsspec
 import numpy as np
-from typing import Dict, Tuple, List
+import torch
 from coqpit import Coqpit
-
 from torch.utils.data.sampler import WeightedRandomSampler
 
+
 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
     in a way that can be queried by language.
@@ -20,7 +21,9 @@ class LanguageManager:
         >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
         >>> language_id_mapper = manager.language_ids
     """
+
     language_id_mapping: Dict = {}
+
     def __init__(
         self,
         language_id_file_path: str = "",
@@ -85,6 +88,7 @@ class LanguageManager:
         """
         self._save_json(file_path, self.language_id_mapping)
 
+
 def _set_file_path(path):
     """Find the language_ids.json under the given path or the above it.
     Intended to band aid the different paths returned in restored and continued training."""
@@ -97,6 +101,7 @@ def _set_file_path(path):
         return path_continue
     return None
 
+
 def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
     """Initiate a `LanguageManager` instance by the provided config.
 
@@ -118,7 +123,7 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
             # restoring language manager from a previous run.
             if language_file:
                 language_manager.set_language_ids_from_file(language_file)
-        if language_manager.num_languages > 0: 
+        if language_manager.num_languages > 0:
             print(
                 " > Language manager is loaded with {} languages: {}".format(
                     language_manager.num_languages, ", ".join(language_manager.language_names)
@@ -126,11 +131,12 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
             )
     return language_manager
 
+
 def get_language_weighted_sampler(items: list):
     language_names = np.array([item[3] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
-    weight_language = 1. / language_count
+    weight_language = 1.0 / language_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 8ccbdafc..d6381a70 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -432,11 +432,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
         speaker_manager.save_speaker_ids_to_file(out_file_path)
     return speaker_manager
 
+
 def get_speaker_weighted_sampler(items: list):
     speaker_names = np.array([item[2] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
-    weight_speaker = 1. / speaker_count
+    weight_speaker = 1.0 / speaker_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
-    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
\ No newline at end of file
+    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
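Reviewer note: `get_language_weighted_sampler` and `get_speaker_weighted_sampler` are the same inverse-frequency trick, so rare languages or speakers are drawn as often as common ones. A self-contained run with toy metadata rows shaped like TTS items:

```python
import numpy as np
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# (text, wav_file, speaker_name, language) — two "en" rows, one "pt" row
items = [("t1", "a.wav", "s1", "en"), ("t2", "b.wav", "s1", "en"), ("t3", "c.wav", "s2", "pt")]

language_names = np.array([item[3] for item in items])
unique = np.unique(language_names).tolist()
ids = [unique.index(l) for l in language_names]
counts = np.array([(language_names == l).sum() for l in unique])
weights = 1.0 / counts                     # rare languages get larger weights
sample_weights = torch.from_numpy(np.array([weights[i] for i in ids])).double()
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
# pass to DataLoader(..., sampler=sampler) so "pt" is drawn as often as "en" overall
```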
diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py
index 826919c2..f3ffa478 100644
--- a/TTS/tts/utils/text/cleaners.py
+++ b/TTS/tts/utils/text/cleaners.py
@@ -136,8 +136,9 @@ def phoneme_cleaners(text):
     text = collapse_whitespace(text)
     return text
 
+
 def multilingual_cleaners(text):
-    '''Pipeline for multilingual text'''
+    """Pipeline for multilingual text"""
     text = lowercase(text)
     text = replace_symbols(text, lang=None)
     text = remove_aux_symbols(text)
diff --git a/tests/tts_tests/test_vits_multilingual_train.py b/tests/tts_tests/test_vits_multilingual_train.py
index 664de57e..04b42e61 100644
--- a/tests/tts_tests/test_vits_multilingual_train.py
+++ b/tests/tts_tests/test_vits_multilingual_train.py
@@ -3,19 +3,27 @@ import os
 import shutil
 
 from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
 
 config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")
 
 dataset_config1 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en",
 )
 dataset_config2 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en2"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en2",
 )
 
 config = VitsConfig(