Add `base_model` field to `forward_tts` configs

pull/800/head
Eren Gölge 2021-09-10 17:23:48 +00:00
parent 22822cd41c
commit 66732025e1
4 changed files with 20 additions and 6 deletions

View File

@ -1182,7 +1182,6 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# init config if not already defined
if config is None:
if args.config_path:

View File

@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
@ -94,9 +98,11 @@ class FastPitchConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "forward_tts"
model: str = "fast_pitch"
base_model: str = "forward_tts"
# model specific params
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs)
model_args: ForwardTTSArgs = ForwardTTSArgs()
# multi-speaker settings
use_speaker_embedding: bool = False

View File

@ -16,7 +16,11 @@ class SpeedySpeechConfig(BaseTTSConfig):
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
@ -91,7 +95,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "forward_tts"
model: str = "speedy_speech"
base_model: str = "forward_tts"
# set model args as SpeedySpeech
model_args: ForwardTTSArgs = ForwardTTSArgs(

View File

@ -4,7 +4,11 @@ from TTS.utils.generic_utils import find_module
def setup_model(config):
print(" > Using model: {}".format(config.model))
MyModel = find_module("TTS.tts.models", config.model.lower())
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model
if config.characters is not None:
# set characters from config