mirror of https://github.com/coqui-ai/TTS.git
fix: handle difference in xtts/tortoise attention (#199)
parent
b545ab8b80
commit
c0d9ed3d18
|
@ -70,11 +70,10 @@ class QKVAttentionLegacy(nn.Module):
|
|||
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
|
||||
bs * self.n_heads, weight.shape[-2], weight.shape[-1]
|
||||
)
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
if mask is not None:
|
||||
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
||||
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
||||
weight = weight * mask
|
||||
mask = mask.repeat(self.n_heads, 1, 1)
|
||||
weight[mask.logical_not()] = -torch.inf
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
|
||||
return a.reshape(bs, -1, length)
|
||||
|
@ -93,7 +92,9 @@ class AttentionBlock(nn.Module):
|
|||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
*,
|
||||
relative_pos_embeddings=False,
|
||||
tortoise_norm=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
@ -108,6 +109,7 @@ class AttentionBlock(nn.Module):
|
|||
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
self.tortoise_norm = tortoise_norm
|
||||
|
||||
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
||||
if relative_pos_embeddings:
|
||||
|
@ -124,10 +126,13 @@ class AttentionBlock(nn.Module):
|
|||
def forward(self, x, mask=None):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
x_norm = self.norm(x)
|
||||
qkv = self.qkv(x_norm)
|
||||
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
if self.tortoise_norm:
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
return (x_norm + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
|
|
@ -176,12 +176,14 @@ class ConditioningEncoder(nn.Module):
|
|||
embedding_dim,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=4,
|
||||
*,
|
||||
tortoise_norm=False,
|
||||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ class AudioMiniEncoder(nn.Module):
|
|||
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
|
||||
attn = []
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ class DiffusionLayer(TimestepBlock):
|
|||
dims=1,
|
||||
use_scale_shift_norm=True,
|
||||
)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True)
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
y = self.resblk(x, time_emb)
|
||||
|
@ -177,17 +177,17 @@ class DiffusionTts(nn.Module):
|
|||
# transformer network.
|
||||
self.code_embedding = nn.Embedding(in_tokens, model_channels)
|
||||
self.code_converter = nn.Sequential(
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
)
|
||||
self.code_norm = normalization(model_channels)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
|
||||
)
|
||||
self.contextual_embedder = nn.Sequential(
|
||||
nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
|
||||
|
@ -196,26 +196,31 @@ class DiffusionTts(nn.Module):
|
|||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
tortoise_norm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
tortoise_norm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
tortoise_norm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
tortoise_norm=True,
|
||||
),
|
||||
AttentionBlock(
|
||||
model_channels * 2,
|
||||
num_heads,
|
||||
relative_pos_embeddings=True,
|
||||
tortoise_norm=True,
|
||||
),
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, mask=None, qk_bias=0):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = weight + qk_bias
|
||||
if mask is not None:
|
||||
mask = mask.repeat(self.n_heads, 1, 1)
|
||||
weight[mask.logical_not()] = -torch.inf
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""An attention block that allows spatial positions to attend to each other."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
out_channels=None,
|
||||
do_activation=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
out_channels = channels if out_channels is None else out_channels
|
||||
self.do_activation = do_activation
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, out_channels * 3, 1)
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
|
||||
|
||||
def forward(self, x, mask=None, qk_bias=0):
|
||||
b, c, *spatial = x.shape
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
||||
if mask.shape[1] != x.shape[-1]:
|
||||
mask = mask[:, : x.shape[-1], : x.shape[-1]]
|
||||
|
||||
x = x.reshape(b, c, -1)
|
||||
x = self.norm(x)
|
||||
if self.do_activation:
|
||||
x = F.silu(x, inplace=True)
|
||||
qkv = self.qkv(x)
|
||||
h = self.attention(qkv, mask=mask, qk_bias=qk_bias)
|
||||
h = self.proj_out(h)
|
||||
xp = self.x_proj(x)
|
||||
return (xp + h).reshape(b, xp.shape[1], *spatial)
|
Loading…
Reference in New Issue