diff --git a/tests/python_client/deploy/scripts/first_recall_test.py b/tests/python_client/deploy/scripts/first_recall_test.py index 1ce6862cda..43bfde8597 100644 --- a/tests/python_client/deploy/scripts/first_recall_test.py +++ b/tests/python_client/deploy/scripts/first_recall_test.py @@ -15,7 +15,7 @@ pymilvus_version = pymilvus.__version__ all_index_types = ["IVF_FLAT", "IVF_SQ8", "HNSW"] -default_index_params = [{"nlist": 128}, {"nlist": 128}, {"M": 48, "efConstruction": 100}] +default_index_params = [{"nlist": 128}, {"nlist": 128}, {"M": 48, "efConstruction": 200}] index_params_map = dict(zip(all_index_types, default_index_params)) @@ -40,7 +40,7 @@ def gen_search_param(index_type, metric_type="L2"): bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}} search_params.append(bin_search_params) elif index_type in ["HNSW"]: - for ef in [50]: + for ef in [150]: hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}} search_params.append(hnsw_search_param) elif index_type == "ANNOY": @@ -161,7 +161,7 @@ def milvus_recall_test(host='127.0.0.1', index_type="HNSW"): assert len(item) == len(true_ids[index]) tmp = set(true_ids[index]).intersection(set(item)) sum_radio = sum_radio + len(tmp) / len(item) - recall = round(sum_radio / len(result_ids), 3) + recall = round(sum_radio / len(result_ids), 6) logger.info(f"recall={recall}") if index_type in ["IVF_PQ", "ANNOY"]: assert recall >= 0.6, f"recall={recall} < 0.6" diff --git a/tests/python_client/deploy/scripts/second_recall_test.py b/tests/python_client/deploy/scripts/second_recall_test.py index 4b15892657..59caa3f62d 100644 --- a/tests/python_client/deploy/scripts/second_recall_test.py +++ b/tests/python_client/deploy/scripts/second_recall_test.py @@ -30,7 +30,7 @@ def gen_search_param(index_type, metric_type="L2"): bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}} search_params.append(bin_search_params) elif index_type in ["HNSW"]: - for ef in [50]: + for ef in [150]: hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}} search_params.append(hnsw_search_param) elif index_type == "ANNOY": @@ -80,24 +80,13 @@ def search_test(host="127.0.0.1", index_type="HNSW"): assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}" tmp = set(true_ids[index]).intersection(set(item)) sum_radio = sum_radio + len(tmp) / len(item) - recall = round(sum_radio / len(result_ids), 3) + recall = round(sum_radio / len(result_ids), 6) logger.info(f"recall={recall}") if index_type in ["IVF_PQ", "ANNOY"]: assert recall >= 0.6, f"recall={recall} < 0.6" else: assert 0.95 <= recall < 1.0, f"recall is {recall}, less than 0.95, greater than or equal to 1.0" - # calculate recall - true_ids = neighbors[:nq,:topK] - sum_radio = 0.0 - for index, item in enumerate(result_ids): - # tmp = set(item).intersection(set(flat_id_list[index])) - assert len(item) == len(true_ids[index]), f"get {len(item)} but expect {len(true_ids[index])}" - tmp = set(true_ids[index]).intersection(set(item)) - sum_radio = sum_radio + len(tmp) / len(item) - recall = round(sum_radio / len(result_ids), 3) - assert recall >= 0.95, f"recall is {recall}, less than 0.95" - logger.info(f"recall={recall}") if __name__ == "__main__":