mirror of https://github.com/milvus-io/milvus.git
163 lines
7.3 KiB
Python
163 lines
7.3 KiB
Python
import os
|
|
import pdb
|
|
import time
|
|
import random
|
|
import sys
|
|
import logging
|
|
import h5py
|
|
import numpy
|
|
from influxdb import InfluxDBClient
|
|
|
|
INSERT_INTERVAL = 100000
|
|
# s
|
|
DELETE_INTERVAL_TIME = 5
|
|
INFLUXDB_HOST = "192.168.1.194"
|
|
INFLUXDB_PORT = 8086
|
|
INFLUXDB_USER = "admin"
|
|
INFLUXDB_PASSWD = "admin"
|
|
INFLUXDB_NAME = "test_result"
|
|
influxdb_client = InfluxDBClient(host=INFLUXDB_HOST, port=INFLUXDB_PORT, username=INFLUXDB_USER, password=INFLUXDB_PASSWD, database=INFLUXDB_NAME)
|
|
|
|
logger = logging.getLogger("milvus_acc.runner")
|
|
|
|
|
|
def parse_dataset_name(dataset_name):
|
|
data_type = dataset_name.split("-")[0]
|
|
dimension = int(dataset_name.split("-")[1])
|
|
metric = dataset_name.split("-")[-1]
|
|
# metric = dataset.attrs['distance']
|
|
# dimension = len(dataset["train"][0])
|
|
if metric == "euclidean":
|
|
metric_type = "l2"
|
|
elif metric == "angular":
|
|
metric_type = "ip"
|
|
return ("ann"+data_type, dimension, metric_type)
|
|
|
|
|
|
def get_dataset(hdf5_path, dataset_name):
|
|
file_path = os.path.join(hdf5_path, '%s.hdf5' % dataset_name)
|
|
if not os.path.exists(file_path):
|
|
raise Exception("%s not existed" % file_path)
|
|
dataset = h5py.File(file_path)
|
|
return dataset
|
|
|
|
|
|
def get_table_name(hdf5_path, dataset_name, index_file_size):
|
|
data_type, dimension, metric_type = parse_dataset_name(dataset_name)
|
|
dataset = get_dataset(hdf5_path, dataset_name)
|
|
table_size = len(dataset["train"])
|
|
table_size = str(table_size // 1000000)+"m"
|
|
table_name = data_type+'_'+table_size+'_'+str(index_file_size)+'_'+str(dimension)+'_'+metric_type
|
|
return table_name
|
|
|
|
|
|
def recall_calc(result_ids, true_ids, top_k, recall_k):
|
|
sum_intersect_num = 0
|
|
recall = 0.0
|
|
for index, result_item in enumerate(result_ids):
|
|
# if len(set(true_ids[index][:top_k])) != len(set(result_item)):
|
|
# logger.warning("Error happened: query result length is wrong")
|
|
# continue
|
|
tmp = set(true_ids[index][:recall_k]).intersection(set(result_item))
|
|
sum_intersect_num = sum_intersect_num + len(tmp)
|
|
recall = round(sum_intersect_num / (len(result_ids) * recall_k), 4)
|
|
return recall
|
|
|
|
|
|
def run(milvus, config, hdf5_path, force=True):
|
|
server_version = milvus.get_server_version()
|
|
logger.info(server_version)
|
|
|
|
for dataset_name, config_value in config.items():
|
|
dataset = get_dataset(hdf5_path, dataset_name)
|
|
index_file_sizes = config_value["index_file_sizes"]
|
|
index_types = config_value["index_types"]
|
|
nlists = config_value["nlists"]
|
|
search_param = config_value["search_param"]
|
|
top_ks = search_param["top_ks"]
|
|
nprobes = search_param["nprobes"]
|
|
nqs = search_param["nqs"]
|
|
|
|
for index_file_size in index_file_sizes:
|
|
table_name = get_table_name(hdf5_path, dataset_name, index_file_size)
|
|
if milvus.exists_table(table_name):
|
|
if force is True:
|
|
logger.info("Re-create table: %s" % table_name)
|
|
milvus.delete(table_name)
|
|
time.sleep(DELETE_INTERVAL_TIME)
|
|
else:
|
|
logger.warning("Table name: %s existed" % table_name)
|
|
continue
|
|
data_type, dimension, metric_type = parse_dataset_name(dataset_name)
|
|
milvus.create_table(table_name, dimension, index_file_size, metric_type)
|
|
logger.info(milvus.describe())
|
|
insert_vectors = numpy.array(dataset["train"])
|
|
# milvus.insert(insert_vectors)
|
|
|
|
loops = len(insert_vectors) // INSERT_INTERVAL + 1
|
|
for i in range(loops):
|
|
start = i*INSERT_INTERVAL
|
|
end = min((i+1)*INSERT_INTERVAL, len(insert_vectors))
|
|
tmp_vectors = insert_vectors[start:end]
|
|
if start < end:
|
|
milvus.insert(tmp_vectors, ids=[i for i in range(start, end)])
|
|
time.sleep(20)
|
|
row_count = milvus.count()
|
|
logger.info("Table: %s, row count: %s" % (table_name, row_count))
|
|
if milvus.count() != len(insert_vectors):
|
|
logger.error("Table row count is not equal to insert vectors")
|
|
return
|
|
for index_type in index_types:
|
|
for nlist in nlists:
|
|
milvus.create_index(index_type, nlist)
|
|
logger.info(milvus.describe_index())
|
|
logger.info("Start preload table: %s, index_type: %s, nlist: %s" % (table_name, index_type, nlist))
|
|
milvus.preload_table()
|
|
true_ids = numpy.array(dataset["neighbors"])
|
|
for nprobe in nprobes:
|
|
for nq in nqs:
|
|
query_vectors = numpy.array(dataset["test"][:nq])
|
|
for top_k in top_ks:
|
|
rec1 = 0.0
|
|
rec10 = 0.0
|
|
rec100 = 0.0
|
|
result_ids = milvus.query(query_vectors, top_k, nprobe)
|
|
logger.info("Query result: %s" % len(result_ids))
|
|
rec1 = recall_calc(result_ids, true_ids, top_k, 1)
|
|
if top_k == 10:
|
|
rec10 = recall_calc(result_ids, true_ids, top_k, 10)
|
|
if top_k == 100:
|
|
rec10 = recall_calc(result_ids, true_ids, top_k, 10)
|
|
rec100 = recall_calc(result_ids, true_ids, top_k, 100)
|
|
avg_radio = recall_calc(result_ids, true_ids, top_k, top_k)
|
|
logger.debug("Recall_1: %s" % rec1)
|
|
logger.debug("Recall_10: %s" % rec10)
|
|
logger.debug("Recall_100: %s" % rec100)
|
|
logger.debug("Accuracy: %s" % avg_radio)
|
|
acc_record = [{
|
|
"measurement": "accuracy",
|
|
"tags": {
|
|
"server_version": server_version,
|
|
"dataset": dataset_name,
|
|
"index_file_size": index_file_size,
|
|
"index_type": index_type,
|
|
"nlist": nlist,
|
|
"search_nprobe": nprobe,
|
|
"top_k": top_k,
|
|
"nq": len(query_vectors)
|
|
},
|
|
# "time": time.ctime(),
|
|
"time": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
|
|
"fields": {
|
|
"recall1": rec1,
|
|
"recall10": rec10,
|
|
"recall100": rec100,
|
|
"avg_radio": avg_radio
|
|
}
|
|
}]
|
|
logger.info(acc_record)
|
|
try:
|
|
res = influxdb_client.write_points(acc_record)
|
|
except Exception as e:
|
|
logger.error("Insert infuxdb failed: %s" % str(e))
|