# Mirror of https://github.com/coqui-ai/TTS.git (Python, 93 lines, 2.8 KiB)
# -*- coding: utf-8 -*-
|
|
|
|
from network import *
|
|
from data import inv_spectrogram, find_endpoint, save_wav, spectrogram
|
|
import numpy as np
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import io
|
|
from text import text_to_sequence
|
|
|
|
# Module-level flag: True when PyTorch can see a CUDA-capable device.
# The synthesis code below reads it to decide where the model lives.
use_cuda = bool(torch.cuda.is_available())
|
|
|
|
|
|
def main(args):
    """Restore a Tacotron checkpoint and synthesize the demo sentences to WAV.

    Args:
        args: Parsed command-line namespace. ``args.restore_step`` selects the
            checkpoint ``checkpoint_<step>.pth.tar`` under
            ``hp.checkpoint_path`` and is also embedded in every output file
            name ``result_<step>_<n>.wav`` written to ``hp.output_path``.

    Raises:
        FileNotFoundError: If the checkpoint file for ``args.restore_step``
            does not exist. Other loading failures (for example a checkpoint
            without a ``'model'`` entry, or parameter names that do not match)
            propagate with their own exception instead of being reported as a
            missing file.
    """
    # Make model. Previously `model` was only created when CUDA was available,
    # so CPU-only runs hit a NameError that the bare `except:` turned into a
    # misleading "Model not exists". Build it on every machine, and keep the
    # DataParallel wrapper on both paths so parameter names are the same ones
    # the GPU path has always loaded from the checkpoint.
    model = Tacotron()
    if use_cuda:
        model = model.cuda()
    model = nn.DataParallel(model)

    # Load checkpoint. Without a GPU, remap every saved storage to CPU memory
    # so a checkpoint written on a GPU can still be deserialized.
    checkpoint_file = os.path.join(
        hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)
    map_location = None if use_cuda else (lambda storage, loc: storage)
    try:
        checkpoint = torch.load(checkpoint_file, map_location=map_location)
    except FileNotFoundError as err:
        # Only a genuinely missing file is reported as such; chaining keeps
        # the offending path visible in the traceback.
        raise FileNotFoundError(
            "\n------------Model not exists------------\n") from err
    model.load_state_dict(checkpoint['model'])
    print("\n--------model restored at step %d--------\n" % args.restore_step)

    # Switch the model to evaluation mode for inference.
    model = model.eval()

    # Make result folder if it does not exist. This also creates missing
    # parents and avoids a check-then-create race.
    os.makedirs(hp.output_path, exist_ok=True)

    # Sentences for generation
    sentences = [
        "I try my best to translate text to speech. But I know I need more work",
        "The new Firefox, Fast for good.",
        "Technology is continually providing us with new ways to create and publish stories.",
        "For these stories to achieve their full impact, it requires tool.",
        "I am allien and I am here to destron your world."
    ]

    # Synthesize each sentence and save it as result_<step>_<n>.wav, with n
    # counting from 1.
    for i, text in enumerate(sentences, start=1):
        wav = generate(model, text)
        path = os.path.join(hp.output_path,
                            'result_%d_%d.wav' % (args.restore_step, i))
        # The `with` block closes the file; no explicit close() is needed.
        with open(path, 'wb') as f:
            f.write(wav)
        print("save wav file at step %d ..." % i)
|
|
|
|
|
|
def generate(model, text):
    """Synthesize ``text`` with ``model`` and return the audio as WAV bytes.

    Args:
        model: Model called as ``model(characters, mel_input)``. It must return
            a pair whose second element is the linear-spectrogram output; the
            first batch item of that output is inverted to a waveform.
        text: Input sentence. It is converted to a symbol-id sequence using the
            comma-separated cleaner names in ``hp.cleaners``.

    Returns:
        bytes: The waveform, truncated at ``find_endpoint``'s index and
        serialized by ``save_wav`` into an in-memory buffer.
    """
    # Text to index sequence, as a batch of one: shape (1, len(sequence)).
    cleaner_names = [x.strip() for x in hp.cleaners.split(',')]
    seq = np.expand_dims(np.asarray(text_to_sequence(text, cleaner_names),
                                    dtype=np.int32), axis=0)

    # Provide the [GO] frame: one all-zero frame of hp.num_mels bins per batch
    # item.
    go_frame = np.zeros([seq.shape[0], hp.num_mels, 1], dtype=np.float32)

    # Variables. Build CPU tensors first and move them to the GPU only when one
    # is available. Previously torch.cuda tensor types were hard-coded, so
    # this crashed on CPU-only machines.
    characters = torch.from_numpy(seq).long()
    mel_input = torch.from_numpy(go_frame).float()
    if use_cuda:
        characters = characters.cuda()
        mel_input = mel_input.cuda()
    # NOTE(review): `volatile` is the pre-0.4 PyTorch way to disable autograd.
    # Newer releases ignore it in favour of `torch.no_grad()`; revisit if the
    # project upgrades.
    characters = Variable(characters, volatile=True)
    mel_input = Variable(mel_input, volatile=True)

    # Spectrogram to wav. Call the module itself, not `.forward`, so that any
    # registered hooks run.
    _, linear_output = model(characters, mel_input)
    wav = inv_spectrogram(linear_output[0].data.cpu().numpy())
    # Cut the waveform at the index returned by find_endpoint (presumably
    # where trailing silence begins; confirm in data.py).
    wav = wav[:find_endpoint(wav)]
    out = io.BytesIO()
    save_wav(wav, out)

    return out.getvalue()
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse options and run synthesis.
    # NOTE(review): --batch_size is accepted but main() only reads
    # args.restore_step; it is kept so existing command lines still parse.
    cli = argparse.ArgumentParser()
    cli.add_argument('--restore_step', type=int, default=0,
                     help='Global step to restore checkpoint')
    cli.add_argument('--batch_size', type=int, default=1,
                     help='Batch size')
    main(cli.parse_args())
|