mirror of https://github.com/suno-ai/bark.git
first commit
parent f4f32d4cd4
commit fdf5705fc4
@ -54,18 +54,7 @@ SAMPLE_RATE = 24_000
SUPPORTED_LANGS = [
    ("English", "en"),
    ("German", "de"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("Hindi", "hi"),
    ("Italian", "it"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Polish", "pl"),
    ("Portuguese", "pt"),
    ("Russian", "ru"),
    ("Turkish", "tr"),
    ("Chinese", "zh"),
]

ALLOWED_PROMPTS = {"announcer"}
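
ALLOWED_PROMPTS is reduced here to the bare "announcer" preset. For reference, a whitelist like this is usually expanded from SUPPORTED_LANGS; a minimal sketch of that expansion (the "<lang>_speaker_<n>" and "v2/" naming is an assumption, not something this diff defines):

    import os

    # Hypothetical expansion: ten numbered speakers per language, in both the
    # flat and the "v2/"-prefixed form, plus the plain "announcer" preset.
    ALLOWED_PROMPTS = {"announcer"}
    for _, lang in SUPPORTED_LANGS:
        for prefix in ("", f"v2{os.path.sep}"):
            for n in range(10):
                ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}")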
@ -129,13 +118,25 @@ if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cu
    )


import torch

if hasattr(torch.nn.functional, 'flash_attention'):
    print("------------------------------------------------->Flash Attention is available in PyTorch.")
    flash_attention_available = True
else:
    # print("------------------------------------------------->Flash Attention is NOT available in PyTorch.")
    flash_attention_available = False
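
Stock PyTorch does not expose a torch.nn.functional.flash_attention attribute, so the hasattr check above always takes the else branch; the function PyTorch 2.x actually provides (and which the surrounding file already probes, per the hunk header) is scaled_dot_product_attention. A small sketch of a more reliable probe:

    import torch
    import torch.nn.functional as F

    # scaled_dot_product_attention can dispatch to a flash-attention kernel on
    # supported GPUs; probing for it is the usual availability check.
    sdpa_available = hasattr(F, "scaled_dot_product_attention")
    flash_kernel_possible = sdpa_available and torch.cuda.is_available()
    print(f"SDPA available: {sdpa_available}; flash kernel possible: {flash_kernel_possible}")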
def _grab_best_device(use_gpu=True):
    if torch.cuda.device_count() > 0 and use_gpu:
        device = "cuda"
    elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
        device = "mps"
    else:
        device = "cpu"
    device = "cuda"
    return device
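
The unconditional device = "cuda" added above makes the helper report "cuda" even on machines without a GPU, which will fail as soon as a tensor or model is moved there. A hedged alternative that keeps the CPU and MPS fallbacks (an option, not what this commit does):

    def _grab_best_device(use_gpu=True):
        # Prefer CUDA, then MPS (when enabled), then CPU; never report a
        # device that torch cannot actually use on this machine.
        if use_gpu and torch.cuda.device_count() > 0:
            return "cuda"
        if use_gpu and torch.backends.mps.is_available() and GLOBAL_ENABLE_MPS:
            return "mps"
        return "cpu"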
@ -251,7 +252,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
def _load_codec_model(device):
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    model.set_target_bandwidth(3.0)
    model.eval()
    model.to(device)
    _clear_cuda_cache()
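
For the 24 kHz EnCodec model, the target bandwidth maps to a number of residual codebooks (roughly 75 frames per second at 10 bits per codebook), so 6.0 kbps corresponds to about 8 codebooks and 3.0 kbps to about 4; the lower setting mainly matters when this model is used to encode audio, since decoding simply consumes whatever codes it is given. A quick illustrative check, assuming the encodec package used above:

    from encodec import EncodecModel

    model = EncodecModel.encodec_model_24khz()
    for bw in (1.5, 3.0, 6.0, 12.0):
        model.set_target_bandwidth(bw)
        # 75 Hz frame rate x 10 bits per codebook for the 24 kHz model
        print(f"{bw} kbps -> ~{int(bw * 1000 / (75 * 10))} codebooks")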
@ -268,7 +269,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
    model_key = f"{model_type}"
    if OFFLOAD_CPU:
        models_devices[model_key] = device
        device = "cpu"
        device = "cuda"
    if model_key not in models or force_reload:
        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
        clean_models(model_key=model_key)
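
The OFFLOAD_CPU branch records the real target device in models_devices and then loads the checkpoint somewhere else (normally "cpu"; here the commit forces "cuda"). The usual point of the pattern is to park idle models in host RAM and move only the active one to the accelerator for the duration of a call; a minimal sketch of that idea (the helper name is hypothetical):

    import torch

    def run_offloaded(model, batch, device="cuda", offload=True):
        # Move the model to the compute device only while it is needed,
        # then return it to the CPU so accelerator memory is freed.
        if offload:
            model.to(device)
        try:
            with torch.inference_mode():
                return model(batch.to(device))
        finally:
            if offload:
                model.to("cpu")
                torch.cuda.empty_cache()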
@ -287,11 +288,11 @@ def load_codec_model(use_gpu=True, force_reload=False):
    device = _grab_best_device(use_gpu=use_gpu)
    if device == "mps":
        # encodec doesn't support mps
        device = "cpu"
        device = "cuda"
    model_key = "codec"
    if OFFLOAD_CPU:
        models_devices[model_key] = device
        device = "cpu"
        device = "cuda"
    if model_key not in models or force_reload:
        clean_models(model_key=model_key)
        model = _load_codec_model(device)
@ -311,7 +312,7 @@ def preload_models(
    force_reload=False,
):
    """Load all the necessary models for the pipeline."""
    if _grab_best_device() == "cpu" and (
    if _grab_best_device() == "cuda" and (
        text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
    ):
        logger.warning("No GPU being used. Careful, inference might be very slow!")
@ -508,7 +509,7 @@ def generate_text_semantic(
    pbar.close()
    out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
    if OFFLOAD_CPU:
        model.to("cpu")
        model.to("cuda")
    assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
    _clear_cuda_cache()
    return out
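
The slice [256 + 256 + 1 :] drops the fixed-size prompt from the sampled sequence: 256 encoded text tokens, 256 semantic-history tokens and one infer token come first, and only what follows them is generated output. A toy illustration of that layout (the number of new tokens is made up):

    import numpy as np

    prompt_len = 256 + 256 + 1                 # text + semantic history + infer token
    full_sequence = np.arange(prompt_len + 5)  # pretend 5 new tokens were sampled
    generated = full_sequence[prompt_len:]
    assert len(generated) == 5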
@ -527,6 +528,10 @@ def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
COARSE_SEMANTIC_PAD_TOKEN = 12_048
COARSE_INFER_TOKEN = 12_050

import torch.cuda
import numpy as np
import torch.nn.functional as F
import tqdm
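
These constants line up with the flattened coarse-token layout used below: with SEMANTIC_VOCAB_SIZE = 10_000 and CODEBOOK_SIZE = 1024 (the values used elsewhere in this file, consistent with the 12_048 pad token), codebook 0 occupies ids 10_000-11_023, codebook 1 occupies 11_024-12_047, and 12_048 / 12_050 mark padding and the start of inference. A standalone sketch of the flattening idea (not the file's _flatten_codebooks itself):

    import numpy as np

    SEMANTIC_VOCAB_SIZE = 10_000
    CODEBOOK_SIZE = 1024

    def flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
        # arr: (n_codebooks, n_frames). Offset each codebook so ids from
        # different codebooks never collide, then interleave frame by frame.
        arr = arr.copy()
        for n in range(1, arr.shape[0]):
            arr[n, :] += n * offset_size
        return arr.ravel("F")

    codes = np.array([[1, 2, 3], [4, 5, 6]])
    print(flatten_codebooks(codes) + SEMANTIC_VOCAB_SIZE)
    # [10001 11028 10002 11029 10003 11030]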
def generate_coarse(
    x_semantic,
@ -536,25 +541,45 @@ def generate_coarse(
    top_p=None,
    silent=False,
    max_coarse_history=630,  # min 60 (faster), max 630 (more context)
    sliding_window_len=60,
    use_kv_caching=False,
    sliding_window_len=120,
    use_kv_caching=True,
    # kv_cache_dtype = torch.bfloat16,
    num_streams=4  # New parameter to control number of CUDA streams
):
    """Generate coarse audio codes from semantic tokens."""
    """Generate coarse audio codes from semantic tokens with CUDA stream optimization.

    Args:
        ... (existing args remain the same) ...
        num_streams: Number of CUDA streams to use for parallel processing
    """
    # Original input validation
    assert (
        isinstance(x_semantic, np.ndarray)
        and len(x_semantic.shape) == 1
        and len(x_semantic) > 0
        and x_semantic.min() >= 0
        and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
        and 60 <= max_coarse_history <= 630
        and max_coarse_history + sliding_window_len <= 1024 - 256
    )
    assert 60 <= max_coarse_history <= 630
    assert max_coarse_history + sliding_window_len <= 1024 - 256

    # Initialize CUDA streams only if CUDA is available
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        streams = [torch.cuda.Stream() for _ in range(num_streams)]
    else:
        streams = [None] * num_streams

    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

    # History prompt processing
    if history_prompt is not None:
        history_prompt = _load_history_prompt(history_prompt)
        x_semantic_history = history_prompt["semantic_prompt"]
        x_coarse_history = history_prompt["coarse_prompt"]

        # Original history prompt validation
        assert (
            isinstance(x_semantic_history, np.ndarray)
            and len(x_semantic_history.shape) == 1
@ -572,33 +597,42 @@ def generate_coarse(
                == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
            )
        )
        x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
        # trim histories correctly
        n_semantic_hist_provided = np.min(
            [
                max_semantic_history,
                len(x_semantic_history) - len(x_semantic_history) % 2,
                int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
            ]

        # Process history using the first stream if CUDA is available
        if use_cuda:
            torch.cuda.synchronize()
            with torch.cuda.stream(streams[0]):
                x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
        else:
            x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE

        n_semantic_hist_provided = min(
            max_semantic_history,
            len(x_semantic_history) - len(x_semantic_history) % 2,
            int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio))
        )
        n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
        x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
        # TODO: bit of a hack for time alignment (sounds better)
        x_coarse_history = x_coarse_history[:-2]
        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:-2].astype(np.int32)
    else:
        x_semantic_history = np.array([], dtype=np.int32)
        x_coarse_history = np.array([], dtype=np.int32)
    # load models if not yet exist

    # Model loading and device setup
    global models
    global models_devices
    if "coarse" not in models:
        preload_models()
    model = models["coarse"]
    if OFFLOAD_CPU:
        model.to(models_devices["coarse"])
        if use_cuda:
            with torch.cuda.stream(streams[0]):
                model.to(models_devices["coarse"])
        else:
            model.to(models_devices["coarse"])
    device = next(model.parameters()).device
    # start loop

    # Pre-calculations
    n_steps = int(
        round(
            np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
@ -606,86 +640,132 @@ def generate_coarse(
        )
    )
    assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0

    # Prepare input tensors
    x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
    x_coarse = x_coarse_history.astype(np.int32)
    base_semantic_idx = len(x_semantic_history)
    with _inference_mode():

        # Move tensors to device
        if use_cuda:
            with torch.cuda.stream(streams[0]):
                x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
                x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
                infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)
            torch.cuda.synchronize()
        else:
            x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
            x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
            infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)

    with _inference_mode():
        n_window_steps = int(np.ceil(n_steps / sliding_window_len))
        n_step = 0
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
            # pad from right side
            x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
            x_in = x_in[:, :256]
            x_in = F.pad(
                x_in,
                (0, 256 - x_in.shape[-1]),
                "constant",
                COARSE_SEMANTIC_PAD_TOKEN,
            )
            x_in = torch.hstack(
                [

        for window_idx in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            stream_idx = window_idx % num_streams if use_cuda else 0

            # Use CUDA stream if available
            if use_cuda:
                torch.cuda.synchronize()
                stream_context = torch.cuda.stream(streams[stream_idx])
            else:
                stream_context = nullcontext()

            with stream_context:
                semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))

                # Prepare input window
                x_in = x_semantic_in[:, max(0, semantic_idx - max_semantic_history):]
                x_in = x_in[:, :256]
                if x_in.shape[-1] < 256:
                    x_in = F.pad(
                        x_in,
                        (0, 256 - x_in.shape[-1]),
                        "constant",
                        COARSE_SEMANTIC_PAD_TOKEN,
                    )

                x_in = torch.cat([
                    x_in,
                    torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
                    infer_token,
                    x_coarse_in[:, -max_coarse_history:],
                ]
            )
            kv_cache = None
            for _ in range(sliding_window_len):
                if n_step >= n_steps:
                    continue
                is_major_step = n_step % N_COARSE_CODEBOOKS == 0
                ], dim=1)

                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in
                # Process window
                kv_cache = None
                for _ in range(sliding_window_len):
                    if n_step >= n_steps:
                        continue

                    is_major_step = n_step % N_COARSE_CODEBOOKS == 0
                    x_input = x_in[:, [-1]] if (use_kv_caching and kv_cache is not None) else x_in

                    # Model inference
                    logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)

                    logit_start_idx = SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                    logit_end_idx = SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
                    relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]

                    if top_p is not None:
                        relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                        sorted_indices = np.argsort(relevant_logits)[::-1]
                        sorted_logits = relevant_logits[sorted_indices]
                        cumulative_probs = np.cumsum(softmax(sorted_logits))
                        sorted_indices_to_remove = cumulative_probs > top_p
                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                        sorted_indices_to_remove[0] = False
                        relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                        relevant_logits = torch.from_numpy(relevant_logits).to(device)

                    if top_k is not None:
                        v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                        relevant_logits[relevant_logits < v[-1]] = -float("Inf")

                    probs = F.softmax(relevant_logits / temp, dim=-1)
                    item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
                    item_next += logit_start_idx

                    x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                    x_in = torch.cat((x_in, item_next[None]), dim=1)

                    del logits, relevant_logits, probs, item_next
                    n_step += 1

                del x_in

            # Synchronize at the end of each window if using CUDA
            if use_cuda:
                torch.cuda.synchronize()

                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                logit_start_idx = (
                    SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                )
                logit_end_idx = (
                    SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
                )
                relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                if top_p is not None:
                    # faster to convert to numpy
                    original_device = relevant_logits.device
                    relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                    sorted_indices = np.argsort(relevant_logits)[::-1]
                    sorted_logits = relevant_logits[sorted_indices]
                    cumulative_probs = np.cumsum(softmax(sorted_logits))
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                    sorted_indices_to_remove[0] = False
                    relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                    relevant_logits = torch.from_numpy(relevant_logits)
                    relevant_logits = relevant_logits.to(original_device)
                if top_k is not None:
                    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                    relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                probs = F.softmax(relevant_logits / temp, dim=-1)
                item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
                item_next += logit_start_idx
                x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                x_in = torch.cat((x_in, item_next[None]), dim=1)
                del logits, relevant_logits, probs, item_next
                n_step += 1
            del x_in
    del x_semantic_in

    if OFFLOAD_CPU:
        model.to("cpu")
    gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
        if use_cuda:
            with torch.cuda.stream(streams[0]):
                model.to("cuda")
            torch.cuda.synchronize()
        else:
            model.to("cuda")

    # Output processing
    if use_cuda:
        with torch.cuda.stream(streams[0]):
            gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]
        torch.cuda.synchronize()
    else:
        gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]

    del x_coarse_in
    assert len(gen_coarse_arr) == n_steps

    gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
    for n in range(1, N_COARSE_CODEBOOKS):
        gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
    offsets = np.arange(1, N_COARSE_CODEBOOKS) * CODEBOOK_SIZE
    gen_coarse_audio_arr[1:] -= offsets[:, None]

    _clear_cuda_cache()
    return gen_coarse_audio_arr
    return gen_coarse_audio_arr
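
Two details in the loop above are worth calling out. First, because each sampled token is appended to x_in before the next forward pass, every step depends on the previous one, so the per-window CUDA streams cannot actually overlap work inside this loop. Second, both copies of the sampling code move the window's logits to numpy for nucleus (top-p) filtering and back to the GPU every step; the same filtering can be written directly in torch, avoiding the round trip. A self-contained sketch (not the code this commit ships):

    import torch
    import torch.nn.functional as F

    def sample_filtered(logits, temp=0.7, top_k=None, top_p=None):
        # logits: 1-D tensor over one codebook's vocabulary.
        if top_p is not None:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cumprobs > top_p
            remove[1:] = remove[:-1].clone()  # always keep the top token
            remove[0] = False
            mask = torch.zeros_like(logits, dtype=torch.bool)
            mask[sorted_idx[remove]] = True
            logits = logits.masked_fill(mask, -float("inf"))
        if top_k is not None:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[-1]
            logits = logits.masked_fill(logits < kth, -float("inf"))
        probs = F.softmax(logits / temp, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    # usage inside the step:
    # item_next = sample_filtered(relevant_logits, temp, top_k, top_p) + logit_start_idx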
def generate_fine(
@ -788,7 +868,7 @@ def generate_fine(
    gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
    del in_arr
    if OFFLOAD_CPU:
        model.to("cpu")
        model.to("cuda")
    gen_fine_arr = gen_fine_arr[:, n_history:]
    if n_remove_from_end > 0:
        gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
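
n_history and n_remove_from_end undo the padding added before generation: history frames are prepended and the signal is right-padded so the fine model sees full windows, and both paddings are stripped from the result here. A toy illustration of the trimming (shapes are made up):

    import numpy as np

    n_history, n_remove_from_end = 2, 3
    arr = np.arange(8 * 10).reshape(8, 10)   # (n_fine_codebooks, n_frames)
    arr = arr[:, n_history:]
    if n_remove_from_end > 0:
        arr = arr[:, :-n_remove_from_end]
    assert arr.shape == (8, 5)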
@ -816,5 +896,5 @@ def codec_decode(fine_tokens):
    audio_arr = out.detach().cpu().numpy().squeeze()
    del arr, emb, out
    if OFFLOAD_CPU:
        model.to("cpu")
    return audio_arr
        model.to("cuda")
    return audio_arr
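
Taken together, the functions touched in this diff form the text-to-waveform pipeline. A hedged usage sketch against the upstream bark.generation API (defaults and signatures may differ in this fork):

    from bark.generation import (
        preload_models,
        generate_text_semantic,
        generate_coarse,
        generate_fine,
        codec_decode,
        SAMPLE_RATE,
    )

    preload_models()
    semantic = generate_text_semantic("Hello world.", history_prompt=None)
    coarse = generate_coarse(semantic, history_prompt=None)
    fine = generate_fine(coarse, history_prompt=None)
    audio = codec_decode(fine)  # float waveform at SAMPLE_RATE (24 kHz)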