import time
import pdb
import logging
import json

import pytest

from utils.util_pymilvus import get_milvus, restart_server, gen_entities, gen_unique_str, default_nb
from common.constants import default_fields, default_entities
from common.common_type import CaseLabel

uid = "wal"
TIMEOUT = 120
insert_interval_time = 1.5
big_nb = 100000
field_name = "float_vector"
big_entities = gen_entities(big_nb)
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"}


class TestRestartBase:
    """
    ******************************************************************
      The following cases are used to test recovery after server restart
    ******************************************************************
    """
    @pytest.fixture(scope="module", autouse=True)
    def skip_check(self, args):
        logging.getLogger().info(args)
        if "service_name" not in args or not args["service_name"]:
            reason = "Skip if service name not provided"
            logging.getLogger().info(reason)
            pytest.skip(reason)
        if args["service_name"].find("shards") != -1:
            reason = "Skip restart cases in shards mode"
            logging.getLogger().info(reason)
            pytest.skip(reason)

    @pytest.mark.tags(CaseLabel.L2)
    def _test_insert_flush(self, connect, collection, args):
        """
        target: return the same row count after server restart
        method: create collection, insert and flush entities twice, restart server, then check row count
        expected: row count keeps the same
        """
        ids = connect.bulk_insert(collection, default_entities)
        connect.flush([collection])
        ids = connect.bulk_insert(collection, default_entities)
        connect.flush([collection])
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        assert res_count == 2 * default_nb
        # restart server
        logging.getLogger().info("Start restart server")
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count = new_connect.count_entities(collection)
        logging.getLogger().info(res_count)
        assert res_count == 2 * default_nb

    @pytest.mark.tags(CaseLabel.L2)
    def _test_insert_during_flushing(self, connect, collection, args):
        """
        target: data inserted while flushing will recover
        method: create collection, insert entities and flush asynchronously, restart server and check row count
        expected: row count eventually equals big_nb after restart
        """
        # disable_autoflush()
        ids = connect.bulk_insert(collection, big_entities)
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # only restart while the async flush is still in progress
        if res_count < big_nb:
            # restart server
            assert restart_server(args["service_name"])
            # assert row count again
            new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
            res_count_2 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_2)
            timeout = 300
            start_time = time.time()
            # poll until the recovered row count reaches big_nb or the timeout expires
            while new_connect.count_entities(collection) != big_nb and (time.time() - start_time < timeout):
                time.sleep(10)
                logging.getLogger().info(new_connect.count_entities(collection))
            res_count_3 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_3)
            assert res_count_3 == big_nb

    @pytest.mark.tags(CaseLabel.L2)
    def _test_delete_during_flushing(self, connect, collection, args):
        """
        target: deletes applied while flushing will recover
        method: create collection, insert entities, delete some of them and flush asynchronously,
                restart server and check row count
        expected: row count equals (big_nb - delete_length)
        """
        # disable_autoflush()
        ids = connect.bulk_insert(collection, big_entities)
        connect.flush([collection])
        delete_length = 1000
        delete_ids = ids[big_nb // 4: big_nb // 4 + delete_length]
        delete_res = connect.delete_entity_by_id(collection, delete_ids)
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count_2 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_2)
        timeout = 100
        start_time = time.time()
        while new_connect.count_entities(collection) != big_nb - delete_length and (time.time() - start_time < timeout):
            time.sleep(10)
            logging.getLogger().info(new_connect.count_entities(collection))
        # wait a little longer once the expected count is reached, in case it is still settling
        if new_connect.count_entities(collection) == big_nb - delete_length:
            time.sleep(10)
        res_count_3 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_3)
        assert res_count_3 == big_nb - delete_length

    @pytest.mark.tags(CaseLabel.L2)
    def _test_during_indexed(self, connect, collection, args):
        """
        target: built index will recover
        method: create collection, insert and flush entities, create index, restart server and check row count
        expected: row count equals big_nb and the index type is preserved
        """
        # disable_autoflush()
        ids = connect.bulk_insert(collection, big_entities)
        connect.flush([collection])
        connect.create_index(collection, field_name, default_index)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        stats = connect.get_collection_stats(collection)
        # logging.getLogger().info(stats)
        # pdb.set_trace()
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        assert new_connect.count_entities(collection) == big_nb
        stats = new_connect.get_collection_stats(collection)
        # every indexed vector file should be non-empty and keep the created index type
        for file in stats["partitions"][0]["segments"][0]["files"]:
            if file["field"] == field_name and file["name"] != "_raw":
                assert file["data_size"] > 0
                assert file["index_type"] == default_index["index_type"]

    @pytest.mark.tags(CaseLabel.L2)
    def _test_during_indexing(self, connect, collection, args):
        """
        target: in-progress index building will recover
        method: create collection, insert and flush entities in a loop, create index asynchronously,
                restart server and check row count
        expected: row count equals loop * big_nb; server continues to build the index after restart
        """
        # disable_autoflush()
        loop = 5
        for i in range(loop):
            ids = connect.bulk_insert(collection, big_entities)
            connect.flush([collection])
        connect.create_index(collection, field_name, default_index, _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        stats = connect.get_collection_stats(collection)
        # logging.getLogger().info(stats)
        # restart server
        assert restart_server(args["service_name"])
        # assert row count again
        new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
        res_count_2 = new_connect.count_entities(collection)
        logging.getLogger().info(res_count_2)
        assert res_count_2 == loop * big_nb
        status = new_connect._cmd("status")
        assert json.loads(status)["indexing"] == True
        # timeout = 100
        # start_time = time.time()
        # while time.time() - start_time < timeout:
        #     time.sleep(5)
        #     assert new_connect.count_entities(collection) == loop * big_nb
        # stats = connect.get_collection_stats(collection)
        # assert stats["row_count"] == loop * big_nb
        # for file in stats["partitions"][0]["segments"][0]["files"]:
        #     # logging.getLogger().info(file)
        #     if file["field"] == field_name and file["name"] != "_raw":
"_raw": # assert file["data_size"] > 0 # if file["index_type"] != default_index["index_type"]: # continue # for file in stats["partitions"][0]["segments"][0]["files"]: # if file["field"] == field_name and file["name"] != "_raw": # assert file["data_size"] > 0 # if file["index_type"] != default_index["index_type"]: # assert False # else: # assert True @pytest.mark.tags(CaseLabel.L2) def _test_delete_flush_during_compacting(self, connect, collection, args): """ target: verify server work after restart during compaction method: call function: create collection, then delete/flush/compacting, restart server and assert row count call `compact` again, compact pass expected: row count equals (nb - delete_length) """ # disable_autoflush() ids = connect.bulk_insert(collection, big_entities) connect.flush([collection]) delete_length = 1000 loop = 10 for i in range(loop): delete_ids = ids[i*delete_length:(i+1)*delete_length] delete_res = connect.delete_entity_by_id(collection, delete_ids) connect.flush([collection]) connect.compact(collection, _async=True) res_count = connect.count_entities(collection) logging.getLogger().info(res_count) assert res_count == big_nb - delete_length*loop info = connect.get_collection_stats(collection) size_old = info["partitions"][0]["segments"][0]["data_size"] logging.getLogger().info(size_old) # restart server assert restart_server(args["service_name"]) # assert row count again new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"]) res_count_2 = new_connect.count_entities(collection) logging.getLogger().info(res_count_2) assert res_count_2 == big_nb - delete_length*loop info = connect.get_collection_stats(collection) size_before = info["partitions"][0]["segments"][0]["data_size"] status = connect.compact(collection) assert status.OK() info = connect.get_collection_stats(collection) size_after = info["partitions"][0]["segments"][0]["data_size"] assert size_before > size_after @pytest.mark.tags(CaseLabel.L2) def _test_insert_during_flushing_multi_collections(self, connect, args): """ target: flushing will recover method: call function: create collections, then insert/flushing, restart server and assert row count expected: row count equals 0 """ # disable_autoflush() collection_num = 2 collection_list = [] for i in range(collection_num): collection_name = gen_unique_str(uid) collection_list.append(collection_name) connect.create_collection(collection_name, default_fields) ids = connect.bulk_insert(collection_name, big_entities) connect.flush(collection_list, _async=True) res_count = connect.count_entities(collection_list[-1]) logging.getLogger().info(res_count) if res_count < big_nb: # restart server assert restart_server(args["service_name"]) # assert row count again new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"]) res_count_2 = new_connect.count_entities(collection_list[-1]) logging.getLogger().info(res_count_2) timeout = 300 start_time = time.time() while time.time() - start_time < timeout: count_list = [] break_flag = True for index, name in enumerate(collection_list): tmp_count = new_connect.count_entities(name) count_list.append(tmp_count) logging.getLogger().info(count_list) if tmp_count != big_nb: break_flag = False break if break_flag == True: break time.sleep(10) for name in collection_list: assert new_connect.count_entities(name) == big_nb @pytest.mark.tags(CaseLabel.L2) def _test_insert_during_flushing_multi_partitions(self, connect, collection, args): """ target: flushing will recover method: call function: 
        method: create collection and partitions, insert entities into each partition and flush asynchronously,
                restart server and check row count
        expected: row count eventually equals big_nb * 2 after restart
        """
        # disable_autoflush()
        partitions_num = 2
        partitions = []
        for i in range(partitions_num):
            tag_tmp = gen_unique_str()
            partitions.append(tag_tmp)
            connect.create_partition(collection, tag_tmp)
            ids = connect.bulk_insert(collection, big_entities, partition_name=tag_tmp)
        connect.flush([collection], _async=True)
        res_count = connect.count_entities(collection)
        logging.getLogger().info(res_count)
        # only restart while the async flush is still in progress
        if res_count < big_nb:
            # restart server
            assert restart_server(args["service_name"])
            # assert row count again
            new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"])
            res_count_2 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_2)
            timeout = 300
            start_time = time.time()
            # poll until both partitions' rows are recovered or the timeout expires
            while new_connect.count_entities(collection) != big_nb * 2 and (time.time() - start_time < timeout):
                time.sleep(10)
                logging.getLogger().info(new_connect.count_entities(collection))
            res_count_3 = new_connect.count_entities(collection)
            logging.getLogger().info(res_count_3)
            assert res_count_3 == big_nb * 2
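
# The restart cases above each hand-roll the same "poll count_entities until it
# reaches the expected row count or a timeout expires" loop. The sketch below
# captures that shared pattern; `wait_for_row_count` is a hypothetical helper,
# not part of the original suite, and assumes `client` exposes the same
# `count_entities` call used throughout this file.
def wait_for_row_count(client, collection_name, expected, timeout=300, interval=10):
    """Poll the row count until it equals `expected`; True on success, False on timeout."""
    start_time = time.time()
    while time.time() - start_time < timeout:
        count = client.count_entities(collection_name)
        logging.getLogger().info(count)
        if count == expected:
            return True
        time.sleep(interval)
    return False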