Add alphas to control language and speaker balancer (#1216)

* Add alphas to control language and speaker balancer

* Add docs for speaker and language samplers

* Change the Samplers weights to float for save memory

* Change the test_samplers to unittest format

* Add get_sampler method in BaseTTS

* Fix rebase issues

* Add language and speaker samplers support for DDP training

* Rename distributed sampler wrapper

* Remove the DistributedSamplerWrapper and use the one from Trainer

* Bugfix after rebase

* Move the samplers config to tts config
pull/1349/head
Edresson Casanova 2022-03-10 10:56:09 -03:00 committed by GitHub
parent f381e29b91
commit 917f417ac4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 121 additions and 76 deletions

View File

@ -258,4 +258,3 @@ class BaseTrainingConfig(TrainerConfig):
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
use_language_weighted_sampler: bool = False

View File

@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig):
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
use_speaker_weighted_sampler (bool):
Enable / Disable the batch balancer by speaker. Defaults to ```False```.
speaker_weighted_sampler_alpha (float):
Number that control the influence of the speaker sampler weights. Defaults to ```1.0```.
use_language_weighted_sampler (bool):
Enable / Disable the batch balancer by language. Defaults to ```False```.
language_weighted_sampler_alpha (float):
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@ -262,3 +274,8 @@ class BaseTTSConfig(BaseTrainingConfig):
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
# weighted samplers
use_speaker_weighted_sampler: bool = False
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0

View File

@ -7,14 +7,15 @@ import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from torch.utils.data.sampler import WeightedRandomSampler
# pylint: skip-file
@ -232,6 +233,36 @@ class BaseTTS(BaseTrainerModel):
"language_ids": language_ids,
}
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
weights = None
data_items = dataset.samples
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
weights = get_speaker_balancer_weights(data_items) * alpha
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:
sampler = None
# sampler for DDP
if sampler is None:
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
return sampler
def get_data_loader(
self,
config: Coqpit,
@ -300,25 +331,8 @@ class BaseTTS(BaseTrainerModel):
# sort input sequences from short to long
dataset.preprocess_samples()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.samples)
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)
loader = DataLoader(
dataset,

View File

@ -13,7 +13,6 @@ from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
@ -24,8 +23,8 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlock
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -1354,31 +1353,15 @@ class Vits(BaseTTS):
# sort input sequences from short to long
dataset.preprocess_samples()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.samples)
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False, # shuffle is done in the dataset.
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,

View File

@ -6,7 +6,6 @@ import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import check_config_and_model_args
@ -128,11 +127,14 @@ def _set_file_path(path):
return None
def get_language_weighted_sampler(items: list):
def get_language_balancer_weights(items: list):
language_names = np.array([item["language"] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1.0 / language_count
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
# get weight for each sample
dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()

View File

@ -7,7 +7,6 @@ import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
@ -449,11 +448,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
return speaker_manager
def get_speaker_weighted_sampler(items: list):
def get_speaker_balancer_weights(items: list):
speaker_names = np.array([item["speaker_name"] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
weight_speaker = 1.0 / speaker_count
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()

View File

@ -1,10 +1,13 @@
import functools
import unittest
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
# Fixing random state to avoid random fails
torch.manual_seed(0)
@ -25,34 +28,57 @@ dataset_config_pt = BaseDatasetConfig(
language="pt-br",
)
# Adding the EN samples twice to create an unbalanced dataset
# Adding the EN samples twice to create a language unbalanced dataset
train_samples, eval_samples = load_tts_samples(
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)
# gerenate a speaker unbalanced dataset
for i, sample in enumerate(train_samples):
if i < 5:
sample["speaker_name"] = "ljspeech-0"
else:
sample["speaker_name"] = "ljspeech-1"
def is_balanced(lang_1, lang_2):
return 0.85 < lang_1 / lang_2 < 1.2
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
class TestSamplers(unittest.TestCase):
def test_language_random_sampler(self): # pylint: disable=no-self-use
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
spk1, spk2 = 0, 0
for index in ids:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"

View File

@ -45,7 +45,7 @@ config = VitsConfig(
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
],
datasets=[dataset_config_en, dataset_config_pt],
datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
@ -71,8 +71,11 @@ config.d_vector_dim = 256
config.model_args.use_sdp = True
config.use_sdp = True
# deactivate language sampler
config.use_language_weighted_sampler = False
# activate language and speaker samplers
config.use_language_weighted_sampler = True
config.language_weighted_sampler_alpha = 10
config.use_speaker_weighted_sampler = True
config.speaker_weighted_sampler_alpha = 5
config.save_json(config_path)