From 03a5f7e6c05f69b67602b8bd06a2937505da51ec Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Fri, 21 Jun 2024 14:20:08 +0800 Subject: [PATCH] test: update the lib of bf16 (#34043) Signed-off-by: zhuwenxing --- tests/restful_client_v2/requirements.txt | 3 +-- tests/restful_client_v2/utils/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/restful_client_v2/requirements.txt b/tests/restful_client_v2/requirements.txt index 5f7aa72423..624e0f269d 100644 --- a/tests/restful_client_v2/requirements.txt +++ b/tests/restful_client_v2/requirements.txt @@ -12,5 +12,4 @@ pytest-xdist==2.5.0 minio==7.1.14 tenacity==8.1.0 # for bf16 datatype -jax==0.4.13 -jaxlib==0.4.13 +ml-dtypes==0.2.0 diff --git a/tests/restful_client_v2/utils/utils.py b/tests/restful_client_v2/utils/utils.py index 112e26e787..cbd7640edf 100644 --- a/tests/restful_client_v2/utils/utils.py +++ b/tests/restful_client_v2/utils/utils.py @@ -4,7 +4,7 @@ import random import string from faker import Faker import numpy as np -import jax.numpy as jnp +from ml_dtypes import bfloat16 from sklearn import preprocessing import base64 import requests @@ -191,7 +191,7 @@ def gen_bf16_vectors(num, dim): for _ in range(num): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) - bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist() bf16_vectors.append(bytes(bf16_vector)) return raw_vectors, bf16_vectors