milvus/tests/milvus_ann_acc/runner.py

163 lines
7.3 KiB
Python

import os
import pdb
import time
import random
import sys
import logging
import h5py
import numpy
from influxdb import InfluxDBClient
# Number of vectors sent to milvus.insert() per call when loading a table.
INSERT_INTERVAL = 100000
# Seconds to sleep after dropping a table before re-creating it.
DELETE_INTERVAL_TIME = 5
# InfluxDB connection used to record accuracy results. Each setting keeps its
# historical default but may be overridden by the environment variable of the
# same name, so the runner can target another server without code edits.
INFLUXDB_HOST = os.environ.get("INFLUXDB_HOST", "192.168.1.194")
INFLUXDB_PORT = int(os.environ.get("INFLUXDB_PORT", "8086"))
INFLUXDB_USER = os.environ.get("INFLUXDB_USER", "admin")
INFLUXDB_PASSWD = os.environ.get("INFLUXDB_PASSWD", "admin")
INFLUXDB_NAME = os.environ.get("INFLUXDB_NAME", "test_result")
influxdb_client = InfluxDBClient(host=INFLUXDB_HOST, port=INFLUXDB_PORT, username=INFLUXDB_USER, password=INFLUXDB_PASSWD, database=INFLUXDB_NAME)
logger = logging.getLogger("milvus_acc.runner")
def parse_dataset_name(dataset_name):
    """Split an ann-benchmarks style dataset name into table-name components.

    ``"<data_type>-<dimension>-...-<metric>"`` maps to
    ``("ann" + data_type, dimension, metric_type)``, e.g.
    ``"sift-128-euclidean"`` -> ``("annsift", 128, "l2")``.

    Args:
        dataset_name: dash-separated name; the first field is the data type,
            the second the integer dimension, the last the distance metric.

    Returns:
        Tuple ``(data_type, dimension, metric_type)`` where ``metric_type`` is
        ``"l2"`` for ``euclidean`` and ``"ip"`` for ``angular``.

    Raises:
        ValueError: if the dimension field is not an integer or the metric is
            not one of the supported names.
    """
    parts = dataset_name.split("-")
    data_type = parts[0]
    dimension = int(parts[1])
    metric = parts[-1]
    metric_types = {"euclidean": "l2", "angular": "ip"}
    # Previously an unknown metric left metric_type unbound and surfaced as an
    # UnboundLocalError; report the real problem instead.
    if metric not in metric_types:
        raise ValueError("Unsupported metric %r in dataset name %r" % (metric, dataset_name))
    return ("ann" + data_type, dimension, metric_types[metric])
def get_dataset(hdf5_path, dataset_name):
    """Open ``<hdf5_path>/<dataset_name>.hdf5`` read-only and return the h5py File.

    The caller owns the returned handle and should close it (``close()`` or a
    ``with`` block) when finished.

    Raises:
        FileNotFoundError: if the file does not exist (a subclass of
            Exception, so existing ``except Exception`` handlers still work).
    """
    file_path = os.path.join(hdf5_path, '%s.hdf5' % dataset_name)
    if not os.path.exists(file_path):
        raise FileNotFoundError("%s not existed" % file_path)
    # Explicit read-only mode: h5py < 3.0 chooses a writable mode by default,
    # which needs write permission and risks modifying the benchmark file.
    return h5py.File(file_path, "r")
