update server and synthesizer to handle ParallelWaveGAN

pull/10/head
root 2020-02-04 17:09:59 +01:00 committed by erogol
parent fbe5310be0
commit ed8a9fc82a
2 changed files with 38 additions and 17 deletions

View File

@ -14,10 +14,13 @@ def create_argparser():
parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file') parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
parser.add_argument('--tts_config', type=str, help='path to TTS config.json file') parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.') parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.') parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')

View File

@ -1,17 +1,18 @@
import io import io
import os import os
import re
import sys
import numpy as np import numpy as np
import torch import torch
import sys import yaml
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config, setup_model from TTS.utils.generic_utils import load_config, setup_model
from TTS.utils.text import phonemes, symbols
from TTS.utils.speakers import load_speaker_mapping from TTS.utils.speakers import load_speaker_mapping
from TTS.utils.synthesis import * from TTS.utils.synthesis import *
from TTS.utils.text import phonemes, symbols
import re
alphabets = r"([A-Za-z])" alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]" prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)" suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@ -23,6 +24,7 @@ websites = r"[.](com|net|org|io|gov)"
class Synthesizer(object): class Synthesizer(object):
def __init__(self, config): def __init__(self, config):
self.wavernn = None self.wavernn = None
self.pwgan = None
self.config = config self.config = config
self.use_cuda = self.config.use_cuda self.use_cuda = self.config.use_cuda
if self.use_cuda: if self.use_cuda:
@ -30,9 +32,11 @@ class Synthesizer(object):
self.load_tts(self.config.tts_checkpoint, self.config.tts_config, self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
self.config.use_cuda) self.config.use_cuda)
if self.config.wavernn_lib_path: if self.config.wavernn_lib_path:
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path, self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
self.config.wavernn_file, self.config.wavernn_config, self.config.wavernn_config, self.config.use_cuda)
self.config.use_cuda) if self.config.pwgan_lib_path:
self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
self.config.pwgan_config, self.config.use_cuda)
def load_tts(self, tts_checkpoint, tts_config, use_cuda): def load_tts(self, tts_checkpoint, tts_config, use_cuda):
print(" > Loading TTS model ...") print(" > Loading TTS model ...")
@ -45,9 +49,9 @@ class Synthesizer(object):
self.input_size = len(phonemes) self.input_size = len(phonemes)
else: else:
self.input_size = len(symbols) self.input_size = len(symbols)
# load speakers # TODO: fix this for multi-speaker model - load speakers
if self.config.tts_speakers is not None: if self.config.tts_speakers is not None:
self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers)) self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
num_speakers = len(self.tts_speakers) num_speakers = len(self.tts_speakers)
else: else:
num_speakers = 0 num_speakers = 0
@ -63,16 +67,14 @@ class Synthesizer(object):
if 'r' in cp: if 'r' in cp:
self.tts_model.decoder.set_r(cp['r']) self.tts_model.decoder.set_r(cp['r'])
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
# TODO: set a function in wavernn code base for model setup and call it here. # TODO: set a function in wavernn code base for model setup and call it here.
sys.path.append(lib_path) # set this if TTS is not installed globally sys.path.append(lib_path) # set this if TTS is not installed globally
from WaveRNN.models.wavernn import Model from WaveRNN.models.wavernn import Model
wavernn_config = os.path.join(model_path, model_config)
model_file = os.path.join(model_path, model_file)
print(" > Loading WaveRNN model ...") print(" > Loading WaveRNN model ...")
print(" | > model config: ", wavernn_config) print(" | > model config: ", model_config)
print(" | > model file: ", model_file) print(" | > model file: ", model_file)
self.wavernn_config = load_config(wavernn_config) self.wavernn_config = load_config(model_config)
self.wavernn = Model( self.wavernn = Model(
rnn_dims=512, rnn_dims=512,
fc_dims=512, fc_dims=512,
@ -91,11 +93,27 @@ class Synthesizer(object):
).cuda() ).cuda()
check = torch.load(model_file) check = torch.load(model_file)
self.wavernn.load_state_dict(check['model']) self.wavernn.load_state_dict(check['model'], map_location="cpu")
if use_cuda: if use_cuda:
self.wavernn.cuda() self.wavernn.cuda()
self.wavernn.eval() self.wavernn.eval()
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
    """Load a ParallelWaveGAN generator and store it on ``self.pwgan``.

    The parsed YAML config is kept on ``self.pwgan_config``. The generator
    is built from the config's ``generator_params`` section, its weights
    are restored from the checkpoint, weight normalization is removed, and
    the model is put into eval mode.

    Args:
        lib_path: path to the ParallelWaveGAN project folder; it is added
            to ``sys.path`` so that ``parallel_wavegan`` can be imported.
        model_file: path to the ParallelWaveGAN checkpoint file.
        model_config: path to the ParallelWaveGAN YAML config file.
        use_cuda: when true, the generator is moved to the GPU.

    Raises:
        ValueError: if ``model_file`` or ``model_config`` is not provided.
    """
    # Both paths default to None on the command line; fail with a clear
    # message instead of an opaque error from open()/torch.load().
    if not model_file or not model_config:
        raise ValueError(
            "ParallelWaveGAN requires both a checkpoint file and a config file.")
    # Make the ParallelWaveGAN project importable when it is not installed
    # globally; avoid stacking duplicate entries on repeated calls.
    if lib_path not in sys.path:
        sys.path.append(lib_path)
    from parallel_wavegan.models import ParallelWaveGANGenerator
    print(" > Loading PWGAN model ...")
    print(" | > model config: ", model_config)
    print(" | > model file: ", model_file)
    with open(model_config, encoding="utf-8") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only point this at trusted config files, or switch to
        # yaml.safe_load once confirmed the config uses plain YAML types.
        self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
    self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
    # Deserialize onto the CPU first so a checkpoint saved on a GPU can be
    # loaded on a CPU-only host; the model is moved to CUDA below if asked.
    checkpoint = torch.load(model_file, map_location="cpu")
    self.pwgan.load_state_dict(checkpoint["model"]["generator"])
    self.pwgan.remove_weight_norm()
    if use_cuda:
        self.pwgan.cuda()
    self.pwgan.eval()
def save_wav(self, wav, path): def save_wav(self, wav, path):
# wav *= 32767 / max(1e-8, np.max(np.abs(wav))) # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
wav = np.array(wav) wav = np.array(wav)