fix: handle difference in xtts/tortoise attention (#199)

pull/4115/head^2
Enno Hermann 2024-12-09 16:13:13 +01:00 committed by GitHub
parent b545ab8b80
commit c0d9ed3d18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 28 additions and 111 deletions

View File

@ -70,11 +70,10 @@ class QKVAttentionLegacy(nn.Module):
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
bs * self.n_heads, weight.shape[-2], weight.shape[-1]
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@ -93,7 +92,9 @@ class AttentionBlock(nn.Module):
channels,
num_heads=1,
num_head_channels=-1,
*,
relative_pos_embeddings=False,
tortoise_norm=False,
):
super().__init__()
self.channels = channels
@ -108,6 +109,7 @@ class AttentionBlock(nn.Module):
self.qkv = nn.Conv1d(channels, channels * 3, 1)
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.tortoise_norm = tortoise_norm
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
if relative_pos_embeddings:
@ -124,10 +126,13 @@ class AttentionBlock(nn.Module):
def forward(self, x, mask=None):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
x_norm = self.norm(x)
qkv = self.qkv(x_norm)
h = self.attention(qkv, mask, self.relative_pos_embeddings)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
if self.tortoise_norm:
return (x + h).reshape(b, c, *spatial)
return (x_norm + h).reshape(b, c, *spatial)
class Upsample(nn.Module):

View File

@ -176,12 +176,14 @@ class ConditioningEncoder(nn.Module):
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
*,
tortoise_norm=False,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

View File

@ -97,7 +97,7 @@ class AudioMiniEncoder(nn.Module):
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

View File

@ -130,7 +130,7 @@ class DiffusionLayer(TimestepBlock):
dims=1,
use_scale_shift_norm=True,
)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True)
def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
@ -177,17 +177,17 @@ class DiffusionTts(nn.Module):
# transformer network.
self.code_embedding = nn.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
)
self.code_norm = normalization(model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
)
self.contextual_embedder = nn.Sequential(
nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
@ -196,26 +196,31 @@ class DiffusionTts(nn.Module):
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
tortoise_norm=True,
),
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))

View File

@ -1,95 +0,0 @@
# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class QKVAttention(nn.Module):
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None, qk_bias=0):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = weight + qk_bias
if mask is not None:
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
"""An attention block that allows spatial positions to attend to each other."""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
out_channels=None,
do_activation=False,
):
super().__init__()
self.channels = channels
out_channels = channels if out_channels is None else out_channels
self.do_activation = do_activation
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, out_channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
def forward(self, x, mask=None, qk_bias=0):
b, c, *spatial = x.shape
if mask is not None:
if len(mask.shape) == 2:
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
if mask.shape[1] != x.shape[-1]:
mask = mask[:, : x.shape[-1], : x.shape[-1]]
x = x.reshape(b, c, -1)
x = self.norm(x)
if self.do_activation:
x = F.silu(x, inplace=True)
qkv = self.qkv(x)
h = self.attention(qkv, mask=mask, qk_bias=qk_bias)
h = self.proj_out(h)
xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial)