# TTS/tests/generic_utils_text.py
# Recovered from a repository blame view (36 lines, 1.1 KiB, Python).
# Blame timestamp: 2018-02-27 15:32:09 +00:00
import unittest
import torch as T
# Blame timestamp: 2018-12-11 14:07:50 +00:00
from utils.generic_utils import save_checkpoint, save_best_model
from layers.tacotron import Prenet, CBHG, Decoder, Encoder
# Scratch file path handed to the checkpoint-saving helpers by the tests in
# this module.
OUT_PATH = "/tmp/test.pth.tar"
class ModelSavingTests(unittest.TestCase):
    """Round-trip tests for ``save_checkpoint`` and ``save_best_model``.

    Each test wraps a small ``Prenet`` in ``torch.nn.DataParallel``, saves it
    with one of the helpers from ``utils.generic_utils``, then reloads the
    saved ``'model'`` entry onto the CPU and loads it back into the model.
    """

    def save_checkpoint_test(self):
        """Save the model with ``save_checkpoint`` and load it back on CPU."""
        # create a dummy model
        model = Prenet(128, out_features=[256, 128])
        # Wrap the model just built (the original referenced an undefined
        # name ``layer`` here).
        model = T.nn.DataParallel(model)
        # save the model
        save_checkpoint(model, None, 100, OUT_PATH, 1, 1)
        # load the model to CPU: the map_location callable returns each
        # storage unchanged instead of moving it back to its saved device.
        # NOTE(review): the original loaded from an undefined MODEL_PATH;
        # OUT_PATH is the only path defined in this module. Confirm that
        # save_checkpoint writes its file at exactly this path (it may
        # treat the argument as an output directory instead).
        model_dict = T.load(
            OUT_PATH, map_location=lambda storage, loc: storage)
        model.load_state_dict(model_dict['model'])

    def save_best_model_test(self):
        """Save the model with ``save_best_model`` and load it back on CPU."""
        # create a dummy model
        model = Prenet(256, out_features=[256, 256])
        # Wrap the model just built (the original referenced an undefined
        # name ``layer`` here).
        model = T.nn.DataParallel(model)
        # save the model; the helper's return value was never used.
        save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
        # load the model to CPU (see the path note in save_checkpoint_test).
        # NOTE(review): confirm save_best_model writes its file at OUT_PATH.
        model_dict = T.load(
            OUT_PATH, map_location=lambda storage, loc: storage)
        model.load_state_dict(model_dict['model'])

    # unittest (and pytest's unittest support) only collects methods whose
    # names start with "test". These aliases make both checks discoverable
    # while keeping the original method names for any existing callers.
    test_save_checkpoint = save_checkpoint_test
    test_save_best_model = save_best_model_test