milvus/tests/benchmark/milvus_benchmark/runners/accuracy.py

import copy
import logging

import numpy as np

from milvus_benchmark import parser
from milvus_benchmark.runners import utils
from milvus_benchmark.runners.base import BaseRunner

logger = logging.getLogger("milvus_benchmark.runners.accuracy")

INSERT_INTERVAL = 50000


class AccuracyRunner(BaseRunner):
    """run accuracy"""

    name = "accuracy"

    def __init__(self, env, metric):
        super(AccuracyRunner, self).__init__(env, metric)
    def extract_cases(self, collection):
        collection_name = collection.get("collection_name")
        (data_type, collection_size, dimension, metric_type) = parser.collection_parser(collection_name)
        vector_type = utils.get_vector_type(data_type)
        index_field_name = utils.get_default_field_name(vector_type)
        base_query_vectors = utils.get_vectors_from_binary(utils.MAX_NQ, dimension, data_type)
        collection_info = {
            "dimension": dimension,
            "metric_type": metric_type,
            "dataset_name": collection_name,
            "collection_size": collection_size
        }
        index_info = self.milvus.describe_index(index_field_name, collection_name)
        filters = collection.get("filters", [])
        top_ks = collection["top_ks"]
        nqs = collection["nqs"]
        # Expand the list-valued search params into concrete combinations
        search_params = utils.generate_combinations(collection["search_params"])
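        # utils.generate_combinations presumably expands a dict of lists into
        # the cross product of scalar dicts, e.g. {"nprobe": [8, 32]} ->
        # [{"nprobe": 8}, {"nprobe": 32}] (an assumption about its semantics,
        # not a statement of its implementation).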
        cases = list()
        case_metrics = list()
        self.init_metric(self.name, collection_info, index_info, search_info=None)
        # Fall back to a single pass with no filter when none are configured
        if not filters:
            filters.append(None)
        for search_param in search_params:
            for filter in filters:
                # Reset per filter so a case only carries its own filter query
                filter_query = []
                filter_param = []
                if isinstance(filter, dict) and "range" in filter:
                    filter_query.append(eval(filter["range"]))
                    filter_param.append(filter["range"])
                if isinstance(filter, dict) and "term" in filter:
                    filter_query.append(eval(filter["term"]))
                    filter_param.append(filter["term"])
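                # The range/term filters are presumably Python-literal dict
                # strings taken from the suite config, which is why eval() is
                # used above to turn them into structured query expressions.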
                for nq in nqs:
                    query_vectors = base_query_vectors[0:nq]
                    for top_k in top_ks:
                        search_info = {
                            "topk": top_k,
                            "query": query_vectors,
                            "metric_type": utils.metric_type_trans(metric_type),
                            "params": search_param}
                        # TODO: only update search_info
                        case_metric = copy.deepcopy(self.metric)
                        # set metric type as case
                        case_metric.set_case_metric_type()
                        case_metric.search = {
                            "nq": nq,
                            "topk": top_k,
                            "search_param": search_param,
                            "filter": filter_param
                        }
                        vector_query = {"vector": {index_field_name: search_info}}
                        case = {
                            "collection_name": collection_name,
                            "index_field_name": index_field_name,
                            "dimension": dimension,
                            "data_type": data_type,
                            "metric_type": metric_type,
                            "vector_type": vector_type,
                            "collection_size": collection_size,
                            "filter_query": filter_query,
                            "vector_query": vector_query
                        }
                        cases.append(case)
                        case_metrics.append(case_metric)
        return cases, case_metrics
    def prepare(self, **case_param):
        collection_name = case_param["collection_name"]
        self.milvus.set_collection(collection_name)
        if not self.milvus.exists_collection():
            logger.error("Collection %s does not exist" % collection_name)
            raise Exception("Collection %s does not exist" % collection_name)
        self.milvus.load_collection(timeout=600)
    def run_case(self, case_metric, **case_param):
        collection_size = case_param["collection_size"]
        nq = case_metric.search["nq"]
        top_k = case_metric.search["topk"]
        query_res = self.milvus.query(case_param["vector_query"], filter_query=case_param["filter_query"])
        true_ids = utils.get_ground_truth_ids(collection_size)
        logger.debug({"true_ids": [len(true_ids), len(true_ids[0])]})
        result_ids = self.milvus.get_ids(query_res)
        logger.debug({"result_ids": len(result_ids[0])})
        acc_value = utils.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
        return {"acc": acc_value}


class AccAccuracyRunner(AccuracyRunner):
    """run ann accuracy

    1. entities from hdf5
    2. one collection test different index
    """

    name = "ann_accuracy"

    def __init__(self, env, metric):
        super(AccAccuracyRunner, self).__init__(env, metric)
    def extract_cases(self, collection):
        collection_name = collection.get("collection_name")
        (data_type, dimension, metric_type) = parser.parse_ann_collection_name(collection_name)
        # hdf5_source_file: path of the source data file saved on the NAS
        hdf5_source_file = collection["source_file"]
        index_types = collection["index_types"]
        index_params = collection["index_params"]
        top_ks = collection["top_ks"]
        nqs = collection["nqs"]
        search_params = collection["search_params"]
        vector_type = utils.get_vector_type(data_type)
        index_field_name = utils.get_default_field_name(vector_type)
        dataset = utils.get_dataset(hdf5_source_file)
        collection_info = {
            "dimension": dimension,
            "metric_type": metric_type,
            "dataset_name": collection_name
        }
        filters = collection.get("filters", [])
        # Expand the list-valued params into concrete parameter combinations
        search_params = utils.generate_combinations(search_params)
        index_params = utils.generate_combinations(index_params)
        cases = list()
        case_metrics = list()
        self.init_metric(self.name, collection_info, {}, search_info=None)
        # true_ids: ground-truth neighbor ids used to verify query results
        true_ids = np.array(dataset["neighbors"])
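        # The hdf5 file presumably follows the ann-benchmarks layout: "train"
        # holds the vectors to insert, "test" the query vectors, and
        # "neighbors" the ground-truth neighbor ids for each test vector.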
        # Fall back to a single pass with no filter when none are configured
        if not filters:
            filters.append(None)
        for index_type in index_types:
            for index_param in index_params:
                index_info = {
                    "index_type": index_type,
                    "index_param": index_param
                }
                for search_param in search_params:
                    for filter in filters:
                        # Reset per filter so a case only carries its own filter query
                        filter_query = []
                        filter_param = []
                        if isinstance(filter, dict) and "range" in filter:
                            filter_query.append(eval(filter["range"]))
                            filter_param.append(filter["range"])
                        if isinstance(filter, dict) and "term" in filter:
                            filter_query.append(eval(filter["term"]))
                            filter_param.append(filter["term"])
                        for nq in nqs:
                            query_vectors = utils.normalize(metric_type, np.array(dataset["test"][:nq]))
                            for top_k in top_ks:
                                search_info = {
                                    "topk": top_k,
                                    "query": query_vectors,
                                    "metric_type": utils.metric_type_trans(metric_type),
                                    "params": search_param}
                                # TODO: only update search_info
                                case_metric = copy.deepcopy(self.metric)
                                # set metric type as case
                                case_metric.set_case_metric_type()
                                case_metric.index = index_info
                                case_metric.search = {
                                    "nq": nq,
                                    "topk": top_k,
                                    "search_param": search_param,
                                    "filter": filter_param
                                }
                                vector_query = {"vector": {index_field_name: search_info}}
                                case = {
                                    "collection_name": collection_name,
                                    "dataset": dataset,
                                    "index_field_name": index_field_name,
                                    "dimension": dimension,
                                    "data_type": data_type,
                                    "metric_type": metric_type,
                                    "vector_type": vector_type,
                                    "index_type": index_type,
                                    "index_param": index_param,
                                    "filter_query": filter_query,
                                    "vector_query": vector_query,
                                    "true_ids": true_ids
                                }
                                # Collect the generated case and its metric
                                cases.append(case)
                                case_metrics.append(case_metric)
        return cases, case_metrics
    def prepare(self, **case_param):
        """Initialize the test according to the case parameters."""
        collection_name = case_param["collection_name"]
        metric_type = case_param["metric_type"]
        dimension = case_param["dimension"]
        vector_type = case_param["vector_type"]
        index_type = case_param["index_type"]
        index_param = case_param["index_param"]
        index_field_name = case_param["index_field_name"]
        self.milvus.set_collection(collection_name)
        if self.milvus.exists_collection(collection_name):
            logger.info("Re-create collection: %s" % collection_name)
            self.milvus.drop()
        dataset = case_param["dataset"]
        self.milvus.create_collection(dimension, data_type=vector_type)
        # Normalize the train set before inserting it into the collection
        insert_vectors = utils.normalize(metric_type, np.array(dataset["train"]))
        if len(insert_vectors) != dataset["train"].shape[0]:
            raise Exception("Row count of insert vectors: %d is not equal to dataset size: %d" % (
                len(insert_vectors), dataset["train"].shape[0]))
        logger.debug("The row count of entities to be inserted: %d" % len(insert_vectors))
        # Insert in batches of at most INSERT_INTERVAL rows
        info = self.milvus.get_info(collection_name)
        loops = len(insert_vectors) // INSERT_INTERVAL + 1
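        # Example of the batch arithmetic: with 120,000 train vectors and
        # INSERT_INTERVAL = 50,000, loops = 120000 // 50000 + 1 = 3, yielding
        # batches [0, 50000), [50000, 100000) and [100000, 120000); when the
        # row count is an exact multiple of INSERT_INTERVAL, the final
        # iteration has start == end and is skipped by the guard below.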
        for i in range(loops):
            start = i * INSERT_INTERVAL
            end = min((i + 1) * INSERT_INTERVAL, len(insert_vectors))
            if start < end:
                tmp_vectors = insert_vectors[start:end]
                ids = list(range(start, end))
                if not isinstance(tmp_vectors, list):
                    entities = utils.generate_entities(info, tmp_vectors.tolist(), ids)
                else:
                    entities = utils.generate_entities(info, tmp_vectors, ids)
                res_ids = self.milvus.insert(entities)
                assert res_ids == ids
        logger.debug("End insert, start flush")
        self.milvus.flush()
        logger.debug("End flush")
        res_count = self.milvus.count()
        logger.info("Table: %s, row count: %d" % (collection_name, res_count))
        if res_count != len(insert_vectors):
            raise Exception("Table row count: %d is not equal to insert vectors: %d" % (
                res_count, len(insert_vectors)))
        if self.milvus.describe_index(index_field_name):
            self.milvus.drop_index(index_field_name)
            logger.info("Re-create index: %s" % collection_name)
        self.milvus.create_index(index_field_name, index_type, metric_type, index_param=index_param)
        logger.info(self.milvus.describe_index(index_field_name))
        logger.info("Start load collection: %s" % collection_name)
        self.milvus.load_collection(timeout=600)
        logger.info("End load collection: %s" % collection_name)
    def run_case(self, case_metric, **case_param):
        true_ids = case_param["true_ids"]
        nq = case_metric.search["nq"]
        top_k = case_metric.search["topk"]
        query_res = self.milvus.query(case_param["vector_query"], filter_query=case_param["filter_query"])
        result_ids = self.milvus.get_ids(query_res)
        # Calculate the recall of the query results against the ground truth
        acc_value = utils.get_recall_value(true_ids[:nq, :top_k].tolist(), result_ids)
        # Return accuracy results for reporting
        return {"acc": acc_value}
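
# A minimal sketch of how these runners are presumably driven by the benchmark
# scheduler (hypothetical driver code; the real entry point lives elsewhere in
# milvus_benchmark):
#
#   runner = AccAccuracyRunner(env, metric)
#   cases, case_metrics = runner.extract_cases(collection)
#   for case, case_metric in zip(cases, case_metrics):
#       runner.prepare(**case)
#       result = runner.run_case(case_metric, **case)  # -> {"acc": <recall>}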