mirror of https://github.com/suno-ai/bark.git
fix issues with mps
parent
19636b21b8
commit
8a02243931
|
@ -781,16 +781,10 @@ def generate_fine(
|
|||
else:
|
||||
relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
|
||||
probs = F.softmax(relevant_logits, dim=-1)
|
||||
# multinomial bugged on mps: shuttle to cpu if necessary
|
||||
inf_device = probs.device
|
||||
if probs.device.type == "mps":
|
||||
probs = probs.to("cpu")
|
||||
codebook_preds = torch.hstack(
|
||||
[
|
||||
torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
|
||||
for nnn in range(rel_start_fill_idx, 1024)
|
||||
]
|
||||
)
|
||||
codebook_preds = torch.multinomial(
|
||||
probs[rel_start_fill_idx:1024], num_samples=1
|
||||
).reshape(-1)
|
||||
codebook_preds = codebook_preds.to(torch.int32)
|
||||
in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
|
||||
del logits, codebook_preds
|
||||
# transfer over info into model_in and convert to numpy
|
||||
|
|
Loading…
Reference in New Issue