mirror of https://github.com/milvus-io/milvus.git
test: update the lib of bf16 (#34043)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/34057/head
parent
e653ad27e2
commit
03a5f7e6c0
|
@ -12,5 +12,4 @@ pytest-xdist==2.5.0
|
||||||
minio==7.1.14
|
minio==7.1.14
|
||||||
tenacity==8.1.0
|
tenacity==8.1.0
|
||||||
# for bf16 datatype
|
# for bf16 datatype
|
||||||
jax==0.4.13
|
ml-dtypes==0.2.0
|
||||||
jaxlib==0.4.13
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import random
|
||||||
import string
|
import string
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax.numpy as jnp
|
from ml_dtypes import bfloat16
|
||||||
from sklearn import preprocessing
|
from sklearn import preprocessing
|
||||||
import base64
|
import base64
|
||||||
import requests
|
import requests
|
||||||
|
@ -191,7 +191,7 @@ def gen_bf16_vectors(num, dim):
|
||||||
for _ in range(num):
|
for _ in range(num):
|
||||||
raw_vector = [random.random() for _ in range(dim)]
|
raw_vector = [random.random() for _ in range(dim)]
|
||||||
raw_vectors.append(raw_vector)
|
raw_vectors.append(raw_vector)
|
||||||
bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist()
|
bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
|
||||||
bf16_vectors.append(bytes(bf16_vector))
|
bf16_vectors.append(bytes(bf16_vector))
|
||||||
|
|
||||||
return raw_vectors, bf16_vectors
|
return raw_vectors, bf16_vectors
|
||||||
|
|
Loading…
Reference in New Issue