diff --git a/TTS/trainer.py b/TTS/trainer.py index 2f3e9bb1..d90c5473 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1182,7 +1182,6 @@ def process_args(args, config=None): args.restore_path, best_model = get_last_checkpoint(args.continue_path) if not args.best_path: args.best_path = best_model - # init config if not already defined if config is None: if args.config_path: diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 3840d1d6..668ea227 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig): model (str): Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + model_args (Coqpit): Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. @@ -94,9 +98,11 @@ class FastPitchConfig(BaseTTSConfig): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. """ - model: str = "forward_tts" + model: str = "fast_pitch" + base_model: str = "forward_tts" + # model specific params - model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) + model_args: ForwardTTSArgs = ForwardTTSArgs() # multi-speaker settings use_speaker_embedding: bool = False diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index bdfc2a6b..1ec8f729 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -16,7 +16,11 @@ class SpeedySpeechConfig(BaseTTSConfig): Args: model (str): - Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. model_args (Coqpit): Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. @@ -91,7 +95,8 @@ class SpeedySpeechConfig(BaseTTSConfig): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. """ - model: str = "forward_tts" + model: str = "speedy_speech" + base_model: str = "forward_tts" # set model args as SpeedySpeech model_args: ForwardTTSArgs = ForwardTTSArgs( diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 8c1bd430..1236fa76 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -4,7 +4,11 @@ from TTS.utils.generic_utils import find_module def setup_model(config): print(" > Using model: {}".format(config.model)) - MyModel = find_module("TTS.tts.models", config.model.lower()) + # fetch the right model implementation. + if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.tts.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.tts.models", config.model.lower()) # define set of characters used by the model if config.characters is not None: # set characters from config