mirror of https://github.com/suno-ai/bark.git
Merge pull request #27 from zygi/main
Add key/value caching for autoregressive generationpull/62/head
commit
3247106492
|
@ -1,2 +1 @@
|
|||
__pycache__/
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ def text_to_semantic(
|
|||
history_prompt: Optional[str] = None,
|
||||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
use_kv_caching = False,
|
||||
):
|
||||
"""Generate semantic array from text.
|
||||
|
||||
|
@ -27,6 +28,7 @@ def text_to_semantic(
|
|||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
use_kv_caching=use_kv_caching
|
||||
)
|
||||
return x_semantic
|
||||
|
||||
|
@ -37,6 +39,7 @@ def semantic_to_waveform(
|
|||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
output_full: bool = False,
|
||||
use_kv_caching = False
|
||||
):
|
||||
"""Generate audio array from semantic input.
|
||||
|
||||
|
@ -55,6 +58,7 @@ def semantic_to_waveform(
|
|||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
use_kv_caching=use_kv_caching
|
||||
)
|
||||
fine_tokens = generate_fine(
|
||||
coarse_tokens,
|
||||
|
@ -88,6 +92,7 @@ def generate_audio(
|
|||
waveform_temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
output_full: bool = False,
|
||||
use_kv_caching = False
|
||||
):
|
||||
"""Generate audio array from input text.
|
||||
|
||||
|
@ -103,7 +108,7 @@ def generate_audio(
|
|||
numpy audio array at sample frequency 24khz
|
||||
"""
|
||||
semantic_tokens = text_to_semantic(
|
||||
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
|
||||
text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
|
||||
)
|
||||
out = semantic_to_waveform(
|
||||
semantic_tokens,
|
||||
|
@ -111,6 +116,7 @@ def generate_audio(
|
|||
temp=waveform_temp,
|
||||
silent=silent,
|
||||
output_full=output_full,
|
||||
use_kv_caching=use_kv_caching
|
||||
)
|
||||
if output_full:
|
||||
full_generation, audio_arr = out
|
||||
|
|
|
@ -359,6 +359,7 @@ def generate_text_semantic(
|
|||
max_gen_duration_s=None,
|
||||
allow_early_stop=True,
|
||||
model=None,
|
||||
use_kv_caching=False
|
||||
):
|
||||
"""Generate semantic tokens from text."""
|
||||
assert isinstance(text, str)
|
||||
|
@ -420,8 +421,14 @@ def generate_text_semantic(
|
|||
pbar = tqdm.tqdm(disable=silent, total=100)
|
||||
pbar_state = 0
|
||||
tot_generated_duration_s = 0
|
||||
kv_cache = None
|
||||
for n in range(n_tot_steps):
|
||||
logits = model(x, merge_context=True)
|
||||
if use_kv_caching and kv_cache is not None:
|
||||
x_input = x[:, [-1]]
|
||||
else:
|
||||
x_input = x
|
||||
|
||||
logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
|
||||
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
|
||||
if allow_early_stop:
|
||||
relevant_logits = torch.hstack(
|
||||
|
@ -498,6 +505,7 @@ def generate_coarse(
|
|||
max_coarse_history=630, # min 60 (faster), max 630 (more context)
|
||||
sliding_window_len=60,
|
||||
model=None,
|
||||
use_kv_caching=False
|
||||
):
|
||||
"""Generate coarse audio codes from semantic tokens."""
|
||||
assert (
|
||||
|
@ -592,11 +600,18 @@ def generate_coarse(
|
|||
x_coarse_in[:, -max_coarse_history:],
|
||||
]
|
||||
)
|
||||
kv_cache = None
|
||||
for _ in range(sliding_window_len):
|
||||
if n_step >= n_steps:
|
||||
continue
|
||||
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
|
||||
logits = model(x_in)
|
||||
|
||||
if use_kv_caching and kv_cache is not None:
|
||||
x_input = x_in[:, [-1]]
|
||||
else:
|
||||
x_input = x_in
|
||||
|
||||
logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
|
||||
logit_start_idx = (
|
||||
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
|
||||
)
|
||||
|
|
|
@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
|
|||
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
||||
.view(1, 1, config.block_size, config.block_size))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, past_kv=None, use_cache=False):
|
||||
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
|
@ -52,14 +52,36 @@ class CausalSelfAttention(nn.Module):
|
|||
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
|
||||
if past_kv is not None:
|
||||
past_key = past_kv[0]
|
||||
past_value = past_kv[1]
|
||||
k = torch.cat((past_key, k), dim=-2)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
|
||||
FULL_T = k.shape[-2]
|
||||
|
||||
if use_cache is True:
|
||||
present = (k, v)
|
||||
else:
|
||||
present = None
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
if self.flash:
|
||||
# efficient attention using Flash Attention CUDA kernels
|
||||
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
||||
if past_kv is not None:
|
||||
# When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
|
||||
# the query for the last token. scaled_dot_product_attention interprets this as the first token in the
|
||||
# sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
|
||||
# to work around this we set is_causal=False.
|
||||
is_causal = False
|
||||
else:
|
||||
is_causal = True
|
||||
|
||||
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
|
||||
else:
|
||||
# manual implementation of attention
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
||||
att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
|
@ -67,7 +89,7 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
# output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
return y
|
||||
return (y, present)
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
|
@ -95,10 +117,11 @@ class Block(nn.Module):
|
|||
self.mlp = MLP(config)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
def forward(self, x, past_kv=None, use_cache=False):
|
||||
attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
|
||||
x = x + attn_output
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
return (x, prev_kvs)
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
@ -142,33 +165,55 @@ class GPT(nn.Module):
|
|||
n_params -= self.transformer.wpe.weight.numel()
|
||||
return n_params
|
||||
|
||||
def forward(self, idx, merge_context=False):
|
||||
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
if merge_context:
|
||||
assert(idx.shape[1] >= 256+256+1)
|
||||
t = idx.shape[1] - 256
|
||||
else:
|
||||
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
||||
|
||||
# forward the GPT model itself
|
||||
if merge_context:
|
||||
tok_emb = torch.cat([
|
||||
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
|
||||
self.transformer.wte(idx[:,256+256:])
|
||||
], dim=1)
|
||||
else:
|
||||
if past_kv is not None:
|
||||
assert t == 1
|
||||
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
||||
else:
|
||||
if merge_context:
|
||||
assert(idx.shape[1] >= 256+256+1)
|
||||
t = idx.shape[1] - 256
|
||||
else:
|
||||
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
||||
|
||||
# forward the GPT model itself
|
||||
if merge_context:
|
||||
tok_emb = torch.cat([
|
||||
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
|
||||
self.transformer.wte(idx[:,256+256:])
|
||||
], dim=1)
|
||||
else:
|
||||
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
||||
|
||||
if past_kv is None:
|
||||
past_length = 0
|
||||
past_kv = tuple([None] * len(self.transformer.h))
|
||||
else:
|
||||
past_length = past_kv[0][0].size(-2)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0) # shape (1, t)
|
||||
assert position_ids.shape == (1, t)
|
||||
|
||||
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
|
||||
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
|
||||
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
|
||||
new_kv = () if use_cache else None
|
||||
|
||||
for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
|
||||
x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
|
||||
|
||||
if use_cache:
|
||||
new_kv = new_kv + (kv,)
|
||||
|
||||
x = self.transformer.ln_f(x)
|
||||
|
||||
# inference-time mini-optimization: only forward the lm_head on the very last position
|
||||
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
||||
|
||||
return logits
|
||||
return (logits, new_kv)
|
||||
|
|
Loading…
Reference in New Issue