reduce mps code handling logic (#327)

Tony(Baojia) Tong 2023-05-25 10:52:56 -04:00 committed by GitHub
parent bfb7ebf530
commit f6f2db527b
1 changed file with 2 additions and 14 deletions


@@ -480,13 +480,7 @@ def generate_text_semantic(
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
-            # multinomial bugged on mps: shuttle to cpu if necessary
-            inf_device = probs.device
-            if probs.device.type == "mps":
-                probs = probs.to("cpu")
-            item_next = torch.multinomial(probs, num_samples=1)
-            probs = probs.to(inf_device)
-            item_next = item_next.to(inf_device)
+            item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
             if allow_early_stop and (
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
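Both hunks in this commit touch the same sampling step: top-k filtering, a temperature-scaled softmax, and a multinomial draw. A minimal self-contained sketch of that pattern, for reference (the function name and default values are illustrative, not taken from the Bark source):

    import torch
    import torch.nn.functional as F

    def sample_top_k(logits: torch.Tensor, temp: float = 0.7, top_k: int = 50) -> torch.Tensor:
        # Mask everything below the k-th largest logit to -inf so that
        # softmax assigns those tokens zero probability.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[..., -1, None], -float("Inf"))
        # Temperature-scaled softmax, then draw one token index.
        probs = F.softmax(logits / temp, dim=-1)
        return torch.multinomial(probs, num_samples=1).to(torch.int32)

    token = sample_top_k(torch.randn(10_000))  # e.g. tensor([1234], dtype=torch.int32)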
@@ -670,13 +664,7 @@ def generate_coarse(
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
-            # multinomial bugged on mps: shuttle to cpu if necessary
-            inf_device = probs.device
-            if probs.device.type == "mps":
-                probs = probs.to("cpu")
-            item_next = torch.multinomial(probs, num_samples=1)
-            probs = probs.to(inf_device)
-            item_next = item_next.to(inf_device)
+            item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
             item_next += logit_start_idx
             x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
             x_in = torch.cat((x_in, item_next[None]), dim=1)
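The deleted lines in both functions were a workaround for torch.multinomial being unreliable on the MPS backend in earlier PyTorch releases: probabilities were shuttled to the CPU for sampling and the drawn indices moved back. In isolation, the removed pattern looked like this (the helper name is hypothetical; the original code was inline):

    import torch

    def multinomial_mps_safe(probs: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        # Removed workaround: sample on CPU when running on "mps",
        # then move the drawn indices back to the inference device.
        inf_device = probs.device
        if inf_device.type == "mps":
            probs = probs.to("cpu")
        item_next = torch.multinomial(probs, num_samples=num_samples)
        return item_next.to(inf_device)

The simplified replacement assumes torch.multinomial now works directly on MPS; the added .to(torch.int32) cast presumably keeps the sampled indices in a dtype the MPS backend handles well, rather than the int64 that torch.multinomial returns by default.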