mirror of https://github.com/coqui-ai/TTS.git
Test `TTS.tts.utils.helpers`
parent
8b7e094bde
commit
ed4b1d8514
|
@ -1,6 +1,3 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
@ -14,9 +11,9 @@ except ModuleNotFoundError:
|
|||
|
||||
|
||||
class StandardScaler:
|
||||
"""StandardScaler for mean-std normalization with the given mean and std values.
|
||||
"""
|
||||
def __init__(self, mean:np.ndarray=None, std:np.ndarray=None) -> None:
|
||||
"""StandardScaler for mean-std normalization with the given mean and std values."""
|
||||
|
||||
def __init__(self, mean: np.ndarray = None, std: np.ndarray = None) -> None:
|
||||
self.mean_ = mean
|
||||
self.std_ = std
|
||||
|
||||
|
@ -97,6 +94,7 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=
|
|||
ret = segment(x, segment_indices, segment_size)
|
||||
return ret, segment_indices
|
||||
|
||||
|
||||
def average_over_durations(values, durs):
|
||||
"""Average values over durations.
|
||||
|
||||
|
@ -212,4 +210,4 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
|
|||
index = index + direction[index_range, index, j] - 1
|
||||
path = path * mask.astype(np.float32)
|
||||
path = torch.from_numpy(path).to(device=device, dtype=dtype)
|
||||
return path
|
||||
return path
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
import torch as T
|
||||
|
||||
from TTS.tts.utils.helpers import *
|
||||
|
||||
|
||||
def average_over_durations_test(): # pylint: disable=no-self-use
|
||||
pitch = T.rand(1, 1, 128)
|
||||
|
||||
durations = T.randint(1, 5, (1, 21))
|
||||
coeff = 128.0 / durations.sum()
|
||||
durations = T.floor(durations * coeff)
|
||||
diff = 128.0 - durations.sum()
|
||||
durations[0, -1] += diff
|
||||
durations = durations.long()
|
||||
|
||||
pitch_avg = average_over_durations(pitch, durations)
|
||||
|
||||
index = 0
|
||||
for idx, dur in enumerate(durations[0]):
|
||||
assert abs(pitch_avg[0, 0, idx] - pitch[0, 0, index : index + dur.item()].mean()) < 1e-5
|
||||
index += dur
|
||||
|
||||
|
||||
def seqeunce_mask_test():
|
||||
lengths = T.randint(10, 15, (8,))
|
||||
mask = sequence_mask(lengths)
|
||||
for i in range(8):
|
||||
l = lengths[i].item()
|
||||
assert mask[i, :l].sum() == l
|
||||
assert mask[i, l:].sum() == 0
|
||||
|
||||
|
||||
def segment_test():
|
||||
x = T.range(0, 11)
|
||||
x = x.repeat(8, 1).unsqueeze(1)
|
||||
segment_ids = T.randint(0, 7, (8,))
|
||||
|
||||
segments = segment(x, segment_ids, segment_size=4)
|
||||
for idx, start_indx in enumerate(segment_ids):
|
||||
assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()
|
||||
|
||||
|
||||
def generate_path_test():
|
||||
durations = T.randint(1, 4, (10, 21))
|
||||
x_length = T.randint(18, 22, (10,))
|
||||
x_mask = sequence_mask(x_length).unsqueeze(1).long()
|
||||
durations = durations * x_mask.squeeze(1)
|
||||
y_length = durations.sum(1)
|
||||
y_mask = sequence_mask(y_length).unsqueeze(1).long()
|
||||
attn_mask = (torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)).squeeze(1).long()
|
||||
print(attn_mask.shape)
|
||||
path = generate_path(durations, attn_mask)
|
||||
assert path.shape == (10, 21, durations.sum(1).max().item())
|
||||
for b in range(durations.shape[0]):
|
||||
current_idx = 0
|
||||
for t in range(durations.shape[1]):
|
||||
assert all(path[b, t, current_idx : current_idx + durations[b, t].item()] == 1.0)
|
||||
assert all(path[b, t, :current_idx] == 0.0)
|
||||
assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
|
||||
current_idx += durations[b, t].item()
|
Loading…
Reference in New Issue