From 8a02243931238330e9657cd2f998648afc16c46a Mon Sep 17 00:00:00 2001 From: baojia tong Date: Wed, 24 May 2023 17:11:40 -0400 Subject: [PATCH] fix issues with mps --- bark/generation.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index c0f5e94..755a58f 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -781,16 +781,10 @@ def generate_fine( else: relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp probs = F.softmax(relevant_logits, dim=-1) - # multinomial bugged on mps: shuttle to cpu if necessary - inf_device = probs.device - if probs.device.type == "mps": - probs = probs.to("cpu") - codebook_preds = torch.hstack( - [ - torch.multinomial(probs[nnn], num_samples=1).to(inf_device) - for nnn in range(rel_start_fill_idx, 1024) - ] - ) + codebook_preds = torch.multinomial( + probs[rel_start_fill_idx:1024], num_samples=1 + ).reshape(-1) + codebook_preds = codebook_preds.to(torch.int32) in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds del logits, codebook_preds # transfer over info into model_in and convert to numpy