refactor: use save_fsspec() from Trainer

pull/3243/head
Enno Hermann 2023-11-16 23:46:26 +01:00
parent fdf0c8b10a
commit 39fe38bda4
3 changed files with 3 additions and 14 deletions

View File

@@ -5,10 +5,10 @@ import random
import numpy as np
from scipy import signal
from trainer.io import save_fsspec
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class AugmentWAV(object):

View File

@@ -1,7 +1,7 @@
import datetime
import os
from TTS.utils.io import save_fsspec
from trainer.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):

View File

@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Union
import fsspec
import torch
from coqpit import Coqpit
from trainer.io import save_fsspec
from TTS.utils.generic_utils import get_user_data_dir
@@ -102,18 +103,6 @@ def load_checkpoint(
return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()