mirror of https://github.com/coqui-ai/TTS.git
update server and synthesizer to handle ParallelWaveGAN
parent
fbe5310be0
commit
ed8a9fc82a
|
@ -14,10 +14,13 @@ def create_argparser():
|
|||
parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
|
||||
parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
|
||||
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
|
||||
parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||
parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
|
||||
parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
|
||||
parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||
parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
|
||||
parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
|
||||
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
|
||||
parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||
parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
|
||||
parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
|
||||
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
||||
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
||||
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import load_config, setup_model
|
||||
from TTS.utils.text import phonemes, symbols
|
||||
from TTS.utils.speakers import load_speaker_mapping
|
||||
from TTS.utils.synthesis import *
|
||||
from TTS.utils.text import phonemes, symbols
|
||||
|
||||
import re
|
||||
alphabets = r"([A-Za-z])"
|
||||
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
|
||||
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
|
||||
|
@ -23,6 +24,7 @@ websites = r"[.](com|net|org|io|gov)"
|
|||
class Synthesizer(object):
|
||||
def __init__(self, config):
|
||||
self.wavernn = None
|
||||
self.pwgan = None
|
||||
self.config = config
|
||||
self.use_cuda = self.config.use_cuda
|
||||
if self.use_cuda:
|
||||
|
@ -30,9 +32,11 @@ class Synthesizer(object):
|
|||
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
|
||||
self.config.use_cuda)
|
||||
if self.config.wavernn_lib_path:
|
||||
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
|
||||
self.config.wavernn_file, self.config.wavernn_config,
|
||||
self.config.use_cuda)
|
||||
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
|
||||
self.config.wavernn_config, self.config.use_cuda)
|
||||
if self.config.pwgan_lib_path:
|
||||
self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
|
||||
self.config.pwgan_config, self.config.use_cuda)
|
||||
|
||||
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
|
||||
print(" > Loading TTS model ...")
|
||||
|
@ -45,9 +49,9 @@ class Synthesizer(object):
|
|||
self.input_size = len(phonemes)
|
||||
else:
|
||||
self.input_size = len(symbols)
|
||||
# load speakers
|
||||
# TODO: fix this for multi-speaker model - load speakers
|
||||
if self.config.tts_speakers is not None:
|
||||
self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
|
||||
self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
|
||||
num_speakers = len(self.tts_speakers)
|
||||
else:
|
||||
num_speakers = 0
|
||||
|
@ -63,16 +67,14 @@ class Synthesizer(object):
|
|||
if 'r' in cp:
|
||||
self.tts_model.decoder.set_r(cp['r'])
|
||||
|
||||
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
|
||||
def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
|
||||
# TODO: set a function in wavernn code base for model setup and call it here.
|
||||
sys.path.append(lib_path) # set this if TTS is not installed globally
|
||||
from WaveRNN.models.wavernn import Model
|
||||
wavernn_config = os.path.join(model_path, model_config)
|
||||
model_file = os.path.join(model_path, model_file)
|
||||
print(" > Loading WaveRNN model ...")
|
||||
print(" | > model config: ", wavernn_config)
|
||||
print(" | > model config: ", model_config)
|
||||
print(" | > model file: ", model_file)
|
||||
self.wavernn_config = load_config(wavernn_config)
|
||||
self.wavernn_config = load_config(model_config)
|
||||
self.wavernn = Model(
|
||||
rnn_dims=512,
|
||||
fc_dims=512,
|
||||
|
@ -91,11 +93,27 @@ class Synthesizer(object):
|
|||
).cuda()
|
||||
|
||||
check = torch.load(model_file)
|
||||
self.wavernn.load_state_dict(check['model'])
|
||||
self.wavernn.load_state_dict(check['model'], map_location="cpu")
|
||||
if use_cuda:
|
||||
self.wavernn.cuda()
|
||||
self.wavernn.eval()
|
||||
|
||||
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
|
||||
sys.path.append(lib_path) # set this if TTS is not installed globally
|
||||
from parallel_wavegan.models import ParallelWaveGANGenerator
|
||||
from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder
|
||||
print(" > Loading PWGAN model ...")
|
||||
print(" | > model config: ", model_config)
|
||||
print(" | > model file: ", model_file)
|
||||
with open(model_config) as f:
|
||||
self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
|
||||
self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
|
||||
self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
|
||||
self.pwgan.remove_weight_norm()
|
||||
if use_cuda:
|
||||
self.pwgan.cuda()
|
||||
self.pwgan.eval()
|
||||
|
||||
def save_wav(self, wav, path):
|
||||
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
||||
wav = np.array(wav)
|
||||
|
|
Loading…
Reference in New Issue