From f4abb19515a1ec14e8f7c7be11066b6511ccd783 Mon Sep 17 00:00:00 2001 From: Edresson Date: Thu, 26 Aug 2021 16:01:07 -0300 Subject: [PATCH] Fix bug after merge --- TTS/tts/models/vits.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index bc4bf235..600a9551 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -5,6 +5,7 @@ from itertools import chain from typing import Dict, List, Tuple import torch +import math from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast @@ -574,11 +575,11 @@ class Vits(BaseTTS): attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) with torch.no_grad(): o_scale = torch.exp(-2 * logs_p) - # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)]) logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) - # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] - logp = logp2 + logp3 + logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # expand prior