update tests

pull/441/head
Eren Gölge 2021-04-23 17:55:02 +02:00
parent 7eb0c60d2e
commit a878d8fb42
1 changed files with 17 additions and 1 deletions

View File

@ -1,4 +1,5 @@
import os
import torch
import unittest
import numpy as np
@ -11,9 +12,9 @@ from TTS.utils.io import load_config
encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
encoder_model_path = os.path.join(get_tests_input_path(), "dummy_speaker_encoder.pth.tar")
sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav")
sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav")
x_vectors_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json")
class SpeakerManagerTest(unittest.TestCase):
"""Test SpeakerManager for loading embedding files and computing x_vectors from waveforms"""
@staticmethod
@ -32,6 +33,21 @@ class SpeakerManagerTest(unittest.TestCase):
x_vector = manager.compute_x_vector(mel.T)
assert x_vector.shape[1] == 256
# compute x_vector directly from an input file
x_vector = manager.compute_x_vector_from_clip(sample_wav_path)
x_vector2 = manager.compute_x_vector_from_clip(sample_wav_path)
x_vector = torch.FloatTensor(x_vector)
x_vector2 = torch.FloatTensor(x_vector2)
assert x_vector.shape[0] == 256
assert (x_vector - x_vector2).sum() == 0.0
# compute x_vector from a list of wav files.
x_vector3 = manager.compute_x_vector_from_clip([sample_wav_path, sample_wav_path2])
x_vector3 = torch.FloatTensor(x_vector3)
assert x_vector3.shape[0] == 256
assert (x_vector - x_vector3).sum() != 0.0
@staticmethod
def test_speakers_file_processing():
manager = SpeakerManager(x_vectors_file_path=x_vectors_file_path)