mirror of https://github.com/coqui-ai/TTS.git
fix(xtts): clearer error message when file given to checkpoint_dir
parent
98a372bca2
commit
ce202532cf
|
@ -2,6 +2,7 @@ import logging
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
|
@ -10,6 +11,7 @@ import torchaudio
|
|||
from coqpit import Coqpit
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
|
@ -719,14 +721,14 @@ class Xtts(BaseTTS):
|
|||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
checkpoint_path=None,
|
||||
vocab_path=None,
|
||||
eval=True,
|
||||
strict=True,
|
||||
use_deepspeed=False,
|
||||
speaker_file_path=None,
|
||||
config: XttsConfig,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
vocab_path: Optional[str] = None,
|
||||
eval: bool = True,
|
||||
strict: bool = True,
|
||||
use_deepspeed: bool = False,
|
||||
speaker_file_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||
|
@ -742,7 +744,9 @@ class Xtts(BaseTTS):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
|
||||
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
|
||||
raise ValueError(msg)
|
||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
if vocab_path is None:
|
||||
if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file():
|
||||
|
|
Loading…
Reference in New Issue