pull/366/merge
Michael Wei 2024-04-12 14:45:45 -04:00 committed by GitHub
commit cf5cff9dc0
1 changed file with 7 additions and 7 deletions

@@ -445,16 +445,16 @@ def generate_text_semantic(
     with _inference_mode():
         x = x.to(device)
         n_tot_steps = 768
+        # preallocate tensor
+        x_initial = x.shape[1]
+        x = torch.hstack([x, torch.empty([1, n_tot_steps], dtype=torch.int32, device=device)])
         # custom tqdm updates since we don't know when eos will occur
         pbar = tqdm.tqdm(disable=silent, total=n_tot_steps)
         pbar_state = 0
         tot_generated_duration_s = 0
         kv_cache = None
         for n in range(n_tot_steps):
-            if use_kv_caching and kv_cache is not None:
-                x_input = x[:, [-1]]
-            else:
-                x_input = x
+            x_input = x[:, [x_initial + n - 1]] if use_kv_caching and kv_cache is not None else x[:, : x_initial + n]
             logits, kv_cache = model(
                 x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
             )
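
A toy sketch of why the input indexing changes (not part of the diff; shapes and values are illustrative): once the buffer is preallocated, the tail of x is uninitialized scratch space, so the newest token is no longer at position -1 but at x_initial + n - 1, and the no-cache path must feed only the filled prefix.

    import torch

    x_initial = 4                    # length of the prompt already in x
    n_tot_steps = 8                  # preallocated generation budget
    x = torch.hstack([
        torch.arange(x_initial, dtype=torch.int32)[None],  # prompt tokens
        torch.empty([1, n_tot_steps], dtype=torch.int32),  # uninitialized tail
    ])

    n = 3  # pretend steps 0..2 have already written positions 4..6
    x_last = x[:, [x_initial + n - 1]]   # shape [1, 1]: newest written token (KV-cache path)
    x_full = x[:, : x_initial + n]       # shape [1, 7]: filled prefix only (no-cache path)
    assert x_last.shape == (1, 1) and x_full.shape[1] == x_initial + n
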
@@ -485,10 +485,11 @@ def generate_text_semantic(
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
             ):
+                n -= 1  # backtrack 1
                 # eos found, so break
                 pbar.update(n - pbar_state)
                 break
-            x = torch.cat((x, item_next[None]), dim=1)
+            x[0][x_initial + n] = item_next
             tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
             if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
                 pbar.update(n - pbar_state)
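
The `n -= 1` backtrack on the eos path exists so the final slice (in the last hunk) stays correct: at step n the eos branch writes nothing into the buffer, so only positions x_initial .. x_initial + n - 1 hold real tokens. A small worked check (illustrative, not from the PR):

    # With x_initial = 4, suppose eos is sampled at step n = 5: steps 0..4
    # wrote positions 4..8 and step 5 wrote nothing.
    x_initial, n = 4, 5
    n -= 1  # backtrack 1, as in the patch
    # The final slice x[x_initial : x_initial + n + 1] == x[4:9] then covers
    # exactly the five written positions and excludes the eos step.
    assert list(range(x_initial, x_initial + n + 1)) == [4, 5, 6, 7, 8]
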
@@ -496,7 +497,6 @@ def generate_text_semantic(
             if n == n_tot_steps - 1:
                 pbar.update(n - pbar_state)
                 break
             del logits, relevant_logits, probs, item_next
             if n > pbar_state:
                 if n > pbar.total:
@@ -506,7 +506,7 @@ def generate_text_semantic(
         pbar.total = n
         pbar.refresh()
         pbar.close()
-        out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
+        out = x.detach().cpu().numpy().squeeze()[x_initial : x_initial + n + 1]
     if OFFLOAD_CPU:
         model.to("cpu")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
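
The patch as a whole is an instance of a standard pattern: torch.cat reallocates and copies the entire sequence on every step, so the generation loop pays quadratic copy cost, whereas writing into a preallocated buffer is linear, with the valid prefix sliced once at the end (the hard-coded 256 + 256 + 1 prompt offset also becomes x_initial). A minimal self-contained sketch of the two variants (step count and function names are illustrative, not from the PR):

    import time
    import torch

    N = 2000  # toy number of generation steps

    def grow_with_cat():
        # old pattern: reallocate and copy the whole sequence each step
        x = torch.empty([1, 0], dtype=torch.int32)
        for i in range(N):
            x = torch.cat((x, torch.tensor([[i]], dtype=torch.int32)), dim=1)
        return x

    def grow_preallocated():
        # new pattern: over-allocate once (like n_tot_steps), write in place
        budget = 2 * N
        x = torch.empty([1, budget], dtype=torch.int32)
        for i in range(N):
            x[0][i] = i          # in-place write, no reallocation
        return x[:, :N]          # slice the written prefix, like `out`

    for fn in (grow_with_cat, grow_preallocated):
        t0 = time.perf_counter()
        fn()
        print(f"{fn.__name__}: {time.perf_counter() - t0:.3f}s")
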