diff --git a/server/server.py b/server/server.py
index 3be66f9e..6af119bf 100644
--- a/server/server.py
+++ b/server/server.py
@@ -14,10 +14,13 @@ def create_argparser():
     parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
     parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
     parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
-    parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
-    parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
-    parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
+    parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
+    parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
     parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
+    parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
+    parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
     parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
     parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
     parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
diff --git a/server/synthesizer.py b/server/synthesizer.py
index d8852a3e..b703c62e 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -1,17 +1,18 @@
 import io
 import os
+import re
+import sys
 
 import numpy as np
 import torch
-import sys
+import yaml
 
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text import phonemes, symbols
 from TTS.utils.speakers import load_speaker_mapping
 from TTS.utils.synthesis import *
+from TTS.utils.text import phonemes, symbols
 
-import re
 alphabets = r"([A-Za-z])"
 prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
 suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@@ -23,6 +24,7 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
+        self.pwgan = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
@@ -30,9 +32,11 @@ class Synthesizer(object):
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
         if self.config.wavernn_lib_path:
-            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
-                              self.config.wavernn_file, self.config.wavernn_config,
-                              self.config.use_cuda)
+            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
+                              self.config.wavernn_config, self.config.use_cuda)
+        if self.config.pwgan_lib_path:
+            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
+                            self.config.pwgan_config, self.config.use_cuda)
 
     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         print(" > Loading TTS model ...")
model ...") @@ -45,9 +49,9 @@ class Synthesizer(object): self.input_size = len(phonemes) else: self.input_size = len(symbols) - # load speakers + # TODO: fix this for multi-speaker model - load speakers if self.config.tts_speakers is not None: - self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers)) + self.tts_speakers = load_speaker_mapping(self.config.tts_speakers) num_speakers = len(self.tts_speakers) else: num_speakers = 0 @@ -63,16 +67,14 @@ class Synthesizer(object): if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) - def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): + def load_wavernn(self, lib_path, model_file, model_config, use_cuda): # TODO: set a function in wavernn code base for model setup and call it here. sys.path.append(lib_path) # set this if TTS is not installed globally from WaveRNN.models.wavernn import Model - wavernn_config = os.path.join(model_path, model_config) - model_file = os.path.join(model_path, model_file) print(" > Loading WaveRNN model ...") - print(" | > model config: ", wavernn_config) + print(" | > model config: ", model_config) print(" | > model file: ", model_file) - self.wavernn_config = load_config(wavernn_config) + self.wavernn_config = load_config(model_config) self.wavernn = Model( rnn_dims=512, fc_dims=512, @@ -91,11 +93,27 @@ class Synthesizer(object): ).cuda() check = torch.load(model_file) - self.wavernn.load_state_dict(check['model']) + self.wavernn.load_state_dict(check['model'], map_location="cpu") if use_cuda: self.wavernn.cuda() self.wavernn.eval() + def load_pwgan(self, lib_path, model_file, model_config, use_cuda): + sys.path.append(lib_path) # set this if TTS is not installed globally + from parallel_wavegan.models import ParallelWaveGANGenerator + from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder + print(" > Loading PWGAN model ...") + print(" | > model config: ", model_config) + print(" | > model file: ", model_file) + with open(model_config) as f: + self.pwgan_config = yaml.load(f, Loader=yaml.Loader) + self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"]) + self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"]) + self.pwgan.remove_weight_norm() + if use_cuda: + self.pwgan.cuda() + self.pwgan.eval() + def save_wav(self, wav, path): # wav *= 32767 / max(1e-8, np.max(np.abs(wav))) wav = np.array(wav)