mirror of https://github.com/coqui-ai/TTS.git
add speaker encoder coqpit
parent 9f2d2d2081
commit f8e52965dd
@@ -0,0 +1,54 @@
from coqpit import MISSING
from dataclasses import dataclass, field, asdict
from typing import List
from TTS.config.shared_configs import BaseTrainingConfig, BaseAudioConfig, BaseDatasetConfig


@dataclass
class SpeakerEncoderConfig(BaseTrainingConfig):
    """Defines parameters for the Speaker Encoder model."""

    model: str = "speaker_encoder"
    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])

    # model params
    model_params: dict = field(default_factory=lambda: {
        "input_dim": 40,
        "proj_dim": 256,
        "lstm_dim": 768,
        "num_lstm_layers": 3,
        "use_lstm_with_projection": True,
    })

    storage: dict = field(default_factory=lambda: {
        "sample_from_storage_p": 0.66,  # the probability of sampling from the dataset's in-memory storage
        "storage_size": 15,  # the size of the in-memory storage with respect to a single batch
        "additive_noise": 1e-5,  # add very small gaussian noise to the data to increase robustness
    })

    # training params
    max_train_step: int = 1000  # end training when the number of training steps reaches this value.
    loss: str = "angleproto"
    grad_clip: float = 3.0
    lr: float = 0.0001
    lr_decay: bool = False
    warmup_steps: int = 4000
    wd: float = 1e-6

    # logging params
    tb_model_param_stats: bool = False
    steps_plot_stats: int = 10
    checkpoint: bool = True
    save_step: int = 1000
    print_step: int = 20

    # data loader
    num_speakers_in_batch: int = MISSING
    num_utters_per_speaker: int = MISSING
    num_loader_workers: int = MISSING

    def check_values(self):
        super().check_values()
        c = asdict(self)
        assert c["model_params"]["input_dim"] == self.audio.num_mels, " [!] model input dimension must be equal to the melspectrogram dimension."
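Below is a minimal usage sketch for this config, not part of the commit. The module path TTS.speaker_encoder.speaker_encoder_config is assumed (the diff page does not show the file name), and the values for the three MISSING data-loader fields are illustrative, not recommendations.

from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig  # assumed module path

config = SpeakerEncoderConfig(
    num_speakers_in_batch=32,   # illustrative value for a MISSING field
    num_utters_per_speaker=10,  # illustrative value for a MISSING field
    num_loader_workers=4,       # illustrative value for a MISSING field
)
config.audio.num_mels = 40  # must match model_params["input_dim"], per the assertion in check_values()
config.check_values()       # runs the input_dim == num_mels assertion defined above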