mirror of https://github.com/coqui-ai/TTS.git
add compute embedding for the new speaker encoder
parent 3fcc748b2e
commit 3433c2f348
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from TTS.speaker_encoder.model import SpeakerEncoder
+from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.utils.speakers import save_speaker_mapping
 from TTS.utils.audio import AudioProcessor
@@ -77,7 +77,7 @@ for output_file in output_files:
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
 
 # define Encoder model
-model = SpeakerEncoder(**c.model)
+model = setup_model(c)
 model.load_state_dict(torch.load(args.model_path)["model"])
 model.eval()
 if args.use_cuda:
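Note on the hunk above: the script now builds the encoder through setup_model(c) instead of hard-coding SpeakerEncoder(**c.model), so the config decides which encoder class is instantiated. A minimal loading sketch; only setup_model and the "model" checkpoint key come from the diff, while load_config and the file paths are assumptions for illustration:

    import torch

    from TTS.speaker_encoder.utils.generic_utils import setup_model
    from TTS.utils.io import load_config  # assumed helper; adjust to your TTS version

    c = load_config("config.json")  # hypothetical config path
    model = setup_model(c)          # builds the encoder class named in the config
    model.load_state_dict(torch.load("best_model.pth.tar")["model"])  # weights stored under "model"
    model.eval()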
@@ -124,7 +124,7 @@ class ResNetSpeakerEncoder(nn.Module):
         nn.init.xavier_normal_(out)
         return out
 
-    def forward(self, x):
+    def forward(self, x, training=True):
         x = x.transpose(1, 2)
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
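The new training argument (default True, so existing training code is unaffected) selects between raw projections and inference-time embeddings: when training=False, the forward pass L2-normalizes its output (see the @@ -154 hunk below). A quick illustration of what that normalization does, with made-up shapes:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 256)         # a batch of raw embeddings (shapes illustrative)
    x = F.normalize(x, p=2, dim=1)  # same call as in the diff: L2-normalize each row
    print(x.norm(dim=1))            # tensor of ones: every embedding has unit length

Cosine similarity between unit-length vectors reduces to a dot product, which is the usual motivation for normalizing speaker embeddings on the inference path only.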
@@ -140,7 +140,7 @@ class ResNetSpeakerEncoder(nn.Module):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        x = x.reshape(x.size()[0],-1,x.size()[-1])
+        x = x.reshape(x.size()[0], -1, x.size()[-1])
 
         w = self.attention(x)
 
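For orientation, this reshape collapses the channel and frequency axes of the ResNet output so each time step carries one flat feature vector before attention pooling. A shape check with illustrative dimensions (only the reshape line itself comes from the diff):

    import torch

    x = torch.randn(2, 256, 10, 30)               # (batch, channels, freq, time) leaving layer4
    x = x.reshape(x.size()[0], -1, x.size()[-1])  # -> (2, 2560, 30): one feature vector per frame
    print(x.shape)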
@@ -154,4 +154,33 @@ class ResNetSpeakerEncoder(nn.Module):
         x = x.view(x.size()[0], -1)
         x = self.fc(x)
 
+        if not training:
+            x = torch.nn.functional.normalize(x, p=2, dim=1)
         return x
+
+    @torch.no_grad()
+    def compute_embedding(self, x, num_frames=250, overlap=0.5):
+        """
+        Generate the embedding for a single utterance
+        x: 1xTxD
+        """
+        num_overlap = int(num_frames * overlap)
+        max_len = x.shape[1]
+        embed = None
+        cur_iter = 0
+        for offset in range(0, max_len, num_frames - num_overlap):
+            end_offset = min(x.shape[1], offset + num_frames)
+
+            # skip slices shorter than two frames; they can break instance normalization
+            if end_offset - offset <= 1:
+                continue
+
+            frames = x[:, offset:end_offset]
+            cur_iter += 1  # count only slices that contribute to the sum
+
+            if embed is None:
+                embed = self.forward(frames, training=False)
+            else:
+                embed += self.forward(frames, training=False)
+
+        return embed / cur_iter
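Taken together, compute_embedding slides a num_frames window across the utterance with the given overlap (with the defaults, 250-frame slices starting every 125 frames), embeds each slice through forward(..., training=False), and returns the mean of the slice embeddings. A hedged usage sketch; the input shape follows the 1xTxD docstring, but the feature dimension and the model variable are assumptions:

    import torch

    # "model" built and loaded as in the sketch after the first script hunk
    mel = torch.randn(1, 600, 80)  # one utterance, 600 frames of 80-dim features (illustrative)
    embedding = model.compute_embedding(mel, num_frames=250, overlap=0.5)
    print(embedding.shape)         # (1, proj_dim): the averaged utterance embedding

One caveat on the design: the mean of L2-normalized slice embeddings is not itself unit-length, so callers comparing utterances by cosine similarity may want to renormalize the returned vector.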