mirror of https://github.com/coqui-ai/TTS.git
Add `base_model` field to `forward_tts` configs
parent
22822cd41c
commit
66732025e1
|
@ -1182,7 +1182,6 @@ def process_args(args, config=None):
|
|||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||
if not args.best_path:
|
||||
args.best_path = best_model
|
||||
|
||||
# init config if not already defined
|
||||
if config is None:
|
||||
if args.config_path:
|
||||
|
|
|
@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
|
|||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||
|
||||
base_model (str):
|
||||
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
|
||||
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
|
||||
|
||||
model_args (Coqpit):
|
||||
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||
|
||||
|
@ -94,9 +98,11 @@ class FastPitchConfig(BaseTTSConfig):
|
|||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||
"""
|
||||
|
||||
model: str = "forward_tts"
|
||||
model: str = "fast_pitch"
|
||||
base_model: str = "forward_tts"
|
||||
|
||||
# model specific params
|
||||
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs)
|
||||
model_args: ForwardTTSArgs = ForwardTTSArgs()
|
||||
|
||||
# multi-speaker settings
|
||||
use_speaker_embedding: bool = False
|
||||
|
|
|
@ -16,7 +16,11 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
|||
|
||||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||
Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
|
||||
|
||||
base_model (str):
|
||||
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
|
||||
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
|
||||
|
||||
model_args (Coqpit):
|
||||
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||
|
@ -91,7 +95,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
|||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||
"""
|
||||
|
||||
model: str = "forward_tts"
|
||||
model: str = "speedy_speech"
|
||||
base_model: str = "forward_tts"
|
||||
|
||||
# set model args as SpeedySpeech
|
||||
model_args: ForwardTTSArgs = ForwardTTSArgs(
|
||||
|
|
|
@ -4,7 +4,11 @@ from TTS.utils.generic_utils import find_module
|
|||
|
||||
def setup_model(config):
|
||||
print(" > Using model: {}".format(config.model))
|
||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||
# fetch the right model implementation.
|
||||
if "base_model" in config and config["base_model"] is not None:
|
||||
MyModel = find_module("TTS.tts.models", config.base_model.lower())
|
||||
else:
|
||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||
# define set of characters used by the model
|
||||
if config.characters is not None:
|
||||
# set characters from config
|
||||
|
|
Loading…
Reference in New Issue