Merge pull request #184 from idiap/xtts-error

fix(xtts): clearer error message when file given to checkpoint_dir
pull/4115/head^2
Enno Hermann 2024-12-06 06:46:48 +01:00 committed by GitHub
commit e8d99aaf2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 25 additions and 20 deletions

View File

@ -5,7 +5,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager

View File

@ -18,7 +18,7 @@ from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.tts.models.xtts import Xtts, XttsArgs
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__)
@ -34,11 +34,6 @@ class GPTTrainerConfig(XttsConfig):
test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150

View File

@ -2,6 +2,7 @@ import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import librosa
import torch
@ -101,10 +102,12 @@ class XttsAudioConfig(Coqpit):
Args:
sample_rate (int): The sample rate in which the GPT operates.
output_sample_rate (int): The sample rate of the output audio waveform.
dvae_sample_rate (int): The sample rate of the DVAE
"""
sample_rate: int = 22050
output_sample_rate: int = 24000
dvae_sample_rate: int = 22050
@dataclass
@ -719,14 +722,14 @@ class Xtts(BaseTTS):
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
use_deepspeed=False,
speaker_file_path=None,
config: "XttsConfig",
checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None,
vocab_path: Optional[str] = None,
eval: bool = True,
strict: bool = True,
use_deepspeed: bool = False,
speaker_file_path: Optional[str] = None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -742,7 +745,9 @@ class Xtts(BaseTTS):
Returns:
None
"""
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
if vocab_path is None:
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager
# Logging parameters

View File

@ -4,7 +4,8 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
from TTS.utils.manage import ModelManager
# Logging parameters

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig(
formatter="ljspeech",

View File

@ -8,7 +8,8 @@ from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig
from TTS.tts.models.xtts import XttsAudioConfig
config_dataset = BaseDatasetConfig(
formatter="ljspeech",