mirror of https://github.com/coqui-ai/TTS.git
Move MAS to `TTS.tts.utils.helpers`
parent 2dfc5bdd11
commit bfc6ceac29
@@ -1,106 +0,0 @@
-import numpy as np
-import torch
-from torch.nn import functional as F
-
-from TTS.tts.utils.helpers import sequence_mask
-
-try:
-    # TODO: fix pypi cython installation problem.
-    from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
-
-    CYTHON = True
-except ModuleNotFoundError:
-    CYTHON = False
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def generate_path(duration, mask):
-    """
-    Shapes:
-        - duration: :math:`[B, T_en]`
-        - mask: :math:'[B, T_en, T_de]`
-        - path: :math:`[B, T_en, T_de]`
-    """
-    device = duration.device
-    b, t_x, t_y = mask.shape
-    cum_duration = torch.cumsum(duration, 1)
-    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path * mask
-    return path
-
-
-def maximum_path(value, mask):
-    if CYTHON:
-        return maximum_path_cython(value, mask)
-    return maximum_path_numpy(value, mask)
-
-
-def maximum_path_cython(value, mask):
-    """Cython optimised version.
-    Shapes:
-        - value: :math:`[B, T_en, T_de]`
-        - mask: :math:`[B, T_en, T_de]`
-    """
-    value = value * mask
-    device = value.device
-    dtype = value.dtype
-    value = value.data.cpu().numpy().astype(np.float32)
-    path = np.zeros_like(value).astype(np.int32)
-    mask = mask.data.cpu().numpy()
-
-    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
-    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
-    maximum_path_c(path, value, t_x_max, t_y_max)
-    return torch.from_numpy(path).to(device=device, dtype=dtype)
-
-
-def maximum_path_numpy(value, mask, max_neg_val=None):
-    """
-    Monotonic alignment search algorithm
-    Numpy-friendly version. It's about 4 times faster than torch version.
-    value: [b, t_x, t_y]
-    mask: [b, t_x, t_y]
-    """
-    if max_neg_val is None:
-        max_neg_val = -np.inf  # Patch for Sphinx complaint
-    value = value * mask
-
-    device = value.device
-    dtype = value.dtype
-    value = value.cpu().detach().numpy()
-    mask = mask.cpu().detach().numpy().astype(np.bool)
-
-    b, t_x, t_y = value.shape
-    direction = np.zeros(value.shape, dtype=np.int64)
-    v = np.zeros((b, t_x), dtype=np.float32)
-    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
-    for j in range(t_y):
-        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
-        v1 = v
-        max_mask = v1 >= v0
-        v_max = np.where(max_mask, v1, v0)
-        direction[:, :, j] = max_mask
-
-        index_mask = x_range <= j
-        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
-    direction = np.where(mask, direction, 1)
-
-    path = np.zeros(value.shape, dtype=np.float32)
-    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
-    index_range = np.arange(b)
-    for j in reversed(range(t_y)):
-        path[index_range, index, j] = 1
-        index = index + direction[index_range, index, j] - 1
-    path = path * mask.astype(np.float32)
-    path = torch.from_numpy(path).to(device=device, dtype=dtype)
-    return path
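For code that imported from the deleted module, the visible effect of this commit is only an import-path change; the functions themselves are re-added unchanged to the shared helpers module further down in this diff. A minimal before/after sketch (module paths taken from this diff, nothing else assumed):

# before this commit
# from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path

# after this commit
from TTS.tts.utils.helpers import generate_path, maximum_path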
@@ -10,7 +10,7 @@ from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.helpers import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -118,9 +118,11 @@ class BaseTacotron(BaseTTS):
         if "r" in state:
             self.decoder.set_r(state["r"])
         else:
+            # set the reduction rate from the config values embedded in the checkpoint
             self.decoder.set_r(state["config"]["r"])
         if eval:
             self.eval()
+            print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
             assert not self.training

     def get_criterion(self) -> nn.Module:
@@ -11,7 +11,7 @@ from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.aligner import AlignmentNetwork
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.helpers import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
@@ -7,7 +7,7 @@ from torch.nn import functional as F
 from TTS.tts.configs import GlowTTSConfig
 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.helpers import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import get_speaker_manager
@@ -8,7 +8,7 @@ from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path
+from TTS.tts.utils.helpers import generate_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
@@ -9,7 +9,7 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast

 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.helpers import generate_path, maximum_path
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -1,6 +1,17 @@
-import torch
-import numpy as np
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+try:
+    from TTS.tts.utils.monotonic_align.core import maximum_path_c
+
+    CYTHON = True
+except ModuleNotFoundError:
+    CYTHON = False


 class StandardScaler:
     """StandardScaler for mean-std normalization with the given mean and std values.
@@ -109,3 +120,96 @@ def average_over_durations(values, durs):

     avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems)
     return avg
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def generate_path(duration, mask):
+    """
+    Shapes:
+        - duration: :math:`[B, T_en]`
+        - mask: :math:'[B, T_en, T_de]`
+        - path: :math:`[B, T_en, T_de]`
+    """
+    device = duration.device
+    b, t_x, t_y = mask.shape
+    cum_duration = torch.cumsum(duration, 1)
+    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path * mask
+    return path
+
+
+def maximum_path(value, mask):
+    if CYTHON:
+        return maximum_path_cython(value, mask)
+    return maximum_path_numpy(value, mask)
+
+
+def maximum_path_cython(value, mask):
+    """Cython optimised version.
+    Shapes:
+        - value: :math:`[B, T_en, T_de]`
+        - mask: :math:`[B, T_en, T_de]`
+    """
+    value = value * mask
+    device = value.device
+    dtype = value.dtype
+    value = value.data.cpu().numpy().astype(np.float32)
+    path = np.zeros_like(value).astype(np.int32)
+    mask = mask.data.cpu().numpy()
+
+    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
+    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
+    maximum_path_c(path, value, t_x_max, t_y_max)
+    return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+def maximum_path_numpy(value, mask, max_neg_val=None):
+    """
+    Monotonic alignment search algorithm
+    Numpy-friendly version. It's about 4 times faster than torch version.
+    value: [b, t_x, t_y]
+    mask: [b, t_x, t_y]
+    """
+    if max_neg_val is None:
+        max_neg_val = -np.inf  # Patch for Sphinx complaint
+    value = value * mask
+
+    device = value.device
+    dtype = value.dtype
+    value = value.cpu().detach().numpy()
+    mask = mask.cpu().detach().numpy().astype(np.bool)
+
+    b, t_x, t_y = value.shape
+    direction = np.zeros(value.shape, dtype=np.int64)
+    v = np.zeros((b, t_x), dtype=np.float32)
+    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
+    for j in range(t_y):
+        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
+        v1 = v
+        max_mask = v1 >= v0
+        v_max = np.where(max_mask, v1, v0)
+        direction[:, :, j] = max_mask
+
+        index_mask = x_range <= j
+        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
+    direction = np.where(mask, direction, 1)
+
+    path = np.zeros(value.shape, dtype=np.float32)
+    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
+    index_range = np.arange(b)
+    for j in reversed(range(t_y)):
+        path[index_range, index, j] = 1
+        index = index + direction[index_range, index, j] - 1
+    path = path * mask.astype(np.float32)
+    path = torch.from_numpy(path).to(device=device, dtype=dtype)
+    return path
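`maximum_path` runs monotonic alignment search: a dynamic program over the `[T_en, T_de]` score matrix that assigns every valid decoder frame to exactly one text token, never moving backwards in time. `generate_path` is the inverse expansion, turning per-token durations back into the same kind of binary path. The following is a minimal, self-contained usage sketch, not code from the commit: the imported names and shape conventions come from this diff, while the lengths, tensors, and variable names are synthetic.

# Illustrative only: synthetic batch of 2 utterances with 5 text tokens and
# 12 decoder frames; real models pass attention log-likelihoods as `scores`.
import torch

from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask

B, T_en, T_de = 2, 5, 12
scores = torch.randn(B, T_en, T_de)                           # soft alignment scores
en_mask = sequence_mask(torch.tensor([5, 3]), T_en).float()   # [B, T_en]
de_mask = sequence_mask(torch.tensor([12, 9]), T_de).float()  # [B, T_de]
attn_mask = en_mask.unsqueeze(2) * de_mask.unsqueeze(1)       # [B, T_en, T_de]

# Hard monotonic alignment: each valid decoder frame maps to exactly one token.
alignment = maximum_path(scores, attn_mask)                   # [B, T_en, T_de], binary

# Summing over frames gives per-token durations; generate_path expands those
# durations back into the corresponding binary path, which is how the pair is
# used for alignment extraction during training vs. duration-based inference.
durations = alignment.sum(2)                                  # [B, T_en]
path = generate_path(durations, attn_mask)                    # [B, T_en, T_de]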
File diff suppressed because it is too large
setup.py (4 changed lines)
@@ -54,8 +54,8 @@ with open("README.md", "r", encoding="utf-8") as readme_file:

 exts = [
     Extension(
-        name="TTS.tts.layers.glow_tts.monotonic_align.core",
-        sources=["TTS/tts/layers/glow_tts/monotonic_align/core.pyx"],
+        name="TTS.tts.utils.monotonic_align.core",
+        sources=["TTS/tts/utils/monotonic_align/core.pyx"],
     )
 ]
 setup(
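Because the compiled kernel moves together with the package path, an existing checkout needs the Cython extension rebuilt (for example via `pip install -e .` or `python setup.py build_ext --inplace`) before the fast path is available again. A small, hedged sanity check, not part of the commit:

# If the rebuilt extension imports from its new location, maximum_path() uses
# the Cython kernel; otherwise it quietly falls back to the numpy version
# (CYTHON = False in the helpers shown above).
try:
    from TTS.tts.utils.monotonic_align.core import maximum_path_c
    print("Cython MAS kernel found:", callable(maximum_path_c))
except ModuleNotFoundError:
    print("Cython MAS kernel not built; maximum_path() will use the numpy fallback.")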