mirror of https://github.com/coqui-ai/TTS.git
update server and synthesizer to handle ParallelWaveGAN
parent
fbe5310be0
commit
ed8a9fc82a
|
@ -14,10 +14,13 @@ def create_argparser():
|
||||||
parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
|
parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
|
||||||
parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
|
parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
|
||||||
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
|
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
|
||||||
parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||||
parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
|
parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
|
||||||
parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
|
parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
|
||||||
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
|
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
|
||||||
|
parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||||
|
parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
|
||||||
|
parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
|
||||||
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
||||||
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
||||||
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import yaml
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import load_config, setup_model
|
from TTS.utils.generic_utils import load_config, setup_model
|
||||||
from TTS.utils.text import phonemes, symbols
|
|
||||||
from TTS.utils.speakers import load_speaker_mapping
|
from TTS.utils.speakers import load_speaker_mapping
|
||||||
from TTS.utils.synthesis import *
|
from TTS.utils.synthesis import *
|
||||||
|
from TTS.utils.text import phonemes, symbols
|
||||||
|
|
||||||
import re
|
|
||||||
alphabets = r"([A-Za-z])"
|
alphabets = r"([A-Za-z])"
|
||||||
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
|
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
|
||||||
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
|
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
|
||||||
|
@ -23,6 +24,7 @@ websites = r"[.](com|net|org|io|gov)"
|
||||||
class Synthesizer(object):
|
class Synthesizer(object):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.wavernn = None
|
self.wavernn = None
|
||||||
|
self.pwgan = None
|
||||||
self.config = config
|
self.config = config
|
||||||
self.use_cuda = self.config.use_cuda
|
self.use_cuda = self.config.use_cuda
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
|
@ -30,9 +32,11 @@ class Synthesizer(object):
|
||||||
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
|
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
|
||||||
self.config.use_cuda)
|
self.config.use_cuda)
|
||||||
if self.config.wavernn_lib_path:
|
if self.config.wavernn_lib_path:
|
||||||
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
|
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
|
||||||
self.config.wavernn_file, self.config.wavernn_config,
|
self.config.wavernn_config, self.config.use_cuda)
|
||||||
self.config.use_cuda)
|
if self.config.pwgan_lib_path:
|
||||||
|
self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
|
||||||
|
self.config.pwgan_config, self.config.use_cuda)
|
||||||
|
|
||||||
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
|
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
|
||||||
print(" > Loading TTS model ...")
|
print(" > Loading TTS model ...")
|
||||||
|
@ -45,9 +49,9 @@ class Synthesizer(object):
|
||||||
self.input_size = len(phonemes)
|
self.input_size = len(phonemes)
|
||||||
else:
|
else:
|
||||||
self.input_size = len(symbols)
|
self.input_size = len(symbols)
|
||||||
# load speakers
|
# TODO: fix this for multi-speaker model - load speakers
|
||||||
if self.config.tts_speakers is not None:
|
if self.config.tts_speakers is not None:
|
||||||
self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
|
self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
|
||||||
num_speakers = len(self.tts_speakers)
|
num_speakers = len(self.tts_speakers)
|
||||||
else:
|
else:
|
||||||
num_speakers = 0
|
num_speakers = 0
|
||||||
|
@ -63,16 +67,14 @@ class Synthesizer(object):
|
||||||
if 'r' in cp:
|
if 'r' in cp:
|
||||||
self.tts_model.decoder.set_r(cp['r'])
|
self.tts_model.decoder.set_r(cp['r'])
|
||||||
|
|
||||||
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
|
def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
|
||||||
# TODO: set a function in wavernn code base for model setup and call it here.
|
# TODO: set a function in wavernn code base for model setup and call it here.
|
||||||
sys.path.append(lib_path) # set this if TTS is not installed globally
|
sys.path.append(lib_path) # set this if TTS is not installed globally
|
||||||
from WaveRNN.models.wavernn import Model
|
from WaveRNN.models.wavernn import Model
|
||||||
wavernn_config = os.path.join(model_path, model_config)
|
|
||||||
model_file = os.path.join(model_path, model_file)
|
|
||||||
print(" > Loading WaveRNN model ...")
|
print(" > Loading WaveRNN model ...")
|
||||||
print(" | > model config: ", wavernn_config)
|
print(" | > model config: ", model_config)
|
||||||
print(" | > model file: ", model_file)
|
print(" | > model file: ", model_file)
|
||||||
self.wavernn_config = load_config(wavernn_config)
|
self.wavernn_config = load_config(model_config)
|
||||||
self.wavernn = Model(
|
self.wavernn = Model(
|
||||||
rnn_dims=512,
|
rnn_dims=512,
|
||||||
fc_dims=512,
|
fc_dims=512,
|
||||||
|
@ -91,11 +93,27 @@ class Synthesizer(object):
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
check = torch.load(model_file)
|
check = torch.load(model_file)
|
||||||
self.wavernn.load_state_dict(check['model'])
|
self.wavernn.load_state_dict(check['model'], map_location="cpu")
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.wavernn.cuda()
|
self.wavernn.cuda()
|
||||||
self.wavernn.eval()
|
self.wavernn.eval()
|
||||||
|
|
||||||
|
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
|
||||||
|
sys.path.append(lib_path) # set this if TTS is not installed globally
|
||||||
|
from parallel_wavegan.models import ParallelWaveGANGenerator
|
||||||
|
from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder
|
||||||
|
print(" > Loading PWGAN model ...")
|
||||||
|
print(" | > model config: ", model_config)
|
||||||
|
print(" | > model file: ", model_file)
|
||||||
|
with open(model_config) as f:
|
||||||
|
self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
|
||||||
|
self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
|
||||||
|
self.pwgan.remove_weight_norm()
|
||||||
|
if use_cuda:
|
||||||
|
self.pwgan.cuda()
|
||||||
|
self.pwgan.eval()
|
||||||
|
|
||||||
def save_wav(self, wav, path):
|
def save_wav(self, wav, path):
|
||||||
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
||||||
wav = np.array(wav)
|
wav = np.array(wav)
|
||||||
|
|
Loading…
Reference in New Issue