mirror of https://github.com/coqui-ai/TTS.git
Fix bug after merge
parent d7042ecfd8
commit f4abb19515
@@ -5,6 +5,7 @@ from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
+import math
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
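The added import math is needed because the restored logp1 term in the hunk below calls math.log(2 * math.pi).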
@@ -574,11 +575,11 @@ class Vits(BaseTTS):
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         with torch.no_grad():
             o_scale = torch.exp(-2 * logs_p)
-            # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
             logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
             logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp = logp2 + logp3
+            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3 + logp1 + logp4
             attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
 
         # expand prior
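Why the extra terms matter: logp1 through logp4 are the four pieces of the channel-summed Gaussian log-density log N(z_p; m_p, exp(logs_p)) = -0.5*log(2*pi) - logs_p - 0.5*exp(-2*logs_p)*(z_p - m_p)^2, expanded so the pairwise text/frame scores can be built with two einsums. Dropping logp1 and logp4, as the pre-merge code did, hands maximum_path a score that is not the full log-likelihood. Below is a minimal standalone sketch, not code from the repository and with made-up tensor sizes, that rebuilds the four terms and checks them against the per-pair density computed directly:

import math
import torch

b, c, t, t_spec = 2, 192, 10, 37              # hypothetical batch/channel/text/frame sizes
m_p = torch.randn(b, c, t)                    # prior means,      [b, c, t]
logs_p = torch.randn(b, c, t)                 # prior log-scales, [b, c, t]
z_p = torch.randn(b, c, t_spec)               # flow output,      [b, c, t']

o_scale = torch.exp(-2 * logs_p)              # 1 / sigma^2,      [b, c, t]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)   # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", o_scale, -0.5 * (z_p ** 2))           # [b, t, t']
logp3 = torch.einsum("klm, kln -> kmn", m_p * o_scale, z_p)                   # [b, t, t']
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)             # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4          # broadcasts to [b, t, t']

# Reference: sum the per-channel Gaussian log-density over channels for every
# (text position, spectrogram frame) pair directly, without the expansion.
ref = (
    -0.5 * math.log(2 * math.pi)
    - logs_p.unsqueeze(-1)
    - 0.5 * o_scale.unsqueeze(-1) * (z_p.unsqueeze(2) - m_p.unsqueeze(-1)) ** 2
).sum(1)
print(torch.allclose(logp, ref, atol=1e-3))   # True: the four terms recover the full density

maximum_path then picks the monotonic text-to-frame alignment with the highest total of these scores; since logp1 and logp4 depend only on the text position, they are weighted by how many frames each text position absorbs and so can shift which alignment wins.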