fix(xtts): clearer error message when file given to checkpoint_dir

pull/4115/head^2
Enno Hermann 2024-12-02 16:54:11 +01:00
parent 98a372bca2
commit ce202532cf
1 changed files with 13 additions and 9 deletions

View File

@ -2,6 +2,7 @@ import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import librosa
import torch
@ -10,6 +11,7 @@ import torchaudio
from coqpit import Coqpit
from trainer.io import load_fsspec
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
@ -719,14 +721,14 @@ class Xtts(BaseTTS):
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
use_deepspeed=False,
speaker_file_path=None,
config: XttsConfig,
checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None,
vocab_path: Optional[str] = None,
eval: bool = True,
strict: bool = True,
use_deepspeed: bool = False,
speaker_file_path: Optional[str] = None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -742,7 +744,9 @@ class Xtts(BaseTTS):
Returns:
None
"""
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
if vocab_path is None:
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():