import unittest

import torch as T

from TTS.utils.generic_utils import save_checkpoint, save_best_model
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder

# Path handed to the saving helpers and read back by the tests.
OUT_PATH = '/tmp/test.pth.tar'


class ModelSavingTests(unittest.TestCase):
    """Round-trip tests: save a small wrapped model, then reload its weights."""

    def save_checkpoint_test(self):
        """Save a DataParallel-wrapped Prenet with save_checkpoint and reload it."""
        # create a dummy model
        model = Prenet(128, out_features=[256, 128])
        # Wrap the model that was just built (the original referenced an
        # undefined name `layer` here).
        model = T.nn.DataParallel(model)

        # save the model (positional arguments kept exactly as in the
        # original call; only the misspelled path constant is corrected)
        save_checkpoint(model, None, 100, OUT_PATH, 1, 1)

        # load the model to CPU
        # NOTE(review): the original read from an undefined `MODEL_PATH`;
        # this assumes save_checkpoint writes to OUT_PATH itself -- confirm
        # against the helper's implementation.
        model_dict = T.load(OUT_PATH,
                            map_location=lambda storage, loc: storage)
        model.load_state_dict(model_dict['model'])

    def save_best_model_test(self):
        """Save a DataParallel-wrapped Prenet with save_best_model and reload it."""
        # create a dummy model
        model = Prenet(256, out_features=[256, 256])
        # Wrap the model that was just built (the original referenced an
        # undefined name `layer` here).
        model = T.nn.DataParallel(model)

        # save the model (return value is not needed by this test)
        save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)

        # load the model to CPU
        # NOTE(review): the original read from an undefined `MODEL_PATH`;
        # this assumes save_best_model writes to OUT_PATH itself -- confirm
        # against the helper's implementation.
        model_dict = T.load(OUT_PATH,
                            map_location=lambda storage, loc: storage)
        model.load_state_dict(model_dict['model'])