diff --git a/.gitignore b/.gitignore
index 48e4ceb..ba0430d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1 @@
-__pycache__/
-.venv
\ No newline at end of file
+__pycache__/
\ No newline at end of file
diff --git a/bark/generation.py b/bark/generation.py
index 5753125..4e860d8 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -428,7 +428,7 @@ def generate_text_semantic(
             else:
                 x_input = x
 
-            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache)
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
@@ -611,7 +611,7 @@ def generate_coarse(
                 else:
                     x_input = x_in
 
-                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
diff --git a/bark/model.py b/bark/model.py
index 463557c..bb99932 100644
--- a/bark/model.py
+++ b/bark/model.py
@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x, layer_past=None, use_cache=False):
+    def forward(self, x, past_kv=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,9 +52,9 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
-        if layer_past is not None:
-            past_key = layer_past[0]
-            past_value = layer_past[1]
+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
             k = torch.cat((past_key, k), dim=-2)
             v = torch.cat((past_value, v), dim=-2)
 
@@ -68,9 +68,11 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            if layer_past is not None:
-                # in theory the attention is still causal but because we're computing it incrementally,
-                # the last query can attend on all previous keys/values, which which is equivalent to non-causal
+            if past_kv is not None:
+                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
+                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
+                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
+                # to work around this we set is_causal=False.
                 is_causal = False
             else:
                 is_causal = True
@@ -115,8 +117,8 @@ class Block(nn.Module):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx
 
-    def forward(self, x, layer_past=None, use_cache=False):
-        attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache)
+    def forward(self, x, past_kv=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
         x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
         return (x, prev_kvs)
@@ -163,10 +165,10 @@ class GPT(nn.Module):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if past_key_values is not None:
+        if past_kv is not None:
             assert t == 1
             tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
         else:
@@ -185,11 +187,11 @@ class GPT(nn.Module):
             else:
                 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
 
-        if past_key_values is None:
+        if past_kv is None:
             past_length = 0
-            past_key_values = tuple([None] * len(self.transformer.h))
+            past_kv = tuple([None] * len(self.transformer.h))
         else:
-            past_length = past_key_values[0][0].size(-2)
+            past_length = past_kv[0][0].size(-2)
 
         if position_ids is None:
             position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
@@ -201,17 +203,17 @@
 
         x = self.transformer.drop(tok_emb + pos_emb)
 
-        presents = () if use_cache else None
+        new_kv = () if use_cache else None
 
-        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
-            x, kv = block(x, layer_past=layer_past, use_cache=use_cache)
+        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
 
             if use_cache:
-                presents = presents + (kv,)
+                new_kv = new_kv + (kv,)
 
         x = self.transformer.ln_f(x)
 
         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
 
-        return (logits, presents)
+        return (logits, new_kv)
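For reference, the two generation.py hunks only rename the keyword at the call sites; the surrounding code is the usual KV-cached decode loop. A minimal sketch of that pattern, where `model`, `x`, `n_steps`, and `use_kv_caching` are hypothetical stand-ins for the real locals and greedy decoding replaces Bark's sampling:

```python
import torch

# Hypothetical stand-ins: `model` is a GPT as above, `x` is a (1, T) tensor
# of token ids, `n_steps` and `use_kv_caching` mirror the generation.py locals.
kv_cache = None
for _ in range(n_steps):
    if use_kv_caching and kv_cache is not None:
        # Cache hit: all earlier positions are already encoded in kv_cache,
        # so only the newest token has to be forwarded through the model.
        x_input = x[:, [-1]]
    else:
        # First step, or caching disabled: forward the full sequence.
        x_input = x
    # past_kv is the per-layer tuple of (key, value) tensors returned by the
    # previous call; use_cache asks the model to return an updated one.
    logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
    next_token = torch.argmax(logits[0, -1], dim=-1, keepdim=True)  # greedy, for brevity
    x = torch.cat((x, next_token[None]), dim=1)
```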
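The rewritten comment in model.py is the subtle part of the change, and it can be verified in isolation. During incremental decoding, `q` covers only the newest token (sequence length 1) while `k`/`v` cover every cached position, and `torch.nn.functional.scaled_dot_product_attention` derives its causal mask from those two lengths alone, aligning the lone query with position 0. A small repro sketch (shapes are illustrative, not Bark's real dimensions):

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 1, 2, 5, 8        # illustrative sizes
q = torch.randn(B, n_head, 1, head_dim)    # query for the newest token only
k = torch.randn(B, n_head, T, head_dim)    # keys for all T cached positions
v = torch.randn(B, n_head, T, head_dim)

# is_causal=True builds a lower-triangular (1, T) mask from the input shapes,
# so the single query is treated as position 0 and can attend only to the
# first cached key -- the output collapses to v[:, :, :1, :].
masked = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# With a KV cache, causality already holds by construction: the newest token
# is the last position and may attend to every cached key, so no mask is needed.
full = F.scaled_dot_product_attention(q, k, v, is_causal=False)

assert torch.allclose(masked, v[:, :, :1, :])   # only key 0 contributes
assert not torch.allclose(masked, full)         # the wrong mask changes the result
```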