Implement FreeVC (#2451)

* Update .gitignore

* Draft FreeVC implementation

* Tests and relevant updates

* Update API tests

* Add missings

* Update requirements

* :(

* Lazy handle for vc

* Update docs for voice conversion

* Make style
pull/2462/head
Eren Gölge 2023-03-25 18:33:23 +01:00 committed by GitHub
parent 090cadf270
commit d309f50e53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 4416 additions and 38 deletions

3
.gitignore vendored
View File

@ -137,7 +137,7 @@ VCTK-Corpus-removed-silence/*
# ignore training logs
trainer_*_log.txt
# files used internally fro dev, test etc.
# files used internally for dev, test etc.
tests/outputs/*
tests/train_outputs/*
TODO.txt
@ -168,3 +168,4 @@ internal/*
wandb
depot/*
coqui_recipes/*
local_scripts/*

View File

@ -802,5 +802,18 @@
}
}
}
},
"voice_conversion_models":{
"multilingual":{
"vctk":{
"freevc24":{
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip",
"description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC",
"author": "Jing-Yi Li @OlaWod",
"license": "MIT",
"commit": null
}
}
}
}
}

View File

@ -1,5 +1,7 @@
import tempfile
from pathlib import Path
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
@ -49,11 +51,14 @@ class TTS:
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
self.synthesizer = None
self.voice_converter = None
if model_name:
self.load_model_by_name(model_name, gpu)
self.load_tts_model_by_name(model_name, gpu)
if model_path:
self.load_model_by_path(
self.load_tts_model_by_path(
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
)
@ -96,12 +101,22 @@ class TTS:
def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if model_item["default_vocoder"] is None:
if model_item.get("default_vocoder") is None:
return model_path, config_path, None, None
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path
def load_model_by_name(self, model_name: str, gpu: bool = False):
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the voice conversion models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
model_path, config_path, _, _ = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of 🐸TTS models by name.
Args:
@ -127,7 +142,7 @@ class TTS:
use_cuda=gpu,
)
def load_model_by_path(
def load_tts_model_by_path(
self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
"""Load a model from a path.
@ -219,3 +234,67 @@ class TTS:
"""
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
self.synthesizer.save_wav(wav=wav, path=file_path)
def voice_conversion(
self,
sourve_wav: str,
target_wav: str,
):
"""Voice conversion with FreeVC. Convert source wav to target speaker.
Args:
source_wav (str):
Path to the source wav file.
target_wav (str):
Path to the target wav file.
"""
wav = self.synthesizer.voice_conversion(source_wav=sourve_wav, target_wav=target_wav)
return wav
def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
"""Convert text to speech with voice conversion.
It combines tts with voice conversion to fake voice cloning.
- Convert text to speech with tts.
- Convert the output wav to target speaker with voice conversion.
Args:
text (str):
Input text to synthesize.
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
Defaults to None.
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
# Lazy code... save it to a temp file to resample it while reading it for VC
self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name)
if self.voice_converter is None:
self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
return wav
def tts_with_vc_to_file(
self, text: str, language: str = None, speaker_wav: str = None, file_path: str = "output.wav"
):
"""Convert text to speech with voice conversion and save to file.
Check `tts_with_vc` for more details.
Args:
text (str):
Input text to synthesize.
language (str, optional):
Language code for multi-lingual models. You can check whether loaded model is multi-lingual
`tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
speaker_wav (str, optional):
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
"""
wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)

View File

@ -100,6 +100,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
```
### Voice Conversion Models
```
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
```
"""
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
# documentation in sync more easily.
@ -245,6 +251,20 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=True,
)
# voice conversion args
parser.add_argument(
"--source_wav",
type=str,
default=None,
help="Original audio file to convert in the voice of the target_wav",
)
parser.add_argument(
"--target_wav",
type=str,
default=None,
help="Target audio file to convert in the voice of the source_wav",
)
args = parser.parse_args()
# print the description if either text or list_models is not set
@ -256,6 +276,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
args.reference_wav,
args.model_info_by_idx,
args.model_info_by_name,
args.source_wav,
args.target_wav,
]
if not any(check_args):
parser.parse_args(["-h"])
@ -264,21 +286,23 @@ If you don't specify any models, then it uses LJSpeech based English model.
path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path, progress_bar=args.progress_bar)
model_path = None
config_path = None
tts_path = None
tts_config_path = None
speakers_file_path = None
language_ids_file_path = None
vocoder_path = None
vocoder_config_path = None
encoder_path = None
encoder_config_path = None
vc_path = None
vc_config_path = None
# CASE1 #list : list pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()
# CASE2 #info : model info of pre-trained TTS models
# CASE2 #info : model info for pre-trained TTS models
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
@ -292,15 +316,27 @@ If you don't specify any models, then it uses LJSpeech based English model.
# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
# tts model
if model_item["model_type"] == "tts_models":
tts_path = model_path
tts_config_path = config_path
if "default_vocoder" in model_item:
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
# voice conversion model
if model_item["model_type"] == "voice_conversion_models":
vc_path = model_path
vc_config_path = config_path
# load vocoder
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE4: set custom model paths
if args.model_path is not None:
model_path = args.model_path
config_path = args.config_path
tts_path = args.model_path
tts_config_path = args.config_path
speakers_file_path = args.speakers_file_path
language_ids_file_path = args.language_ids_file_path
@ -314,14 +350,16 @@ If you don't specify any models, then it uses LJSpeech based English model.
# load models
synthesizer = Synthesizer(
model_path,
config_path,
tts_path,
tts_config_path,
speakers_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
encoder_config_path,
vc_path,
vc_config_path,
args.use_cuda,
)
@ -354,16 +392,22 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(" > Text: {}".format(args.text))
# kick it
wav = synthesizer.tts(
args.text,
args.speaker_idx,
args.language_idx,
args.speaker_wav,
reference_wav=args.reference_wav,
style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
)
if tts_path is not None:
wav = synthesizer.tts(
args.text,
args.speaker_idx,
args.language_idx,
args.speaker_wav,
reference_wav=args.reference_wav,
style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
)
elif vc_path is not None:
wav = synthesizer.voice_conversion(
source_wav=args.source_wav,
target_wav=args.target_wav,
)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
"""
config_class = None
config_name = model_name + "_config"
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"]
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
for path in paths:
try:
config_class = find_module(path, config_name)

View File

@ -27,6 +27,8 @@ class BaseTTS(BaseTrainerModel):
It defines common `tts` specific functions on top of `Model` implementation.
"""
MODEL_TYPE = "tts"
def __init__(
self,
config: Coqpit,

View File

@ -85,6 +85,7 @@ def to_camel(text):
text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
text = text.replace("Tts", "TTS")
text = text.replace("vc", "VC")
return text

View File

@ -185,6 +185,13 @@ class ModelManager(object):
"""
return self._list_for_model_type("vocoder_models")
def list_vc_models(self):
"""Print all the voice conversion models and return a list of model names
Format is `language/dataset/model`
"""
return self._list_for_model_type("voice_conversion_models")
def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
@ -234,6 +241,7 @@ class ModelManager(object):
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):

View File

@ -12,6 +12,8 @@ from TTS.tts.models import setup_model as setup_tts_model
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@ -19,14 +21,16 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
class Synthesizer(object):
def __init__(
self,
tts_checkpoint: str,
tts_config_path: str,
tts_checkpoint: str = "",
tts_config_path: str = "",
tts_speakers_file: str = "",
tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
encoder_checkpoint: str = "",
encoder_config: str = "",
vc_checkpoint: str = "",
vc_config: str = "",
use_cuda: bool = False,
) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
@ -41,12 +45,14 @@ class Synthesizer(object):
TODO: set the segmenter based on the source language
Args:
tts_checkpoint (str): path to the tts model file.
tts_config_path (str): path to the tts config file.
tts_checkpoint (str, optional): path to the tts model file.
tts_config_path (str, optional): path to the tts config file.
vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`,
encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`,
vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`,
vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
"""
self.tts_checkpoint = tts_checkpoint
@ -57,10 +63,13 @@ class Synthesizer(object):
self.vocoder_config = vocoder_config
self.encoder_checkpoint = encoder_checkpoint
self.encoder_config = encoder_config
self.vc_checkpoint = vc_checkpoint
self.vc_config = vc_config
self.use_cuda = use_cuda
self.tts_model = None
self.vocoder_model = None
self.vc_model = None
self.speaker_manager = None
self.tts_speakers = {}
self.language_manager = None
@ -72,12 +81,19 @@ class Synthesizer(object):
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
"""get the sentence segmenter for the given language.
@ -90,6 +106,26 @@ class Synthesizer(object):
"""
return pysbd.Segmenter(language=lang, clean=True)
def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
"""Load the voice conversion model.
1. Load the model config.
2. Init the model from the config.
3. Load the model weights.
4. Move the model to the GPU if CUDA is enabled.
Args:
vc_checkpoint (str): path to the model checkpoint.
tts_config_path (str): path to the model config file.
use_cuda (bool): enable/disable CUDA use.
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
self.vc_model.cuda()
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model.
@ -168,7 +204,11 @@ class Synthesizer(object):
path (str): output path to save the waveform.
"""
wav = np.array(wav)
self.tts_model.ap.save_wav(wav, path, self.output_sample_rate)
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
return output_wav
def tts(
self,

View File

View File

@ -0,0 +1,5 @@
from dataclasses import dataclass, field
from typing import List
from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig

View File

@ -0,0 +1,155 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List
from coqpit import Coqpit, check_argument
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
@dataclass
class BaseVCConfig(BaseTrainingConfig):
"""Shared parameters among all the tts models.
Args:
audio (BaseAudioConfig):
Audio processor config object instance.
batch_group_size (int):
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
prevent using the same batches for each epoch.
loss_masking (bool):
enable / disable masking loss values against padded segments of samples in a batch.
min_text_len (int):
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
max_text_len (int):
Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
min_audio_len (int):
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
max_audio_len (int):
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
OOM error in training. Defaults to float("inf").
compute_f0 (int):
(Not in use yet).
compute_energy (int):
(Not in use yet).
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
precompute_num_workers (int):
Number of workers to precompute features. Defaults to 0.
use_noise_augment (bool):
Augment the input audio with random noise.
start_by_longest (bool):
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
Defaults to False.
shuffle (bool):
If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True.
drop_last (bool):
If True, the data loader will drop the last batch if it is not complete. It helps to prevent
issues that emerge from the partial batch statistics. Defaults to True.
add_blank (bool):
Add blank characters between each other two characters. It improves performance for some models at expense
of slower run-time due to the longer input sequence.
datasets (List[BaseDatasetConfig]):
List of datasets used for training. If multiple datasets are provided, they are merged and used together
for training.
optimizer (str):
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
Defaults to ``.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to ``.
lr_scheduler_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]'
eval_split_max_size (int):
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
use_speaker_weighted_sampler (bool):
Enable / Disable the batch balancer by speaker. Defaults to ```False```.
speaker_weighted_sampler_alpha (float):
Number that control the influence of the speaker sampler weights. Defaults to ```1.0```.
use_language_weighted_sampler (bool):
Enable / Disable the batch balancer by language. Defaults to ```False```.
language_weighted_sampler_alpha (float):
Number that control the influence of the language sampler weights. Defaults to ```1.0```.
use_length_weighted_sampler (bool):
Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided
into 10 buckets considering the min and max audio of the dataset. The sampler weights will be
computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```.
length_weighted_sampler_alpha (float):
Number that control the influence of the length sampler weights. Defaults to ```1.0```.
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
# training params
batch_group_size: int = 0
loss_masking: bool = None
# dataloading
min_audio_len: int = 1
max_audio_len: int = float("inf")
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_energy: bool = False
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False
start_by_longest: bool = False
shuffle: bool = False
drop_last: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
optimizer: str = "radam"
optimizer_params: dict = None
# scheduler
lr_scheduler: str = None
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda: [])
# evaluation
eval_split_max_size: int = None
eval_split_size: float = 0.01
# weighted samplers
use_speaker_weighted_sampler: bool = False
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0

17
TTS/vc/models/__init__.py Normal file
View File

@ -0,0 +1,17 @@
import importlib
import re
from typing import Dict, List, Union
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "model" in config and config["model"].lower() == "freevc":
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
model = MyModel.init_from_config(config, samples)
return model

429
TTS/vc/models/base_vc.py Normal file
View File

@ -0,0 +1,429 @@
import os
import random
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
class BaseVC(BaseTrainerModel):
"""Base `vc` class. Every new `vc` model must inherit this.
It defines common `vc` specific functions on top of `Model` implementation.
"""
MODEL_TYPE = "vc"
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
super().__init__()
self.config = config
self.ap = ap
self.speaker_manager = speaker_manager
self.language_manager = language_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields reuqired to initialize the model architecture.
`ModelConfig` has all the fields required for training, inference and containes `ModelArgs`.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
If the config is for the model with a name like "*Args", then we assign the directly.
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
self.config = config
self.args = config.model_args
elif "Args" in config.__class__.__name__:
self.args = config
else:
raise ValueError("config must be either a *Config or *Args")
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
This implementation yields 3 possible outcomes:
1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
`config.d_vector_dim` or 512.
You can override this function for new models.
Args:
config (Coqpit): Model configuration.
"""
# set number of speakers
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
elif hasattr(config, "num_speakers"):
self.num_speakers = config.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_sentences(self, sentence_info):
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
text = sentence_info[0]
elif len(sentence_info) == 2:
text, speaker_name = sentence_info
elif len(sentence_info) == 3:
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
if self.speaker_manager is not None:
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embedding()
else:
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.name_to_id[speaker_name]
# get language id
if self.language_manager is not None and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.name_to_id[language_name]
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `VCDataset`.
You must override this if you use a custom dataset.
Args:
batch (Dict): [description]
Returns:
Dict: [description]
"""
# setup input batch
text_input = batch["token_id"]
text_lengths = batch["token_id_lengths"]
speaker_names = batch["speaker_names"]
linear_input = batch["linear"]
mel_input = batch["mel"]
mel_lengths = batch["mel_lengths"]
stop_targets = batch["stop_targets"]
item_idx = batch["item_idxs"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
energy = batch["energy"]
language_ids = batch["language_ids"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
# compute durations from attention masks
durations = None
if attn_mask is not None:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
dur[c_idxs] = counts
# smooth the durations and set any 0 duration to 1
# by cutting off from the largest duration indeces.
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
# set stop targets wrt reduction factor
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
return {
"text_input": text_input,
"text_lengths": text_lengths,
"speaker_names": speaker_names,
"mel_input": mel_input,
"mel_lengths": mel_lengths,
"linear_input": linear_input,
"stop_targets": stop_targets,
"stop_target_lengths": stop_target_lengths,
"attn_mask": attn_mask,
"durations": durations,
"speaker_ids": speaker_ids,
"d_vectors": d_vectors,
"max_text_length": float(max_text_length),
"max_spec_length": float(max_spec_length),
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
"energy": energy,
"language_ids": language_ids,
"audio_unique_names": batch["audio_unique_names"],
}
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
weights = None
data_items = dataset.samples
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
weights = get_speaker_balancer_weights(data_items) * alpha
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
weights = get_length_balancer_weights(data_items) * alpha
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:
sampler = None
# sampler for DDP
if sampler is None:
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
return sampler
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
else:
# setup multi-speaker attributes
if self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
# setup multi-lingual attributes
if self.language_manager is not None:
language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
else:
language_id_mapping = None
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
compute_energy=config.get("compute_energy", False),
energy_cache_path=config.get("energy_cache_path", None),
samples=samples,
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_text_len=config.min_text_len,
max_text_len=config.max_text_len,
min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=None,
start_by_longest=config.start_by_longest,
language_id_mapping=language_id_mapping,
)
# wait all the DDP process to be ready
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.preprocess_samples()
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=config.shuffle if sampler is None else False, # if there is no other sampler
collate_fn=dataset.collate_fn,
drop_last=config.drop_last, # setting this False might cause issues in AMP training.
sampler=sampler,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def _get_test_aux_input(
self,
) -> Dict:
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1),
"d_vector": d_vector,
"style_wav": None, # TODO: handle GST style input
}
return aux_inputs
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `vc` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
if isinstance(sen, list):
aux_inputs = self.get_aux_input_from_test_sentences(sen)
sen = aux_inputs["text"]
outputs_dict = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return test_figures, test_audios
def on_init_start(self, trainer):
"""Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.pth")
self.speaker_manager.save_ids_to_file(output_path)
trainer.config.speakers_file = output_path
# some models don't have `model_args` set
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")

833
TTS/vc/models/freevc.py Normal file
View File

@ -0,0 +1,833 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import librosa
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.io import load_fsspec, save_checkpoint
from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.base_vc import BaseVC
from TTS.vc.modules.freevc.commons import get_padding, init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
class ResidualCouplingBlock(nn.Module):
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(
modules.ResidualCouplingLayer(
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
mean_only=True,
)
)
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class Encoder(nn.Module):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
def forward(self, mels):
self.lstm.flatten_parameters()
_, (hidden, _) = self.lstm(mels)
embeds_raw = self.relu(self.linear(hidden[-1]))
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
mel_slices = []
for i in range(0, total_frames - partial_frames, partial_hop):
mel_range = torch.arange(i, i + partial_frames)
mel_slices.append(mel_range)
return mel_slices
def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
mel_len = mel.size(1)
last_mel = mel[:, -partial_frames:]
if mel_len > partial_frames:
mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
mels = list(mel[:, s] for s in mel_slices)
mels.append(last_mel)
mels = torch.stack(tuple(mels), 0).squeeze(1)
with torch.no_grad():
partial_embeds = self(mels)
embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
# embed = embed / torch.linalg.norm(embed, 2)
else:
with torch.no_grad():
embed = self(last_mel)
return embed
@dataclass
class FreeVCAudioConfig(Coqpit):
"""Audio configuration
Args:
max_wav_value (float):
The maximum value of the waveform.
input_sample_rate (int):
The sampling rate of the input waveform.
output_sample_rate (int):
The sampling rate of the output waveform.
filter_length (int):
The length of the filter.
hop_length (int):
The hop length.
win_length (int):
The window length.
n_mel_channels (int):
The number of mel channels.
mel_fmin (float):
The minimum frequency of the mel filterbank.
mel_fmax (Optional[float]):
The maximum frequency of the mel filterbank.
"""
max_wav_value: float = field(default=32768.0)
input_sample_rate: int = field(default=16000)
output_sample_rate: int = field(default=24000)
filter_length: int = field(default=1280)
hop_length: int = field(default=320)
win_length: int = field(default=1280)
n_mel_channels: int = field(default=80)
mel_fmin: float = field(default=0.0)
mel_fmax: Optional[float] = field(default=None)
@dataclass
class FreeVCArgs(Coqpit):
"""FreeVC model arguments
Args:
spec_channels (int):
The number of channels in the spectrogram.
inter_channels (int):
The number of channels in the intermediate layers.
hidden_channels (int):
The number of channels in the hidden layers.
filter_channels (int):
The number of channels in the filter layers.
n_heads (int):
The number of attention heads.
n_layers (int):
The number of layers.
kernel_size (int):
The size of the kernel.
p_dropout (float):
The dropout probability.
resblock (str):
The type of residual block.
resblock_kernel_sizes (List[int]):
The kernel sizes for the residual blocks.
resblock_dilation_sizes (List[List[int]]):
The dilation sizes for the residual blocks.
upsample_rates (List[int]):
The upsample rates.
upsample_initial_channel (int):
The number of channels in the initial upsample layer.
upsample_kernel_sizes (List[int]):
The kernel sizes for the upsample layers.
n_layers_q (int):
The number of layers in the quantization network.
use_spectral_norm (bool):
Whether to use spectral normalization.
gin_channels (int):
The number of channels in the global conditioning vector.
ssl_dim (int):
The dimension of the self-supervised learning embedding.
use_spk (bool):
Whether to use external speaker encoder.
"""
spec_channels: int = field(default=641)
inter_channels: int = field(default=192)
hidden_channels: int = field(default=192)
filter_channels: int = field(default=768)
n_heads: int = field(default=2)
n_layers: int = field(default=6)
kernel_size: int = field(default=3)
p_dropout: float = field(default=0.1)
resblock: str = field(default="1")
resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
upsample_initial_channel: int = field(default=512)
upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
n_layers_q: int = field(default=3)
use_spectral_norm: bool = field(default=False)
gin_channels: int = field(default=256)
ssl_dim: int = field(default=1024)
use_spk: bool = field(default=False)
num_spks: int = field(default=0)
segment_size: int = field(default=8960)
class FreeVC(BaseVC):
"""
Papaer::
https://arxiv.org/abs/2210.15418#
Paper Abstract::
Voice conversion (VC) can be achieved by first extracting source content information and target speaker
information, and then reconstructing waveform with these information. However, current approaches normally
either extract dirty content information with speaker information leaked in, or demand a large amount of
annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the
mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for
high-quality waveform reconstruction, and propose strategies for clean content information extraction without
text annotation. We disentangle content information by imposing an information bottleneck to WavLM features,
and propose the spectrogram-resize based data augmentation to improve the purity of extracted content
information. Experimental results show that the proposed method outperforms the latest VC models trained with
annotated data and has greater robustness.
Original Code::
https://github.com/OlaWod/FreeVC
Examples:
>>> from TTS.vc.configs.freevc_config import FreeVCConfig
>>> from TTS.vc.models.freevc import FreeVC
>>> config = FreeVCConfig()
>>> model = FreeVC(config)
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config, None, speaker_manager, None)
self.init_multispeaker(config)
self.spec_channels = self.args.spec_channels
self.inter_channels = self.args.inter_channels
self.hidden_channels = self.args.hidden_channels
self.filter_channels = self.args.filter_channels
self.n_heads = self.args.n_heads
self.n_layers = self.args.n_layers
self.kernel_size = self.args.kernel_size
self.p_dropout = self.args.p_dropout
self.resblock = self.args.resblock
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
self.upsample_rates = self.args.upsample_rates
self.upsample_initial_channel = self.args.upsample_initial_channel
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
self.segment_size = self.args.segment_size
self.gin_channels = self.args.gin_channels
self.ssl_dim = self.args.ssl_dim
self.use_spk = self.args.use_spk
self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16)
self.dec = Generator(
self.inter_channels,
self.resblock,
self.resblock_kernel_sizes,
self.resblock_dilation_sizes,
self.upsample_rates,
self.upsample_initial_channel,
self.upsample_kernel_sizes,
gin_channels=self.gin_channels,
)
self.enc_q = Encoder(
self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels
)
self.flow = ResidualCouplingBlock(
self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels
)
if not self.use_spk:
self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels)
else:
self.load_pretrained_speaker_encoder()
self.wavlm = get_wavlm()
@property
def device(self):
return next(self.parameters()).device
def load_pretrained_speaker_encoder(self):
"""Load pretrained speaker encoder model as mentioned in the paper."""
print(" > Loading pretrained speaker encoder model ...")
self.enc_spk_ex = SpeakerEncoderEx(
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
)
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.num_spks = self.args.num_spks
if self.speaker_manager:
self.num_spks = self.speaker_manager.num_spks
def forward(
self,
c: torch.Tensor,
spec: torch.Tensor,
g: Optional[torch.Tensor] = None,
mel: Optional[torch.Tensor] = None,
c_lengths: Optional[torch.Tensor] = None,
spec_lengths: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Forward pass of the model.
Args:
c: WavLM features. Shape: (batch_size, c_seq_len).
spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim).
g: The speaker embedding. Shape: (batch_size, spk_emb_dim).
mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim).
c_lengths: The lengths of the WavLM features. Shape: (batch_size,).
spec_lengths: The lengths of the spectrogram. Shape: (batch_size,).
Returns:
o: The output spectrogram. Shape: (batch_size, spec_seq_len, spec_dim).
ids_slice: The slice indices. Shape: (batch_size, num_slices).
spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len).
(z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables.
"""
# If c_lengths is None, set it to the length of the last dimension of c
if c_lengths is None:
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
# If spec_lengths is None, set it to the length of the last dimension of spec
if spec_lengths is None:
spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
# If use_spk is False, compute g from mel using enc_spk
g = None
if not self.use_spk:
g = self.enc_spk(mel).unsqueeze(-1)
# Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q
_, m_p, logs_p, _ = self.enc_p(c, c_lengths)
z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g)
# Compute z_p using flow
z_p = self.flow(z, spec_mask, g=g)
# Randomly slice z and compute o using dec
z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
@torch.no_grad()
def inference(self, c, g=None, mel=None, c_lengths=None):
"""
Inference pass of the model
Args:
c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len).
g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim).
c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
Returns:
torch.Tensor: Output tensor.
"""
if c_lengths == None:
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
if not self.use_spk:
g = self.enc_spk.embed_utterance(mel)
g = g.unsqueeze(-1)
z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
z = self.flow(z_p, c_mask, g=g, reverse=True)
o = self.dec(z * c_mask, g=g)
return o
def extract_wavlm_features(self, y):
"""Extract WavLM features from an audio tensor.
Args:
y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len).
"""
with torch.no_grad():
c = self.wavlm.extract_features(y)[0]
c = c.transpose(1, 2)
return c
def load_audio(self, wav):
"""Read and format the input audio."""
if isinstance(wav, str):
wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate)
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav).to(self.device)
if isinstance(wav, torch.Tensor):
wav = wav.to(self.device)
if isinstance(wav, list):
wav = torch.from_numpy(np.array(wav)).to(self.device)
return wav.float()
@torch.inference_mode()
def voice_conversion(self, src, tgt):
"""
Voice conversion pass of the model.
Args:
src (str or torch.Tensor): Source utterance.
tgt (str or torch.Tensor): Target utterance.
Returns:
torch.Tensor: Output tensor.
"""
wav_tgt = self.load_audio(tgt).cpu().numpy()
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
if self.config.model_args.use_spk:
g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)
g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device)
else:
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
mel_tgt = mel_spectrogram_torch(
wav_tgt,
self.config.audio.filter_length,
self.config.audio.n_mel_channels,
self.config.audio.input_sample_rate,
self.config.audio.hop_length,
self.config.audio.win_length,
self.config.audio.mel_fmin,
self.config.audio.mel_fmax,
)
# src
wav_src = self.load_audio(src)
c = self.extract_wavlm_features(wav_src[None, :])
if self.config.model_args.use_spk:
audio = self.inference(c, g=g_tgt)
else:
audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
audio = audio[0][0].data.cpu().float().numpy()
return audio
def eval_step():
...
@staticmethod
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
model = FreeVC(config)
return model
def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"], strict=strict)
if eval:
self.eval()
def train_step():
...
@dataclass
class FreeVCConfig(BaseVCConfig):
"""Defines parameters for FreeVC End2End TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (FreeVCArgs):
Model architecture arguments. Defaults to `FreeVCArgs()`.
audio (FreeVCAudioConfig):
Audio processing configuration. Defaults to `FreeVCAudioConfig()`.
grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.
lr_disc (float):
Initial learning rate for the discriminator. Defaults to 0.0002.
lr_scheduler_gen (str):
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_gen_params (dict):
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
lr_scheduler_disc (str):
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.
lr_scheduler_disc_params (dict):
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
scheduler_after_epoch (bool):
If true, step the schedulers after each epoch else after each step. Defaults to `False`.
optimizer (str):
Name of the optimizer to use with both the generator and the discriminator networks. One of the
`torch.optim.*`. Defaults to `AdamW`.
kl_loss_alpha (float):
Loss weight for KL loss. Defaults to 1.0.
disc_loss_alpha (float):
Loss weight for the discriminator loss. Defaults to 1.0.
gen_loss_alpha (float):
Loss weight for the generator loss. Defaults to 1.0.
feat_loss_alpha (float):
Loss weight for the feature matching loss. Defaults to 1.0.
mel_loss_alpha (float):
Loss weight for the mel loss. Defaults to 45.0.
return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
weighted_sampler_attrs (dict):
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.
weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.
test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing.
language_ids_file (str):
Path to the language ids file.
use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs.freevc_config import FreeVCConfig
>>> config = FreeVCConfig()
"""
model: str = "freevc"
# model specific params
model_args: FreeVCArgs = FreeVCArgs()
audio: FreeVCAudioConfig = FreeVCAudioConfig()
# optimizer
# TODO with training support
# loss params
# TODO with training support
# data loader params
return_wav: bool = True
compute_linear_spec: bool = True
# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True
# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
speakers_file: str = None
speaker_embedding_channels: int = 256
# use d-vectors
use_d_vector_file: bool = False
d_vector_file: List[str] = None
d_vector_dim: int = None
def __post_init__(self):
for key, val in self.model_args.items():
if hasattr(self, key):
self[key] = val

View File

View File

View File

@ -0,0 +1,170 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def rand_spec_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm

View File

@ -0,0 +1,125 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec

View File

@ -0,0 +1,391 @@
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
weight_norm(
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
),
]
)
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x

View File

@ -0,0 +1,65 @@
import struct
from pathlib import Path
from typing import Optional, Union
# import webrtcvad
import librosa
import numpy as np
from scipy.ndimage.morphology import binary_dilation
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
int16_max = (2**15) - 1
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None):
"""
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
just .wav), either the waveform as a numpy array of floats.
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
preprocessing. After preprocessing, the waveform's sampling rate will match the data
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
this argument will be ignored.
"""
# Load the wav from disk if needed
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
wav, source_sr = librosa.load(fpath_or_wav, sr=None)
else:
wav = fpath_or_wav
# Resample the wav if needed
if source_sr is not None and source_sr != sampling_rate:
wav = librosa.resample(wav, source_sr, sampling_rate)
# Apply the preprocessing: normalize volume and shorten long silences
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
wav = trim_long_silences(wav)
return wav
def wav_to_mel_spectrogram(wav):
"""
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
Note: this not a log-mel spectrogram.
"""
frames = librosa.feature.melspectrogram(
y=wav,
sr=sampling_rate,
n_fft=int(sampling_rate * mel_window_length / 1000),
hop_length=int(sampling_rate * mel_window_step / 1000),
n_mels=mel_n_channels,
)
return frames.astype(np.float32).T
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
if increase_only and decrease_only:
raise ValueError("Both increase only and decrease only are set")
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
return wav
return wav * (10 ** (dBFS_change / 20))

View File

@ -0,0 +1,31 @@
## Mel-filterbank
mel_window_length = 25 # In milliseconds
mel_window_step = 10 # In milliseconds
mel_n_channels = 40
## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160 # 1600 ms
## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30 # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6
## Audio volume normalization
audio_norm_target_dBFS = -30
## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3

View File

@ -0,0 +1,175 @@
from pathlib import Path
from time import perf_counter as timer
from typing import List, Union
import numpy as np
import torch
from torch import nn
from TTS.utils.io import load_fsspec
from TTS.vc.modules.freevc.speaker_encoder import audio
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
class SpeakerEncoder(nn.Module):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
"""
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
If None, defaults to cuda if it is available on your machine, otherwise the model will
run on cpu. Outputs are always returned on the cpu, as numpy arrays.
"""
super().__init__()
# Define the network
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
# Get the target device
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
elif isinstance(device, str):
device = torch.device(device)
self.device = device
# Load the pretrained model'speaker weights
# weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
# if not weights_fpath.exists():
# raise Exception("Couldn't find the voice encoder pretrained model at %s." %
# weights_fpath)
start = timer()
checkpoint = load_fsspec(weights_fpath, map_location="cpu")
self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device)
if verbose:
print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
def forward(self, mels: torch.FloatTensor):
"""
Computes the embeddings of a batch of utterance spectrograms.
:param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
(batch_size, n_frames, n_channels)
:return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
"""
# Pass the input through the LSTM layers and retrieve the final hidden state of the last
# layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
_, (hidden, _) = self.lstm(mels)
embeds_raw = self.relu(self.linear(hidden[-1]))
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
@staticmethod
def compute_partial_slices(n_samples: int, rate, min_coverage):
"""
Computes where to split an utterance waveform and its corresponding mel spectrogram to
obtain partial utterances of <partials_n_frames> each. Both the waveform and the
mel spectrogram slices are returned, so as to make each partial utterance waveform
correspond to its spectrogram.
The returned ranges may be indexing further than the length of the waveform. It is
recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
:param n_samples: the number of samples in the waveform
:param rate: how many partial utterances should occur per second. Partial utterances must
cover the span of the entire utterance, thus the rate should not be lower than the inverse
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
the minimum rate is thus 0.625.
:param min_coverage: when reaching the last partial utterance, it may or may not have
enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
it will be discarded. If there aren't enough frames for one partial utterance,
this parameter is ignored so that the function always returns at least one slice.
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
respectively the waveform and the mel spectrogram with these slices to obtain the partial
utterances.
"""
assert 0 < min_coverage <= 1
# Compute how many frames separate two partial utterances
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
assert 0 < frame_step, "The rate is too high"
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % (
sampling_rate / (samples_per_frame * partials_n_frames)
)
# Compute the slices
wav_slices, mel_slices = [], []
steps = max(1, n_frames - partials_n_frames + frame_step + 1)
for i in range(0, steps, frame_step):
mel_range = np.array([i, i + partials_n_frames])
wav_range = mel_range * samples_per_frame
mel_slices.append(slice(*mel_range))
wav_slices.append(slice(*wav_range))
# Evaluate whether extra padding is warranted or not
last_wav_range = wav_slices[-1]
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
if coverage < min_coverage and len(mel_slices) > 1:
mel_slices = mel_slices[:-1]
wav_slices = wav_slices[:-1]
return wav_slices, mel_slices
def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
"""
Computes an embedding for a single utterance. The utterance is divided in partial
utterances and an embedding is computed for each. The complete utterance embedding is the
L2-normed average embedding of the partial utterances.
TODO: independent batched version of this function
:param wav: a preprocessed utterance waveform as a numpy array of float32
:param return_partials: if True, the partial embeddings will also be returned along with
the wav slices corresponding to each partial utterance.
:param rate: how many partial utterances should occur per second. Partial utterances must
cover the span of the entire utterance, thus the rate should not be lower than the inverse
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
the minimum rate is thus 0.625.
:param min_coverage: when reaching the last partial utterance, it may or may not have
enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
it will be discarded. If there aren't enough frames for one partial utterance,
this parameter is ignored so that the function always returns at least one slice.
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
returned.
"""
# Compute where to split the utterance into partials and pad the waveform with zeros if
# the partial utterances cover a larger range.
wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
max_wave_length = wav_slices[-1].stop
if max_wave_length >= len(wav):
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
# Split the utterance into partials and forward them through the model
mel = audio.wav_to_mel_spectrogram(wav)
mels = np.array([mel[s] for s in mel_slices])
with torch.no_grad():
mels = torch.from_numpy(mels).to(self.device)
partial_embeds = self(mels).cpu().numpy()
# Compute the utterance embedding from the partial embeddings
raw_embed = np.mean(partial_embeds, axis=0)
embed = raw_embed / np.linalg.norm(raw_embed, 2)
if return_partials:
return embed, partial_embeds, wav_slices
return embed
def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
"""
Compute the embedding of a collection of wavs (presumably from the same speaker) by
averaging their embedding and L2-normalizing it.
:param wavs: list of wavs a numpy arrays of float32.
:param kwargs: extra arguments to embed_utterance()
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
"""
raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
return raw_embed / np.linalg.norm(raw_embed, 2)

View File

@ -0,0 +1,35 @@
import os
import urllib.request
import torch
from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
def get_wavlm(device="cpu"):
"""Download the model and return the model object."""
output_path = get_user_data_dir("tts")
output_path = os.path.join(output_path, "wavlm")
if not os.path.exists(output_path):
os.makedirs(output_path)
output_path = os.path.join(output_path, "WavLM-Large.pt")
if not os.path.exists(output_path):
print(f" > Downloading WavLM model to {output_path} ...")
urllib.request.urlretrieve(model_uri, output_path)
checkpoint = torch.load(output_path, map_location=torch.device(device))
cfg = WavLMConfig(checkpoint["cfg"])
wavlm = WavLM(cfg).to(device)
wavlm.load_state_dict(checkpoint["model"])
wavlm.eval()
return wavlm
if __name__ == "__main__":
wavlm = get_wavlm()

View File

@ -0,0 +1,99 @@
{
"_name_or_path": "./wavlm-large/",
"activation_dropout": 0.0,
"adapter_kernel_size": 3,
"adapter_stride": 2,
"add_adapter": false,
"apply_spec_augment": true,
"architectures": [
"WavLMModel"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"classifier_proj_size": 256,
"codevector_dim": 768,
"contrastive_logits_temperature": 0.1,
"conv_bias": false,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"diversity_loss_weight": 0.1,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"feat_quantizer_dropout": 0.0,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.075,
"mask_time_selection": "static",
"max_bucket_distance": 800,
"model_type": "wavlm",
"num_adapter_layers": 3,
"num_attention_heads": 16,
"num_buckets": 320,
"num_codevector_groups": 2,
"num_codevectors_per_group": 320,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_ctc_classes": 80,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"num_negatives": 100,
"output_hidden_size": 1024,
"pad_token_id": 0,
"proj_codevector_dim": 768,
"replace_prob": 0.5,
"tokenizer_class": "Wav2Vec2CTCTokenizer",
"torch_dtype": "float32",
"transformers_version": "4.15.0.dev0",
"use_weighted_layer_sum": false,
"vocab_size": 32
}

View File

@ -0,0 +1,768 @@
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import warnings
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
class TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None):
super().__init__()
self.deconstruct_idx = deconstruct_idx
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(-2, -1)
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(),
self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
class Swish(nn.Module):
"""Swish function"""
def __init__(self):
"""Construct an MultiHeadedAttention object."""
super(Swish, self).__init__()
self.act = torch.nn.Sigmoid()
def forward(self, x):
return x * self.act(x)
class GLU_Linear(nn.Module):
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
super(GLU_Linear, self).__init__()
self.glu_type = glu_type
self.output_dim = output_dim
if glu_type == "sigmoid":
self.glu_act = torch.nn.Sigmoid()
elif glu_type == "swish":
self.glu_act = Swish()
elif glu_type == "relu":
self.glu_act = torch.nn.ReLU()
elif glu_type == "gelu":
self.glu_act = torch.nn.GELU()
if bias_in_glu:
self.linear = nn.Linear(input_dim, output_dim * 2, True)
else:
self.linear = nn.Linear(input_dim, output_dim * 2, False)
def forward(self, x):
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
x = self.linear(x)
if self.glu_type == "bilinear":
x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
else:
x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
return x
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def get_activation_fn(activation: str):
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "glu":
return lambda x: x
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bais will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then weights of
in_project_weight for MultiHeadAttention initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
mask.bernoulli_(p)
mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
# scale weights and apply mask
mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
num_buckets=32,
max_distance=128,
gru_rel_pos=False,
rescale_init=False,
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = nn.Dropout(dropout)
self.has_relative_attention_bias = has_relative_attention_bias
self.num_buckets = num_buckets
self.max_distance = max_distance
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
self.head_dim = embed_dim // num_heads
self.q_head_dim = self.head_dim
self.k_head_dim = self.head_dim
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
k_bias = True
if rescale_init:
k_bias = False
k_embed_dim = embed_dim
q_embed_dim = embed_dim
self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.gru_rel_pos = gru_rel_pos
if self.gru_rel_pos:
self.grep_linear = nn.Linear(self.q_head_dim, 8)
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
if self.has_relative_attention_bias:
nn.init.xavier_normal_(self.relative_attention_bias.weight)
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
num_buckets = self.num_buckets
max_distance = self.max_distance
relative_buckets = 0
if bidirectional:
num_buckets = num_buckets // 2
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
relative_positions = torch.abs(relative_positions)
else:
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
max_exact = num_buckets // 2
is_small = relative_positions < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_positions.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length):
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket)
values = values.permute([2, 0, 1])
return values
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if self.has_relative_attention_bias and position_bias is None:
position_bias = self.compute_bias(tgt_len, src_len)
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
if (
not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
and self.q_head_dim == self.head_dim
):
assert key is not None and value is not None
assert attn_mask is None
attn_mask_rel_pos = None
if position_bias is not None:
attn_mask_rel_pos = position_bias
if self.gru_rel_pos:
query_layer = query.transpose(0, 1)
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
query_layer = query_layer.view(*new_x_shape)
query_layer = query_layer.permute(0, 2, 1, 3)
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
k_proj_bias = self.k_proj.bias
if k_proj_bias is None:
k_proj_bias = torch.zeros_like(self.q_proj.bias)
x, attn = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training,
# self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask_rel_pos,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
return x, attn, position_bias
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v, position_bias
if position_bias is not None:
if self.gru_rel_pos == 1:
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
position_bias = position_bias.view(attn_weights.size())
attn_weights = attn_weights + position_bias
attn_weights_float = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights, position_bias
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights

View File

@ -0,0 +1,719 @@
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from TTS.vc.modules.freevc.wavlm.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
GLU_Linear,
GradMultiply,
MultiheadAttention,
SamePad,
TransposeLast,
get_activation_fn,
init_bert_params,
)
logger = logging.getLogger(__name__)
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class WavLMConfig:
def __init__(self, cfg=None):
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
self.encoder_layers: int = 12 # num encoder layers in the transformer
self.encoder_embed_dim: int = 768 # encoder embedding dimension
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
self.encoder_attention_heads: int = 12 # num encoder attention heads
self.activation_fn: str = "gelu" # activation function to use
self.layer_norm_first: bool = False # apply layernorm first in the transformer
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
self.conv_bias: bool = False # include bias in conv encoder
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
# dropouts
self.dropout: float = 0.1 # dropout probability for the transformer
self.attention_dropout: float = 0.1 # dropout probability for attention weights
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
# masking
self.mask_length: int = 10 # mask length
self.mask_prob: float = 0.65 # probability of replacing a token with mask
self.mask_selection: str = "static" # how to choose mask length
self.mask_other: float = (
0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
)
self.no_mask_overlap: bool = False # whether to allow masks to overlap
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
# channel masking
self.mask_channel_length: int = 10 # length of the mask for features (channels)
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
self.mask_channel_other: float = (
0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
)
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
# positional embeddings
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
# relative position embedding
self.relative_position_embedding: bool = False # apply relative position embedding
self.num_buckets: int = 320 # number of buckets for relative position embedding
self.max_distance: int = 1280 # maximum distance for relative position embedding
self.gru_rel_pos: bool = False # apply gated relative position embedding
if cfg is not None:
self.update(cfg)
def update(self, cfg: dict):
self.__dict__.update(cfg)
class WavLM(nn.Module):
def __init__(
self,
cfg: WavLMConfig,
) -> None:
super().__init__()
logger.info(f"WavLM Config: {cfg.__dict__}")
self.cfg = cfg
feature_enc_layers = eval(cfg.conv_feature_layers)
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
def apply_mask(self, x, padding_mask):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
x[mask_channel_indices] = 0
return x, mask_indices
def forward_padding_mask(
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
# padding_mask = padding_mask.all(-1)
padding_mask = padding_mask.any(-1)
return padding_mask
def extract_features(
self,
source: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
mask: bool = False,
ret_conv: bool = False,
output_layer: Optional[int] = None,
ret_layer_results: bool = False,
):
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(source)
features = features.transpose(1, 2)
features = self.layer_norm(features)
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask)
else:
x = features
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, layer_results = self.encoder(
x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1
)
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
feature = res["features"] if ret_conv else res["x"]
if ret_layer_results:
feature = (feature, res["layer_results"])
return feature, res["padding_mask"]
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
conv_layers: List[Tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
conv_type: str = "default",
):
super().__init__()
assert mode in {"default", "layer_norm"}
def block(
n_in,
n_out,
k,
stride,
is_layer_norm=False,
is_group_norm=False,
conv_bias=False,
):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
if is_layer_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
nn.Sequential(
TransposeLast(),
Fp32LayerNorm(dim, elementwise_affine=True),
TransposeLast(),
),
nn.GELU(),
)
elif is_group_norm:
return nn.Sequential(
make_conv(),
nn.Dropout(p=dropout),
Fp32GroupNorm(dim, dim, affine=True),
nn.GELU(),
)
else:
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
self.conv_type = conv_type
if self.conv_type == "default":
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
self.conv_layers.append(
block(
in_d,
dim,
k,
stride,
is_layer_norm=mode == "layer_norm",
is_group_norm=mode == "default" and i == 0,
conv_bias=conv_bias,
)
)
in_d = dim
elif self.conv_type == "conv2d":
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3
(dim, k, stride) = cl
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
self.conv_layers.append(torch.nn.ReLU())
in_d = dim
elif self.conv_type == "custom":
in_d = 1
idim = 80
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3
(dim, k, stride) = cl
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
self.conv_layers.append(torch.nn.ReLU())
in_d = dim
if (i + 1) % 2 == 0:
self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
idim = int(math.ceil(idim / 2))
else:
pass
def forward(self, x, mask=None):
# BxT -> BxCxT
x = x.unsqueeze(1)
if self.conv_type == "custom":
for conv in self.conv_layers:
if isinstance(conv, nn.LayerNorm):
x = x.transpose(1, 2)
x = conv(x).transpose(1, 2)
else:
x = conv(x)
x = x.transpose(2, 3).contiguous()
x = x.view(x.size(0), -1, x.size(-1))
else:
for conv in self.conv_layers:
x = conv(x)
if self.conv_type == "conv2d":
b, c, t, f = x.size()
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
return x
class TransformerEncoder(nn.Module):
def __init__(self, args):
super().__init__()
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):
self.relative_position_embedding = args.relative_position_embedding
self.num_buckets = args.num_buckets
self.max_distance = args.max_distance
else:
self.relative_position_embedding = False
self.num_buckets = 0
self.max_distance = 0
self.layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
num_buckets=self.num_buckets,
max_distance=self.max_distance,
gru_rel_pos=args.gru_rel_pos,
)
for i in range(args.encoder_layers)
]
)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
if padding_mask is not None:
x[padding_mask] = 0
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x += x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
layer_results = []
z = None
if tgt_layer is not None:
layer_results.append((x, z))
r = None
pos_bias = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z, pos_bias = layer(
x,
self_attn_padding_mask=padding_mask,
need_weights=False,
self_attn_mask=streaming_mask,
pos_bias=pos_bias,
)
if tgt_layer is not None:
layer_results.append((x, z))
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
return x, layer_results
class TransformerSentenceEncoderLayer(nn.Module):
"""
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
models.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: float = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
num_buckets: int = 0,
max_distance: int = 0,
rescale_init: bool = False,
gru_rel_pos: bool = False,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_name = activation_fn
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
num_buckets=num_buckets,
max_distance=max_distance,
rescale_init=rescale_init,
gru_rel_pos=gru_rel_pos,
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
if self.activation_name == "glu":
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
else:
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
pos_bias=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules similar to the original Transformer imlementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=False,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn, pos_bias = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
if self.activation_name == "glu":
x = self.fc1(x)
else:
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn, pos_bias

View File

@ -18,6 +18,8 @@ class BaseVocoder(BaseTrainerModel):
- 1D tensors `batch x 1`
"""
MODEL_TYPE = "vocoder"
def __init__(self, config):
super().__init__()
self._set_model_args(config)

View File

@ -36,8 +36,8 @@ Run a tts and a vocoder model from the released model list. Note that not every
```bash
tts --text "Text for TTS" \
--model_name "<type>/<language>/<dataset>/<model_name>" \
--vocoder_name "<type>/<language>/<dataset>/<model_name>" \
--model_name "tts_models/<language>/<dataset>/<model_name>" \
--vocoder_name "vocoder_models/<language>/<dataset>/<model_name>" \
--out_path folder/to/save/output.wav
```
@ -64,8 +64,17 @@ tts --text "Text for TTS" \
Run a multi-speaker TTS model from the released models list.
```bash
tts --model_name "<type>/<language>/<dataset>/<model_name>" --list_speaker_idxs # list the possible speaker IDs.
tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
tts --model_name "tts_models/<language>/<dataset>/<model_name>" --list_speaker_idxs # list the possible speaker IDs.
tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "tts_models/<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
```
Run a released voice conversion model
```bash
tts --model_name "voice_conversion/<language>/<dataset>/<model_name>"
--source_wav "my/source/speaker/audio.wav"
--target_wav "my/target/speaker/audio.wav"
--out_path folder/to/save/output.wav
```
**Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
@ -135,4 +144,23 @@ tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```
Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```
Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can
clone voices by using any model in 🐸TTS.
```python
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
tts.tts_with_vc_to_file(
"Wie sage ich auf Italienisch, dass ich dich liebe?",
speaker_wav="target/speaker.wav",
file_path="ouptut.wav"
)
```

View File

@ -14,6 +14,7 @@ tqdm
anyascii
pyyaml
fsspec>=2021.04.0
aiohttp
packaging
# deps for examples
flask

View File

@ -28,7 +28,7 @@ class TTSTest(unittest.TestCase):
def test_multi_speaker_multi_lingual_model(self):
tts = TTS()
tts.load_model_by_name(tts.models[0]) # YourTTS
tts.load_tts_model_by_name(tts.models[0]) # YourTTS
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH)
self.assertTrue(tts.is_multi_speaker)
@ -38,5 +38,5 @@ class TTSTest(unittest.TestCase):
def test_voice_cloning(self): # pylint: disable=no-self-use
tts = TTS()
tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
tts.load_tts_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)

View File

View File

@ -0,0 +1,135 @@
import os
import unittest
import torch
from tests import get_tests_input_path
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC
# pylint: disable=unused-variable
# pylint: disable=no-self-use
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
c = FreeVCConfig()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TestFreeVC(unittest.TestCase):
def _create_inputs(self, config, batch_size=2):
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device)
input_lengths[-1] = 30 * config.audio["hop_length"]
spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
spec_lengths[-1] = spec.size(2)
waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device)
return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
@staticmethod
def _create_inputs_inference():
source_wav = torch.rand(16000)
target_wav = torch.rand(16000)
return source_wav, target_wav
@staticmethod
def _check_parameter_changes(model, model_ref):
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
def test_methods(self):
config = FreeVCConfig()
model = FreeVC(config).to(device)
model.load_pretrained_speaker_encoder()
model.init_multispeaker(config)
wavlm_feats = model.extract_wavlm_features(torch.rand(1, 16000))
assert wavlm_feats.shape == (1, 1024, 49), wavlm_feats.shape
def test_load_audio(self):
config = FreeVCConfig()
model = FreeVC(config).to(device)
wav = model.load_audio(WAV_FILE)
wav2 = model.load_audio(wav)
assert all(torch.isclose(wav, wav2))
def _test_forward(self, batch_size):
# create model
config = FreeVCConfig()
model = FreeVC(config).to(device)
model.train()
print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))
_, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
y = model.forward(wavlm_vec, spec, None, mel, spec_lengths, wavlm_vec_lengths)
# TODO: assert with training implementation
def test_forward(self):
self._test_forward(1)
self._test_forward(3)
def _test_inference(self, batch_size):
config = FreeVCConfig()
model = FreeVC(config).to(device)
model.eval()
_, _, mel, _, _, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths)
assert (
output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1]
), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"
def test_inference(self):
self._test_inference(1)
self._test_inference(3)
def test_voice_conversion(self):
config = FreeVCConfig()
model = FreeVC(config).to(device)
model.eval()
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
assert (
output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
), f"{output_wav.shape} != {source_wav.shape}"
def test_train_step(self):
...
def test_train_eval_log(self):
...
def test_test_run(self):
...
def test_load_checkpoint(self):
...
def test_get_criterion(self):
...
def test_init_from_config(self):
...

View File

@ -51,6 +51,13 @@ def run_models(offset=0, step=1):
# remove downloaded models
shutil.rmtree(local_download_dir)
shutil.rmtree(get_user_data_dir("tts"))
elif "voice_conversion_models" in model_name:
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
run_cli(
f"tts --model_name {model_name} "
f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False'
)
else:
# only download the model
manager.download_model(model_name)