bug fix at server

pull/10/head
erogol 2020-02-21 14:57:10 +01:00
parent aef12f0c21
commit bc6764a5c7
2 changed files with 20 additions and 2 deletions

@@ -94,8 +94,8 @@ class Synthesizer(object):
             sample_rate=self.ap.sample_rate,
         ).cuda()
-        check = torch.load(model_file)
-        self.wavernn.load_state_dict(check['model'], map_location="cpu")
+        check = torch.load(model_file, map_location="cpu")
+        self.wavernn.load_state_dict(check['model'])
         if use_cuda:
             self.wavernn.cuda()
         self.wavernn.eval()

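The first hunk fixes where map_location is passed: in PyTorch it is a keyword argument of torch.load, which deserializes the checkpoint, while nn.Module.load_state_dict does not accept it, so the old call would raise a TypeError on a machine without CUDA. A minimal sketch of the corrected pattern, using a stand-in module and a hypothetical checkpoint path in place of the server's WaveRNN:

import torch
import torch.nn as nn

# sketch only: nn.Linear stands in for the real WaveRNN and the checkpoint
# path is hypothetical; the point is that map_location belongs to torch.load,
# not to load_state_dict
model = nn.Linear(4, 4)
check = torch.load("wavernn_checkpoint.pth", map_location="cpu")  # load tensors onto CPU
model.load_state_dict(check['model'])
if torch.cuda.is_available():
    model.cuda()  # move to GPU only after the CPU-side load
model.eval()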

@@ -69,6 +69,24 @@ def id_to_torch(speaker_id):
     return speaker_id
 
+# TODO: perform GL with pytorch for batching
+def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
+    '''Apply Griffin-Lim to each sample, iterating through the first dimension.
+    Args:
+        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
+        input_lens (Tensor or np.Array): 1D array of sample lengths.
+        CONFIG (Dict): TTS config.
+        ap (AudioProcessor): TTS audio processor.
+    '''
+    wavs = []
+    for idx, spec in enumerate(inputs):
+        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav = inv_spectrogram(spec, ap, CONFIG)
+        # assert len(wav) == wav_len, f" [!] wav length: {len(wav)} vs expected: {wav_len}"
+        wavs.append(wav[:wav_len])
+    return wavs
+
 def synthesis(model,
               text,
               CONFIG,
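The second hunk adds apply_griffin_lim, which runs Griffin-Lim over a batch of spectrograms one sample at a time and trims each waveform to its valid length. A hedged usage sketch, assuming a Tacotron2-style config with 80 mel bins and the repo's usual load_config / AudioProcessor setup; the import paths, config path, and spectrogram shapes below are illustrative assumptions, not part of this commit:

import numpy as np
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config
from TTS.utils.synthesis import apply_griffin_lim

CONFIG = load_config("config.json")   # hypothetical config path
ap = AudioProcessor(**CONFIG.audio)   # audio settings taken from the config

# made-up batch of two mel spectrograms shaped [batch, T, num_mels];
# only the first 50 and 42 frames of each sample are treated as valid
inputs = np.random.rand(2, 50, 80).astype(np.float32)
input_lens = np.array([50, 42])

wavs = apply_griffin_lim(inputs, input_lens, CONFIG, ap)
for wav in wavs:
    # each waveform is cut to (frames - 1) * hop_length samples,
    # undoing the extra hop of padding added by librosa
    print(wav.shape)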