diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index c38e0e7e..410086de 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from TTS.speaker_encoder.model import SpeakerEncoder
+from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.utils.speakers import save_speaker_mapping
 from TTS.utils.audio import AudioProcessor
@@ -77,7 +77,7 @@ for output_file in output_files:
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
 
 # define Encoder model
-model = SpeakerEncoder(**c.model)
+model = setup_model(c)
 model.load_state_dict(torch.load(args.model_path)["model"])
 model.eval()
 if args.use_cuda:
diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py
index b35a9c89..9b79b7a7 100644
--- a/TTS/speaker_encoder/models/resnet.py
+++ b/TTS/speaker_encoder/models/resnet.py
@@ -124,7 +124,7 @@ class ResNetSpeakerEncoder(nn.Module):
         nn.init.xavier_normal_(out)
         return out
 
-    def forward(self, x):
+    def forward(self, x, training=True):
         x = x.transpose(1, 2)
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
@@ -140,7 +140,7 @@ class ResNetSpeakerEncoder(nn.Module):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        x = x.reshape(x.size()[0],-1,x.size()[-1])
+        x = x.reshape(x.size()[0], -1, x.size()[-1])
 
         w = self.attention(x)
 
@@ -154,4 +154,35 @@ class ResNetSpeakerEncoder(nn.Module):
         x = x.view(x.size()[0], -1)
         x = self.fc(x)
 
+        # L2-normalize embeddings at inference time so they are directly comparable
+        if not training:
+            x = torch.nn.functional.normalize(x, p=2, dim=1)
         return x
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, overlap=0.5):
+        """
+        Generate an embedding for an utterance by averaging over sliding windows.
+        x: 1xTxD
+        """
+        num_overlap = int(num_frames * overlap)
+        max_len = x.shape[1]
+        embed = None
+        cur_iter = 0
+        for offset in range(0, max_len, num_frames - num_overlap):
+            end_offset = min(x.shape[1], offset + num_frames)
+
+            # skip slices with one frame or fewer, since they can break instance normalization
+            if end_offset - offset <= 1:
+                continue
+            # count only the windows that actually contribute to the average
+            cur_iter += 1
+
+            frames = x[:, offset:end_offset]
+
+            if embed is None:
+                embed = self.forward(frames, training=False)
+            else:
+                embed += self.forward(frames, training=False)
+
+        return embed / cur_iter
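
A minimal usage sketch of the new windowed-embedding path (not part of the patch). It assumes a trained ResNet speaker-encoder checkpoint with a matching config, and mirrors the feature-extraction steps in TTS/bin/compute_embeddings.py; the file paths and the load_config import are placeholders based on the repo layout at this revision.

    import torch

    from TTS.speaker_encoder.utils.generic_utils import setup_model
    from TTS.utils.audio import AudioProcessor
    from TTS.utils.io import load_config

    # Placeholder paths -- substitute a real config and checkpoint.
    c = load_config("config.json")
    ap = AudioProcessor(**c.audio)

    model = setup_model(c)
    model.load_state_dict(torch.load("best_model.pth.tar")["model"])
    model.eval()

    # Extract mel frames shaped 1xTxD, as compute_embedding's docstring expects.
    mel = ap.melspectrogram(ap.load_wav("sample.wav")).T
    frames = torch.FloatTensor(mel[None, :, :])

    # Mean of L2-normalized embeddings over 250-frame windows with 50% overlap.
    embed = model.compute_embedding(frames, num_frames=250, overlap=0.5)

Because forward() is called with training=False, each window's embedding is L2-normalized before averaging, so the result is suitable for cosine-similarity comparisons between speakers.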