diff --git a/tests/python_client/common/bulk_insert_data.py b/tests/python_client/common/bulk_insert_data.py index a064efde97..ce80d9cbf0 100644 --- a/tests/python_client/common/bulk_insert_data.py +++ b/tests/python_client/common/bulk_insert_data.py @@ -4,7 +4,7 @@ import os import time import numpy as np -import jax.numpy as jnp +from ml_dtypes import bfloat16 import pandas as pd import random from faker import Faker @@ -128,9 +128,9 @@ def gen_bf16_vectors(num, dim, for_json=False): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) if for_json: - bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).tolist() else: - bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist() bf16_vectors.append(bf16_vector) return raw_vectors, bf16_vectors diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index 4a4c15d322..71d3328ced 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -8,7 +8,7 @@ import uuid from functools import singledispatch import numpy as np import pandas as pd -import jax.numpy as jnp +from ml_dtypes import bfloat16 from sklearn import preprocessing from npy_append_array import NpyAppendArray from faker import Faker @@ -20,7 +20,6 @@ from common import common_type as ct from utils.util_log import test_log as log from customize.milvus_operator import MilvusOperator import pickle -import tensorflow as tf fake = Faker() """" Methods of processing data """ @@ -1070,14 +1069,12 @@ def gen_data_by_collection_field(field, nb=None, start=None): dim = field.params['dim'] if nb is None: raw_vector = [random.random() for _ in range(dim)] - bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16) - bf16_vector = np.array(bf16_vector).view(np.uint8).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist() return bytes(bf16_vector) bf16_vectors = [] for i in range(nb): raw_vector = [random.random() for _ in range(dim)] - bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16) - bf16_vector = np.array(bf16_vector).view(np.uint8).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist() bf16_vectors.append(bytes(bf16_vector)) return bf16_vectors if data_type == DataType.FLOAT16_VECTOR: @@ -2077,7 +2074,7 @@ def gen_bf16_vectors(num, dim): for _ in range(num): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) - bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy() + bf16_vector = np.array(raw_vector, dtype=bfloat16) bf16_vectors.append(bf16_vector) return raw_vectors, bf16_vectors diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index 177e44cd36..99ad4f62c9 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -56,7 +56,5 @@ pyarrow==14.0.1 fastparquet==2023.7.0 # for bf16 datatype -jax==0.4.13 -jaxlib==0.4.13 -tensorflow==2.13.1 +ml-dtypes==0.2.0