mirror of https://github.com/milvus-io/milvus.git
test: use jax.numpy to generate brain float16 datatype (#30825)
JAX is a more lightweight library that can also generate the bfloat16 datatype.

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
parent 7a49cf2104
commit 7f78e9d40d
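For context, the change replaces the TensorFlow-based bfloat16 cast with jax.numpy. Below is a minimal, self-contained sketch of the same conversion (not taken from the patch; the helper name make_bf16_bytes is illustrative), assuming numpy and jax are installed:

import random

import numpy as np
import jax.numpy as jnp

def make_bf16_bytes(dim):
    # Generate raw floats, cast them to bfloat16 via jax.numpy, then
    # reinterpret the 2-byte elements as uint8 and pack them into bytes,
    # the payload format used for BFLOAT16_VECTOR fields.
    raw_vector = [random.random() for _ in range(dim)]
    bf16 = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16))
    return bytes(bf16.view(np.uint8).tolist())

print(len(make_bf16_bytes(8)))  # 16: two bytes per bfloat16 element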
@@ -19,7 +19,6 @@ from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrappe
 from common import common_type as ct
 from utils.util_log import test_log as log
 from customize.milvus_operator import MilvusOperator
-import tensorflow as tf
 fake = Faker()
 """" Methods of processing data """
 
@@ -1802,8 +1801,7 @@ def gen_bf16_vectors(num, dim):
     for _ in range(num):
         raw_vector = [random.random() for _ in range(dim)]
         raw_vectors.append(raw_vector)
-        # bf16_vector = np.array(raw_vector, dtype=tf.bfloat16).view(np.uint8).tolist()
-        bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy().view(np.uint8).tolist()
+        bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist()
         bf16_vectors.append(bytes(bf16_vector))
 
     return raw_vectors, bf16_vectors
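As a quick sanity check on the new jnp path (a sketch, not part of the patch), the packed bytes can be decoded back to the same bfloat16 values; this assumes jnp.bfloat16 is accepted as a NumPy dtype, which holds because JAX exposes the ml_dtypes bfloat16 extension type:

import numpy as np
import jax.numpy as jnp

raw_vector = [0.1, 0.5, 0.9]
bf16 = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16))
packed = bytes(bf16.view(np.uint8).tolist())

# Decoding the byte string with the same dtype recovers the original values.
decoded = np.frombuffer(packed, dtype=jnp.bfloat16)
assert np.array_equal(decoded, bf16)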
@@ -1841,4 +1839,4 @@ def gen_vectors_based_on_vector_type(num, dim, vector_data_type):
     elif vector_data_type == "BFLOAT16_VECTOR":
         vectors = gen_bf16_vectors(num, dim)[1]
 
-    return vectors
+    return vectors
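Note that gen_bf16_vectors returns a (raw_vectors, bf16_byte_vectors) pair, so the [1] index above keeps only the packed byte vectors. A compact sketch of that dispatch pattern follows; the function name gen_vectors and the FLOAT_VECTOR branch are assumptions, and gen_bf16_vectors is assumed importable from the patched helper module:

import random

def gen_vectors(num, dim, vector_data_type):
    # Dispatch on the vector field type; only BFLOAT16_VECTOR needs the
    # byte-packed encoding produced by gen_bf16_vectors.
    if vector_data_type == "FLOAT_VECTOR":  # assumed branch, not from the patch
        return [[random.random() for _ in range(dim)] for _ in range(num)]
    elif vector_data_type == "BFLOAT16_VECTOR":
        return gen_bf16_vectors(num, dim)[1]  # index 1: the bytes payloads
    raise ValueError(f"unsupported vector_data_type: {vector_data_type}")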
@@ -54,8 +54,6 @@ prettytable==3.8.0
 pyarrow==14.0.1
 fastparquet==2023.7.0
 
-# for generating bfloat16 data
-tensorflow==2.13.1
 # for bf16 datatype
 jax==0.4.13
 jaxlib==0.4.13