mirror of https://github.com/coqui-ai/TTS.git
Implement FreeVC (#2451)

* Update .gitignore
* Draft FreeVC implementation
* Tests and relevant updates
* Update API tests
* Add missings
* Update requirements
* :(
* Lazy handle for vc
* Update docs for voice conversion
* Make style

Branch: pull/2462/head
parent 090cadf270
commit d309f50e53
.gitignore

@@ -137,7 +137,7 @@ VCTK-Corpus-removed-silence/*
# ignore training logs
trainer_*_log.txt

-# files used internally fro dev, test etc.
+# files used internally for dev, test etc.
tests/outputs/*
tests/train_outputs/*
TODO.txt
@@ -168,3 +168,4 @@ internal/*
wandb
depot/*
coqui_recipes/*
+local_scripts/*
TTS/.models.json

@@ -802,5 +802,18 @@
            }
        }
    }
-    }
+    },
+    "voice_conversion_models":{
+        "multilingual":{
+            "vctk":{
+                "freevc24":{
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip",
+                    "description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC",
+                    "author": "Jing-Yi Li @OlaWod",
+                    "license": "MIT",
+                    "commit": null
+                }
+            }
+        }
+    }
}
TTS/api.py
@@ -1,5 +1,7 @@
+import tempfile
from pathlib import Path

+from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

@@ -49,11 +51,14 @@ class TTS:
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

        self.synthesizer = None
+       self.voice_converter = None

        if model_name:
-           self.load_model_by_name(model_name, gpu)
+           self.load_tts_model_by_name(model_name, gpu)
        if model_path:
-           self.load_model_by_path(
+           self.load_tts_model_by_path(
                model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
            )

@@ -96,12 +101,22 @@ class TTS:

    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)
-       if model_item["default_vocoder"] is None:
+       if model_item.get("default_vocoder") is None:
            return model_path, config_path, None, None
        vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
        return model_path, config_path, vocoder_path, vocoder_config_path

-   def load_model_by_name(self, model_name: str, gpu: bool = False):
+   def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
+       """Load one of the voice conversion models by name.
+
+       Args:
+           model_name (str): Model name to load. You can list models by ```tts.models```.
+           gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
+       """
+       model_path, config_path, _, _ = self.download_model_by_name(model_name)
+       self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
+
    def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
        """Load one of 🐸TTS models by name.

        Args:

@@ -127,7 +142,7 @@ class TTS:
            use_cuda=gpu,
        )

-   def load_model_by_path(
+   def load_tts_model_by_path(
        self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
    ):
        """Load a model from a path.

@@ -219,3 +234,67 @@ class TTS:
        """
        wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
        self.synthesizer.save_wav(wav=wav, path=file_path)

    def voice_conversion(
        self,
        source_wav: str,
        target_wav: str,
    ):
        """Voice conversion with FreeVC. Convert source wav to target speaker.

        Args:
            source_wav (str):
                Path to the source wav file.
            target_wav (str):
                Path to the target wav file.
        """
        wav = self.synthesizer.voice_conversion(source_wav=source_wav, target_wav=target_wav)
        return wav

    def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
        """Convert text to speech with voice conversion.

        It combines tts with voice conversion to fake voice cloning.

        - Convert text to speech with tts.
        - Convert the output wav to target speaker with voice conversion.

        Args:
            text (str):
                Input text to synthesize.
            language (str, optional):
                Language code for multi-lingual models. You can check whether loaded model is multi-lingual
                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
        """
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            # Lazy code... save it to a temp file to resample it while reading it for VC
            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name)
        if self.voice_converter is None:
            self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
        wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
        return wav

    def tts_with_vc_to_file(
        self, text: str, language: str = None, speaker_wav: str = None, file_path: str = "output.wav"
    ):
        """Convert text to speech with voice conversion and save to file.

        Check `tts_with_vc` for more details.

        Args:
            text (str):
                Input text to synthesize.
            language (str, optional):
                Language code for multi-lingual models. You can check whether loaded model is multi-lingual
                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
            file_path (str, optional):
                Output file path. Defaults to "output.wav".
        """
        wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav)
        save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
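Not part of the diff — a minimal usage sketch of the API added above. The zero-argument `TTS()` construction and the example TTS model name are assumptions based on the surrounding code and the stock model list; all wav paths are placeholders.

```python
from TTS.api import TTS

# Pure voice conversion with the FreeVC model registered in .models.json above.
api = TTS(progress_bar=False, gpu=False)  # assumption: model_name defaults to None
api.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = api.voice_conversion("source.wav", "target_voice.wav")  # placeholder wav paths

# TTS followed by voice conversion ("fake voice cloning"): synthesize with a TTS model,
# then convert the result into the voice of speaker_wav. FreeVC is loaded lazily if needed.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
tts.tts_with_vc_to_file(
    text="Voice conversion is now part of the API.",
    speaker_wav="target_voice.wav",  # placeholder reference recording
    file_path="cloned_output.wav",
)
```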
TTS/bin/synthesize.py
@@ -100,6 +100,12 @@ If you don't specify any models, then it uses LJSpeech based English model.

    ```
    $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
    ```

    ### Voice Conversion Models

    ```
    $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
    ```
    """
    # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
    # documentation in sync more easily.

@@ -245,6 +251,20 @@ If you don't specify any models, then it uses LJSpeech based English model.
        default=True,
    )

    # voice conversion args
    parser.add_argument(
        "--source_wav",
        type=str,
        default=None,
        help="Original audio file to convert in the voice of the target_wav",
    )
    parser.add_argument(
        "--target_wav",
        type=str,
        default=None,
        help="Audio file of the target voice; the source_wav is converted into this voice",
    )

    args = parser.parse_args()

    # print the description if either text or list_models is not set

@@ -256,6 +276,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
        args.reference_wav,
        args.model_info_by_idx,
        args.model_info_by_name,
        args.source_wav,
        args.target_wav,
    ]
    if not any(check_args):
        parser.parse_args(["-h"])
@ -264,21 +286,23 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
path = Path(__file__).parent / "../.models.json"
|
||||
manager = ModelManager(path, progress_bar=args.progress_bar)
|
||||
|
||||
model_path = None
|
||||
config_path = None
|
||||
tts_path = None
|
||||
tts_config_path = None
|
||||
speakers_file_path = None
|
||||
language_ids_file_path = None
|
||||
vocoder_path = None
|
||||
vocoder_config_path = None
|
||||
encoder_path = None
|
||||
encoder_config_path = None
|
||||
vc_path = None
|
||||
vc_config_path = None
|
||||
|
||||
# CASE1 #list : list pre-trained TTS models
|
||||
if args.list_models:
|
||||
manager.list_models()
|
||||
sys.exit()
|
||||
|
||||
# CASE2 #info : model info of pre-trained TTS models
|
||||
# CASE2 #info : model info for pre-trained TTS models
|
||||
if args.model_info_by_idx:
|
||||
model_query = args.model_info_by_idx
|
||||
manager.model_info_by_idx(model_query)
|
||||
|
@ -292,15 +316,27 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
# CASE3: load pre-trained model paths
|
||||
if args.model_name is not None and not args.model_path:
|
||||
model_path, config_path, model_item = manager.download_model(args.model_name)
|
||||
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||
|
||||
# tts model
|
||||
if model_item["model_type"] == "tts_models":
|
||||
tts_path = model_path
|
||||
tts_config_path = config_path
|
||||
if "default_vocoder" in model_item:
|
||||
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||
|
||||
# voice conversion model
|
||||
if model_item["model_type"] == "voice_conversion_models":
|
||||
vc_path = model_path
|
||||
vc_config_path = config_path
|
||||
|
||||
# load vocoder
|
||||
if args.vocoder_name is not None and not args.vocoder_path:
|
||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
||||
|
||||
# CASE4: set custom model paths
|
||||
if args.model_path is not None:
|
||||
model_path = args.model_path
|
||||
config_path = args.config_path
|
||||
tts_path = args.model_path
|
||||
tts_config_path = args.config_path
|
||||
speakers_file_path = args.speakers_file_path
|
||||
language_ids_file_path = args.language_ids_file_path
|
||||
|
||||
|
@ -314,14 +350,16 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
|
||||
# load models
|
||||
synthesizer = Synthesizer(
|
||||
model_path,
|
||||
config_path,
|
||||
tts_path,
|
||||
tts_config_path,
|
||||
speakers_file_path,
|
||||
language_ids_file_path,
|
||||
vocoder_path,
|
||||
vocoder_config_path,
|
||||
encoder_path,
|
||||
encoder_config_path,
|
||||
vc_path,
|
||||
vc_config_path,
|
||||
args.use_cuda,
|
||||
)
|
||||
|
||||
|
@ -354,16 +392,22 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
print(" > Text: {}".format(args.text))
|
||||
|
||||
# kick it
|
||||
wav = synthesizer.tts(
|
||||
args.text,
|
||||
args.speaker_idx,
|
||||
args.language_idx,
|
||||
args.speaker_wav,
|
||||
reference_wav=args.reference_wav,
|
||||
style_wav=args.capacitron_style_wav,
|
||||
style_text=args.capacitron_style_text,
|
||||
reference_speaker_name=args.reference_speaker_idx,
|
||||
)
|
||||
if tts_path is not None:
|
||||
wav = synthesizer.tts(
|
||||
args.text,
|
||||
args.speaker_idx,
|
||||
args.language_idx,
|
||||
args.speaker_wav,
|
||||
reference_wav=args.reference_wav,
|
||||
style_wav=args.capacitron_style_wav,
|
||||
style_text=args.capacitron_style_text,
|
||||
reference_speaker_name=args.reference_speaker_idx,
|
||||
)
|
||||
elif vc_path is not None:
|
||||
wav = synthesizer.voice_conversion(
|
||||
source_wav=args.source_wav,
|
||||
target_wav=args.target_wav,
|
||||
)
|
||||
|
||||
# save the results
|
||||
print(" > Saving output to {}".format(args.out_path))
|
||||
|
TTS/config/__init__.py

@@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
    """
    config_class = None
    config_name = model_name + "_config"
-   paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"]
+   paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
    for path in paths:
        try:
            config_class = find_module(path, config_name)
TTS/tts/models/base_tts.py

@@ -27,6 +27,8 @@ class BaseTTS(BaseTrainerModel):
    It defines common `tts` specific functions on top of `Model` implementation.
    """

+   MODEL_TYPE = "tts"
+
    def __init__(
        self,
        config: Coqpit,
@@ -85,6 +85,7 @@ def to_camel(text):
    text = text.capitalize()
    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
    text = text.replace("Tts", "TTS")
+   text = text.replace("vc", "VC")
    return text

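For context on why the `vc` → `VC` replacement is needed: it lets lowercase model and config names resolve to the FreeVC class names. A standalone sketch mirroring the helper above:

```python
import re


def to_camel(text):
    # Same logic as the updated helper above.
    text = text.capitalize()
    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
    text = text.replace("Tts", "TTS")
    text = text.replace("vc", "VC")
    return text


assert to_camel("freevc") == "FreeVC"
assert to_camel("freevc_config") == "FreeVCConfig"  # class name resolved for "freevc_config"
```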
TTS/utils/manage.py

@@ -185,6 +185,13 @@ class ModelManager(object):
        """
        return self._list_for_model_type("vocoder_models")

+   def list_vc_models(self):
+       """Print all the voice conversion models and return a list of model names
+
+       Format is `language/dataset/model`
+       """
+       return self._list_for_model_type("voice_conversion_models")
+
    def list_langs(self):
        """Print all the available languages"""
        print(" Name format: type/language")

@@ -234,6 +241,7 @@ class ModelManager(object):
        model_type, lang, dataset, model = model_name.split("/")
        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
        model_item = self.models_dict[model_type][lang][dataset][model]
+       model_item["model_type"] = model_type
        # set the model specific output path
        output_path = os.path.join(self.output_prefix, model_full_name)
        if os.path.exists(output_path):
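A small, hedged sketch of how the new manager hooks can be exercised. The zero-argument `ModelManager()` construction is an assumption based on the surrounding code; the model name comes from the `.models.json` entry added above.

```python
from TTS.utils.manage import ModelManager

# Assumption: ModelManager() falls back to the bundled .models.json when no file is given.
manager = ModelManager()
manager.list_vc_models()  # prints names like "voice_conversion_models/multilingual/vctk/freevc24"

model_path, config_path, model_item = manager.download_model(
    "voice_conversion_models/multilingual/vctk/freevc24"
)
# The new line in download_model tags the item with its type, so callers can branch on it.
print(model_item["model_type"])  # -> "voice_conversion_models"
```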
TTS/utils/synthesizer.py

@@ -12,6 +12,8 @@ from TTS.tts.models import setup_model as setup_tts_model
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import save_wav
+from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input

@@ -19,14 +21,16 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
class Synthesizer(object):
    def __init__(
        self,
-       tts_checkpoint: str,
-       tts_config_path: str,
+       tts_checkpoint: str = "",
+       tts_config_path: str = "",
        tts_speakers_file: str = "",
        tts_languages_file: str = "",
        vocoder_checkpoint: str = "",
        vocoder_config: str = "",
        encoder_checkpoint: str = "",
        encoder_config: str = "",
+       vc_checkpoint: str = "",
+       vc_config: str = "",
        use_cuda: bool = False,
    ) -> None:
        """General 🐸 TTS interface for inference. It takes a tts and a vocoder

@@ -41,12 +45,14 @@ class Synthesizer(object):
        TODO: set the segmenter based on the source language

        Args:
-           tts_checkpoint (str): path to the tts model file.
-           tts_config_path (str): path to the tts config file.
+           tts_checkpoint (str, optional): path to the tts model file.
+           tts_config_path (str, optional): path to the tts config file.
            vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
            vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
            encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`,
            encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`,
+           vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`,
+           vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
            use_cuda (bool, optional): enable/disable cuda. Defaults to False.
        """
        self.tts_checkpoint = tts_checkpoint

@@ -57,10 +63,13 @@ class Synthesizer(object):
        self.vocoder_config = vocoder_config
        self.encoder_checkpoint = encoder_checkpoint
        self.encoder_config = encoder_config
+       self.vc_checkpoint = vc_checkpoint
+       self.vc_config = vc_config
        self.use_cuda = use_cuda

        self.tts_model = None
        self.vocoder_model = None
+       self.vc_model = None
        self.speaker_manager = None
        self.tts_speakers = {}
        self.language_manager = None

@@ -72,12 +81,19 @@ class Synthesizer(object):

        if self.use_cuda:
            assert torch.cuda.is_available(), "CUDA is not available on this machine."
-       self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
-       self.output_sample_rate = self.tts_config.audio["sample_rate"]

+       if tts_checkpoint:
+           self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
+           self.output_sample_rate = self.tts_config.audio["sample_rate"]

        if vocoder_checkpoint:
            self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
            self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

+       if vc_checkpoint:
+           self._load_vc(vc_checkpoint, vc_config, use_cuda)
+           self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

    @staticmethod
    def _get_segmenter(lang: str):
        """get the sentence segmenter for the given language.

@@ -90,6 +106,26 @@ class Synthesizer(object):
        """
        return pysbd.Segmenter(language=lang, clean=True)

+   def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
+       """Load the voice conversion model.
+
+       1. Load the model config.
+       2. Init the model from the config.
+       3. Load the model weights.
+       4. Move the model to the GPU if CUDA is enabled.
+
+       Args:
+           vc_checkpoint (str): path to the model checkpoint.
+           vc_config_path (str): path to the model config file.
+           use_cuda (bool): enable/disable CUDA use.
+       """
+       # pylint: disable=global-statement
+       self.vc_config = load_config(vc_config_path)
+       self.vc_model = setup_vc_model(config=self.vc_config)
+       self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
+       if use_cuda:
+           self.vc_model.cuda()
+
    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
        """Load the TTS model.

@@ -168,7 +204,11 @@ class Synthesizer(object):
            path (str): output path to save the waveform.
        """
        wav = np.array(wav)
-       self.tts_model.ap.save_wav(wav, path, self.output_sample_rate)
+       save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)

+   def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
+       output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
+       return output_wav
+
    def tts(
        self,
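For context, a minimal sketch of driving the new VC path on the `Synthesizer` directly. The checkpoint and config paths are placeholders pointing at whatever `ModelManager` downloaded for the FreeVC model; the `save_wav` call mirrors the one used in `TTS.api` above.

```python
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.synthesizer import Synthesizer

# Placeholder paths: point these at the files downloaded for
# "voice_conversion_models/multilingual/vctk/freevc24".
synth = Synthesizer(vc_checkpoint="path/to/model.pth", vc_config="path/to/config.json", use_cuda=False)

# Convert the content of source.wav into the voice of target.wav.
wav = synth.voice_conversion(source_wav="source.wav", target_wav="target.wav")

# output_sample_rate was set from vc_config.audio["output_sample_rate"] in __init__.
save_wav(wav=wav, path="converted.wav", sample_rate=synth.output_sample_rate)
```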
@@ -0,0 +1,5 @@
from dataclasses import dataclass, field
from typing import List

from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig
|
@ -0,0 +1,155 @@
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
from coqpit import Coqpit, check_argument
|
||||
|
||||
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVCConfig(BaseTrainingConfig):
|
||||
"""Shared parameters among all the tts models.
|
||||
|
||||
Args:
|
||||
|
||||
audio (BaseAudioConfig):
|
||||
Audio processor config object instance.
|
||||
|
||||
batch_group_size (int):
|
||||
Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
|
||||
length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
|
||||
prevent using the same batches for each epoch.
|
||||
|
||||
loss_masking (bool):
|
||||
enable / disable masking loss values against padded segments of samples in a batch.
|
||||
|
||||
min_text_len (int):
|
||||
Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.
|
||||
|
||||
max_text_len (int):
|
||||
Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").
|
||||
|
||||
min_audio_len (int):
|
||||
Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.
|
||||
|
||||
max_audio_len (int):
|
||||
Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
|
||||
dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
|
||||
OOM error in training. Defaults to float("inf").
|
||||
|
||||
compute_f0 (int):
|
||||
(Not in use yet).
|
||||
|
||||
compute_energy (int):
|
||||
(Not in use yet).
|
||||
|
||||
compute_linear_spec (bool):
|
||||
If True data loader computes and returns linear spectrograms alongside the other data.
|
||||
|
||||
precompute_num_workers (int):
|
||||
Number of workers to precompute features. Defaults to 0.
|
||||
|
||||
use_noise_augment (bool):
|
||||
Augment the input audio with random noise.
|
||||
|
||||
start_by_longest (bool):
|
||||
If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
|
||||
Defaults to False.
|
||||
|
||||
shuffle (bool):
|
||||
If True, the data loader will shuffle the dataset when there is no sampler defined. Defaults to True.
|
||||
|
||||
drop_last (bool):
|
||||
If True, the data loader will drop the last batch if it is not complete. It helps to prevent
|
||||
issues that emerge from the partial batch statistics. Defaults to True.
|
||||
|
||||
add_blank (bool):
|
||||
Add blank characters between every two characters. It improves performance for some models at the expense
|
||||
of slower run-time due to the longer input sequence.
|
||||
|
||||
datasets (List[BaseDatasetConfig]):
|
||||
List of datasets used for training. If multiple datasets are provided, they are merged and used together
|
||||
for training.
|
||||
|
||||
optimizer (str):
|
||||
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
|
||||
Defaults to ``.
|
||||
|
||||
optimizer_params (dict):
|
||||
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
|
||||
|
||||
lr_scheduler (str):
|
||||
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
|
||||
`TTS.utils.training`. Defaults to ``.
|
||||
|
||||
lr_scheduler_params (dict):
|
||||
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
|
||||
|
||||
test_sentences (List[str]):
|
||||
List of sentences to be used at testing. Defaults to '[]'
|
||||
|
||||
eval_split_max_size (int):
|
||||
Maximum number of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
|
||||
|
||||
eval_split_size (float):
|
||||
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
|
||||
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
|
||||
|
||||
use_speaker_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by speaker. Defaults to ```False```.
|
||||
|
||||
speaker_weighted_sampler_alpha (float):
|
||||
Number that controls the influence of the speaker sampler weights. Defaults to ```1.0```.
|
||||
|
||||
use_language_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by language. Defaults to ```False```.
|
||||
|
||||
language_weighted_sampler_alpha (float):
|
||||
Number that controls the influence of the language sampler weights. Defaults to ```1.0```.
|
||||
|
||||
use_length_weighted_sampler (bool):
|
||||
Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided
|
||||
into 10 buckets considering the min and max audio of the dataset. The sampler weights will be
|
||||
computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```.
|
||||
|
||||
length_weighted_sampler_alpha (float):
|
||||
Number that controls the influence of the length sampler weights. Defaults to ```1.0```.
|
||||
"""
|
||||
|
||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||
# training params
|
||||
batch_group_size: int = 0
|
||||
loss_masking: bool = None
|
||||
# dataloading
|
||||
min_audio_len: int = 1
|
||||
max_audio_len: int = float("inf")
|
||||
min_text_len: int = 1
|
||||
max_text_len: int = float("inf")
|
||||
compute_f0: bool = False
|
||||
compute_energy: bool = False
|
||||
compute_linear_spec: bool = False
|
||||
precompute_num_workers: int = 0
|
||||
use_noise_augment: bool = False
|
||||
start_by_longest: bool = False
|
||||
shuffle: bool = False
|
||||
drop_last: bool = False
|
||||
# dataset
|
||||
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|
||||
# optimizer
|
||||
optimizer: str = "radam"
|
||||
optimizer_params: dict = None
|
||||
# scheduler
|
||||
lr_scheduler: str = None
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {})
|
||||
# testing
|
||||
test_sentences: List[str] = field(default_factory=lambda: [])
|
||||
# evaluation
|
||||
eval_split_max_size: int = None
|
||||
eval_split_size: float = 0.01
|
||||
# weighted samplers
|
||||
use_speaker_weighted_sampler: bool = False
|
||||
speaker_weighted_sampler_alpha: float = 1.0
|
||||
use_language_weighted_sampler: bool = False
|
||||
language_weighted_sampler_alpha: float = 1.0
|
||||
use_length_weighted_sampler: bool = False
|
||||
length_weighted_sampler_alpha: float = 1.0
|
|
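To make the field list above concrete, a hedged sketch of instantiating the shared config. The field names come from the dataclass above; `batch_size` is inherited from `BaseTrainingConfig` and is an assumption here, as are the example values.

```python
from TTS.vc.configs.shared_configs import BaseVCConfig

config = BaseVCConfig(
    batch_size=32,                       # inherited from BaseTrainingConfig (assumed)
    eval_split_size=0.01,                # 1% of the samples go to the evaluation split
    use_speaker_weighted_sampler=True,   # balance batches across speakers
    speaker_weighted_sampler_alpha=1.0,  # weight of the speaker balancer
    max_audio_len=10 * 16000,            # drop clips longer than ~10 s at 16 kHz
)
print(config.datasets)  # defaults to a single empty BaseDatasetConfig
```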
TTS/vc/models/__init__.py

@@ -0,0 +1,17 @@
import importlib
import re
from typing import Dict, List, Union


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "model" in config and config["model"].lower() == "freevc":
        MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
    model = MyModel.init_from_config(config, samples)
    return model
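Not part of the diff — a short sketch of how this factory is used elsewhere (for example by `Synthesizer._load_vc` above). The config and checkpoint paths are placeholders.

```python
from TTS.config import load_config
from TTS.vc.models import setup_model

config = load_config("path/to/freevc_config.json")  # placeholder path
model = setup_model(config)  # prints " > Using model: freevc" and returns a FreeVC instance
model.load_checkpoint(config, "path/to/freevc_checkpoint.pth")  # placeholder checkpoint
```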
@ -0,0 +1,429 @@
|
|||
import os
|
||||
import random
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseVC(BaseTrainerModel):
|
||||
"""Base `vc` class. Every new `vc` model must inherit this.
|
||||
|
||||
It defines common `vc` specific functions on top of `Model` implementation.
|
||||
"""
|
||||
|
||||
MODEL_TYPE = "vc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Coqpit,
|
||||
ap: "AudioProcessor",
|
||||
speaker_manager: SpeakerManager = None,
|
||||
language_manager: LanguageManager = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.speaker_manager = speaker_manager
|
||||
self.language_manager = language_manager
|
||||
self._set_model_args(config)
|
||||
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
|
||||
|
||||
`ModelArgs` has all the fields required to initialize the model architecture.
|
||||
|
||||
`ModelConfig` has all the fields required for training, inference and contains `ModelArgs`.
|
||||
|
||||
If the config is for training with a name like "*Config", then the model args are embedded in the
|
||||
config.model_args
|
||||
|
||||
If the config is for the model with a name like "*Args", then we assign it directly.
|
||||
"""
|
||||
# don't use isinstance to avoid importing recursively
|
||||
if "Config" in config.__class__.__name__:
|
||||
self.config = config
|
||||
self.args = config.model_args
|
||||
elif "Args" in config.__class__.__name__:
|
||||
self.args = config
|
||||
else:
|
||||
raise ValueError("config must be either a *Config or *Args")
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
|
||||
`in_channels` size of the connected layers.
|
||||
|
||||
This implementation yields 3 possible outcomes:
|
||||
|
||||
1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
|
||||
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
|
||||
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
|
||||
`config.d_vector_dim` or 512.
|
||||
|
||||
You can override this function for new models.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
# set number of speakers
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
elif hasattr(config, "num_speakers"):
|
||||
self.num_speakers = config.num_speakers
|
||||
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
print(" > Init speaker_embedding layer.")
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
|
||||
|
||||
def get_aux_input_from_test_sentences(self, sentence_info):
|
||||
if hasattr(self.config, "model_args"):
|
||||
config = self.config.model_args
|
||||
else:
|
||||
config = self.config
|
||||
|
||||
# extract speaker and language info
|
||||
text, speaker_name, style_wav, language_name = None, None, None, None
|
||||
|
||||
if isinstance(sentence_info, list):
|
||||
if len(sentence_info) == 1:
|
||||
text = sentence_info[0]
|
||||
elif len(sentence_info) == 2:
|
||||
text, speaker_name = sentence_info
|
||||
elif len(sentence_info) == 3:
|
||||
text, speaker_name, style_wav = sentence_info
|
||||
elif len(sentence_info) == 4:
|
||||
text, speaker_name, style_wav, language_name = sentence_info
|
||||
else:
|
||||
text = sentence_info
|
||||
|
||||
# get speaker id/d_vector
|
||||
speaker_id, d_vector, language_id = None, None, None
|
||||
if self.speaker_manager is not None:
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_embedding()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.name_to_id[speaker_name]
|
||||
|
||||
# get language id
|
||||
if self.language_manager is not None and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.name_to_id[language_name]
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"speaker_id": speaker_id,
|
||||
"style_wav": style_wav,
|
||||
"d_vector": d_vector,
|
||||
"language_id": language_id,
|
||||
}
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `VCDataset`.
|
||||
|
||||
You must override this if you use a custom dataset.
|
||||
|
||||
Args:
|
||||
batch (Dict): [description]
|
||||
|
||||
Returns:
|
||||
Dict: [description]
|
||||
"""
|
||||
# setup input batch
|
||||
text_input = batch["token_id"]
|
||||
text_lengths = batch["token_id_lengths"]
|
||||
speaker_names = batch["speaker_names"]
|
||||
linear_input = batch["linear"]
|
||||
mel_input = batch["mel"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
item_idx = batch["item_idxs"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
attn_mask = batch["attns"]
|
||||
waveform = batch["waveform"]
|
||||
pitch = batch["pitch"]
|
||||
energy = batch["energy"]
|
||||
language_ids = batch["language_ids"]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
# compute durations from attention masks
|
||||
durations = None
|
||||
if attn_mask is not None:
|
||||
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
||||
for idx, am in enumerate(attn_mask):
|
||||
# compute raw durations
|
||||
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
|
||||
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
||||
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
||||
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
||||
dur[c_idxs] = counts
|
||||
# smooth the durations and set any 0 duration to 1
|
||||
# by cutting off from the largest duration indices.
|
||||
extra_frames = dur.sum() - mel_lengths[idx]
|
||||
largest_idxs = torch.argsort(-dur)[:extra_frames]
|
||||
dur[largest_idxs] -= 1
|
||||
assert (
|
||||
dur.sum() == mel_lengths[idx]
|
||||
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||
durations[idx, : text_lengths[idx]] = dur
|
||||
|
||||
# set stop targets wrt reduction factor
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
|
||||
|
||||
return {
|
||||
"text_input": text_input,
|
||||
"text_lengths": text_lengths,
|
||||
"speaker_names": speaker_names,
|
||||
"mel_input": mel_input,
|
||||
"mel_lengths": mel_lengths,
|
||||
"linear_input": linear_input,
|
||||
"stop_targets": stop_targets,
|
||||
"stop_target_lengths": stop_target_lengths,
|
||||
"attn_mask": attn_mask,
|
||||
"durations": durations,
|
||||
"speaker_ids": speaker_ids,
|
||||
"d_vectors": d_vectors,
|
||||
"max_text_length": float(max_text_length),
|
||||
"max_spec_length": float(max_spec_length),
|
||||
"item_idx": item_idx,
|
||||
"waveform": waveform,
|
||||
"pitch": pitch,
|
||||
"energy": energy,
|
||||
"language_ids": language_ids,
|
||||
"audio_unique_names": batch["audio_unique_names"],
|
||||
}
|
||||
|
||||
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
|
||||
weights = None
|
||||
data_items = dataset.samples
|
||||
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Language weighted sampler with alpha:", alpha)
|
||||
weights = get_language_balancer_weights(data_items) * alpha
|
||||
|
||||
if getattr(config, "use_speaker_weighted_sampler", False):
|
||||
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Speaker weighted sampler with alpha:", alpha)
|
||||
if weights is not None:
|
||||
weights += get_speaker_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
weights = get_speaker_balancer_weights(data_items) * alpha
|
||||
|
||||
if getattr(config, "use_length_weighted_sampler", False):
|
||||
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Length weighted sampler with alpha:", alpha)
|
||||
if weights is not None:
|
||||
weights += get_length_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
weights = get_length_balancer_weights(data_items) * alpha
|
||||
|
||||
if weights is not None:
|
||||
sampler = WeightedRandomSampler(weights, len(weights))
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
# sampler for DDP
|
||||
if sampler is None:
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
else: # If a sampler is already defined use this sampler and DDP sampler together
|
||||
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
|
||||
|
||||
return sampler
|
||||
|
||||
def get_data_loader(
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: bool,
|
||||
samples: Union[List[Dict], List[List]],
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None,
|
||||
) -> "DataLoader":
|
||||
if is_eval and not config.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
# setup multi-speaker attributes
|
||||
if self.speaker_manager is not None:
|
||||
if hasattr(config, "model_args"):
|
||||
speaker_id_mapping = (
|
||||
self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
|
||||
)
|
||||
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
|
||||
config.use_d_vector_file = config.model_args.use_d_vector_file
|
||||
else:
|
||||
speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
|
||||
d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
||||
# setup multi-lingual attributes
|
||||
if self.language_manager is not None:
|
||||
language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
|
||||
else:
|
||||
language_id_mapping = None
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||
compute_f0=config.get("compute_f0", False),
|
||||
f0_cache_path=config.get("f0_cache_path", None),
|
||||
compute_energy=config.get("compute_energy", False),
|
||||
energy_cache_path=config.get("energy_cache_path", None),
|
||||
samples=samples,
|
||||
ap=self.ap,
|
||||
return_wav=config.return_wav if "return_wav" in config else False,
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
min_text_len=config.min_text_len,
|
||||
max_text_len=config.max_text_len,
|
||||
min_audio_len=config.min_audio_len,
|
||||
max_audio_len=config.max_audio_len,
|
||||
phoneme_cache_path=config.phoneme_cache_path,
|
||||
precompute_num_workers=config.precompute_num_workers,
|
||||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||
tokenizer=None,
|
||||
start_by_longest=config.start_by_longest,
|
||||
language_id_mapping=language_id_mapping,
|
||||
)
|
||||
|
||||
# wait all the DDP process to be ready
|
||||
if num_gpus > 1:
|
||||
dist.barrier()
|
||||
|
||||
# sort input sequences from short to long
|
||||
dataset.preprocess_samples()
|
||||
|
||||
# get samplers
|
||||
sampler = self.get_sampler(config, dataset, num_gpus)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=config.shuffle if sampler is None else False, # if there is no other sampler
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=config.drop_last, # setting this False might cause issues in AMP training.
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def _get_test_aux_input(
|
||||
self,
|
||||
) -> Dict:
|
||||
d_vector = None
|
||||
if self.config.use_d_vector_file:
|
||||
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
|
||||
d_vector = (random.sample(sorted(d_vector), 1),)
|
||||
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1),
|
||||
"d_vector": d_vector,
|
||||
"style_wav": None, # TODO: handle GST style input
|
||||
}
|
||||
return aux_inputs
|
||||
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `vc` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Args:
|
||||
assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_test_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
if isinstance(sen, list):
|
||||
aux_inputs = self.get_aux_input_from_test_sentences(sen)
|
||||
sen = aux_inputs["text"]
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
)
|
||||
return test_figures, test_audios
|
||||
|
||||
def on_init_start(self, trainer):
|
||||
"""Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
|
||||
if self.speaker_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "speakers.pth")
|
||||
self.speaker_manager.save_ids_to_file(output_path)
|
||||
trainer.config.speakers_file = output_path
|
||||
# some models don't have `model_args` set
|
||||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.speakers_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `speakers.pth` is saved to {output_path}.")
|
||||
print(" > `speakers_file` is updated in the config.json.")
|
||||
|
||||
if self.language_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||
self.language_manager.save_ids_to_file(output_path)
|
||||
trainer.config.language_ids_file = output_path
|
||||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.language_ids_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||
print(" > `language_ids_file` is updated in the config.json.")
|
|
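As a side note on `get_sampler` above: the balancer weights are plain per-sample tensors that are scaled by their alphas and summed before being handed to `WeightedRandomSampler`. A toy illustration with made-up weight values standing in for the `get_*_balancer_weights` outputs:

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# Made-up per-sample weights standing in for the balancer outputs.
language_weights = torch.tensor([1.0, 1.0, 2.0, 2.0])
speaker_weights = torch.tensor([0.5, 1.5, 0.5, 1.5])

alpha_lang, alpha_spk = 1.0, 1.0
weights = language_weights * alpha_lang + speaker_weights * alpha_spk

# Samples with larger combined weights are drawn more often in each epoch.
sampler = WeightedRandomSampler(weights, num_samples=len(weights))
print(list(sampler))
```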
@ -0,0 +1,833 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
import TTS.vc.modules.freevc.commons as commons
|
||||
import TTS.vc.modules.freevc.modules as modules
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.io import load_fsspec, save_checkpoint
|
||||
from TTS.vc.configs.shared_configs import BaseVCConfig
|
||||
from TTS.vc.models.base_vc import BaseVC
|
||||
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
||||
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
||||
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.flows = nn.ModuleList()
|
||||
for i in range(n_flows):
|
||||
self.flows.append(
|
||||
modules.ResidualCouplingLayer(
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=gin_channels,
|
||||
mean_only=True,
|
||||
)
|
||||
)
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
if not reverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||
return z, m, logs, x_mask
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
initial_channel,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=0,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_rs.append(fmap_r)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class SpeakerEncoder(torch.nn.Module):
|
||||
def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
|
||||
super(SpeakerEncoder, self).__init__()
|
||||
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, mels):
|
||||
self.lstm.flatten_parameters()
|
||||
_, (hidden, _) = self.lstm(mels)
|
||||
embeds_raw = self.relu(self.linear(hidden[-1]))
|
||||
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
||||
|
||||
def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
|
||||
mel_slices = []
|
||||
for i in range(0, total_frames - partial_frames, partial_hop):
|
||||
mel_range = torch.arange(i, i + partial_frames)
|
||||
mel_slices.append(mel_range)
|
||||
|
||||
return mel_slices
|
||||
|
||||
def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
|
||||
mel_len = mel.size(1)
|
||||
last_mel = mel[:, -partial_frames:]
|
||||
|
||||
if mel_len > partial_frames:
|
||||
mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
|
||||
mels = list(mel[:, s] for s in mel_slices)
|
||||
mels.append(last_mel)
|
||||
mels = torch.stack(tuple(mels), 0).squeeze(1)
|
||||
|
||||
with torch.no_grad():
|
||||
partial_embeds = self(mels)
|
||||
embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
|
||||
# embed = embed / torch.linalg.norm(embed, 2)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
embed = self(last_mel)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreeVCAudioConfig(Coqpit):
|
||||
"""Audio configuration
|
||||
|
||||
Args:
|
||||
max_wav_value (float):
|
||||
The maximum value of the waveform.
|
||||
|
||||
input_sample_rate (int):
|
||||
The sampling rate of the input waveform.
|
||||
|
||||
output_sample_rate (int):
|
||||
The sampling rate of the output waveform.
|
||||
|
||||
filter_length (int):
|
||||
The length of the filter.
|
||||
|
||||
hop_length (int):
|
||||
The hop length.
|
||||
|
||||
win_length (int):
|
||||
The window length.
|
||||
|
||||
n_mel_channels (int):
|
||||
The number of mel channels.
|
||||
|
||||
mel_fmin (float):
|
||||
The minimum frequency of the mel filterbank.
|
||||
|
||||
mel_fmax (Optional[float]):
|
||||
The maximum frequency of the mel filterbank.
|
||||
"""
|
||||
|
||||
max_wav_value: float = field(default=32768.0)
|
||||
input_sample_rate: int = field(default=16000)
|
||||
output_sample_rate: int = field(default=24000)
|
||||
filter_length: int = field(default=1280)
|
||||
hop_length: int = field(default=320)
|
||||
win_length: int = field(default=1280)
|
||||
n_mel_channels: int = field(default=80)
|
||||
mel_fmin: float = field(default=0.0)
|
||||
mel_fmax: Optional[float] = field(default=None)
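# Illustrative sketch (standalone): the defaults above describe a 16 kHz analysis
# front-end with 20 ms hops (50 spectrogram frames per second) and 24 kHz synthesis output.
example_audio = FreeVCAudioConfig()
print(example_audio.input_sample_rate / example_audio.hop_length)  # 50.0 frames per second
print(example_audio.output_sample_rate)                            # 24000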
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreeVCArgs(Coqpit):
|
||||
"""FreeVC model arguments
|
||||
|
||||
Args:
|
||||
spec_channels (int):
|
||||
The number of channels in the spectrogram.
|
||||
|
||||
inter_channels (int):
|
||||
The number of channels in the intermediate layers.
|
||||
|
||||
hidden_channels (int):
|
||||
The number of channels in the hidden layers.
|
||||
|
||||
filter_channels (int):
|
||||
The number of channels in the filter layers.
|
||||
|
||||
n_heads (int):
|
||||
The number of attention heads.
|
||||
|
||||
n_layers (int):
|
||||
The number of layers.
|
||||
|
||||
kernel_size (int):
|
||||
The size of the kernel.
|
||||
|
||||
p_dropout (float):
|
||||
The dropout probability.
|
||||
|
||||
resblock (str):
|
||||
The type of residual block.
|
||||
|
||||
resblock_kernel_sizes (List[int]):
|
||||
The kernel sizes for the residual blocks.
|
||||
|
||||
resblock_dilation_sizes (List[List[int]]):
|
||||
The dilation sizes for the residual blocks.
|
||||
|
||||
upsample_rates (List[int]):
|
||||
The upsample rates.
|
||||
|
||||
upsample_initial_channel (int):
|
||||
The number of channels in the initial upsample layer.
|
||||
|
||||
upsample_kernel_sizes (List[int]):
|
||||
The kernel sizes for the upsample layers.
|
||||
|
||||
n_layers_q (int):
|
||||
The number of layers in the quantization network.
|
||||
|
||||
use_spectral_norm (bool):
|
||||
Whether to use spectral normalization.
|
||||
|
||||
gin_channels (int):
|
||||
The number of channels in the global conditioning vector.
|
||||
|
||||
ssl_dim (int):
|
||||
The dimension of the self-supervised learning embedding.
|
||||
|
||||
use_spk (bool):
|
||||
Whether to use external speaker encoder.
|
||||
"""
|
||||
|
||||
spec_channels: int = field(default=641)
|
||||
inter_channels: int = field(default=192)
|
||||
hidden_channels: int = field(default=192)
|
||||
filter_channels: int = field(default=768)
|
||||
n_heads: int = field(default=2)
|
||||
n_layers: int = field(default=6)
|
||||
kernel_size: int = field(default=3)
|
||||
p_dropout: float = field(default=0.1)
|
||||
resblock: str = field(default="1")
|
||||
resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
|
||||
resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
|
||||
upsample_initial_channel: int = field(default=512)
|
||||
upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
|
||||
n_layers_q: int = field(default=3)
|
||||
use_spectral_norm: bool = field(default=False)
|
||||
gin_channels: int = field(default=256)
|
||||
ssl_dim: int = field(default=1024)
|
||||
use_spk: bool = field(default=False)
|
||||
num_spks: int = field(default=0)
|
||||
segment_size: int = field(default=8960)
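# Illustrative consistency check (standalone): the generator upsample rates above multiply
# out to 320 (10 * 8 * 2 * 2), and spec_channels equals filter_length // 2 + 1 frequency
# bins from FreeVCAudioConfig (1280 // 2 + 1 = 641).
import math

example_args = FreeVCArgs()
assert math.prod(example_args.upsample_rates) == 320
assert example_args.spec_channels == 1280 // 2 + 1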
|
||||
|
||||
|
||||
class FreeVC(BaseVC):
|
||||
"""
|
||||
|
||||
Paper::
|
||||
https://arxiv.org/abs/2210.15418#
|
||||
|
||||
Paper Abstract::
|
||||
Voice conversion (VC) can be achieved by first extracting source content information and target speaker
|
||||
information, and then reconstructing waveform with these information. However, current approaches normally
|
||||
either extract dirty content information with speaker information leaked in, or demand a large amount of
|
||||
annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the
|
||||
mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for
|
||||
high-quality waveform reconstruction, and propose strategies for clean content information extraction without
|
||||
text annotation. We disentangle content information by imposing an information bottleneck to WavLM features,
|
||||
and propose the spectrogram-resize based data augmentation to improve the purity of extracted content
|
||||
information. Experimental results show that the proposed method outperforms the latest VC models trained with
|
||||
annotated data and has greater robustness.
|
||||
|
||||
Original Code::
|
||||
https://github.com/OlaWod/FreeVC
|
||||
|
||||
Examples:
|
||||
>>> from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||
>>> from TTS.vc.models.freevc import FreeVC
|
||||
>>> config = FreeVCConfig()
|
||||
>>> model = FreeVC(config)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config, None, speaker_manager, None)
|
||||
|
||||
self.init_multispeaker(config)
|
||||
|
||||
self.spec_channels = self.args.spec_channels
|
||||
self.inter_channels = self.args.inter_channels
|
||||
self.hidden_channels = self.args.hidden_channels
|
||||
self.filter_channels = self.args.filter_channels
|
||||
self.n_heads = self.args.n_heads
|
||||
self.n_layers = self.args.n_layers
|
||||
self.kernel_size = self.args.kernel_size
|
||||
self.p_dropout = self.args.p_dropout
|
||||
self.resblock = self.args.resblock
|
||||
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
|
||||
self.upsample_rates = self.args.upsample_rates
|
||||
self.upsample_initial_channel = self.args.upsample_initial_channel
|
||||
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
|
||||
self.segment_size = self.args.segment_size
|
||||
self.gin_channels = self.args.gin_channels
|
||||
self.ssl_dim = self.args.ssl_dim
|
||||
self.use_spk = self.args.use_spk
|
||||
|
||||
self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16)
|
||||
self.dec = Generator(
|
||||
self.inter_channels,
|
||||
self.resblock,
|
||||
self.resblock_kernel_sizes,
|
||||
self.resblock_dilation_sizes,
|
||||
self.upsample_rates,
|
||||
self.upsample_initial_channel,
|
||||
self.upsample_kernel_sizes,
|
||||
gin_channels=self.gin_channels,
|
||||
)
|
||||
self.enc_q = Encoder(
|
||||
self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels
|
||||
)
|
||||
if not self.use_spk:
|
||||
self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels)
|
||||
else:
|
||||
self.load_pretrained_speaker_encoder()
|
||||
|
||||
self.wavlm = get_wavlm()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def load_pretrained_speaker_encoder(self):
|
||||
"""Load pretrained speaker encoder model as mentioned in the paper."""
|
||||
print(" > Loading pretrained speaker encoder model ...")
|
||||
self.enc_spk_ex = SpeakerEncoderEx(
|
||||
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
|
||||
)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
"""
|
||||
self.num_spks = self.args.num_spks
|
||||
if self.speaker_manager:
|
||||
self.num_spks = self.speaker_manager.num_spks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
c: torch.Tensor,
|
||||
spec: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
mel: Optional[torch.Tensor] = None,
|
||||
c_lengths: Optional[torch.Tensor] = None,
|
||||
spec_lengths: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
]:
|
||||
"""
|
||||
Forward pass of the model.
|
||||
|
||||
Args:
|
||||
c: WavLM features. Shape: (batch_size, c_seq_len).
|
||||
spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim).
|
||||
g: The speaker embedding. Shape: (batch_size, spk_emb_dim).
|
||||
mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim).
|
||||
c_lengths: The lengths of the WavLM features. Shape: (batch_size,).
|
||||
spec_lengths: The lengths of the spectrogram. Shape: (batch_size,).
|
||||
|
||||
Returns:
|
||||
o: The decoder output waveform for the randomly sliced segments.
ids_slice: The start indices of the sliced segments. Shape: (batch_size,).
|
||||
spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len).
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables.
|
||||
"""
|
||||
|
||||
# If c_lengths is None, set it to the length of the last dimension of c
|
||||
if c_lengths is None:
|
||||
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
||||
|
||||
# If spec_lengths is None, set it to the length of the last dimension of spec
|
||||
if spec_lengths is None:
|
||||
spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
|
||||
|
||||
# If use_spk is False, compute g from mel using enc_spk
|
||||
|
||||
if not self.use_spk:
|
||||
g = self.enc_spk(mel).unsqueeze(-1)
|
||||
|
||||
# Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q
|
||||
_, m_p, logs_p, _ = self.enc_p(c, c_lengths)
|
||||
z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g)
|
||||
|
||||
# Compute z_p using flow
|
||||
z_p = self.flow(z, spec_mask, g=g)
|
||||
|
||||
# Randomly slice z and compute o using dec
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
|
||||
o = self.dec(z_slice, g=g)
|
||||
|
||||
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c, g=None, mel=None, c_lengths=None):
|
||||
"""
|
||||
Inference pass of the model
|
||||
|
||||
Args:
|
||||
c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len).
|
||||
g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
|
||||
mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim).
|
||||
c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
"""
|
||||
if c_lengths is None:
|
||||
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
||||
if not self.use_spk:
|
||||
g = self.enc_spk.embed_utterance(mel)
|
||||
g = g.unsqueeze(-1)
|
||||
z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
|
||||
z = self.flow(z_p, c_mask, g=g, reverse=True)
|
||||
o = self.dec(z * c_mask, g=g)
|
||||
return o
|
||||
|
||||
def extract_wavlm_features(self, y):
|
||||
"""Extract WavLM features from an audio tensor.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len).
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
c = self.wavlm.extract_features(y)[0]
|
||||
c = c.transpose(1, 2)
|
||||
return c
|
||||
|
||||
def load_audio(self, wav):
|
||||
"""Read and format the input audio."""
|
||||
if isinstance(wav, str):
|
||||
wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate)
|
||||
if isinstance(wav, np.ndarray):
|
||||
wav = torch.from_numpy(wav).to(self.device)
|
||||
if isinstance(wav, torch.Tensor):
|
||||
wav = wav.to(self.device)
|
||||
if isinstance(wav, list):
|
||||
wav = torch.from_numpy(np.array(wav)).to(self.device)
|
||||
return wav.float()
|
||||
|
||||
@torch.inference_mode()
|
||||
def voice_conversion(self, src, tgt):
|
||||
"""
|
||||
Voice conversion pass of the model.
|
||||
|
||||
Args:
|
||||
src (str or torch.Tensor): Source utterance.
|
||||
tgt (str or torch.Tensor): Target utterance.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
"""
|
||||
|
||||
wav_tgt = self.load_audio(tgt).cpu().numpy()
|
||||
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
|
||||
|
||||
if self.config.model_args.use_spk:
|
||||
g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)
|
||||
g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device)
|
||||
else:
|
||||
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
|
||||
mel_tgt = mel_spectrogram_torch(
|
||||
wav_tgt,
|
||||
self.config.audio.filter_length,
|
||||
self.config.audio.n_mel_channels,
|
||||
self.config.audio.input_sample_rate,
|
||||
self.config.audio.hop_length,
|
||||
self.config.audio.win_length,
|
||||
self.config.audio.mel_fmin,
|
||||
self.config.audio.mel_fmax,
|
||||
)
|
||||
# src
|
||||
wav_src = self.load_audio(src)
|
||||
c = self.extract_wavlm_features(wav_src[None, :])
|
||||
|
||||
if self.config.model_args.use_spk:
|
||||
audio = self.inference(c, g=g_tgt)
|
||||
else:
|
||||
audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
|
||||
audio = audio[0][0].data.cpu().float().numpy()
|
||||
return audio
|
||||
|
||||
def eval_step(self):
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "FreeVCConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||
model = FreeVC(config)
|
||||
return model
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"], strict=strict)
|
||||
if eval:
|
||||
self.eval()
|
||||
|
||||
def train_step(self):
|
||||
...
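# Illustrative, standalone sketch of running voice conversion with the class above.
# The checkpoint and audio paths are placeholders (assumptions), not files provided by
# this diff; `eval=True` puts the model in inference mode after loading.
config = FreeVCConfig()
model = FreeVC.init_from_config(config)
model.load_checkpoint(config, "/path/to/freevc24_checkpoint.pth", eval=True)
converted = model.voice_conversion(src="/path/to/source.wav", tgt="/path/to/target_speaker.wav")
# `converted` is a numpy waveform at config.audio.output_sample_rate (24 kHz by default)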
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreeVCConfig(BaseVCConfig):
|
||||
"""Defines parameters for FreeVC End2End TTS model.
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Model name. Do not change unless you know what you are doing.
|
||||
|
||||
model_args (FreeVCArgs):
|
||||
Model architecture arguments. Defaults to `FreeVCArgs()`.
|
||||
|
||||
audio (FreeVCAudioConfig):
|
||||
Audio processing configuration. Defaults to `FreeVCAudioConfig()`.
|
||||
|
||||
grad_clip (List):
|
||||
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
||||
|
||||
lr_gen (float):
|
||||
Initial learning rate for the generator. Defaults to 0.0002.
|
||||
|
||||
lr_disc (float):
|
||||
Initial learning rate for the discriminator. Defaults to 0.0002.
|
||||
|
||||
lr_scheduler_gen (str):
|
||||
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
||||
`ExponentialLR`.
|
||||
|
||||
lr_scheduler_gen_params (dict):
|
||||
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
||||
|
||||
lr_scheduler_disc (str):
|
||||
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
||||
`ExponentialLR`.
|
||||
|
||||
lr_scheduler_disc_params (dict):
|
||||
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
||||
|
||||
scheduler_after_epoch (bool):
|
||||
If true, step the schedulers after each epoch; otherwise step them after each iteration. Defaults to `False`.
|
||||
|
||||
optimizer (str):
|
||||
Name of the optimizer to use with both the generator and the discriminator networks. One of the
|
||||
`torch.optim.*`. Defaults to `AdamW`.
|
||||
|
||||
kl_loss_alpha (float):
|
||||
Loss weight for KL loss. Defaults to 1.0.
|
||||
|
||||
disc_loss_alpha (float):
|
||||
Loss weight for the discriminator loss. Defaults to 1.0.
|
||||
|
||||
gen_loss_alpha (float):
|
||||
Loss weight for the generator loss. Defaults to 1.0.
|
||||
|
||||
feat_loss_alpha (float):
|
||||
Loss weight for the feature matching loss. Defaults to 1.0.
|
||||
|
||||
mel_loss_alpha (float):
|
||||
Loss weight for the mel loss. Defaults to 45.0.
|
||||
|
||||
return_wav (bool):
|
||||
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
|
||||
|
||||
compute_linear_spec (bool):
|
||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||
|
||||
use_weighted_sampler (bool):
|
||||
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
||||
|
||||
weighted_sampler_attrs (dict):
|
||||
Key returned by the formatter to be used for the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
||||
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
||||
|
||||
weighted_sampler_multipliers (dict):
|
||||
Weight each unique value of a key returned by the formatter for weighted sampling.
|
||||
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
|
||||
It will sample instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`.
|
||||
|
||||
r (int):
|
||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||
|
||||
add_blank (bool):
|
||||
If true, a blank token is added in between every character. Defaults to `True`.
|
||||
|
||||
test_sentences (List[List]):
|
||||
List of sentences with speaker and language information to be used for testing.
|
||||
|
||||
language_ids_file (str):
|
||||
Path to the language ids file.
|
||||
|
||||
use_language_embedding (bool):
|
||||
If true, language embedding is used. Defaults to `False`.
|
||||
|
||||
Note:
|
||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||
>>> config = FreeVCConfig()
|
||||
"""
|
||||
|
||||
model: str = "freevc"
|
||||
# model specific params
|
||||
model_args: FreeVCArgs = field(default_factory=FreeVCArgs)
audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig)
|
||||
|
||||
# optimizer
|
||||
# TODO with training support
|
||||
|
||||
# loss params
|
||||
# TODO with training support
|
||||
|
||||
# data loader params
|
||||
return_wav: bool = True
|
||||
compute_linear_spec: bool = True
|
||||
|
||||
# sampler params
|
||||
use_weighted_sampler: bool = False # TODO: move it to the base config
|
||||
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
|
||||
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
|
||||
|
||||
# overrides
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
||||
# multi-speaker settings
|
||||
# use speaker embedding layer
|
||||
num_speakers: int = 0
|
||||
speakers_file: str = None
|
||||
speaker_embedding_channels: int = 256
|
||||
|
||||
# use d-vectors
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_file: List[str] = None
|
||||
d_vector_dim: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
for key, val in self.model_args.items():
|
||||
if hasattr(self, key):
|
||||
self[key] = val
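# Illustrative sketch (standalone): constructing the config with non-default model
# arguments; the values here are arbitrary and only for demonstration.
example_config = FreeVCConfig(model_args=FreeVCArgs(use_spk=True))
print(example_config.model, example_config.model_args.use_spk)  # freevc True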
|
|
@@ -0,0 +1,170 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
return result
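# Standalone illustration: `intersperse` places a blank token between symbols, e.g. when
# `add_blank` is enabled in the config.
assert intersperse([1, 2, 3], 0) == [0, 1, 0, 2, 0, 3, 0]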
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
return kl
|
||||
|
||||
|
||||
def rand_gumbel(shape):
|
||||
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||
return -torch.log(-torch.log(uniform_samples))
|
||||
|
||||
|
||||
def rand_gumbel_like(x):
|
||||
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
return g
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
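# Illustrative, standalone sketch: sampling random fixed-length segments from a padded
# batch, as done with the latent `z` before the decoder during training.
x = torch.randn(2, 192, 100)           # (batch, channels, frames)
x_lengths = torch.tensor([100, 80])
segments, ids = rand_slice_segments(x, x_lengths, segment_size=32)
print(segments.shape, ids.shape)       # torch.Size([2, 192, 32]) torch.Size([2])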
|
||||
|
||||
|
||||
def rand_spec_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||
)
|
||||
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||
signal = signal.view(1, channels, length)
|
||||
return signal
|
||||
|
||||
|
||||
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
|
||||
|
||||
def subsequent_mask(length):
|
||||
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
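# Standalone illustration: `sequence_mask` turns lengths into a boolean padding mask.
print(sequence_mask(torch.tensor([3, 1]), max_length=4))
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])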
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2, 3) * mask
|
||||
return path
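# Illustrative, standalone sketch: expanding integer durations into a monotonic
# alignment path. With durations [2, 1, 3], the 3 input tokens are responsible for
# output frames 0-1, 2 and 3-5 respectively.
duration = torch.tensor([[[2.0, 1.0, 3.0]]])   # [b=1, 1, t_x=3]
mask = torch.ones(1, 1, 6, 3)                  # [b, 1, t_y=6, t_x=3]
path = generate_path(duration, mask)           # [1, 1, 6, 3], one-hot along t_x per frame
print(path[0, 0].argmax(dim=1))                # tensor([0, 0, 1, 2, 2, 2])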
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
clip_value = float(clip_value)
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm = total_norm ** (1.0 / norm_type)
|
||||
return total_norm
|
|
@@ -0,0 +1,125 @@
|
|||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
return spec
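# Illustrative, standalone sketch: with the FreeVC audio defaults (filter_length=1280,
# hop_length=320, win_length=1280) a one-second 16 kHz clip yields 641 frequency bins,
# which is where FreeVCArgs.spec_channels = 641 comes from.
y = torch.rand(1, 16000) * 2 - 1               # dummy audio in [-1, 1]
spec = spectrogram_torch(y, n_fft=1280, sampling_rate=16000, hop_size=320, win_size=1280)
print(spec.shape)                              # torch.Size([1, 641, 50])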
|
||||
|
||||
|
||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||
global mel_basis
|
||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
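# Illustrative, standalone sketch: the mel path used for the speaker-encoder input in
# FreeVC.voice_conversion, with the default audio settings (80 mel bands, fmax=None).
y = torch.rand(1, 16000) * 2 - 1
mel = mel_spectrogram_torch(
    y, n_fft=1280, num_mels=80, sampling_rate=16000, hop_size=320, win_size=1280, fmin=0.0, fmax=None
)
print(mel.shape)                               # torch.Size([1, 80, 50])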
|
|
@@ -0,0 +1,391 @@
|
|||
import copy
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
import TTS.vc.modules.freevc.commons as commons
|
||||
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(channels))
|
||||
self.beta = nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, -1)
|
||||
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||
return x.transpose(1, -1)
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
assert n_layers > 1, "Number of layers should be larger than 1."
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
||||
for _ in range(n_layers - 1):
|
||||
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x_org = x
|
||||
for i in range(self.n_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x)
|
||||
x = self.relu_drop(x)
|
||||
x = x_org + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class DDSConv(nn.Module):
|
||||
"""
|
||||
Dilated and Depth-Separable Convolution
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.convs_sep = nn.ModuleList()
|
||||
self.convs_1x1 = nn.ModuleList()
|
||||
self.norms_1 = nn.ModuleList()
|
||||
self.norms_2 = nn.ModuleList()
|
||||
for i in range(n_layers):
|
||||
dilation = kernel_size**i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs_sep.append(
|
||||
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
|
||||
)
|
||||
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||
self.norms_1.append(LayerNorm(channels))
|
||||
self.norms_2.append(LayerNorm(channels))
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for i in range(self.n_layers):
|
||||
y = self.convs_sep[i](x * x_mask)
|
||||
y = self.norms_1[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.convs_1x1[i](y)
|
||||
y = self.norms_2[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.drop(y)
|
||||
x = x + y
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
||||
super(WN, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = (kernel_size,)
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate**i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(
|
||||
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||
)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
||||
x = (x + res_acts) * x_mask
|
||||
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.gin_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
if x_mask is not None:
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
if x_mask is not None:
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class Log(nn.Module):
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
if not reverse:
|
||||
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
||||
logdet = torch.sum(-y, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = torch.exp(x) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class Flip(nn.Module):
|
||||
def forward(self, x, *args, reverse=False, **kwargs):
|
||||
x = torch.flip(x, [1])
|
||||
if not reverse:
|
||||
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class ElementwiseAffine(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.m = nn.Parameter(torch.zeros(channels, 1))
|
||||
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
||||
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
if not reverse:
|
||||
y = self.m + torch.exp(self.logs) * x
|
||||
y = y * x_mask
|
||||
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCouplingLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=0,
|
||||
gin_channels=0,
|
||||
mean_only=False,
|
||||
):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.half_channels = channels // 2
|
||||
self.mean_only = mean_only
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(
|
||||
hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels
|
||||
)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
stats = self.post(h) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
|
||||
if not reverse:
|
||||
x1 = m + x1 * torch.exp(logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
logdet = torch.sum(logs, [1, 2])
|
||||
return x, logdet
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
|
@@ -0,0 +1,65 @@
|
|||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
# import webrtcvad
|
||||
import librosa
|
||||
import numpy as np
|
||||
from scipy.ndimage import binary_dilation
|
||||
|
||||
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
|
||||
|
||||
int16_max = (2**15) - 1
|
||||
|
||||
|
||||
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None):
|
||||
"""
|
||||
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
|
||||
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
|
||||
|
||||
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
|
||||
just .wav), or the waveform as a numpy array of floats.
|
||||
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
|
||||
preprocessing. After preprocessing, the waveform's sampling rate will match the data
|
||||
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
|
||||
this argument will be ignored.
|
||||
"""
|
||||
# Load the wav from disk if needed
|
||||
if isinstance(fpath_or_wav, (str, Path)):
|
||||
wav, source_sr = librosa.load(fpath_or_wav, sr=None)
|
||||
else:
|
||||
wav = fpath_or_wav
|
||||
|
||||
# Resample the wav if needed
|
||||
if source_sr is not None and source_sr != sampling_rate:
|
||||
wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)
|
||||
|
||||
# Apply the preprocessing: normalize volume and shorten long silences
|
||||
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
|
||||
wav = trim_long_silences(wav)
|
||||
|
||||
return wav
|
||||
|
||||
|
||||
def wav_to_mel_spectrogram(wav):
|
||||
"""
|
||||
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
|
||||
Note: this is not a log-mel spectrogram.
|
||||
"""
|
||||
frames = librosa.feature.melspectrogram(
|
||||
y=wav,
|
||||
sr=sampling_rate,
|
||||
n_fft=int(sampling_rate * mel_window_length / 1000),
|
||||
hop_length=int(sampling_rate * mel_window_step / 1000),
|
||||
n_mels=mel_n_channels,
|
||||
)
|
||||
return frames.astype(np.float32).T
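# Standalone illustration: the 40-band mel front-end expected by the pretrained speaker
# encoder (per hparams: 16 kHz audio, 25 ms windows, 10 ms steps). The input is dummy audio.
dummy_wav = np.random.uniform(-1, 1, 16000).astype(np.float32)  # 1 second at 16 kHz
dummy_frames = wav_to_mel_spectrogram(dummy_wav)
print(dummy_frames.shape)                                       # (101, 40): (n_frames, mel_n_channels)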
|
||||
|
||||
|
||||
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
|
||||
if increase_only and decrease_only:
|
||||
raise ValueError("Both increase only and decrease only are set")
|
||||
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
|
||||
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
|
||||
return wav
|
||||
return wav * (10 ** (dBFS_change / 20))
|
|
@@ -0,0 +1,31 @@
|
|||
## Mel-filterbank
|
||||
mel_window_length = 25 # In milliseconds
|
||||
mel_window_step = 10 # In milliseconds
|
||||
mel_n_channels = 40
|
||||
|
||||
|
||||
## Audio
|
||||
sampling_rate = 16000
|
||||
# Number of spectrogram frames in a partial utterance
|
||||
partials_n_frames = 160 # 1600 ms
|
||||
|
||||
|
||||
## Voice Activation Detection
|
||||
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
||||
# This sets the granularity of the VAD. Should not need to be changed.
|
||||
vad_window_length = 30 # In milliseconds
|
||||
# Number of frames to average together when performing the moving average smoothing.
|
||||
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
||||
vad_moving_average_width = 8
|
||||
# Maximum number of consecutive silent frames a segment can have.
|
||||
vad_max_silence_length = 6
|
||||
|
||||
|
||||
## Audio volume normalization
|
||||
audio_norm_target_dBFS = -30
|
||||
|
||||
|
||||
## Model parameters
|
||||
model_hidden_size = 256
|
||||
model_embedding_size = 256
|
||||
model_num_layers = 3
|
|
@@ -0,0 +1,175 @@
|
|||
from pathlib import Path
|
||||
from time import perf_counter as timer
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vc.modules.freevc.speaker_encoder import audio
|
||||
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
|
||||
|
||||
|
||||
class SpeakerEncoder(nn.Module):
|
||||
def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
|
||||
"""
|
||||
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
|
||||
If None, defaults to cuda if it is available on your machine, otherwise the model will
|
||||
run on cpu. Outputs are always returned on the cpu, as numpy arrays.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Define the network
|
||||
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# Get the target device
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.device = device
|
||||
|
||||
# Load the pretrained speaker encoder weights
|
||||
# weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
|
||||
# if not weights_fpath.exists():
|
||||
# raise Exception("Couldn't find the voice encoder pretrained model at %s." %
|
||||
# weights_fpath)
|
||||
|
||||
start = timer()
|
||||
checkpoint = load_fsspec(weights_fpath, map_location="cpu")
|
||||
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
self.to(device)
|
||||
|
||||
if verbose:
|
||||
print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
|
||||
|
||||
def forward(self, mels: torch.FloatTensor):
|
||||
"""
|
||||
Computes the embeddings of a batch of utterance spectrograms.
|
||||
:param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
|
||||
(batch_size, n_frames, n_channels)
|
||||
:return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
|
||||
Embeddings are positive and L2-normed, thus they lie in the range [0, 1].
|
||||
"""
|
||||
# Pass the input through the LSTM layers and retrieve the final hidden state of the last
|
||||
# layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
|
||||
_, (hidden, _) = self.lstm(mels)
|
||||
embeds_raw = self.relu(self.linear(hidden[-1]))
|
||||
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
||||
|
||||
@staticmethod
|
||||
def compute_partial_slices(n_samples: int, rate, min_coverage):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to
|
||||
obtain partial utterances of <partials_n_frames> each. Both the waveform and the
|
||||
mel spectrogram slices are returned, so as to make each partial utterance waveform
|
||||
correspond to its spectrogram.
|
||||
|
||||
The returned ranges may be indexing further than the length of the waveform. It is
|
||||
recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
|
||||
|
||||
:param n_samples: the number of samples in the waveform
|
||||
:param rate: how many partial utterances should occur per second. Partial utterances must
|
||||
cover the span of the entire utterance, thus the rate should not be lower than the inverse
|
||||
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
|
||||
the minimum rate is thus 0.625.
|
||||
:param min_coverage: when reaching the last partial utterance, it may or may not have
|
||||
enough frames. If at least <min_coverage> of <partials_n_frames> are present,
|
||||
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
|
||||
it will be discarded. If there aren't enough frames for one partial utterance,
|
||||
this parameter is ignored so that the function always returns at least one slice.
|
||||
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
||||
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
||||
utterances.
|
||||
"""
|
||||
assert 0 < min_coverage <= 1
|
||||
|
||||
# Compute how many frames separate two partial utterances
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||
assert 0 < frame_step, "The rate is too high"
|
||||
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % (
|
||||
sampling_rate / (samples_per_frame * partials_n_frames)
|
||||
)
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
steps = max(1, n_frames - partials_n_frames + frame_step + 1)
|
||||
for i in range(0, steps, frame_step):
|
||||
mel_range = np.array([i, i + partials_n_frames])
|
||||
wav_range = mel_range * samples_per_frame
|
||||
mel_slices.append(slice(*mel_range))
|
||||
wav_slices.append(slice(*wav_range))
|
||||
|
||||
# Evaluate whether extra padding is warranted or not
|
||||
last_wav_range = wav_slices[-1]
|
||||
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
||||
if coverage < min_coverage and len(mel_slices) > 1:
|
||||
mel_slices = mel_slices[:-1]
|
||||
wav_slices = wav_slices[:-1]
|
||||
|
||||
return wav_slices, mel_slices
|
||||
|
||||
def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
|
||||
"""
|
||||
Computes an embedding for a single utterance. The utterance is divided in partial
|
||||
utterances and an embedding is computed for each. The complete utterance embedding is the
|
||||
L2-normed average embedding of the partial utterances.
|
||||
|
||||
TODO: independent batched version of this function
|
||||
|
||||
:param wav: a preprocessed utterance waveform as a numpy array of float32
|
||||
:param return_partials: if True, the partial embeddings will also be returned along with
|
||||
the wav slices corresponding to each partial utterance.
|
||||
:param rate: how many partial utterances should occur per second. Partial utterances must
|
||||
cover the span of the entire utterance, thus the rate should not be lower than the inverse
|
||||
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
|
||||
the minimum rate is thus 0.625.
|
||||
:param min_coverage: when reaching the last partial utterance, it may or may not have
|
||||
enough frames. If at least <min_coverage> of <partials_n_frames> are present,
|
||||
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
|
||||
it will be discarded. If there aren't enough frames for one partial utterance,
|
||||
this parameter is ignored so that the function always returns at least one slice.
|
||||
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
||||
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
||||
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
||||
returned.
|
||||
"""
|
||||
# Compute where to split the utterance into partials and pad the waveform with zeros if
|
||||
# the partial utterances cover a larger range.
|
||||
wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
|
||||
max_wave_length = wav_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials and forward them through the model
|
||||
mel = audio.wav_to_mel_spectrogram(wav)
|
||||
mels = np.array([mel[s] for s in mel_slices])
|
||||
with torch.no_grad():
|
||||
mels = torch.from_numpy(mels).to(self.device)
|
||||
partial_embeds = self(mels).cpu().numpy()
|
||||
|
||||
# Compute the utterance embedding from the partial embeddings
|
||||
raw_embed = np.mean(partial_embeds, axis=0)
|
||||
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
||||
|
||||
if return_partials:
|
||||
return embed, partial_embeds, wav_slices
|
||||
return embed
|
||||
|
||||
def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
|
||||
"""
|
||||
Compute the embedding of a collection of wavs (presumably from the same speaker) by
|
||||
averaging their embedding and L2-normalizing it.
|
||||
|
||||
:param wavs: list of wavs a numpy arrays of float32.
|
||||
:param kwargs: extra arguments to embed_utterance()
|
||||
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
|
||||
"""
|
||||
raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
|
||||
return raw_embed / np.linalg.norm(raw_embed, 2)
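# Illustrative, standalone usage sketch: computing a speaker embedding from a reference
# recording, mirroring how FreeVC uses this encoder when `use_spk` is enabled. The audio
# path is a placeholder; the checkpoint URL is the one loaded in FreeVC above.
import librosa

encoder = SpeakerEncoder(
    "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
)
ref_wav, _ = librosa.load("/path/to/reference.wav", sr=16000)   # 16 kHz mono, float32
embed = encoder.embed_utterance(ref_wav)                        # np.float32, shape (256,), L2-normed
print(embed.shape, float(np.linalg.norm(embed)))                # (256,) 1.0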
|
|
@@ -0,0 +1,35 @@
|
|||
import os
|
||||
import urllib.request
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
|
||||
|
||||
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
|
||||
|
||||
|
||||
def get_wavlm(device="cpu"):
|
||||
"""Download the model and return the model object."""
|
||||
|
||||
output_path = get_user_data_dir("tts")
|
||||
|
||||
output_path = os.path.join(output_path, "wavlm")
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
|
||||
output_path = os.path.join(output_path, "WavLM-Large.pt")
|
||||
if not os.path.exists(output_path):
|
||||
print(f" > Downloading WavLM model to {output_path} ...")
|
||||
urllib.request.urlretrieve(model_uri, output_path)
|
||||
|
||||
checkpoint = torch.load(output_path, map_location=torch.device(device))
|
||||
cfg = WavLMConfig(checkpoint["cfg"])
|
||||
wavlm = WavLM(cfg).to(device)
|
||||
wavlm.load_state_dict(checkpoint["model"])
|
||||
wavlm.eval()
|
||||
return wavlm
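# Illustrative, standalone sketch: extracting the content features FreeVC conditions on.
# A 16 kHz waveform of N samples yields roughly N / 320 frames of 1024-dim features
# (matching ssl_dim=1024 in FreeVCArgs). The input here is random; the first call
# downloads the checkpoint.
wavlm_model = get_wavlm()
dummy_audio_16k = torch.randn(1, 16000)                              # one second of dummy audio
with torch.no_grad():
    dummy_feats = wavlm_model.extract_features(dummy_audio_16k)[0]   # (1, ~49, 1024)
print(dummy_feats.shape)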
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wavlm = get_wavlm()
|
|
@@ -0,0 +1,99 @@
|
|||
{
|
||||
"_name_or_path": "./wavlm-large/",
|
||||
"activation_dropout": 0.0,
|
||||
"adapter_kernel_size": 3,
|
||||
"adapter_stride": 2,
|
||||
"add_adapter": false,
|
||||
"apply_spec_augment": true,
|
||||
"architectures": [
|
||||
"WavLMModel"
|
||||
],
|
||||
"attention_dropout": 0.1,
|
||||
"bos_token_id": 1,
|
||||
"classifier_proj_size": 256,
|
||||
"codevector_dim": 768,
|
||||
"contrastive_logits_temperature": 0.1,
|
||||
"conv_bias": false,
|
||||
"conv_dim": [
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"conv_kernel": [
|
||||
10,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"conv_stride": [
|
||||
5,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"ctc_loss_reduction": "sum",
|
||||
"ctc_zero_infinity": false,
|
||||
"diversity_loss_weight": 0.1,
|
||||
"do_stable_layer_norm": true,
|
||||
"eos_token_id": 2,
|
||||
"feat_extract_activation": "gelu",
|
||||
"feat_extract_dropout": 0.0,
|
||||
"feat_extract_norm": "layer",
|
||||
"feat_proj_dropout": 0.1,
|
||||
"feat_quantizer_dropout": 0.0,
|
||||
"final_dropout": 0.0,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout": 0.1,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"layerdrop": 0.1,
|
||||
"mask_channel_length": 10,
|
||||
"mask_channel_min_space": 1,
|
||||
"mask_channel_other": 0.0,
|
||||
"mask_channel_prob": 0.0,
|
||||
"mask_channel_selection": "static",
|
||||
"mask_feature_length": 10,
|
||||
"mask_feature_min_masks": 0,
|
||||
"mask_feature_prob": 0.0,
|
||||
"mask_time_length": 10,
|
||||
"mask_time_min_masks": 2,
|
||||
"mask_time_min_space": 1,
|
||||
"mask_time_other": 0.0,
|
||||
"mask_time_prob": 0.075,
|
||||
"mask_time_selection": "static",
|
||||
"max_bucket_distance": 800,
|
||||
"model_type": "wavlm",
|
||||
"num_adapter_layers": 3,
|
||||
"num_attention_heads": 16,
|
||||
"num_buckets": 320,
|
||||
"num_codevector_groups": 2,
|
||||
"num_codevectors_per_group": 320,
|
||||
"num_conv_pos_embedding_groups": 16,
|
||||
"num_conv_pos_embeddings": 128,
|
||||
"num_ctc_classes": 80,
|
||||
"num_feat_extract_layers": 7,
|
||||
"num_hidden_layers": 24,
|
||||
"num_negatives": 100,
|
||||
"output_hidden_size": 1024,
|
||||
"pad_token_id": 0,
|
||||
"proj_codevector_dim": 768,
|
||||
"replace_prob": 0.5,
|
||||
"tokenizer_class": "Wav2Vec2CTCTokenizer",
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.15.0.dev0",
|
||||
"use_weighted_layer_sum": false,
|
||||
"vocab_size": 32
|
||||
}
|
|
@ -0,0 +1,768 @@
|
|||
# --------------------------------------------------------
|
||||
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||
# Copyright (c) 2021 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class TransposeLast(nn.Module):
|
||||
def __init__(self, deconstruct_idx=None):
|
||||
super().__init__()
|
||||
self.deconstruct_idx = deconstruct_idx
|
||||
|
||||
def forward(self, x):
|
||||
if self.deconstruct_idx is not None:
|
||||
x = x[self.deconstruct_idx]
|
||||
return x.transpose(-2, -1)
|
||||
|
||||
|
||||
class Fp32LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.layer_norm(
|
||||
input.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class Fp32GroupNorm(nn.GroupNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.group_norm(
|
||||
input.float(),
|
||||
self.num_groups,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
||||
|
||||
|
||||
class SamePad(nn.Module):
|
||||
def __init__(self, kernel_size, causal=False):
|
||||
super().__init__()
|
||||
if causal:
|
||||
self.remove = kernel_size - 1
|
||||
else:
|
||||
self.remove = 1 if kernel_size % 2 == 0 else 0
|
||||
|
||||
def forward(self, x):
|
||||
if self.remove > 0:
|
||||
x = x[:, :, : -self.remove]
|
||||
return x
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
"""Swish function"""
|
||||
|
||||
def __init__(self):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(Swish, self).__init__()
|
||||
self.act = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.act(x)
|
||||
|
||||
|
||||
class GLU_Linear(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
||||
super(GLU_Linear, self).__init__()
|
||||
|
||||
self.glu_type = glu_type
|
||||
self.output_dim = output_dim
|
||||
|
||||
if glu_type == "sigmoid":
|
||||
self.glu_act = torch.nn.Sigmoid()
|
||||
elif glu_type == "swish":
|
||||
self.glu_act = Swish()
|
||||
elif glu_type == "relu":
|
||||
self.glu_act = torch.nn.ReLU()
|
||||
elif glu_type == "gelu":
|
||||
self.glu_act = torch.nn.GELU()
|
||||
|
||||
if bias_in_glu:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
||||
else:
|
||||
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
||||
|
||||
def forward(self, x):
|
||||
# To be consistent with GLU_Linear, we assume the input always has the channel (dim) in the last dimension of the tensor, so we need to switch the dimension first for the 1D-Conv case
|
||||
x = self.linear(x)
|
||||
|
||||
if self.glu_type == "bilinear":
|
||||
x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
|
||||
else:
|
||||
x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gelu_accurate(x):
|
||||
if not hasattr(gelu_accurate, "_a"):
|
||||
gelu_accurate._a = math.sqrt(2 / math.pi)
|
||||
return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def gelu(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(x.float()).type_as(x)
|
||||
|
||||
|
||||
def get_activation_fn(activation: str):
|
||||
"""Returns the activation function corresponding to `activation`"""
|
||||
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
elif activation == "gelu":
|
||||
return gelu
|
||||
elif activation == "gelu_fast":
|
||||
warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
|
||||
return gelu_accurate
|
||||
elif activation == "gelu_accurate":
|
||||
return gelu_accurate
|
||||
elif activation == "tanh":
|
||||
return torch.tanh
|
||||
elif activation == "linear":
|
||||
return lambda x: x
|
||||
elif activation == "glu":
|
||||
return lambda x: x
|
||||
else:
|
||||
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
||||
|
||||
|
||||
def init_bert_params(module):
|
||||
"""
|
||||
Initialize the weights specific to the BERT Model.
|
||||
This overrides the default initializations depending on the specified arguments.
|
||||
1. If normal_init_linear_weights is set then weights of linear
|
||||
layer will be initialized using the normal distribution and
|
||||
bias will be set to the specified value.
|
||||
2. If normal_init_embed_weights is set then weights of embedding
|
||||
layer will be initialized using the normal distribution.
|
||||
3. If normal_init_proj_weights is set then weights of
|
||||
in_project_weight for MultiHeadAttention initialized using
|
||||
the normal distribution (to be validated).
|
||||
"""
|
||||
|
||||
def normal_(data):
|
||||
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
||||
# so that the RNG is consistent with and without FSDP
|
||||
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
normal_(module.weight.data)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
normal_(module.weight.data)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if isinstance(module, MultiheadAttention):
|
||||
normal_(module.q_proj.weight.data)
|
||||
normal_(module.k_proj.weight.data)
|
||||
normal_(module.v_proj.weight.data)
|
||||
|
||||
|
||||
def quant_noise(module, p, block_size):
|
||||
"""
|
||||
Wraps modules and applies quantization noise to the weights for
|
||||
subsequent quantization with Iterative Product Quantization as
|
||||
described in "Training with Quantization Noise for Extreme Model Compression"
|
||||
|
||||
Args:
|
||||
- module: nn.Module
|
||||
- p: amount of Quantization Noise
|
||||
- block_size: size of the blocks for subsequent quantization with iPQ
|
||||
|
||||
Remarks:
|
||||
- Module weights must have the right sizes wrt the block size
|
||||
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
||||
- For more detail on how to quantize by blocks with convolutional weights,
|
||||
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
||||
- We implement the simplest form of noise here as stated in the paper
|
||||
which consists in randomly dropping blocks
|
||||
"""
|
||||
|
||||
# if no quantization noise, don't register hook
|
||||
if p <= 0:
|
||||
return module
|
||||
|
||||
# supported modules
|
||||
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
||||
|
||||
# test whether module.weight has the right sizes wrt block_size
|
||||
is_conv = module.weight.ndim == 4
|
||||
|
||||
# 2D matrix
|
||||
if not is_conv:
|
||||
assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
|
||||
|
||||
# 4D matrix
|
||||
else:
|
||||
# 1x1 convolutions
|
||||
if module.kernel_size == (1, 1):
|
||||
assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
|
||||
# regular convolutions
|
||||
else:
|
||||
k = module.kernel_size[0] * module.kernel_size[1]
|
||||
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
||||
|
||||
def _forward_pre_hook(mod, input):
|
||||
# no noise for evaluation
|
||||
if mod.training:
|
||||
if not is_conv:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_features = weight.size(1)
|
||||
out_features = weight.size(0)
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
||||
|
||||
else:
|
||||
# gather weight and sizes
|
||||
weight = mod.weight
|
||||
in_channels = mod.in_channels
|
||||
out_channels = mod.out_channels
|
||||
|
||||
# split weight matrix into blocks and randomly drop selected blocks
|
||||
if mod.kernel_size == (1, 1):
|
||||
mask = torch.zeros(
|
||||
int(in_channels // block_size * out_channels),
|
||||
device=weight.device,
|
||||
)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
||||
else:
|
||||
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
|
||||
mask.bernoulli_(p)
|
||||
mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
||||
|
||||
# scale weights and apply mask
|
||||
mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
|
||||
s = 1 / (1 - p)
|
||||
mod.weight.data = s * weight.masked_fill(mask, 0)
|
||||
|
||||
module.register_forward_pre_hook(_forward_pre_hook)
|
||||
return module
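# Hedged usage sketch (comments only, not part of the upstream file): wrapping a linear
# layer with 10% quantization noise and block size 8; the registered forward pre-hook
# drops weight blocks only while the module is in training mode:
#     layer = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)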
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
self_attention=False,
|
||||
encoder_decoder_attention=False,
|
||||
q_noise=0.0,
|
||||
qn_block_size=8,
|
||||
has_relative_attention_bias=False,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
gru_rel_pos=False,
|
||||
rescale_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout_module = nn.Dropout(dropout)
|
||||
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
if self.has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_head_dim = self.head_dim
|
||||
self.k_head_dim = self.head_dim
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.self_attention = self_attention
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
|
||||
assert not self.self_attention or self.qkv_same_dim, (
|
||||
"Self-attention requires query, key and " "value to be of the same size"
|
||||
)
|
||||
|
||||
k_bias = True
|
||||
if rescale_init:
|
||||
k_bias = False
|
||||
|
||||
k_embed_dim = embed_dim
|
||||
q_embed_dim = embed_dim
|
||||
|
||||
self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
|
||||
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
|
||||
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self.gru_rel_pos = gru_rel_pos
|
||||
if self.gru_rel_pos:
|
||||
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
||||
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
nn.init.xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
nn.init.xavier_normal_(self.bias_v)
|
||||
if self.has_relative_attention_bias:
|
||||
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
||||
|
||||
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
||||
num_buckets = self.num_buckets
|
||||
max_distance = self.max_distance
|
||||
relative_buckets = 0
|
||||
|
||||
if bidirectional:
|
||||
num_buckets = num_buckets // 2
|
||||
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
||||
relative_positions = torch.abs(relative_positions)
|
||||
else:
|
||||
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_positions < max_exact
|
||||
|
||||
relative_postion_if_large = max_exact + (
|
||||
torch.log(relative_positions.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_postion_if_large = torch.min(
|
||||
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
||||
)
|
||||
|
||||
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length):
|
||||
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
|
||||
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
||||
values = self.relative_attention_bias(relative_position_bucket)
|
||||
values = values.permute([2, 0, 1])
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query,
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = True,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
before_softmax: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
position_bias: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
"""Input shape: Time x Batch x Channel
|
||||
|
||||
Args:
|
||||
key_padding_mask (ByteTensor, optional): mask to exclude
|
||||
keys that are pads, of shape `(batch, src_len)`, where
|
||||
padding elements are indicated by 1s.
|
||||
need_weights (bool, optional): return the attention weights,
|
||||
averaged over heads (default: False).
|
||||
attn_mask (ByteTensor, optional): typically used to
|
||||
implement causal attention, where the mask prevents the
|
||||
attention from looking forward in time (default: None).
|
||||
before_softmax (bool, optional): return the raw attention
|
||||
weights and values before the attention softmax.
|
||||
need_head_weights (bool, optional): return the attention
|
||||
weights for each head. Implies *need_weights*. Default:
|
||||
return the average attention weights over all heads.
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
is_tpu = query.device.type == "xla"
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, _ = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert key_bsz == bsz
|
||||
assert value is not None
|
||||
assert (src_len, bsz) == value.shape[:2]
|
||||
|
||||
if self.has_relative_attention_bias and position_bias is None:
|
||||
position_bias = self.compute_bias(tgt_len, src_len)
|
||||
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if (
|
||||
not is_tpu # don't use PyTorch version on TPUs
|
||||
and incremental_state is None
|
||||
and not static_kv
|
||||
# A workaround for quantization to work. Otherwise JIT compilation
|
||||
# treats bias in linear module as method.
|
||||
and not torch.jit.is_scripting()
|
||||
and self.q_head_dim == self.head_dim
|
||||
):
|
||||
assert key is not None and value is not None
|
||||
assert attn_mask is None
|
||||
|
||||
attn_mask_rel_pos = None
|
||||
if position_bias is not None:
|
||||
attn_mask_rel_pos = position_bias
|
||||
if self.gru_rel_pos:
|
||||
query_layer = query.transpose(0, 1)
|
||||
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
||||
query_layer = query_layer.view(*new_x_shape)
|
||||
query_layer = query_layer.permute(0, 2, 1, 3)
|
||||
_B, _H, _L, __ = query_layer.size()
|
||||
|
||||
gate_a, gate_b = torch.sigmoid(
|
||||
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
||||
).chunk(2, dim=-1)
|
||||
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
||||
|
||||
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
||||
k_proj_bias = self.k_proj.bias
|
||||
if k_proj_bias is None:
|
||||
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
||||
|
||||
x, attn = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
torch.empty([0]),
|
||||
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout_module.p,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
self.training,
|
||||
# self.training or self.dropout_module.apply_during_inference,
|
||||
key_padding_mask,
|
||||
need_weights,
|
||||
attn_mask_rel_pos,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
)
|
||||
return x, attn, position_bias
|
||||
|
||||
if incremental_state is not None:
|
||||
saved_state = self._get_input_buffer(incremental_state)
|
||||
if saved_state is not None and "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute
|
||||
# key and value if they are static
|
||||
if static_kv:
|
||||
assert self.encoder_decoder_attention and not self.self_attention
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
|
||||
if self.self_attention:
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
elif self.encoder_decoder_attention:
|
||||
# encoder-decoder attention
|
||||
q = self.q_proj(query)
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = v = None
|
||||
else:
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(key)
|
||||
|
||||
else:
|
||||
assert key is not None and value is not None
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
q *= self.scaling
|
||||
|
||||
if self.bias_k is not None:
|
||||
assert self.bias_v is not None
|
||||
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
|
||||
if k is not None:
|
||||
k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
|
||||
if v is not None:
|
||||
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||||
|
||||
if saved_state is not None:
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
assert _prev_key is not None
|
||||
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
k = prev_key
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
src_len = k.size(1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
v = prev_value
|
||||
else:
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
prev_key_padding_mask: Optional[Tensor] = None
|
||||
if "prev_key_padding_mask" in saved_state:
|
||||
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
||||
assert k is not None and v is not None
|
||||
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
||||
key_padding_mask=key_padding_mask,
|
||||
prev_key_padding_mask=prev_key_padding_mask,
|
||||
batch_size=bsz,
|
||||
src_len=k.size(1),
|
||||
static_kv=static_kv,
|
||||
)
|
||||
|
||||
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||
saved_state["prev_key_padding_mask"] = key_padding_mask
|
||||
# In this branch incremental_state is never None
|
||||
assert incremental_state is not None
|
||||
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||
assert k is not None
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
assert v is not None
|
||||
src_len += 1
|
||||
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
||||
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = torch.cat(
|
||||
[
|
||||
key_padding_mask,
|
||||
torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
||||
|
||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
if not is_tpu:
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||
float("-inf"),
|
||||
)
|
||||
else:
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
||||
attn_weights = attn_weights.transpose(0, 2)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if before_softmax:
|
||||
return attn_weights, v, position_bias
|
||||
|
||||
if position_bias is not None:
|
||||
if self.gru_rel_pos == 1:
|
||||
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
||||
_B, _H, _L, __ = query_layer.size()
|
||||
gate_a, gate_b = torch.sigmoid(
|
||||
self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
|
||||
).chunk(2, dim=-1)
|
||||
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
||||
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
||||
|
||||
position_bias = position_bias.view(attn_weights.size())
|
||||
|
||||
attn_weights = attn_weights + position_bias
|
||||
|
||||
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_probs = self.dropout_module(attn_weights)
|
||||
|
||||
assert v is not None
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
attn_weights: Optional[Tensor] = None
|
||||
if need_weights:
|
||||
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
||||
if not need_head_weights:
|
||||
# average attention weights over heads
|
||||
attn_weights = attn_weights.mean(dim=0)
|
||||
|
||||
return attn, attn_weights, position_bias
|
||||
|
||||
@staticmethod
|
||||
def _append_prev_key_padding_mask(
|
||||
key_padding_mask: Optional[Tensor],
|
||||
prev_key_padding_mask: Optional[Tensor],
|
||||
batch_size: int,
|
||||
src_len: int,
|
||||
static_kv: bool,
|
||||
) -> Optional[Tensor]:
|
||||
# saved key padding masks have shape (bsz, seq_len)
|
||||
if prev_key_padding_mask is not None and static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current
|
||||
# is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
if src_len > prev_key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
||||
device=prev_key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask.float()
|
||||
elif key_padding_mask is not None:
|
||||
if src_len > key_padding_mask.size(1):
|
||||
filler = torch.zeros(
|
||||
(batch_size, src_len - key_padding_mask.size(1)),
|
||||
device=key_padding_mask.device,
|
||||
)
|
||||
new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
|
||||
else:
|
||||
new_key_padding_mask = key_padding_mask.float()
|
||||
else:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
return new_key_padding_mask
|
||||
|
||||
def _get_input_buffer(
|
||||
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
||||
) -> Dict[str, Optional[Tensor]]:
|
||||
result = self.get_incremental_state(incremental_state, "attn_state")
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
empty_result: Dict[str, Optional[Tensor]] = {}
|
||||
return empty_result
|
||||
|
||||
def _set_input_buffer(
|
||||
self,
|
||||
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
||||
buffer: Dict[str, Optional[Tensor]],
|
||||
):
|
||||
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
||||
|
||||
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
||||
return attn_weights
|
|
@ -0,0 +1,719 @@
|
|||
# --------------------------------------------------------
|
||||
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||
# Copyright (c) 2021 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Based on fairseq code bases
|
||||
# https://github.com/pytorch/fairseq
|
||||
# --------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from TTS.vc.modules.freevc.wavlm.modules import (
|
||||
Fp32GroupNorm,
|
||||
Fp32LayerNorm,
|
||||
GLU_Linear,
|
||||
GradMultiply,
|
||||
MultiheadAttention,
|
||||
SamePad,
|
||||
TransposeLast,
|
||||
get_activation_fn,
|
||||
init_bert_params,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. This will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||
poisson = sample from poisson distribution with lambda = mask length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
||||
elif mask_type == "normal":
|
||||
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = np.random.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise Exception("unknown mask selection " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = np.random.randint(s, e - length)
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - keep_length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
int,  # plain int (np.int was removed in recent NumPy releases)
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = np.random.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
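# Hedged illustration (comments only, not part of the upstream file): for a batch of two
# 100-frame utterances, WavLM-style time masking roughly corresponds to
#     mask = compute_mask_indices((2, 100), None, mask_prob=0.65, mask_length=10, min_masks=2)
# which yields a boolean (2, 100) numpy array marking the frames later replaced by mask_emb.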
|
||||
|
||||
|
||||
class WavLMConfig:
|
||||
def __init__(self, cfg=None):
|
||||
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
||||
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
||||
|
||||
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
||||
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
||||
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
||||
self.activation_fn: str = "gelu" # activation function to use
|
||||
|
||||
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
||||
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
||||
self.conv_bias: bool = False # include bias in conv encoder
|
||||
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
||||
|
||||
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
||||
|
||||
# dropouts
|
||||
self.dropout: float = 0.1 # dropout probability for the transformer
|
||||
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
||||
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
||||
self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
|
||||
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
||||
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
||||
|
||||
# masking
|
||||
self.mask_length: int = 10 # mask length
|
||||
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
||||
self.mask_selection: str = "static" # how to choose mask length
|
||||
self.mask_other: float = (
|
||||
0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
||||
)
|
||||
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
||||
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||
|
||||
# channel masking
|
||||
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
||||
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
||||
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
||||
self.mask_channel_other: float = (
|
||||
0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
||||
)
|
||||
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
||||
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
||||
|
||||
# positional embeddings
|
||||
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
||||
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
||||
|
||||
# relative position embedding
|
||||
self.relative_position_embedding: bool = False # apply relative position embedding
|
||||
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
||||
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
||||
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
||||
|
||||
if cfg is not None:
|
||||
self.update(cfg)
|
||||
|
||||
def update(self, cfg: dict):
|
||||
self.__dict__.update(cfg)
|
||||
|
||||
|
||||
class WavLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg: WavLMConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logger.info(f"WavLM Config: {cfg.__dict__}")
|
||||
|
||||
self.cfg = cfg
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers)
|
||||
self.embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=cfg.extractor_mode,
|
||||
conv_bias=cfg.conv_bias,
|
||||
)
|
||||
|
||||
self.post_extract_proj = (
|
||||
nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
|
||||
)
|
||||
|
||||
self.mask_prob = cfg.mask_prob
|
||||
self.mask_selection = cfg.mask_selection
|
||||
self.mask_other = cfg.mask_other
|
||||
self.mask_length = cfg.mask_length
|
||||
self.no_mask_overlap = cfg.no_mask_overlap
|
||||
self.mask_min_space = cfg.mask_min_space
|
||||
|
||||
self.mask_channel_prob = cfg.mask_channel_prob
|
||||
self.mask_channel_selection = cfg.mask_channel_selection
|
||||
self.mask_channel_other = cfg.mask_channel_other
|
||||
self.mask_channel_length = cfg.mask_channel_length
|
||||
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
||||
self.mask_channel_min_space = cfg.mask_channel_min_space
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
||||
|
||||
self.feature_grad_mult = cfg.feature_grad_mult
|
||||
|
||||
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
|
||||
|
||||
self.encoder = TransformerEncoder(cfg)
|
||||
self.layer_norm = LayerNorm(self.embed)
|
||||
|
||||
def apply_mask(self, x, padding_mask):
|
||||
B, T, C = x.shape
|
||||
if self.mask_prob > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_prob,
|
||||
self.mask_length,
|
||||
self.mask_selection,
|
||||
self.mask_other,
|
||||
min_masks=2,
|
||||
no_overlap=self.no_mask_overlap,
|
||||
min_space=self.mask_min_space,
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
||||
x[mask_indices] = self.mask_emb
|
||||
else:
|
||||
mask_indices = None
|
||||
|
||||
if self.mask_channel_prob > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_channel_prob,
|
||||
self.mask_channel_length,
|
||||
self.mask_channel_selection,
|
||||
self.mask_channel_other,
|
||||
no_overlap=self.no_mask_channel_overlap,
|
||||
min_space=self.mask_channel_min_space,
|
||||
)
|
||||
mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
|
||||
x[mask_channel_indices] = 0
|
||||
|
||||
return x, mask_indices
|
||||
|
||||
def forward_padding_mask(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
extra = padding_mask.size(1) % features.size(1)
|
||||
if extra > 0:
|
||||
padding_mask = padding_mask[:, :-extra]
|
||||
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
||||
# padding_mask = padding_mask.all(-1)
|
||||
padding_mask = padding_mask.any(-1)
|
||||
return padding_mask
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
source: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
mask: bool = False,
|
||||
ret_conv: bool = False,
|
||||
output_layer: Optional[int] = None,
|
||||
ret_layer_results: bool = False,
|
||||
):
|
||||
if self.feature_grad_mult > 0:
|
||||
features = self.feature_extractor(source)
|
||||
if self.feature_grad_mult != 1.0:
|
||||
features = GradMultiply.apply(features, self.feature_grad_mult)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
features = self.feature_extractor(source)
|
||||
|
||||
features = features.transpose(1, 2)
|
||||
features = self.layer_norm(features)
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.forward_padding_mask(features, padding_mask)
|
||||
|
||||
if self.post_extract_proj is not None:
|
||||
features = self.post_extract_proj(features)
|
||||
|
||||
features = self.dropout_input(features)
|
||||
|
||||
if mask:
|
||||
x, mask_indices = self.apply_mask(features, padding_mask)
|
||||
else:
|
||||
x = features
|
||||
|
||||
# feature: (B, T, D), float
|
||||
# target: (B, T), long
|
||||
# x: (B, T, D), float
|
||||
# padding_mask: (B, T), bool
|
||||
# mask_indices: (B, T), bool
|
||||
x, layer_results = self.encoder(
|
||||
x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1
|
||||
)
|
||||
|
||||
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
||||
|
||||
feature = res["features"] if ret_conv else res["x"]
|
||||
if ret_layer_results:
|
||||
feature = (feature, res["layer_results"])
|
||||
return feature, res["padding_mask"]
|
||||
|
||||
|
||||
class ConvFeatureExtractionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conv_layers: List[Tuple[int, int, int]],
|
||||
dropout: float = 0.0,
|
||||
mode: str = "default",
|
||||
conv_bias: bool = False,
|
||||
conv_type: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode in {"default", "layer_norm"}
|
||||
|
||||
def block(
|
||||
n_in,
|
||||
n_out,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=False,
|
||||
is_group_norm=False,
|
||||
conv_bias=False,
|
||||
):
|
||||
def make_conv():
|
||||
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
||||
nn.init.kaiming_normal_(conv.weight)
|
||||
return conv
|
||||
|
||||
assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
|
||||
|
||||
if is_layer_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Sequential(
|
||||
TransposeLast(),
|
||||
Fp32LayerNorm(dim, elementwise_affine=True),
|
||||
TransposeLast(),
|
||||
),
|
||||
nn.GELU(),
|
||||
)
|
||||
elif is_group_norm:
|
||||
return nn.Sequential(
|
||||
make_conv(),
|
||||
nn.Dropout(p=dropout),
|
||||
Fp32GroupNorm(dim, dim, affine=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
||||
|
||||
self.conv_type = conv_type
|
||||
if self.conv_type == "default":
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(
|
||||
block(
|
||||
in_d,
|
||||
dim,
|
||||
k,
|
||||
stride,
|
||||
is_layer_norm=mode == "layer_norm",
|
||||
is_group_norm=mode == "default" and i == 0,
|
||||
conv_bias=conv_bias,
|
||||
)
|
||||
)
|
||||
in_d = dim
|
||||
elif self.conv_type == "conv2d":
|
||||
in_d = 1
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3
|
||||
(dim, k, stride) = cl
|
||||
|
||||
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
|
||||
self.conv_layers.append(torch.nn.ReLU())
|
||||
in_d = dim
|
||||
elif self.conv_type == "custom":
|
||||
in_d = 1
|
||||
idim = 80
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i, cl in enumerate(conv_layers):
|
||||
assert len(cl) == 3
|
||||
(dim, k, stride) = cl
|
||||
self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
|
||||
self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
|
||||
self.conv_layers.append(torch.nn.ReLU())
|
||||
in_d = dim
|
||||
if (i + 1) % 2 == 0:
|
||||
self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
|
||||
idim = int(math.ceil(idim / 2))
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
# BxT -> BxCxT
|
||||
x = x.unsqueeze(1)
|
||||
if self.conv_type == "custom":
|
||||
for conv in self.conv_layers:
|
||||
if isinstance(conv, nn.LayerNorm):
|
||||
x = x.transpose(1, 2)
|
||||
x = conv(x).transpose(1, 2)
|
||||
else:
|
||||
x = conv(x)
|
||||
x = x.transpose(2, 3).contiguous()
|
||||
x = x.view(x.size(0), -1, x.size(-1))
|
||||
else:
|
||||
for conv in self.conv_layers:
|
||||
x = conv(x)
|
||||
if self.conv_type == "conv2d":
|
||||
b, c, t, f = x.size()
|
||||
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
self.dropout = args.dropout
|
||||
self.embedding_dim = args.encoder_embed_dim
|
||||
|
||||
self.pos_conv = nn.Conv1d(
|
||||
self.embedding_dim,
|
||||
self.embedding_dim,
|
||||
kernel_size=args.conv_pos,
|
||||
padding=args.conv_pos // 2,
|
||||
groups=args.conv_pos_groups,
|
||||
)
|
||||
dropout = 0
|
||||
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
||||
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
||||
nn.init.constant_(self.pos_conv.bias, 0)
|
||||
|
||||
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
||||
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
||||
|
||||
if hasattr(args, "relative_position_embedding"):
|
||||
self.relative_position_embedding = args.relative_position_embedding
|
||||
self.num_buckets = args.num_buckets
|
||||
self.max_distance = args.max_distance
|
||||
else:
|
||||
self.relative_position_embedding = False
|
||||
self.num_buckets = 0
|
||||
self.max_distance = 0
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TransformerSentenceEncoderLayer(
|
||||
embedding_dim=self.embedding_dim,
|
||||
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
||||
num_attention_heads=args.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_fn=args.activation_fn,
|
||||
layer_norm_first=args.layer_norm_first,
|
||||
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
gru_rel_pos=args.gru_rel_pos,
|
||||
)
|
||||
for i in range(args.encoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.layer_norm_first = args.layer_norm_first
|
||||
self.layer_norm = LayerNorm(self.embedding_dim)
|
||||
self.layerdrop = args.encoder_layerdrop
|
||||
|
||||
self.apply(init_bert_params)
|
||||
|
||||
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
||||
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
||||
|
||||
if self.layer_norm_first and layer is None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
||||
if padding_mask is not None:
|
||||
x[padding_mask] = 0
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x += x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
layer_results = []
|
||||
z = None
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
r = None
|
||||
pos_bias = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
dropout_probability = np.random.random()
|
||||
if not self.training or (dropout_probability > self.layerdrop):
|
||||
x, z, pos_bias = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_weights=False,
|
||||
self_attn_mask=streaming_mask,
|
||||
pos_bias=pos_bias,
|
||||
)
|
||||
if tgt_layer is not None:
|
||||
layer_results.append((x, z))
|
||||
if i == tgt_layer:
|
||||
r = x
|
||||
break
|
||||
|
||||
if r is not None:
|
||||
x = r
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
|
||||
class TransformerSentenceEncoderLayer(nn.Module):
|
||||
"""
|
||||
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
||||
models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: float = 768,
|
||||
ffn_embedding_dim: float = 3072,
|
||||
num_attention_heads: float = 8,
|
||||
dropout: float = 0.1,
|
||||
attention_dropout: float = 0.1,
|
||||
activation_dropout: float = 0.1,
|
||||
activation_fn: str = "relu",
|
||||
layer_norm_first: bool = False,
|
||||
has_relative_attention_bias: bool = False,
|
||||
num_buckets: int = 0,
|
||||
max_distance: int = 0,
|
||||
rescale_init: bool = False,
|
||||
gru_rel_pos: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Initialize parameters
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
# Initialize blocks
|
||||
self.activation_name = activation_fn
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embedding_dim,
|
||||
num_attention_heads,
|
||||
dropout=attention_dropout,
|
||||
self_attention=True,
|
||||
has_relative_attention_bias=has_relative_attention_bias,
|
||||
num_buckets=num_buckets,
|
||||
max_distance=max_distance,
|
||||
rescale_init=rescale_init,
|
||||
gru_rel_pos=gru_rel_pos,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(self.activation_dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.layer_norm_first = layer_norm_first
|
||||
|
||||
# layer norm associated with the self attention layer
|
||||
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
if self.activation_name == "glu":
|
||||
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
||||
else:
|
||||
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
||||
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
||||
|
||||
# layer norm associated with the position wise feed-forward NN
|
||||
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
self_attn_mask: torch.Tensor = None,
|
||||
self_attn_padding_mask: torch.Tensor = None,
|
||||
need_weights: bool = False,
|
||||
pos_bias=None,
|
||||
):
|
||||
"""
|
||||
LayerNorm is applied either before or after the self-attention/ffn
|
||||
modules, similar to the original Transformer implementation.
|
||||
"""
|
||||
residual = x
|
||||
|
||||
if self.layer_norm_first:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias,
|
||||
)
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
else:
|
||||
x, attn, pos_bias = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=self_attn_mask,
|
||||
position_bias=pos_bias,
|
||||
)
|
||||
|
||||
x = self.dropout1(x)
|
||||
x = residual + x
|
||||
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.activation_name == "glu":
|
||||
x = self.fc1(x)
|
||||
else:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.dropout2(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout3(x)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
return x, attn, pos_bias
|
|
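The `layer_norm_first` flag above switches between the pre-norm and post-norm residual orderings. As a quick illustration only (not the module's actual classes: `torch.nn.MultiheadAttention` stands in for the custom `MultiheadAttention`, dropouts are omitted, and the dimensions are made up), the two orderings look like this:

```python
import torch
import torch.nn as nn

dim, heads = 768, 8
attn = nn.MultiheadAttention(dim, heads)  # stand-in for the custom MultiheadAttention above
ffn = nn.Sequential(nn.Linear(dim, 3072), nn.ReLU(), nn.Linear(3072, dim))
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

x = torch.rand(50, 2, dim)  # (time, batch, channels), the default layout of nn.MultiheadAttention

# layer_norm_first=True (pre-norm): normalize each input before its sub-block, add the residual after
h = x + attn(norm1(x), norm1(x), norm1(x))[0]
h = h + ffn(norm2(h))

# layer_norm_first=False (post-norm): add the residual first, then normalize the sum
h = norm1(x + attn(x, x, x)[0])
h = norm2(h + ffn(h))
```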
@@ -18,6 +18,8 @@ class BaseVocoder(BaseTrainerModel):
        - 1D tensors `batch x 1`
    """

    MODEL_TYPE = "vocoder"

    def __init__(self, config):
        super().__init__()
        self._set_model_args(config)
@@ -36,8 +36,8 @@ Run a tts and a vocoder model from the released model list. Note that not every

```bash
tts --text "Text for TTS" \
    --model_name "<type>/<language>/<dataset>/<model_name>" \
    --vocoder_name "<type>/<language>/<dataset>/<model_name>" \
    --model_name "tts_models/<language>/<dataset>/<model_name>" \
    --vocoder_name "vocoder_models/<language>/<dataset>/<model_name>" \
    --out_path folder/to/save/output.wav
```

@@ -64,8 +64,17 @@ tts --text "Text for TTS" \
Run a multi-speaker TTS model from the released models list.

```bash
tts --model_name "<type>/<language>/<dataset>/<model_name>" --list_speaker_idxs  # list the possible speaker IDs.
tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
tts --model_name "tts_models/<language>/<dataset>/<model_name>" --list_speaker_idxs  # list the possible speaker IDs.
tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "tts_models/<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
```

Run a released voice conversion model.

```bash
tts --model_name "voice_conversion/<language>/<dataset>/<model_name>" \
    --source_wav "my/source/speaker/audio.wav" \
    --target_wav "my/target/speaker/audio.wav" \
    --out_path folder/to/save/output.wav
```

**Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.

@@ -135,4 +144,23 @@ tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```

Example voice conversion, converting the speaker of the `source_wav` to the speaker of the `target_wav`:

```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

Example voice cloning by combining a single-speaker TTS model with the voice conversion model. This way, you can
clone voices by using any model in 🐸TTS.

```python
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
tts.tts_with_vc_to_file(
    "Wie sage ich auf Italienisch, dass ich dich liebe?",
    speaker_wav="target/speaker.wav",
    file_path="output.wav"
)
```
@@ -14,6 +14,7 @@ tqdm
anyascii
pyyaml
fsspec>=2021.04.0
aiohttp
packaging
# deps for examples
flask
@@ -28,7 +28,7 @@ class TTSTest(unittest.TestCase):

    def test_multi_speaker_multi_lingual_model(self):
        tts = TTS()
        tts.load_model_by_name(tts.models[0])  # YourTTS
        tts.load_tts_model_by_name(tts.models[0])  # YourTTS
        tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH)

        self.assertTrue(tts.is_multi_speaker)
@@ -38,5 +38,5 @@ class TTSTest(unittest.TestCase):

    def test_voice_cloning(self):  # pylint: disable=no-self-use
        tts = TTS()
        tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
        tts.load_tts_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
        tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)
@@ -0,0 +1,135 @@
import os
import unittest

import torch

from tests import get_tests_input_path
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC

# pylint: disable=unused-variable
# pylint: disable=no-self-use

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = FreeVCConfig()

WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3


def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class TestFreeVC(unittest.TestCase):
    def _create_inputs(self, config, batch_size=2):
        input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
        input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device)
        input_lengths[-1] = 30 * config.audio["hop_length"]
        spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
        mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        spec_lengths[-1] = spec.size(2)
        waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device)
        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform

    @staticmethod
    def _create_inputs_inference():
        source_wav = torch.rand(16000)
        target_wav = torch.rand(16000)
        return source_wav, target_wav

    @staticmethod
    def _check_parameter_changes(model, model_ref):
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1

    def test_methods(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.load_pretrained_speaker_encoder()
        model.init_multispeaker(config)
        wavlm_feats = model.extract_wavlm_features(torch.rand(1, 16000))
        assert wavlm_feats.shape == (1, 1024, 49), wavlm_feats.shape

    def test_load_audio(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        wav = model.load_audio(WAV_FILE)
        wav2 = model.load_audio(wav)
        assert all(torch.isclose(wav, wav2))

    def _test_forward(self, batch_size):
        # create model
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.train()
        print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))

        _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)

        wavlm_vec = model.extract_wavlm_features(waveform)
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)

        y = model.forward(wavlm_vec, spec, None, mel, spec_lengths, wavlm_vec_lengths)
        # TODO: assert with training implementation

    def test_forward(self):
        self._test_forward(1)
        self._test_forward(3)

    def _test_inference(self, batch_size):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()

        _, _, mel, _, _, waveform = self._create_inputs(config, batch_size)

        wavlm_vec = model.extract_wavlm_features(waveform)
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)

        output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths)
        assert (
            output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1]
        ), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"

    def test_inference(self):
        self._test_inference(1)
        self._test_inference(3)

    def test_voice_conversion(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()

        source_wav, target_wav = self._create_inputs_inference()
        output_wav = model.voice_conversion(source_wav, target_wav)
        assert (
            output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
        ), f"{output_wav.shape} != {source_wav.shape}"

    def test_train_step(self):
        ...

    def test_train_eval_log(self):
        ...

    def test_test_run(self):
        ...

    def test_load_checkpoint(self):
        ...

    def test_get_criterion(self):
        ...

    def test_init_from_config(self):
        ...
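Assuming this new test module lives at `tests/vc_tests/test_freevc.py` (the exact path is not shown in this excerpt), it can be run on its own, for example programmatically through the standard `unittest` loader:

```python
# Minimal sketch: load and run the FreeVC tests by dotted module name.
# Assumes the working directory is the repository root and that
# "tests.vc_tests.test_freevc" is the importable path of the module above.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.vc_tests.test_freevc")
unittest.TextTestRunner(verbosity=2).run(suite)
```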
@@ -51,6 +51,13 @@ def run_models(offset=0, step=1):
            # remove downloaded models
            shutil.rmtree(local_download_dir)
            shutil.rmtree(get_user_data_dir("tts"))
        elif "voice_conversion_models" in model_name:
            speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
            reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
            run_cli(
                f"tts --model_name {model_name} "
                f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False'
            )
        else:
            # only download the model
            manager.download_model(model_name)
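For the new FreeVC voice conversion entry, the f-string above builds a plain `tts` invocation. The sketch below only resolves the placeholders for illustration; the model name comes from the docs above, while the wav and output paths are placeholders rather than the exact values the test helpers produce:

```python
# Illustration only: the command string the zoo test builds for the FreeVC model.
model_name = "voice_conversion_models/multilingual/vctk/freevc24"
output_path = "output.wav"                                   # placeholder
speaker_wav = "tests/data/ljspeech/wavs/LJ001-0001.wav"      # placeholder path
reference_wav = "tests/data/ljspeech/wavs/LJ001-0032.wav"    # placeholder path
print(
    f"tts --model_name {model_name} "
    f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False'
)
```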