Fixes MPS device errors from Tensor.type() when using generate_text_semantic and generate_coarse (#254)

* fix for pytorch's Tensor.type() lacking support for MPS
pull/84/merge
Raf Gemmail 2023-05-16 11:16:36 +12:00 committed by GitHub
parent 1ad007171e
commit 81d3a507fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 6 deletions

View File

@ -465,8 +465,7 @@ def generate_text_semantic(
)
if top_p is not None:
# faster to convert to numpy
logits_device = relevant_logits.device
logits_dtype = relevant_logits.type()
original_device = relevant_logits.device
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
@ -476,7 +475,7 @@ def generate_text_semantic(
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
relevant_logits = relevant_logits.to(original_device)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@ -656,8 +655,7 @@ def generate_coarse(
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
if top_p is not None:
# faster to convert to numpy
logits_device = relevant_logits.device
logits_dtype = relevant_logits.type()
original_device = relevant_logits.device
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
@ -667,7 +665,7 @@ def generate_coarse(
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
relevant_logits = relevant_logits.to(original_device)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")