import sys import copy import time from loguru import logger import pymilvus from pymilvus import ( FieldSchema, CollectionSchema, DataType, Collection, list_collections, ) logger.remove() logger.add(sys.stderr, format= "{time:YYYY-MM-DD HH:mm:ss.SSS} | " "{level: <8} | " "{thread.name} |" "{name}:{function}:{line} - {message}", level="INFO") pymilvus_version = pymilvus.__version__ all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW"] default_index_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8}, {"M": 48, "efConstruction": 500}] index_params_map = dict(zip(all_index_types, default_index_params)) NUM_REPLICAS = 2 def filter_collections_by_prefix(prefix): col_list = list_collections() logger.info(f"all collections: {col_list}") res = [] for col in col_list: if col.startswith(prefix): if any(index_name in col for index_name in all_index_types): res.append(col) else: logger.warning(f"collection {col} has no supported index, skip") logger.info(f"filtered collections with prefix {prefix}: {res}") return res def gen_search_param(index_type, metric_type="L2"): search_params = [] if index_type in ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ"]: for nprobe in [10]: ivf_search_params = {"metric_type": metric_type, "params": {"nprobe": nprobe}} search_params.append(ivf_search_params) elif index_type in ["BIN_FLAT", "BIN_IVF_FLAT"]: for nprobe in [10]: bin_search_params = {"metric_type": "HAMMING", "params": {"nprobe": nprobe}} search_params.append(bin_search_params) elif index_type in ["HNSW"]: for ef in [64]: hnsw_search_param = {"metric_type": metric_type, "params": {"ef": ef}} search_params.append(hnsw_search_param) elif index_type == "ANNOY": for search_k in [1000]: annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}} search_params.append(annoy_search_param) else: logger.info("Invalid index_type.") raise Exception("Invalid index_type.") return search_params def get_collections(prefix, check=False): logger.info("\nList collections...") col_list = filter_collections_by_prefix(prefix) logger.info(f"collections_nums: {len(col_list)}") # list entities if collections for name in col_list: c = Collection(name=name) if pymilvus_version >= "2.2.0": c.flush() else: c.num_entities num_entities = c.num_entities logger.info(f"{name}: {num_entities}") if check: assert num_entities >= 3000 return col_list def create_collections_and_insert_data(prefix, flush=True, count=3000, collection_cnt=11): import random dim = 128 nb = count // 10 default_fields = [ FieldSchema(name="count", dtype=DataType.INT64, is_primary=True), FieldSchema(name="random_value", dtype=DataType.DOUBLE), FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim) ] default_schema = CollectionSchema(fields=default_fields, description="test collection") for index_name in all_index_types[:collection_cnt]: logger.info("\nCreate collection...") col_name = prefix + index_name collection = Collection(name=col_name, schema=default_schema) logger.info(f"collection name: {col_name}") logger.info(f"begin insert, count: {count} nb: {nb}") times = int(count // nb) total_time = 0.0 vectors = [[random.random() for _ in range(dim)] for _ in range(count)] for j in range(times): start_time = time.time() collection.insert( [ [i for i in range(nb * j, nb * j + nb)], [float(random.randrange(-20, -10)) for _ in range(nb)], vectors[nb*j:nb*j+nb] ] ) end_time = time.time() logger.info(f"[{j+1}/{times}] insert {nb} data, time: {end_time - start_time:.4f}") total_time += end_time - start_time if j <= times - 3: collection.flush() collection.num_entities if j == times - 3: collection.compact() logger.info(f"end insert, time: {total_time:.4f}") if flush: logger.info("Get collection entities") start_time = time.time() if pymilvus_version >= "2.2.0": collection.flush() else: collection.num_entities logger.info(f"collection entities: {collection.num_entities}") end_time = time.time() logger.info("Get collection entities time = %.4fs" % (end_time - start_time)) logger.info("\nList collections...") logger.info(get_collections(prefix)) def create_index_flat(): # create index default_flat_index = {"index_type": "FLAT", "params": {}, "metric_type": "L2"} all_col_list = list_collections() col_list = [] for col_name in all_col_list: if "FLAT" in col_name and "task" in col_name and "IVF" not in col_name: col_list.append(col_name) logger.info("\nCreate index for FLAT...") for col_name in col_list: c = Collection(name=col_name) logger.info(c) try: replicas = c.get_replicas() replica_number = len(replicas.groups) c.release() except Exception as e: replica_number = 0 logger.info(e) t0 = time.time() c.create_index(field_name="float_vector", index_params=default_flat_index) logger.info(f"create index time: {time.time() - t0:.4f}") if replica_number > 0: c.load(replica_number=replica_number) def create_index(prefix): # create index default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} col_list = get_collections(prefix) logger.info("\nCreate index...") for col_name in col_list: c = Collection(name=col_name) try: replicas = c.get_replicas() replica_number = len(replicas.groups) c.release() except Exception as e: replica_number = 0 logger.info(e) index_name = col_name.replace(prefix, "") logger.info(index_name) logger.info(c) index = copy.deepcopy(default_index) index["index_type"] = index_name index["params"] = index_params_map[index_name] if index_name in ["BIN_FLAT", "BIN_IVF_FLAT"]: index["metric_type"] = "HAMMING" index_info_list = [x.to_dict() for x in c.indexes] logger.info(index_info_list) is_indexed = False for index_info in index_info_list: if "metric_type" in index_info.keys() or "metric_type" in index_info["index_param"]: is_indexed = True logger.info(f"collection {col_name} has been indexed with {index_info}") if not is_indexed: t0 = time.time() c.create_index(field_name="float_vector", index_params=index) logger.info(f"create index time: {time.time() - t0:.4f}") if replica_number > 0: c.load(replica_number=replica_number) def release_collection(prefix): col_list = get_collections(prefix) logger.info("release collection") for col_name in col_list: c = Collection(name=col_name) c.release() def load_and_search(prefix, replicas=1): logger.info("search data starts") col_list = get_collections(prefix) for col_name in col_list: c = Collection(name=col_name) logger.info(f"collection name: {col_name}") logger.info("load collection") if replicas == 1: t0 = time.time() c.load() logger.info(f"load time: {time.time() - t0:.4f}") if replicas > 1: logger.info("release collection before load if replicas > 1") t0 = time.time() c.release() logger.info(f"release time: {time.time() - t0:.4f}") t0 = time.time() c.load(replica_number=replicas) logger.info(f"load time: {time.time() - t0:.4f}") logger.info(c.get_replicas()) topK = 5 vectors = [[1.0 for _ in range(128)] for _ in range(3000)] index_name = col_name.replace(prefix, "") search_params = gen_search_param(index_name)[0] logger.info(search_params) # search_params = {"metric_type": "L2", "params": {"nprobe": 10}} start_time = time.time() logger.info(f"\nSearch...") # define output_fields of search result v_search = vectors[:1] res = c.search( v_search, "float_vector", search_params, topK, "count > 500", output_fields=["count", "random_value"], timeout=120 ) end_time = time.time() # show result for hits in res: for hit in hits: logger.info(f"hit: {hit}") ids = hits.ids assert len(ids) == topK, f"get {len(ids)} results, but topK is {topK}" logger.info(ids) assert len(res) == len(v_search), f"get {len(res)} results, but search num is {len(v_search)}" logger.info("search latency: %.4fs" % (end_time - start_time)) t0 = time.time() expr = "count in [2,4,6,8]" if "SQ" in col_name or "PQ" in col_name: output_fields = ["count", "random_value"] else: output_fields = ["count", "random_value", "float_vector"] res = c.query(expr, output_fields, timeout=120) sorted_res = sorted(res, key=lambda k: k['count']) for r in sorted_res: logger.info(r) t1 = time.time() assert len(res) == 4 logger.info("query latency: %.4fs" % (t1 - t0)) # c.release() logger.info("###########") logger.info("search data ends")