[test]Update HNSW index param (#22686)

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
pull/22695/head
zhuwenxing 2023-03-10 15:45:56 +08:00 committed by GitHub
parent 7fd59cbcc2
commit 6e51905efc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 16 deletions

View File

@ -15,7 +15,7 @@ pymilvus_version = pymilvus.__version__
all_index_types = ["IVF_FLAT", "IVF_SQ8", "HNSW"]
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"M": 48, "efConstruction": 100}]
default_index_params = [{"nlist": 128}, {"nlist": 128}, {"M": 48, "efConstruction": 200}]
index_params_map = dict(zip(all_index_types, default_index_params))
@ -40,7 +40,7 @@ def gen_search_param(index_type, metric_type="L2"):
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [50]:
for ef in [150]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
@ -161,7 +161,7 @@ def milvus_recall_test(host='127.0.0.1', index_type="HNSW"):
assert len(item) == len(true_ids[index])
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
recall = round(sum_radio / len(result_ids), 6)
logger.info(f"recall={recall}")
if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"

View File

@ -30,7 +30,7 @@ def gen_search_param(index_type, metric_type="L2"):
bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}}
search_params.append(bin_search_params)
elif index_type in ["HNSW"]:
for ef in [50]:
for ef in [150]:
hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}}
search_params.append(hnsw_search_param)
elif index_type == "ANNOY":
@ -80,24 +80,13 @@ def search_test(host="127.0.0.1", index_type="HNSW"):
assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}"
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
recall = round(sum_radio / len(result_ids), 6)
logger.info(f"recall={recall}")
if index_type in ["IVF_PQ", "ANNOY"]:
assert recall >= 0.6, f"recall={recall} < 0.6"
else:
assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0"
# calculate recall
true_ids = neighbors[:nq,:topK]
sum_radio = 0.0
for index, item in enumerate(result_ids):
# tmp = set(item).intersection(set(flat_id_list[index]))
assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}"
tmp = set(true_ids[index]).intersection(set(item))
sum_radio = sum_radio + len(tmp) / len(item)
recall = round(sum_radio / len(result_ids), 3)
assert recall >= 0.95, f"recall is {recall}, less than 0.95"
logger.info(f"recall={recall}")
if __name__ == "__main__":