update server and synthesizer to handle ParallelWaveGAN

pull/10/head
author root
date 2020-02-04 17:09:59 +01:00
parent 9809fda144
commit 5c78816f51
2 changed files with 38 additions and 17 deletions

View File

@@ -14,10 +14,13 @@ def create_argparser():
     parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
     parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
     parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
-    parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
-    parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
-    parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
+    parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
+    parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
     parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
+    parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
+    parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
     parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
     parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
     parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
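With these flags in place, a ParallelWaveGAN vocoder can be selected at startup. As a hedged example only: the server entry point is assumed to be the repo's server.py, and every path below is a placeholder, not a file shipped with this change:

    python server.py \
        --tts_checkpoint /path/to/tts_checkpoint.pth.tar \
        --tts_config /path/to/tts_config.json \
        --pwgan_lib_path /path/to/ParallelWaveGAN \
        --pwgan_file /path/to/pwgan_checkpoint.pkl \
        --pwgan_config /path/to/pwgan_config.yml

If neither the wavernn_* nor the pwgan_* options are given, synthesis falls back to Griffin-Lim, as the help text states.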

View File

@@ -1,17 +1,18 @@
 import io
 import os
+import re
+import sys
 import numpy as np
 import torch
-import sys
+import yaml
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text import phonemes, symbols
 from TTS.utils.speakers import load_speaker_mapping
 from TTS.utils.synthesis import *
+from TTS.utils.text import phonemes, symbols
-import re
 alphabets = r"([A-Za-z])"
 prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
 suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@@ -23,6 +24,7 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
+        self.pwgan = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
@@ -30,9 +32,11 @@ class Synthesizer(object):
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
         if self.config.wavernn_lib_path:
-            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
-                              self.config.wavernn_file, self.config.wavernn_config,
-                              self.config.use_cuda)
+            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
+                              self.config.wavernn_config, self.config.use_cuda)
+        if self.config.pwgan_lib_path:
+            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
+                            self.config.pwgan_config, self.config.use_cuda)

     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         print(" > Loading TTS model ...")
@@ -45,9 +49,9 @@ class Synthesizer(object):
             self.input_size = len(phonemes)
         else:
             self.input_size = len(symbols)
-        # load speakers
+        # TODO: fix this for multi-speaker model - load speakers
         if self.config.tts_speakers is not None:
-            self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
+            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
             num_speakers = len(self.tts_speakers)
         else:
             num_speakers = 0
@@ -63,16 +67,14 @@ class Synthesizer(object):
         if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])

-    def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
+    def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
         # TODO: set a function in wavernn code base for model setup and call it here.
         sys.path.append(lib_path) # set this if TTS is not installed globally
         from WaveRNN.models.wavernn import Model
-        wavernn_config = os.path.join(model_path, model_config)
-        model_file = os.path.join(model_path, model_file)
         print(" > Loading WaveRNN model ...")
-        print(" | > model config: ", wavernn_config)
+        print(" | > model config: ", model_config)
         print(" | > model file: ", model_file)
-        self.wavernn_config = load_config(wavernn_config)
+        self.wavernn_config = load_config(model_config)
         self.wavernn = Model(
             rnn_dims=512,
             fc_dims=512,
@@ -91,11 +93,27 @@ class Synthesizer(object):
         ).cuda()

-        check = torch.load(model_file)
-        self.wavernn.load_state_dict(check['model'])
+        check = torch.load(model_file, map_location="cpu")
+        self.wavernn.load_state_dict(check['model'])
         if use_cuda:
             self.wavernn.cuda()
         self.wavernn.eval()

+    def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
+        sys.path.append(lib_path) # set this if TTS is not installed globally
+        from parallel_wavegan.models import ParallelWaveGANGenerator
+        from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder
+        print(" > Loading PWGAN model ...")
+        print(" | > model config: ", model_config)
+        print(" | > model file: ", model_file)
+        with open(model_config) as f:
+            self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
+        self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
+        self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
+        self.pwgan.remove_weight_norm()
+        if use_cuda:
+            self.pwgan.cuda()
+        self.pwgan.eval()
+
     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
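load_pwgan() only builds and loads the generator; the diff above does not show how it is driven. A minimal sketch of the vocoding step, assuming the upstream parallel_wavegan generator API (an inference() method taking local conditioning features shaped (T', num_mels) and returning a waveform tensor). pwgan_vocode is a hypothetical helper, and feature normalization with the vocoder's own AudioProcessor (the AudioProcessorVocoder import above) is omitted but needed in practice:

    import torch

    def pwgan_vocode(synthesizer, mel, use_cuda=False):
        # mel: spectrogram from the TTS model, shape (num_mels, T')
        c = torch.FloatTensor(mel.T)  # parallel_wavegan expects (T', num_mels)
        if use_cuda:
            c = c.cuda()
        with torch.no_grad():
            wav = synthesizer.pwgan.inference(c)  # waveform tensor, roughly (T, 1)
        return wav.view(-1).cpu().numpy()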