mirror of https://github.com/coqui-ai/TTS.git
refactor: use save_fsspec() from Trainer
parent
fdf0c8b10a
commit
39fe38bda4
|
@ -5,10 +5,10 @@ import random
|
|||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from trainer.io import save_fsspec
|
||||
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
class AugmentWAV(object):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import datetime
|
||||
import os
|
||||
|
||||
from TTS.utils.io import save_fsspec
|
||||
from trainer.io import save_fsspec
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Union
|
|||
import fsspec
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from trainer.io import save_fsspec
|
||||
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
|
@ -102,18 +103,6 @@ def load_checkpoint(
|
|||
return model, state
|
||||
|
||||
|
||||
def save_fsspec(state: Any, path: str, **kwargs):
|
||||
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
|
||||
|
||||
Args:
|
||||
state: State object to save
|
||||
path: Any path or url supported by fsspec.
|
||||
**kwargs: Keyword arguments forwarded to torch.save.
|
||||
"""
|
||||
with fsspec.open(path, "wb") as f:
|
||||
torch.save(state, f, **kwargs)
|
||||
|
||||
|
||||
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
||||
if hasattr(model, "module"):
|
||||
model_state = model.module.state_dict()
|
||||
|
|
Loading…
Reference in New Issue