Fix bug after merge

pull/1032/head
Edresson 2021-08-26 16:01:07 -03:00 committed by Eren Gölge
parent d7042ecfd8
commit f4abb19515
1 changed files with 4 additions and 3 deletions

View File

@ -5,6 +5,7 @@ from itertools import chain
from typing import Dict, List, Tuple
import torch
import math
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@ -574,11 +575,11 @@ class Vits(BaseTTS):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# expand prior