mirror of https://github.com/milvus-io/milvus.git
test: use ml-dtypes lib to produce bf16 datatype (#33354)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>pull/33376/head
parent
970bf18a49
commit
ed883b39d7
|
@ -4,7 +4,7 @@ import os
|
|||
import time
|
||||
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from ml_dtypes import bfloat16
|
||||
import pandas as pd
|
||||
import random
|
||||
from faker import Faker
|
||||
|
@ -128,9 +128,9 @@ def gen_bf16_vectors(num, dim, for_json=False):
|
|||
raw_vector = [random.random() for _ in range(dim)]
|
||||
raw_vectors.append(raw_vector)
|
||||
if for_json:
|
||||
bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).tolist()
|
||||
bf16_vector = np.array(raw_vector, dtype=bfloat16).tolist()
|
||||
else:
|
||||
bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist()
|
||||
bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
|
||||
bf16_vectors.append(bf16_vector)
|
||||
|
||||
return raw_vectors, bf16_vectors
|
||||
|
|
|
@ -8,7 +8,7 @@ import uuid
|
|||
from functools import singledispatch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import jax.numpy as jnp
|
||||
from ml_dtypes import bfloat16
|
||||
from sklearn import preprocessing
|
||||
from npy_append_array import NpyAppendArray
|
||||
from faker import Faker
|
||||
|
@ -20,7 +20,6 @@ from common import common_type as ct
|
|||
from utils.util_log import test_log as log
|
||||
from customize.milvus_operator import MilvusOperator
|
||||
import pickle
|
||||
import tensorflow as tf
|
||||
fake = Faker()
|
||||
"""" Methods of processing data """
|
||||
|
||||
|
@ -1070,14 +1069,12 @@ def gen_data_by_collection_field(field, nb=None, start=None):
|
|||
dim = field.params['dim']
|
||||
if nb is None:
|
||||
raw_vector = [random.random() for _ in range(dim)]
|
||||
bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16)
|
||||
bf16_vector = np.array(bf16_vector).view(np.uint8).tolist()
|
||||
bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
|
||||
return bytes(bf16_vector)
|
||||
bf16_vectors = []
|
||||
for i in range(nb):
|
||||
raw_vector = [random.random() for _ in range(dim)]
|
||||
bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16)
|
||||
bf16_vector = np.array(bf16_vector).view(np.uint8).tolist()
|
||||
bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
|
||||
bf16_vectors.append(bytes(bf16_vector))
|
||||
return bf16_vectors
|
||||
if data_type == DataType.FLOAT16_VECTOR:
|
||||
|
@ -2077,7 +2074,7 @@ def gen_bf16_vectors(num, dim):
|
|||
for _ in range(num):
|
||||
raw_vector = [random.random() for _ in range(dim)]
|
||||
raw_vectors.append(raw_vector)
|
||||
bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy()
|
||||
bf16_vector = np.array(raw_vector, dtype=bfloat16)
|
||||
bf16_vectors.append(bf16_vector)
|
||||
|
||||
return raw_vectors, bf16_vectors
|
||||
|
|
|
@ -56,7 +56,5 @@ pyarrow==14.0.1
|
|||
fastparquet==2023.7.0
|
||||
|
||||
# for bf16 datatype
|
||||
jax==0.4.13
|
||||
jaxlib==0.4.13
|
||||
tensorflow==2.13.1
|
||||
ml-dtypes==0.2.0
|
||||
|
||||
|
|
Loading…
Reference in New Issue