diff --git a/tests/test_speaker_manager.py b/tests/test_speaker_manager.py
index 40914224..3e272f42 100644
--- a/tests/test_speaker_manager.py
+++ b/tests/test_speaker_manager.py
@@ -1,4 +1,5 @@
 import os
+import torch
 import unittest
 
 import numpy as np
@@ -11,9 +12,9 @@ from TTS.utils.io import load_config
 encoder_config_path = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
 encoder_model_path = os.path.join(get_tests_input_path(), "dummy_speaker_encoder.pth.tar")
 sample_wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0001.wav")
+sample_wav_path2 = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs/LJ001-0002.wav")
 x_vectors_file_path = os.path.join(get_tests_input_path(), "../data/dummy_speakers.json")
 
-
 class SpeakerManagerTest(unittest.TestCase):
     """Test SpeakerManager for loading embedding files and computing x_vectors from waveforms"""
     @staticmethod
@@ -32,6 +33,21 @@ class SpeakerManagerTest(unittest.TestCase):
         x_vector = manager.compute_x_vector(mel.T)
         assert x_vector.shape[1] == 256
 
+        # compute x_vector directly from an input file
+        x_vector = manager.compute_x_vector_from_clip(sample_wav_path)
+        x_vector2 = manager.compute_x_vector_from_clip(sample_wav_path)
+        x_vector = torch.FloatTensor(x_vector)
+        x_vector2 = torch.FloatTensor(x_vector2)
+        assert x_vector.shape[0] == 256
+        assert (x_vector - x_vector2).sum() == 0.0
+
+        # compute x_vector from a list of wav files.
+        x_vector3 = manager.compute_x_vector_from_clip([sample_wav_path, sample_wav_path2])
+        x_vector3 = torch.FloatTensor(x_vector3)
+        assert x_vector3.shape[0] == 256
+        assert (x_vector - x_vector3).sum() != 0.0
+
+
     @staticmethod
     def test_speakers_file_processing():
         manager = SpeakerManager(x_vectors_file_path=x_vectors_file_path)