mirror of https://github.com/coqui-ai/TTS.git

merge glow-tts after rebranding

parent 95de34e8ef
commit 89d15bf118
TTS/tts/datasets/TTSDataset.py

@@ -113,7 +113,14 @@ class MyDataset(Dataset):
         return phonemes
 
     def load_data(self, idx):
-        text, wav_file, speaker_name = self.items[idx]
+        item = self.items[idx]
+
+        if len(item) == 4:
+            text, wav_file, speaker_name, attn_file = item
+        else:
+            text, wav_file, speaker_name = item
+            attn = None
+
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
 
         if self.use_phonemes:
@@ -125,9 +132,13 @@ class MyDataset(Dataset):
         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
 
+        if "attn_file" in locals():
+            attn = np.load(attn_file)
+
         sample = {
             'text': text,
             'wav': wav,
+            'attn': attn,
             'item_idx': self.items[idx][1],
             'speaker_name': speaker_name,
             'wav_file_name': os.path.basename(wav_file)
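Taken together, the two hunks above let a dataset item optionally carry a precomputed attention alignment. A hypothetical illustration of the two item layouts (text and paths are made up):

# Hypothetical item layouts handled by load_data after this change:
item3 = ("hello world", "wavs/0001.wav", "speaker_a")
item4 = ("hello world", "wavs/0001.wav", "speaker_a", "alignments/0001.npy")
# With the 4-field layout, sample['attn'] is loaded from the .npy file;
# with the 3-field layout it stays None and is skipped at collate time.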
@@ -245,8 +256,21 @@ class MyDataset(Dataset):
                 linear = torch.FloatTensor(linear).contiguous()
             else:
                 linear = None
+
+            # collate attention alignments
+            if batch[0]['attn'] is not None:
+                attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
+                for idx, attn in enumerate(attns):
+                    pad2 = mel.shape[1] - attn.shape[1]
+                    pad1 = text.shape[1] - attn.shape[0]
+                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
+                    attns[idx] = attn
+                attns = prepare_tensor(attns, self.outputs_per_step)
+                attns = torch.FloatTensor(attns).unsqueeze(1)
+            else:
+                attns = None
             return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
-                stop_targets, item_idxs, speaker_embedding
+                stop_targets, item_idxs, speaker_embedding, attns
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))
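The loop zero-pads each transposed alignment (text_len x mel_len) at the bottom and right up to the batch-wide maxima before stacking. A standalone sketch with made-up sizes:

import numpy as np

# Two alignments of different sizes, already transposed to
# (text_len, mel_len) as in the collate code above.
attns = [np.ones((5, 40), dtype=np.float32),
         np.ones((7, 50), dtype=np.float32)]

max_text, max_mel = 7, 50  # text.shape[1] and mel.shape[1] in the batch
for idx, attn in enumerate(attns):
    pad1 = max_text - attn.shape[0]  # pad the text axis at the end
    pad2 = max_mel - attn.shape[1]   # pad the mel axis at the end
    attns[idx] = np.pad(attn, [[0, pad1], [0, pad2]])  # zero padding

assert all(a.shape == (7, 50) for a in attns)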
TTS/tts/layers/losses.py

@@ -1,3 +1,4 @@
+import math
 import numpy as np
 import torch
 from torch import nn
@@ -150,7 +151,7 @@ class GuidedAttentionLoss(torch.nn.Module):
 
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=ilen.device))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
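The mask penalizes attention far from the diagonal: where grid_y / ilen equals grid_x / olen the exponent is zero and the mask value is 0, rising toward 1 off-diagonal. A small self-contained check (the function name and sizes are illustrative):

import torch

def make_ga_mask(ilen, olen, sigma=0.4):
    # Mirrors _make_ga_mask above, with plain Python ints.
    grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
    grid_x, grid_y = grid_x.float(), grid_y.float()
    return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))

mask = make_ga_mask(ilen=6, olen=10)
print(mask.shape)         # torch.Size([10, 6])
print(mask[0, 0].item())  # 0.0 on the diagonal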
@@ -243,3 +244,21 @@ class TacotronLoss(torch.nn.Module):
 
         return_dict['loss'] = loss
         return return_dict
+
+
+class GlowTTSLoss(torch.nn.Module):
+    def __init__(self):
+        super(GlowTTSLoss, self).__init__()
+        self.constant_factor = 0.5 * math.log(2 * math.pi)
+
+    def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
+        return_dict = {}
+        # flow loss
+        pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means)**2)
+        log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths // 2) * 2 * 80)
+        # duration loss
+        loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
+        return_dict['loss'] = log_mle + loss_dur
+        return_dict['log_mle'] = log_mle
+        return_dict['loss_dur'] = loss_dur
+        return return_dict
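For intuition: with scales holding log-sigma, pz is the summed negative Gaussian log-density of z up to the constant 0.5 * log(2 * pi), so log_mle is the flow negative log-likelihood (prior term minus the flow log-determinants), normalized by the number of mel frames times the 80 mel channels; torch.sum(y_lengths // 2) * 2 reflects the decoder's squeeze factor of 2. A smoke test on random tensors; all shapes here are illustrative assumptions, with 80 matching out_channels below:

import torch
from TTS.tts.layers.losses import GlowTTSLoss  # module path assumed from this commit's layout

criterion = GlowTTSLoss()
B, C, T_mel, T_text = 2, 80, 100, 20
z = torch.randn(B, C, T_mel)            # latents from the flow decoder
means = torch.randn(B, C, T_mel)        # prior means from the encoder
scales = torch.zeros(B, C, T_mel)       # prior log-scales (log sigma)
log_det = torch.randn(B)                # per-sample flow log-determinants
y_lengths = torch.tensor([T_mel, T_mel])
o_dur_log = torch.randn(B, 1, T_text)   # predicted log-durations
o_attn_dur = torch.randn(B, 1, T_text)  # log-durations from the alignment
x_lengths = torch.tensor([T_text, T_text])

out = criterion(z, means, scales, log_det, y_lengths,
                o_dur_log, o_attn_dur, x_lengths)
print(out['loss'].item(), out['log_mle'].item(), out['loss_dur'].item())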
TTS/tts/utils/generic_utils.py

@@ -1,3 +1,4 @@
+import re
 import torch
 import importlib
 import numpy as np
@@ -44,6 +45,11 @@ def sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand
 
 
+def to_camel(text):
+    text = text.capitalize()
+    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+
+
 def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     print(" > Using model: {}".format(c.model))
     MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
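to_camel turns a snake_case config name into the matching class name, which is presumably how the imported module is resolved to a model class (that getattr step is an assumption, not shown in these hunks). A self-contained check:

import re

def to_camel(text):  # copy of the helper added above
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)

print(to_camel("glow_tts"))   # GlowTts
print(to_camel("tacotron2"))  # Tacotron2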
@@ -99,6 +105,32 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                         double_decoder_consistency=c.double_decoder_consistency,
                         ddc_r=c.ddc_r,
                         speaker_embedding_dim=speaker_embedding_dim)
+    elif c.model.lower() == "glow_tts":
+        model = MyModel(num_chars=num_chars,
+                        hidden_channels=192,
+                        filter_channels=768,
+                        filter_channels_dp=256,
+                        out_channels=80,
+                        kernel_size=3,
+                        num_heads=2,
+                        num_layers_enc=6,
+                        dropout_p=0.1,
+                        num_flow_blocks_dec=12,
+                        kernel_size_dec=5,
+                        dilation_rate=1,
+                        num_block_layers=4,
+                        dropout_p_dec=0.05,
+                        num_speakers=num_speakers,
+                        c_in_channels=0,
+                        num_splits=4,
+                        num_sqz=2,
+                        sigmoid_scale=False,
+                        rel_attn_window_size=4,
+                        input_length=None,
+                        mean_only=True,
+                        hidden_channels_enc=192,
+                        hidden_channels_dec=192,
+                        use_encoder_prenet=True)
     return model
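Note the glow_tts branch hard-codes every hyper-parameter (192 hidden channels, 12 flow blocks, squeeze factor 2, 80 mel channels) rather than reading them from the config; only num_chars and num_speakers flow through. A minimal usage sketch, assuming an installed TTS package and a hypothetical config stub carrying just the field this branch reads (real configs have many more):

from TTS.tts.utils.generic_utils import setup_model  # module path assumed

class GlowConfig:  # hypothetical stub, not a real config class
    model = "glow_tts"

model = setup_model(num_chars=120, num_speakers=0, c=GlowConfig())
print(type(model).__name__)  # expected: GlowTts, via to_camel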
setup.py
@@ -5,6 +5,7 @@ import os
 import shutil
 import subprocess
 import sys
+import numpy
 
 from setuptools import setup, find_packages
 import setuptools.command.develop
@@ -118,6 +119,7 @@ setup(
             'tts-server = TTS.server.server:main'
         ]
     },
+    include_dirs=[numpy.get_include()],
     ext_modules=cythonize(find_pyx(), language_level=3),
     packages=find_packages(include=['TTS*']),
     project_urls={
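The two setup.py changes belong together: Cython modules that cimport numpy compile only if NumPy's C headers are on the include path, which numpy.get_include() supplies. A minimal standalone sketch of the same pattern (the .pyx file name is made up):

import numpy
from setuptools import setup
from Cython.Build import cythonize

setup(
    name="example",
    ext_modules=cythonize(["fast_ops.pyx"], language_level=3),  # hypothetical extension
    include_dirs=[numpy.get_include()],  # exposes numpy/arrayobject.h to the compiler
)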