fix issues with mps

pull/325/head
baojia tong 2023-05-24 17:11:40 -04:00
parent 19636b21b8
commit 8a02243931
1 changed files with 4 additions and 10 deletions

View File

@ -781,16 +781,10 @@ def generate_fine(
else:
relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
probs = F.softmax(relevant_logits, dim=-1)
# multinomial bugged on mps: shuttle to cpu if necessary
inf_device = probs.device
if probs.device.type == "mps":
probs = probs.to("cpu")
codebook_preds = torch.hstack(
[
torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
for nnn in range(rel_start_fill_idx, 1024)
]
)
codebook_preds = torch.multinomial(
probs[rel_start_fill_idx:1024], num_samples=1
).reshape(-1)
codebook_preds = codebook_preds.to(torch.int32)
in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
del logits, codebook_preds
# transfer over info into model_in and convert to numpy