pull/610/merge
Yash Rawal 2024-11-03 21:55:37 +05:30 committed by GitHub
commit cf2c075020
1 changed file with 183 additions and 103 deletions

@@ -54,18 +54,7 @@ SAMPLE_RATE = 24_000
SUPPORTED_LANGS = [
("English", "en"),
("German", "de"),
("Spanish", "es"),
("French", "fr"),
("Hindi", "hi"),
("Italian", "it"),
("Japanese", "ja"),
("Korean", "ko"),
("Polish", "pl"),
("Portuguese", "pt"),
("Russian", "ru"),
("Turkish", "tr"),
("Chinese", "zh"),
]
ALLOWED_PROMPTS = {"announcer"}
@@ -129,13 +118,25 @@ if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cu
)
import torch
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    # PyTorch 2.x exposes flash attention through scaled_dot_product_attention;
    # there is no public torch.nn.functional.flash_attention attribute to test.
    print("Flash attention (scaled_dot_product_attention) is available in PyTorch.")
    flash_attention_available = True
else:
    flash_attention_available = False
def _grab_best_device(use_gpu=True):
if torch.cuda.device_count() > 0 and use_gpu:
device = "cuda"
elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
device = "mps"
else:
device = "cpu"
device = "cuda"
return device
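For comparison, a defensive version of this helper keeps the CPU fallback so that a later `.to(device)` cannot fail on machines without CUDA. A minimal, illustrative sketch, not this patch's behaviour (the GLOBAL_ENABLE_MPS gate is omitted):

import torch

def pick_device(use_gpu: bool = True) -> str:
    # Prefer CUDA, then Apple MPS, and only then CPU, rather than returning
    # "cuda" unconditionally when no accelerator is present.
    if use_gpu and torch.cuda.is_available():
        return "cuda"
    if use_gpu and torch.backends.mps.is_available():
        return "mps"
    return "cpu"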
@@ -251,7 +252,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
def _load_codec_model(device):
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
model.set_target_bandwidth(3.0)
model.eval()
model.to(device)
_clear_cuda_cache()
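Lowering the EnCodec target bandwidth from 6.0 to 3.0 kbps also halves the number of codebooks produced at 24 kHz (8 down to 4), whereas Bark's fine tokens are generated for the 8-codebook, 6 kbps configuration. A minimal sketch for checking the codebook count per bandwidth (the one-second silent waveform is illustrative):

import torch
from encodec import EncodecModel

codec = EncodecModel.encodec_model_24khz()
wav = torch.zeros(1, 1, 24_000)  # one second of silence at 24 kHz
for bw in (3.0, 6.0):
    codec.set_target_bandwidth(bw)
    with torch.no_grad():
        frames = codec.encode(wav)
    codes = torch.cat([c for c, _ in frames], dim=-1)  # (batch, n_codebooks, n_frames)
    print(f"{bw} kbps -> {codes.shape[1]} codebooks")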
@@ -268,7 +269,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
model_key = f"{model_type}"
if OFFLOAD_CPU:
models_devices[model_key] = device
device = "cpu"
device = "cuda"
if model_key not in models or force_reload:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
clean_models(model_key=model_key)
@@ -287,11 +288,11 @@ def load_codec_model(use_gpu=True, force_reload=False):
device = _grab_best_device(use_gpu=use_gpu)
if device == "mps":
# encodec doesn't support mps
device = "cpu"
device = "cuda"
model_key = "codec"
if OFFLOAD_CPU:
models_devices[model_key] = device
device = "cpu"
device = "cuda"
if model_key not in models or force_reload:
clean_models(model_key=model_key)
model = _load_codec_model(device)
@@ -311,7 +312,7 @@ def preload_models(
force_reload=False,
):
"""Load all the necessary models for the pipeline."""
if _grab_best_device() == "cpu" and (
if _grab_best_device() == "cuda" and (
text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
@@ -508,7 +509,7 @@ def generate_text_semantic(
pbar.close()
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
if OFFLOAD_CPU:
model.to("cpu")
model.to("cuda")
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
_clear_cuda_cache()
return out
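The `model.to(...)` call changed in this hunk is part of Bark's OFFLOAD_CPU scheme: the stage's GPU device is remembered in models_devices, the model is moved there only while it runs, and it is parked back on the CPU afterwards so the next stage has the VRAM to itself. A generic sketch of that round-trip (the function and argument names are illustrative, not Bark's API):

import torch

def offloaded_forward(model: torch.nn.Module, device: str, x: torch.Tensor) -> torch.Tensor:
    # Bring the stage onto its device only for the duration of inference,
    # then return it to the CPU and release the cached GPU memory.
    model.to(device)
    with torch.inference_mode():
        y = model(x.to(device))
    model.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return y.cpu()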
@@ -527,6 +528,10 @@ def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
COARSE_SEMANTIC_PAD_TOKEN = 12_048
COARSE_INFER_TOKEN = 12_050
import torch.cuda
import numpy as np
import torch.nn.functional as F
import tqdm
def generate_coarse(
x_semantic,
@@ -536,25 +541,45 @@ def generate_coarse(
top_p=None,
silent=False,
max_coarse_history=630, # min 60 (faster), max 630 (more context)
sliding_window_len=60,
use_kv_caching=False,
sliding_window_len=120,
use_kv_caching=True,
# kv_cache_dtype = torch.bfloat16,
num_streams=4 # New parameter to control number of CUDA streams
):
"""Generate coarse audio codes from semantic tokens."""
"""Generate coarse audio codes from semantic tokens with CUDA stream optimization.
Args:
... (existing args remain the same) ...
num_streams: Number of CUDA streams to use for parallel processing
"""
# Original input validation
assert (
isinstance(x_semantic, np.ndarray)
and len(x_semantic.shape) == 1
and len(x_semantic) > 0
and x_semantic.min() >= 0
and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
and 60 <= max_coarse_history <= 630
and max_coarse_history + sliding_window_len <= 1024 - 256
)
assert 60 <= max_coarse_history <= 630
assert max_coarse_history + sliding_window_len <= 1024 - 256
# Initialize CUDA streams only if CUDA is available
use_cuda = torch.cuda.is_available()
if use_cuda:
streams = [torch.cuda.Stream() for _ in range(num_streams)]
else:
streams = [None] * num_streams
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
# History prompt processing
if history_prompt is not None:
history_prompt = _load_history_prompt(history_prompt)
x_semantic_history = history_prompt["semantic_prompt"]
x_coarse_history = history_prompt["coarse_prompt"]
# Original history prompt validation
assert (
isinstance(x_semantic_history, np.ndarray)
and len(x_semantic_history.shape) == 1
@@ -572,33 +597,42 @@ def generate_coarse(
== round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
)
)
x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
# trim histories correctly
n_semantic_hist_provided = np.min(
[
max_semantic_history,
len(x_semantic_history) - len(x_semantic_history) % 2,
int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
]
# Process history using the first stream if CUDA is available
if use_cuda:
torch.cuda.synchronize()
with torch.cuda.stream(streams[0]):
x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
else:
x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
n_semantic_hist_provided = min(
max_semantic_history,
len(x_semantic_history) - len(x_semantic_history) % 2,
int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio))
)
n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
# TODO: bit of a hack for time alignment (sounds better)
x_coarse_history = x_coarse_history[:-2]
x_coarse_history = x_coarse_history[-n_coarse_hist_provided:-2].astype(np.int32)
else:
x_semantic_history = np.array([], dtype=np.int32)
x_coarse_history = np.array([], dtype=np.int32)
# load models if not yet exist
# Model loading and device setup
global models
global models_devices
if "coarse" not in models:
preload_models()
model = models["coarse"]
if OFFLOAD_CPU:
model.to(models_devices["coarse"])
if use_cuda:
with torch.cuda.stream(streams[0]):
model.to(models_devices["coarse"])
else:
model.to(models_devices["coarse"])
device = next(model.parameters()).device
# start loop
# Pre-calculations
n_steps = int(
round(
np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
@@ -606,86 +640,132 @@ def generate_coarse(
)
)
assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
# Prepare input tensors
x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
x_coarse = x_coarse_history.astype(np.int32)
base_semantic_idx = len(x_semantic_history)
with _inference_mode():
# Move tensors to device
if use_cuda:
with torch.cuda.stream(streams[0]):
x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)
torch.cuda.synchronize()
else:
x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)
with _inference_mode():
n_window_steps = int(np.ceil(n_steps / sliding_window_len))
n_step = 0
for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
# pad from right side
x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
x_in = x_in[:, :256]
x_in = F.pad(
x_in,
(0, 256 - x_in.shape[-1]),
"constant",
COARSE_SEMANTIC_PAD_TOKEN,
)
x_in = torch.hstack(
[
for window_idx in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
stream_idx = window_idx % num_streams if use_cuda else 0
# Use CUDA stream if available
if use_cuda:
torch.cuda.synchronize()
stream_context = torch.cuda.stream(streams[stream_idx])
else:
stream_context = nullcontext()
with stream_context:
semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
# Prepare input window
x_in = x_semantic_in[:, max(0, semantic_idx - max_semantic_history):]
x_in = x_in[:, :256]
if x_in.shape[-1] < 256:
x_in = F.pad(
x_in,
(0, 256 - x_in.shape[-1]),
"constant",
COARSE_SEMANTIC_PAD_TOKEN,
)
x_in = torch.cat([
x_in,
torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
infer_token,
x_coarse_in[:, -max_coarse_history:],
]
)
kv_cache = None
for _ in range(sliding_window_len):
if n_step >= n_steps:
continue
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
], dim=1)
if use_kv_caching and kv_cache is not None:
x_input = x_in[:, [-1]]
else:
x_input = x_in
# Process window
kv_cache = None
for _ in range(sliding_window_len):
if n_step >= n_steps:
continue
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
x_input = x_in[:, [-1]] if (use_kv_caching and kv_cache is not None) else x_in
# Model inference
logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
logit_start_idx = SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
logit_end_idx = SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
if top_p is not None:
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
cumulative_probs = np.cumsum(softmax(sorted_logits))
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits).to(device)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
probs = F.softmax(relevant_logits / temp, dim=-1)
item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
item_next += logit_start_idx
x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
x_in = torch.cat((x_in, item_next[None]), dim=1)
del logits, relevant_logits, probs, item_next
n_step += 1
del x_in
# Synchronize at the end of each window if using CUDA
if use_cuda:
torch.cuda.synchronize()
logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
logit_start_idx = (
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
)
logit_end_idx = (
SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
)
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
if top_p is not None:
# faster to convert to numpy
original_device = relevant_logits.device
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
cumulative_probs = np.cumsum(softmax(sorted_logits))
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(original_device)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
probs = F.softmax(relevant_logits / temp, dim=-1)
item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
item_next += logit_start_idx
x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
x_in = torch.cat((x_in, item_next[None]), dim=1)
del logits, relevant_logits, probs, item_next
n_step += 1
del x_in
del x_semantic_in
if OFFLOAD_CPU:
model.to("cpu")
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
if use_cuda:
with torch.cuda.stream(streams[0]):
model.to("cuda")
torch.cuda.synchronize()
else:
model.to("cuda")
# Output processing
if use_cuda:
with torch.cuda.stream(streams[0]):
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]
torch.cuda.synchronize()
else:
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]
del x_coarse_in
assert len(gen_coarse_arr) == n_steps
gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
for n in range(1, N_COARSE_CODEBOOKS):
gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
offsets = np.arange(1, N_COARSE_CODEBOOKS) * CODEBOOK_SIZE
gen_coarse_audio_arr[1:] -= offsets[:, None]
_clear_cuda_cache()
return gen_coarse_audio_arr
return gen_coarse_audio_arr
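CUDA streams only overlap work when launches on different streams are not separated by host-side torch.cuda.synchronize() calls, and the coarse loop above is inherently sequential because each sampled token feeds the next step. A minimal sketch of the usual stream pattern for genuinely independent chunks of work, joining with wait_stream() instead of a global synchronize (fn and chunks are illustrative):

import torch

def run_on_streams(fn, chunks, num_streams=4):
    # Round-robin independent work items over CUDA streams and join with
    # wait_stream() rather than a blocking torch.cuda.synchronize().
    if not torch.cuda.is_available():
        return [fn(chunk) for chunk in chunks]
    streams = [torch.cuda.Stream() for _ in range(num_streams)]
    results = [None] * len(chunks)
    for i, chunk in enumerate(chunks):
        stream = streams[i % num_streams]
        stream.wait_stream(torch.cuda.current_stream())  # see prior default-stream work
        with torch.cuda.stream(stream):
            results[i] = fn(chunk)
    for stream in streams:
        torch.cuda.current_stream().wait_stream(stream)  # join before using results
    return results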
def generate_fine(
@@ -788,7 +868,7 @@ def generate_fine(
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
del in_arr
if OFFLOAD_CPU:
model.to("cpu")
model.to("cuda")
gen_fine_arr = gen_fine_arr[:, n_history:]
if n_remove_from_end > 0:
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -816,5 +896,5 @@ def codec_decode(fine_tokens):
audio_arr = out.detach().cpu().numpy().squeeze()
del arr, emb, out
if OFFLOAD_CPU:
model.to("cpu")
return audio_arr
model.to("cuda")
return audio_arr
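For reference, a sketch of how the stages defined in this file chain together with the defaults changed by this patch; the input text is illustrative and the flow roughly mirrors bark.api.generate_audio:

# assumes generate_text_semantic, generate_coarse, generate_fine and codec_decode
# from this module are in scope
text = "Hello, this is a test."
semantic_tokens = generate_text_semantic(text, temp=0.7)
coarse_tokens = generate_coarse(
    semantic_tokens,
    temp=0.7,
    sliding_window_len=120,  # default changed from 60 in this patch
    use_kv_caching=True,     # default changed from False in this patch
    num_streams=4,           # parameter added in this patch
)
fine_tokens = generate_fine(coarse_tokens, temp=0.5)
audio_arr = codec_decode(fine_tokens)  # float waveform at SAMPLE_RATE (24 kHz)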