mirror of https://github.com/coqui-ai/TTS.git
Add alphas to control language and speaker balancer (#1216)
* Add alphas to control language and speaker balancer * Add docs for speaker and language samplers * Change the Samplers weights to float for save memory * Change the test_samplers to unittest format * Add get_sampler method in BaseTTS * Fix rebase issues * Add language and speaker samplers support for DDP training * Rename distributed sampler wrapper * Remove the DistributedSamplerWrapper and use the one from Trainer * Bugfix after rebase * Move the samplers config to tts configpull/1349/head
parent
f381e29b91
commit
917f417ac4
|
@ -258,4 +258,3 @@ class BaseTrainingConfig(TrainerConfig):
|
|||
num_loader_workers: int = 0
|
||||
num_eval_loader_workers: int = 0
|
||||
use_noise_augment: bool = False
|
||||
use_language_weighted_sampler: bool = False
|
||||
|
|
|
@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
eval_split_size (float):
|
||||
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
|
||||
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
|
||||
|
||||
use_speaker_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by speaker. Defaults to ```False```.
|
||||
|
||||
speaker_weighted_sampler_alpha (float):
|
||||
Number that control the influence of the speaker sampler weights. Defaults to ```1.0```.
|
||||
|
||||
use_language_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by language. Defaults to ```False```.
|
||||
|
||||
language_weighted_sampler_alpha (float):
|
||||
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
|
||||
"""
|
||||
|
||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||
|
@ -262,3 +274,8 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
# evaluation
|
||||
eval_split_max_size: int = None
|
||||
eval_split_size: float = 0.01
|
||||
# weighted samplers
|
||||
use_speaker_weighted_sampler: bool = False
|
||||
speaker_weighted_sampler_alpha: float = 1.0
|
||||
use_language_weighted_sampler: bool = False
|
||||
language_weighted_sampler_alpha: float = 1.0
|
||||
|
|
|
@ -7,14 +7,15 @@ import torch.distributed as dist
|
|||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
@ -232,6 +233,36 @@ class BaseTTS(BaseTrainerModel):
|
|||
"language_ids": language_ids,
|
||||
}
|
||||
|
||||
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
|
||||
weights = None
|
||||
data_items = dataset.samples
|
||||
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Language weighted sampler with alpha:", alpha)
|
||||
weights = get_language_balancer_weights(data_items) * alpha
|
||||
|
||||
if getattr(config, "use_speaker_weighted_sampler", False):
|
||||
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Speaker weighted sampler with alpha:", alpha)
|
||||
if weights is not None:
|
||||
weights += get_speaker_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
weights = get_speaker_balancer_weights(data_items) * alpha
|
||||
|
||||
if weights is not None:
|
||||
sampler = WeightedRandomSampler(weights, len(weights))
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
# sampler for DDP
|
||||
if sampler is None:
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
else: # If a sampler is already defined use this sampler and DDP sampler together
|
||||
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
|
||||
|
||||
return sampler
|
||||
|
||||
def get_data_loader(
|
||||
self,
|
||||
config: Coqpit,
|
||||
|
@ -300,25 +331,8 @@ class BaseTTS(BaseTrainerModel):
|
|||
# sort input sequences from short to long
|
||||
dataset.preprocess_samples()
|
||||
|
||||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
||||
# Weighted samplers
|
||||
# TODO: make this DDP amenable
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
|
||||
), "language_weighted_sampler is not supported with DistributedSampler"
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
|
||||
), "speaker_weighted_sampler is not supported with DistributedSampler"
|
||||
|
||||
if sampler is None:
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_language_weighted_sampler(dataset.samples)
|
||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_speaker_weighted_sampler(dataset.samples)
|
||||
# get samplers
|
||||
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
|
|
|
@ -13,7 +13,6 @@ from torch import nn
|
|||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
|
@ -24,8 +23,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
|
|||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
|
@ -1354,31 +1353,15 @@ class Vits(BaseTTS):
|
|||
# sort input sequences from short to long
|
||||
dataset.preprocess_samples()
|
||||
|
||||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
||||
# Weighted samplers
|
||||
# TODO: make this DDP amenable
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
|
||||
), "language_weighted_sampler is not supported with DistributedSampler"
|
||||
assert not (
|
||||
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
|
||||
), "speaker_weighted_sampler is not supported with DistributedSampler"
|
||||
|
||||
if sampler is None:
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_language_weighted_sampler(dataset.samples)
|
||||
elif getattr(config, "use_speaker_weighted_sampler", False):
|
||||
print(" > Using Language weighted sampler")
|
||||
sampler = get_speaker_weighted_sampler(dataset.samples)
|
||||
# get samplers
|
||||
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False, # shuffle is done in the dataset.
|
||||
drop_last=False, # setting this False might cause issues in AMP training.
|
||||
sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
|
|
|
@ -6,7 +6,6 @@ import fsspec
|
|||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
from TTS.config import check_config_and_model_args
|
||||
|
||||
|
@ -128,11 +127,14 @@ def _set_file_path(path):
|
|||
return None
|
||||
|
||||
|
||||
def get_language_weighted_sampler(items: list):
|
||||
def get_language_balancer_weights(items: list):
|
||||
language_names = np.array([item["language"] for item in items])
|
||||
unique_language_names = np.unique(language_names).tolist()
|
||||
language_ids = [unique_language_names.index(l) for l in language_names]
|
||||
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
|
||||
weight_language = 1.0 / language_count
|
||||
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
|
||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
||||
# get weight for each sample
|
||||
dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
|
||||
# normalize
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
return torch.from_numpy(dataset_samples_weight).float()
|
||||
|
|
|
@ -7,7 +7,6 @@ import fsspec
|
|||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
|
||||
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
|
@ -449,11 +448,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
|||
return speaker_manager
|
||||
|
||||
|
||||
def get_speaker_weighted_sampler(items: list):
|
||||
def get_speaker_balancer_weights(items: list):
|
||||
speaker_names = np.array([item["speaker_name"] for item in items])
|
||||
unique_speaker_names = np.unique(speaker_names).tolist()
|
||||
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
|
||||
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
|
||||
weight_speaker = 1.0 / speaker_count
|
||||
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
|
||||
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
|
||||
dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
|
||||
# normalize
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
return torch.from_numpy(dataset_samples_weight).float()
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import functools
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.languages import get_language_weighted_sampler
|
||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||
|
||||
# Fixing random state to avoid random fails
|
||||
torch.manual_seed(0)
|
||||
|
@ -25,34 +28,57 @@ dataset_config_pt = BaseDatasetConfig(
|
|||
language="pt-br",
|
||||
)
|
||||
|
||||
# Adding the EN samples twice to create an unbalanced dataset
|
||||
# Adding the EN samples twice to create a language unbalanced dataset
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
|
||||
)
|
||||
|
||||
# gerenate a speaker unbalanced dataset
|
||||
for i, sample in enumerate(train_samples):
|
||||
if i < 5:
|
||||
sample["speaker_name"] = "ljspeech-0"
|
||||
else:
|
||||
sample["speaker_name"] = "ljspeech-1"
|
||||
|
||||
|
||||
def is_balanced(lang_1, lang_2):
|
||||
return 0.85 < lang_1 / lang_2 < 1.2
|
||||
|
||||
|
||||
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["language"] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
class TestSamplers(unittest.TestCase):
|
||||
def test_language_random_sampler(self): # pylint: disable=no-self-use
|
||||
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["language"] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
|
||||
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
||||
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
||||
|
||||
weighted_sampler = get_language_weighted_sampler(train_samples)
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["language"] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use
|
||||
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||
en, pt = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["language"] == "en":
|
||||
en += 1
|
||||
else:
|
||||
pt += 1
|
||||
|
||||
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
|
||||
assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
|
||||
|
||||
def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use
|
||||
|
||||
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
|
||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||
spk1, spk2 = 0, 0
|
||||
for index in ids:
|
||||
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||
spk1 += 1
|
||||
else:
|
||||
spk2 += 1
|
||||
|
||||
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
||||
|
|
|
@ -45,7 +45,7 @@ config = VitsConfig(
|
|||
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
|
||||
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
|
||||
],
|
||||
datasets=[dataset_config_en, dataset_config_pt],
|
||||
datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt],
|
||||
)
|
||||
# set audio config
|
||||
config.audio.do_trim_silence = True
|
||||
|
@ -71,8 +71,11 @@ config.d_vector_dim = 256
|
|||
config.model_args.use_sdp = True
|
||||
config.use_sdp = True
|
||||
|
||||
# deactivate language sampler
|
||||
config.use_language_weighted_sampler = False
|
||||
# activate language and speaker samplers
|
||||
config.use_language_weighted_sampler = True
|
||||
config.language_weighted_sampler_alpha = 10
|
||||
config.use_speaker_weighted_sampler = True
|
||||
config.speaker_weighted_sampler_alpha = 5
|
||||
|
||||
config.save_json(config_path)
|
||||
|
||||
|
|
Loading…
Reference in New Issue