mirror of https://github.com/suno-ai/bark.git
reduce mps code handling logic (#327)
parent bfb7ebf530
commit f6f2db527b
@@ -480,13 +480,7 @@ def generate_text_semantic(
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
-            # multinomial bugged on mps: shuttle to cpu if necessary
-            inf_device = probs.device
-            if probs.device.type == "mps":
-                probs = probs.to("cpu")
-            item_next = torch.multinomial(probs, num_samples=1)
-            probs = probs.to(inf_device)
-            item_next = item_next.to(inf_device)
+            item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
             if allow_early_stop and (
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
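Both hunks collapse the sampling step to a single line. Below is a minimal, self-contained sketch of that step; relevant_logits, top_k, and temp are given placeholder values here and are not the real inputs at the call sites in bark/generation.py.

import torch
import torch.nn.functional as F

relevant_logits = torch.randn(10_000)  # stand-in for the model's vocabulary logits
top_k = 50
temp = 0.7

# Keep only the top_k highest logits; everything below the k-th value
# is set to -inf so softmax assigns it zero probability.
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")

# Temperature-scaled softmax, then one multinomial draw: the single
# line that replaces the seven-line mps workaround in this commit.
probs = F.softmax(relevant_logits / temp, dim=-1)
item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
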
@@ -670,13 +664,7 @@ def generate_coarse(
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
-                # multinomial bugged on mps: shuttle to cpu if necessary
-                inf_device = probs.device
-                if probs.device.type == "mps":
-                    probs = probs.to("cpu")
-                item_next = torch.multinomial(probs, num_samples=1)
-                probs = probs.to(inf_device)
-                item_next = item_next.to(inf_device)
+                item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
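For reference, the deleted lines in both hunks implemented the same CPU round-trip. A sketch of that pattern as a standalone helper follows; the name multinomial_cpu_fallback is hypothetical and does not appear in the repository.

import torch

def multinomial_cpu_fallback(probs: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
    # torch.multinomial was unreliable on the "mps" backend, so the old code
    # sampled on the CPU and moved the result back to the original device.
    inf_device = probs.device
    if probs.device.type == "mps":
        probs = probs.to("cpu")
    item_next = torch.multinomial(probs, num_samples=num_samples)
    return item_next.to(inf_device)

Dropping this shuttle assumes torch.multinomial now behaves correctly on mps, which is presumably the PyTorch fix this commit relies on.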