milvus/tests-deprecating/python_test/test_mix.py

200 lines
8.3 KiB
Python

import pdb
import copy
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import sklearn.preprocessing
from utils import *
# Search / index tuning values used by the tests in this module.
top_k = 1
nprobe = 1
nlist = 128
# Upper bound accepted for the distance of a top hit.
epsilon = 0.001
index_file_size = 10

# 10000 generated vectors, each row scaled to unit L2 norm, stored as plain lists.
vectors = sklearn.preprocessing.normalize(
    gen_vectors(10000, default_dim), axis=1, norm='l2'
).tolist()

# Legacy form of the index parameters, kept for reference:
# index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
default_index = {
    "index_type": "IVF_FLAT",
    "params": {"nlist": 16384},
    "metric_type": "L2",
}
class TestMixBase:
    '''
    Scenarios that mix insert, index and search operations.

    Every method name starts with ``_test`` instead of ``test``, which keeps
    pytest's default collection from running them (they are disabled).
    '''

    # TODO
    @pytest.mark.tags(CaseLabel.L2)
    def _test_mix_base(self, connect, collection):
        '''
        target: test insert, flush, create index, load and search on one collection
        method: insert 200000 generated entities, flush, build default_index on the
                float vector field, load the collection and search with 5 queries
        expected: the described index equals default_index, the search returns one
                  result set per query, the first distance of the first result set
                  is within epsilon and ids[0] passes check_id_result
        '''
        nb = 200000
        nq = 5
        entities = gen_entities(nb=nb)
        ids = connect.insert(collection, entities)
        assert len(ids) == nb
        connect.flush([collection])
        connect.create_index(collection, default_float_vec_field_name, default_index)
        index = connect.describe_index(collection, "")
        # NOTE(review): presumably rewrites default_index in place so that it
        # matches the server's description -- confirm against utils.
        create_target_index(default_index, default_float_vec_field_name)
        assert index == default_index
        query, _ = gen_query_vectors(default_float_vec_field_name, entities, default_top_k, nq)
        connect.load_collection(collection)
        res = connect.search(collection, query)
        assert len(res) == nq
        assert len(res[0]) == default_top_k
        assert res[0]._distances[0] <= epsilon
        assert check_id_result(res[0], ids[0])

    # disable
    @pytest.mark.tags(CaseLabel.L2)
    def _test_search_during_createIndex(self, args):
        '''
        target: test searching while another process writes to the same collection
        method: insert the module-level vectors 10 times, then run one process that
                searches 10000 times alongside one process that inserts once more
        expected: every search returns the first two ids of the first batch as the
                  top hits, and both worker processes exit with code 0
        '''
        loops = 10000
        collection = gen_unique_str()
        query_vecs = [vectors[0], vectors[1]]
        id_0 = 0
        id_1 = 0
        milvus_instance = get_milvus(args["handler"])
        milvus_instance.create_collection({'collection_name': collection,
                                           'dimension': default_dim,
                                           'index_file_size': index_file_size,
                                           'metric_type': "L2"})
        for i in range(10):
            status, ids = milvus_instance.bulk_insert(collection, vectors)
            if i == 0:
                # Remember the ids of the two vectors used as queries.
                id_0 = ids[0]
                id_1 = ids[1]

        # NOTE: despite the test name, the concurrent writer inserts vectors;
        # the earlier create_index/get_index_info worker was commented out.
        def insert(milvus_instance):
            logging.getLogger().info("In add vectors")
            status, ids = milvus_instance.bulk_insert(collection, vectors)
            logging.getLogger().info(status)

        def search(milvus_instance):
            logging.getLogger().info("In search vectors")
            for i in range(loops):
                status, result = milvus_instance.search(collection, top_k, nprobe, query_vecs)
                logging.getLogger().info(status)
                assert result[0][0].id == id_0
                assert result[1][0].id == id_1

        # Each worker gets its own client instance.
        p_search = Process(target=search, args=(get_milvus(args["handler"]),))
        p_search.start()
        p_create = Process(target=insert, args=(get_milvus(args["handler"]),))
        p_create.start()

        # Wait for BOTH workers: the search process used to be left running,
        # which leaked the child and hid any assertion failure raised in it.
        p_create.join()
        p_search.join()
        assert p_create.exitcode == 0
        assert p_search.exitcode == 0

    @pytest.mark.tags(CaseLabel.L2)
    def _test_mix_multi_collections(self, connect):
        '''
        target: test functions with multiple collections of different metric_types and index_types
        method: create 60 collections, 30 with L2 and 30 with IP, add vectors into them,
                then create indexes, describe them and search
        expected: status ok
        '''
        log = logging.getLogger()
        collections_per_metric = 30
        group_size = 10
        # Rows of `vectors` that are later used as query vectors.
        probe_rows = (0, 10, 20)
        index_param = {'nlist': nlist}
        collection_list = []
        idx = []

        # create collections and add vectors: first 30 use L2, last 30 use IP
        for metric_type in (MetricType.L2, MetricType.IP):
            for _ in range(collections_per_metric):
                collection_name = gen_unique_str('test_mix_multi_collections')
                collection_list.append(collection_name)
                param = {'collection_name': collection_name,
                         'dimension': default_dim,
                         'index_file_size': index_file_size,
                         'metric_type': metric_type}
                connect.create_collection(param)
                status, ids = connect.bulk_insert(collection_name=collection_name, records=vectors)
                # Check the status before touching ids so a failed insert is
                # reported as such instead of as an indexing error.
                assert status.OK()
                # Flush every collection (previously only the IP ones were flushed).
                status = connect.flush([collection_name])
                assert status.OK()
                idx.extend(ids[row] for row in probe_rows)

        # (start offset in collection_list, index type) for each group of 10
        # collections, in the order the indexes are created.
        index_layout = (
            (0, IndexType.FLAT),
            (30, IndexType.FLAT),
            (10, IndexType.IVFLAT),
            (40, IndexType.IVFLAT),
            (20, IndexType.IVF_SQ8),
            (50, IndexType.IVF_SQ8),
        )

        # create index
        for i in range(group_size):
            for offset, index_type in index_layout:
                status = connect.create_index(collection_list[offset + i], index_type, index_param)
                assert status.OK()

        # describe index
        for i in range(group_size):
            for offset, index_type in index_layout:
                status, result = connect.get_index_info(collection_list[offset + i])
                assert status.OK()
                assert result._index_type == index_type

        # search
        query_vecs = [vectors[row] for row in probe_rows]
        for i, collection in enumerate(collection_list):
            status, result = connect.search(collection, top_k, query_records=query_vecs, params={"nprobe": 1})
            assert status.OK()
            assert len(result) == len(query_vecs)
            log.info(i)
            for j in range(len(query_vecs)):
                assert len(result[j]) == top_k
                # idx holds len(probe_rows) ids per collection, in insertion order.
                expected_id = idx[len(probe_rows) * i + j]
                found = check_result(result[j], expected_id)
                if not found:
                    log.info(result[j]._id_list)
                    log.info(expected_id)
                assert found
def check_result(result, id):
    '''
    Tell whether `id` is the id of one of the first five hits in `result`.

    When `result` holds fewer than five hits, all of them are considered.
    Each hit is expected to expose an `id` attribute.
    '''
    head = 5
    if len(result) < head:
        return id in (hit.id for hit in result)
    return id in [result[pos].id for pos in range(head)]
def check_id_result(result, id):
    '''
    Tell whether `id` appears among the ids of the first `limit_in` (5)
    entities of `result`; shorter results are checked in full.

    Each entity is expected to expose an `id` attribute.
    '''
    limit_in = 5
    ids = [entity.id for entity in result]
    candidates = ids[:limit_in] if len(result) >= limit_in else ids
    return id in candidates