diff --git a/bark/generation.py b/bark/generation.py
index 755a58f..9ad67b0 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -480,13 +480,7 @@ def generate_text_semantic(
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
-            # multinomial bugged on mps: shuttle to cpu if necessary
-            inf_device = probs.device
-            if probs.device.type == "mps":
-                probs = probs.to("cpu")
-            item_next = torch.multinomial(probs, num_samples=1)
-            probs = probs.to(inf_device)
-            item_next = item_next.to(inf_device)
+            item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
             if allow_early_stop and (
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
@@ -670,13 +664,7 @@ def generate_coarse(
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
-                # multinomial bugged on mps: shuttle to cpu if necessary
-                inf_device = probs.device
-                if probs.device.type == "mps":
-                    probs = probs.to("cpu")
-                item_next = torch.multinomial(probs, num_samples=1)
-                probs = probs.to(inf_device)
-                item_next = item_next.to(inf_device)
+                item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
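
Both hunks make the same change: the removed block ("multinomial bugged on mps: shuttle to cpu if necessary") shuttled probabilities to the CPU before sampling, and the replacement samples in place and casts to int32. Below is a minimal standalone sketch of the sampling step as it behaves after this change, assuming a PyTorch build where torch.multinomial works on MPS; the sample_next_token helper and its defaults are illustrative and not part of bark.

from typing import Optional

import torch
import torch.nn.functional as F


def sample_next_token(relevant_logits: torch.Tensor, temp: float = 0.7,
                      top_k: Optional[int] = 50) -> torch.Tensor:
    # Top-k filtering: keep the top_k highest logits, mask the rest to -inf.
    if top_k is not None:
        v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
        relevant_logits[relevant_logits < v[-1]] = -float("Inf")
    # Temperature-scaled softmax over the filtered logits.
    probs = F.softmax(relevant_logits / temp, dim=-1)
    # Post-change behaviour: sample on whatever device holds the logits
    # (cpu, cuda, or mps) and cast straight to int32, with no CPU round-trip.
    return torch.multinomial(probs, num_samples=1).to(torch.int32)


# Example: draw one token id from random logits on the default device.
token = sample_next_token(torch.randn(10_000))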