TTS/synthesis.py

93 lines
2.8 KiB
Python

#-*- coding: utf-8 -*-
from network import *
from data import inv_spectrogram, find_endpoint, save_wav, spectrogram
import numpy as np
import argparse
import os
import sys
import io
from text import text_to_sequence
# Whether PyTorch reports CUDA as available; main() checks this flag when
# it builds the model.
use_cuda = torch.cuda.is_available()
def main(args):
    """Restore a Tacotron checkpoint and synthesize the demo sentences.

    One wav file per sentence is written to ``hp.output_path`` as
    ``result_<restore_step>_<index>.wav`` (index starts at 1).

    Args:
        args: Parsed command-line namespace. Only ``args.restore_step`` (the
            global step of the checkpoint to load) is read here.

    Raises:
        FileNotFoundError: If ``checkpoint_<restore_step>.pth.tar`` does not
            exist under ``hp.checkpoint_path``. Other failures (corrupt file,
            mismatching state dict, ...) propagate unchanged instead of being
            reported as a missing model.
    """
    # Make model. The model is always built (previously it was left undefined
    # without CUDA) and always wrapped in DataParallel, because the checkpoint
    # is loaded into the wrapped model and therefore has to use its key names.
    # NOTE(review): constructing DataParallel on a host without GPUs relies on
    # the CPU fall-through of current PyTorch releases -- confirm for the
    # PyTorch version this project pins.
    net = Tacotron()
    if use_cuda:
        net = net.cuda()
    model = nn.DataParallel(net)

    # Load checkpoint
    checkpoint_file = os.path.join(
        hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)
    # Without CUDA, keep every storage on the CPU while deserializing;
    # with CUDA, leave the default placement untouched.
    map_location = None if use_cuda else (lambda storage, location: storage)
    try:
        checkpoint = torch.load(checkpoint_file, map_location=map_location)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "\n------------Model not exists------------\n") from err
    model.load_state_dict(checkpoint['model'])
    print("\n--------model restored at step %d--------\n" %
          args.restore_step)

    # Evaluation
    model = model.eval()

    # Make result folder if it does not exist (also creates missing parents
    # and tolerates the folder appearing concurrently).
    os.makedirs(hp.output_path, exist_ok=True)

    # Sentences for generation
    sentences = [
        "I try my best to translate text to speech. But I know I need more work",
        "The new Firefox, Fast for good.",
        "Technology is continually providing us with new ways to create and publish stories.",
        "For these stories to achieve their full impact, it requires tool.",
        "I am allien and I am here to destron your world."
    ]

    # Synthesis and save to wav files
    for i, text in enumerate(sentences):
        wav = generate(model, text)
        path = os.path.join(hp.output_path, 'result_%d_%d.wav' %
                            (args.restore_step, i + 1))
        # The context manager closes the file; no explicit close() is needed.
        with open(path, 'wb') as f:
            f.write(wav)
        print("save wav file at step %d ..." % (i + 1))
def generate(model, text):
    """Synthesize ``text`` with ``model`` and return the audio as wav bytes.

    The inputs are moved to the GPU only when the model's parameters live
    there, so the function also works on a CPU-only host.

    Args:
        model: ``torch.nn.Module`` invoked as ``model(characters, mel_input)``;
            the second element of its result is used as the linear
            spectrogram.
        text: Sentence to synthesize.

    Returns:
        bytes: Whatever ``save_wav`` wrote into an in-memory buffer.
    """
    # Text to index sequence, with a leading batch axis of size one
    cleaner_names = [x.strip() for x in hp.cleaners.split(',')]
    seq = np.expand_dims(np.asarray(text_to_sequence(
        text, cleaner_names), dtype=np.int32), axis=0)

    # Provide [GO] Frame
    mel_input = np.zeros([seq.shape[0], hp.num_mels, 1], dtype=np.float32)

    # Variables (created on the CPU first).
    # NOTE(review): `volatile=True` is kept for the PyTorch API this file is
    # written against; on PyTorch >= 0.4 it only emits a warning and
    # `torch.no_grad()` is the replacement -- confirm the targeted version.
    characters = Variable(torch.from_numpy(seq).long(), volatile=True)
    mel_input = Variable(torch.from_numpy(mel_input).float(), volatile=True)

    # Follow the model's device instead of unconditionally requiring CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.is_cuda:
        characters = characters.cuda()
        mel_input = mel_input.cuda()

    # Spectrogram to wav. Call the module itself rather than .forward() so
    # registered hooks run as well.
    _, linear_output = model(characters, mel_input)
    wav = inv_spectrogram(linear_output[0].data.cpu().numpy())
    wav = wav[:find_endpoint(wav)]

    out = io.BytesIO()
    save_wav(wav, out)
    return out.getvalue()
if __name__ == '__main__':
    # Script entry point: read the command-line options and run synthesis.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--restore_step',
        default=0,
        type=int,
        help='Global step to restore checkpoint',
    )
    cli.add_argument(
        '--batch_size',
        default=1,
        type=int,
        help='Batch size',
    )
    args = cli.parse_args()
    main(args)