def get_table_name(hdf5_path, dataset_name, index_file_size):
    """Build the Milvus table name for ``dataset_name`` at ``index_file_size``.

    Format: ``<data_type>_<N>m_<index_file_size>_<dimension>_<metric_type>``
    where ``N`` is the number of "train" vectors in millions, truncated
    (so a dataset with fewer than one million vectors yields ``0m``).
    """
    data_type, dimension, metric_type = parse_dataset_name(dataset_name)
    # Only the train-set length is needed; close the HDF5 file afterwards
    # instead of leaking one open handle per call.
    with get_dataset(hdf5_path, dataset_name) as dataset:
        table_size = len(dataset["train"])
    size_tag = "%dm" % (table_size // 1000000)
    return "_".join([data_type, size_tag, str(index_file_size), str(dimension), metric_type])
def recall_calc(result_ids, true_ids, top_k, recall_k):
    """Return mean recall of ``result_ids`` against the first ``recall_k`` true ids.

    For each query ``i``, count how many of ``true_ids[i][:recall_k]`` appear
    anywhere in ``result_ids[i]``. Sum over all queries, divide by
    ``len(result_ids) * recall_k``, and round to 4 decimal places.

    Args:
        result_ids: per-query sequences of returned ids.
        true_ids: per-query ground-truth id sequences, indexed in step with
            ``result_ids``.
        top_k: number of results requested per query; not used in the
            computation, kept for API compatibility.
        recall_k: number of ground-truth ids checked per query.

    Returns:
        Recall as a float in ``[0, 1]``. Returns 0.0 when there are no queries
        or ``recall_k`` is not positive, instead of raising ZeroDivisionError.
    """
    num_queries = len(result_ids)
    if num_queries == 0 or recall_k <= 0:
        return 0.0
    hits = sum(
        len(set(true_ids[i][:recall_k]).intersection(result_item))
        for i, result_item in enumerate(result_ids)
    )
    return round(hits / (num_queries * recall_k), 4)
def _insert_in_batches(milvus, vectors):
    """Insert ``vectors`` into milvus's current table, INSERT_INTERVAL rows per call.

    Row ``i`` of ``vectors`` is inserted with id ``i``.
    """
    total = len(vectors)
    for start in range(0, total, INSERT_INTERVAL):
        end = min(start + INSERT_INTERVAL, total)
        milvus.insert(vectors[start:end], ids=list(range(start, end)))


def _search_and_report(milvus, dataset, true_ids, base_tags, nprobes, nqs, top_ks):
    """Run every nprobe/nq/top_k search on the loaded index and record accuracy.

    Each search's recall figures are logged and written to InfluxDB as one
    "accuracy" point tagged with ``base_tags`` plus the search parameters.
    A failed InfluxDB write is logged and otherwise ignored (best effort).
    """
    # Query vectors depend only on nq, so slice them once rather than per nprobe.
    queries_by_nq = {nq: numpy.array(dataset["test"][:nq]) for nq in nqs}
    for nprobe in nprobes:
        for nq in nqs:
            query_vectors = queries_by_nq[nq]
            for top_k in top_ks:
                result_ids = milvus.query(query_vectors, top_k, nprobe)
                logger.info("Query result: %s" % len(result_ids))
                rec1 = recall_calc(result_ids, true_ids, top_k, 1)
                # Recall@10 / @100 are only meaningful when at least that many
                # results were requested; report 0.0 otherwise.
                rec10 = recall_calc(result_ids, true_ids, top_k, 10) if top_k >= 10 else 0.0
                rec100 = recall_calc(result_ids, true_ids, top_k, 100) if top_k >= 100 else 0.0
                avg_radio = recall_calc(result_ids, true_ids, top_k, top_k)
                logger.debug("Recall_1: %s" % rec1)
                logger.debug("Recall_10: %s" % rec10)
                logger.debug("Recall_100: %s" % rec100)
                logger.debug("Accuracy: %s" % avg_radio)
                tags = dict(base_tags, search_nprobe=nprobe, top_k=top_k, nq=len(query_vectors))
                acc_record = [{
                    "measurement": "accuracy",
                    "tags": tags,
                    # The trailing "Z" declares UTC, so format UTC time rather
                    # than local time.
                    "time": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "fields": {
                        "recall1": rec1,
                        "recall10": rec10,
                        "recall100": rec100,
                        "avg_radio": avg_radio,
                    },
                }]
                logger.info(acc_record)
                try:
                    influxdb_client.write_points(acc_record)
                except Exception as e:
                    logger.error("Insert infuxdb failed: %s" % str(e))


def run(milvus, config, hdf5_path, force=True):
    """Run the ANN accuracy benchmark described by ``config`` against ``milvus``.

    For every dataset and every ``index_file_size`` this (re)creates a table
    and inserts the dataset's "train" vectors. Then, for every index type and
    nlist, it builds the index, preloads the table, and searches with every
    nprobe / nq / top_k combination, logging recall and writing results to
    InfluxDB.

    Args:
        milvus: project client wrapper exposing get_server_version,
            exists_table, delete, create_table, describe, insert, count,
            create_index, describe_index, preload_table and query.
        config: mapping ``dataset_name -> {"index_file_sizes", "index_types",
            "nlists", "search_param": {"top_ks", "nprobes", "nqs"}}``.
        hdf5_path: directory containing ``<dataset_name>.hdf5``.
        force: when ``True``, drop and recreate existing tables; otherwise
            existing tables are skipped.

    Returns:
        None. Stops the whole run early if a table's row count after insertion
        does not match the number of train vectors.
    """
    server_version = milvus.get_server_version()
    logger.info(server_version)
    for dataset_name, config_value in config.items():
        dataset = get_dataset(hdf5_path, dataset_name)
        try:
            index_file_sizes = config_value["index_file_sizes"]
            index_types = config_value["index_types"]
            nlists = config_value["nlists"]
            search_param = config_value["search_param"]
            top_ks = search_param["top_ks"]
            nprobes = search_param["nprobes"]
            nqs = search_param["nqs"]
            for index_file_size in index_file_sizes:
                table_name = get_table_name(hdf5_path, dataset_name, index_file_size)
                if milvus.exists_table(table_name):
                    if force is True:
                        logger.info("Re-create table: %s" % table_name)
                        milvus.delete(table_name)
                        time.sleep(DELETE_INTERVAL_TIME)
                    else:
                        logger.warning("Table name: %s existed" % table_name)
                        continue
                _, dimension, metric_type = parse_dataset_name(dataset_name)
                milvus.create_table(table_name, dimension, index_file_size, metric_type)
                logger.info(milvus.describe())
                insert_vectors = numpy.array(dataset["train"])
                _insert_in_batches(milvus, insert_vectors)
                # NOTE(review): presumably waits for the server to make the
                # inserted rows visible before counting — confirm the 20 s value.
                time.sleep(20)
                # Count once so the logged value is the one that is checked.
                row_count = milvus.count()
                logger.info("Table: %s, row count: %s" % (table_name, row_count))
                if row_count != len(insert_vectors):
                    logger.error("Table row count is not equal to insert vectors")
                    return
                # Ground truth is the same for every index; load it once per table.
                true_ids = numpy.array(dataset["neighbors"])
                for index_type in index_types:
                    for nlist in nlists:
                        milvus.create_index(index_type, nlist)
                        logger.info(milvus.describe_index())
                        logger.info("Start preload table: %s, index_type: %s, nlist: %s" % (table_name, index_type, nlist))
                        milvus.preload_table()
                        base_tags = {
                            "server_version": server_version,
                            "dataset": dataset_name,
                            "index_file_size": index_file_size,
                            "index_type": index_type,
                            "nlist": nlist,
                        }
                        _search_and_report(milvus, dataset, true_ids, base_tags, nprobes, nqs, top_ks)
        finally:
            # get_dataset returns an open HDF5 file; release it even on early return.
            dataset.close()