diff --git a/internal/querynode/benchmark_test.go b/internal/querynode/benchmark_test.go
index a27679ad60..dab4a2e368 100644
--- a/internal/querynode/benchmark_test.go
+++ b/internal/querynode/benchmark_test.go
@@ -47,23 +47,18 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
 	assert.NoError(b, err)

 	// search only one segment
-	err = queryShardObj.streaming.removeSegment(defaultSegmentID)
-	assert.NoError(b, err)
-	err = queryShardObj.historical.removeSegment(defaultSegmentID)
-	assert.NoError(b, err)
-
-	assert.Equal(b, 0, queryShardObj.historical.getSegmentNum())
-	assert.Equal(b, 0, queryShardObj.streaming.getSegmentNum())
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeSealed))
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeGrowing))

 	segment, err := genSimpleSealedSegment(nb)
 	assert.NoError(b, err)
-	err = queryShardObj.historical.setSegment(segment)
+	err = queryShardObj.metaReplica.setSegment(segment)
 	assert.NoError(b, err)

 	// segment check
-	assert.Equal(b, 1, queryShardObj.historical.getSegmentNum())
-	assert.Equal(b, 0, queryShardObj.streaming.getSegmentNum())
-	seg, err := queryShardObj.historical.getSegmentByID(defaultSegmentID)
+	assert.Equal(b, 1, queryShardObj.metaReplica.getSegmentNum(segmentTypeSealed))
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeGrowing))
+	seg, err := queryShardObj.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed)
 	assert.NoError(b, err)
 	assert.Equal(b, int64(nb), seg.getRowCount())
@@ -75,7 +70,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {

 	// warming up
-	collection, err := queryShardObj.historical.getCollectionByID(defaultCollectionID)
+	collection, err := queryShardObj.metaReplica.getCollectionByID(defaultCollectionID)
 	assert.NoError(b, err)

 	iReq, _ := genSearchRequest(nq, IndexFaissIDMap, collection.schema)
@@ -89,7 +84,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
 	searchReq, err := newSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup())
 	assert.NoError(b, err)
 	for j := 0; j < 10000; j++ {
-		_, _, _, err := searchHistorical(queryShardObj.historical, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
+		_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
 		assert.NoError(b, err)
 	}
@@ -113,7 +108,7 @@ func benchmarkQueryCollectionSearch(nq int64, b *testing.B) {
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		for j := int64(0); j < benchmarkMaxNQ/nq; j++ {
-			_, _, _, err := searchHistorical(queryShardObj.historical, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
+			_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, queryReq.GetSegmentIDs())
 			assert.NoError(b, err)
 		}
 	}
@@ -128,25 +123,20 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
 	queryShardObj, err := genSimpleQueryShard(tx)
 	assert.NoError(b, err)

-	err = queryShardObj.historical.removeSegment(defaultSegmentID)
-	assert.NoError(b, err)
-	err = queryShardObj.streaming.removeSegment(defaultSegmentID)
-	assert.NoError(b, err)
-
-	assert.Equal(b, 0, queryShardObj.historical.getSegmentNum())
-	assert.Equal(b, 0, queryShardObj.streaming.getSegmentNum())
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeSealed))
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeGrowing))

 	node, err := genSimpleQueryNode(tx)
 	assert.NoError(b, err)
-	node.loader.historicalReplica = queryShardObj.historical
+	node.loader.metaReplica = queryShardObj.metaReplica

 	err = loadIndexForSegment(tx, node, defaultSegmentID, nb, indexType, L2, schemapb.DataType_Int64)
 	assert.NoError(b, err)

 	// segment check
-	assert.Equal(b, 1, queryShardObj.historical.getSegmentNum())
-	assert.Equal(b, 0, queryShardObj.streaming.getSegmentNum())
-	seg, err := queryShardObj.historical.getSegmentByID(defaultSegmentID)
+	assert.Equal(b, 1, queryShardObj.metaReplica.getSegmentNum(segmentTypeSealed))
+	assert.Equal(b, 0, queryShardObj.metaReplica.getSegmentNum(segmentTypeGrowing))
+	seg, err := queryShardObj.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed)
 	assert.NoError(b, err)
 	assert.Equal(b, int64(nb), seg.getRowCount())
 	//TODO:: check string data in segcore
@@ -156,14 +146,14 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
 	//assert.Equal(b, seg.getMemSize(), int64(expectSize))

 	// warming up
-	collection, err := queryShardObj.historical.getCollectionByID(defaultCollectionID)
+	collection, err := queryShardObj.metaReplica.getCollectionByID(defaultCollectionID)
 	assert.NoError(b, err)

 	//ollection *Collection, indexType string, nq int32
 	searchReq, _ := genSearchPlanAndRequests(collection, indexType, nq)

 	for j := 0; j < 10000; j++ {
-		_, _, _, err := searchHistorical(queryShardObj.historical, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
+		_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
 		assert.NoError(b, err)
 	}
@@ -188,7 +178,7 @@ func benchmarkQueryCollectionSearchIndex(nq int64, indexType string, b *testing.
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		for j := 0; j < benchmarkMaxNQ/int(nq); j++ {
-			_, _, _, err := searchHistorical(queryShardObj.historical, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
+			_, _, _, err := searchHistorical(queryShardObj.metaReplica, searchReq, defaultCollectionID, nil, []UniqueID{defaultSegmentID})
 			assert.NoError(b, err)
 		}
 	}
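The benchmark hunks above capture the shape of the whole refactor: the separate historical and streaming replicas collapse into one metaReplica, and every segment accessor gains an explicit segment-type argument. A minimal sketch of the migrated call shape, assuming the querynode package context; exampleSegmentLookup is a hypothetical helper, not part of this patch:

package querynode

// exampleSegmentLookup illustrates the type-scoped accessors that replace the
// old historical/streaming replica pair (hypothetical helper, sketch only).
func exampleSegmentLookup(replica ReplicaInterface, segID UniqueID) (*Segment, error) {
	// was: historical.getSegmentNum() and streaming.getSegmentNum()
	_ = replica.getSegmentNum(segmentTypeSealed)
	_ = replica.getSegmentNum(segmentTypeGrowing)
	// was: historical.getSegmentByID(segID)
	return replica.getSegmentByID(segID, segmentTypeSealed)
}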
diff --git a/internal/querynode/collection.go b/internal/querynode/collection.go
index 6daa3455f7..d2b848da22 100644
--- a/internal/querynode/collection.go
+++ b/internal/querynode/collection.go
@@ -80,6 +80,13 @@ func (c *Collection) Schema() *schemapb.CollectionSchema {
 	return c.schema
 }

+// getPartitionIDs return partitionIDs of collection
+func (c *Collection) getPartitionIDs() []UniqueID {
+	dst := make([]UniqueID, len(c.partitionIDs))
+	copy(dst, c.partitionIDs)
+	return dst
+}
+
 // addPartitionID would add a partition id to partition id list of collection
 func (c *Collection) addPartitionID(partitionID UniqueID) {
 	c.partitionIDs = append(c.partitionIDs, partitionID)
diff --git a/internal/querynode/collection_test.go b/internal/querynode/collection_test.go
index 557788d8bd..6781729a36 100644
--- a/internal/querynode/collection_test.go
+++ b/internal/querynode/collection_test.go
@@ -26,8 +26,7 @@ import (
 func TestCollection_newCollection(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	assert.Equal(t, collection.ID(), collectionID)
@@ -35,8 +34,7 @@ func TestCollection_newCollection(t *testing.T) {
 func TestCollection_deleteCollection(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	assert.Equal(t, collection.ID(), collectionID)
@@ -45,8 +43,7 @@ func TestCollection_deleteCollection(t *testing.T) {
 func TestCollection_schema(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collectionSchema := collection.Schema()
@@ -57,8 +54,7 @@ func TestCollection_schema(t *testing.T) {
 func TestCollection_vChannel(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collection.addVChannels([]Channel{defaultDMLChannel})
@@ -75,8 +71,7 @@ func TestCollection_vChannel(t *testing.T) {
 func TestCollection_vDeltaChannel(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collection.addVDeltaChannels([]Channel{defaultDeltaChannel})
@@ -93,8 +88,7 @@ func TestCollection_vDeltaChannel(t *testing.T) {
 func TestCollection_pChannel(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collection.addPChannels([]Channel{"TestCollection_addPChannel_channel-0"})
@@ -107,8 +101,7 @@ func TestCollection_pChannel(t *testing.T) {
 func TestCollection_pDeltaChannel(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collection.addPDeltaChannels([]Channel{"TestCollection_addPDeltaChannel_channel-0"})
@@ -121,8 +114,7 @@ func TestCollection_pDeltaChannel(t *testing.T) {
 func TestCollection_releaseTime(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	t0 := Timestamp(1000)
@@ -134,8 +126,7 @@ func TestCollection_releaseTime(t *testing.T) {
 func TestCollection_loadType(t *testing.T) {
 	collectionID := UniqueID(0)
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	collection := newCollection(collectionID, schema)
 	collection.setLoadType(loadTypeCollection)
diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go
index e2c7bcd2fd..00f59c197d 100644
--- a/internal/querynode/data_sync_service.go
+++ b/internal/querynode/data_sync_service.go
@@ -36,32 +36,24 @@ type dataSyncService struct {
 	dmlChannel2FlowGraph   map[Channel]*queryNodeFlowGraph
 	deltaChannel2FlowGraph map[Channel]*queryNodeFlowGraph

-	streamingReplica  ReplicaInterface
-	historicalReplica ReplicaInterface
-	tSafeReplica      TSafeReplicaInterface
-	msFactory         msgstream.Factory
+	metaReplica  ReplicaInterface
+	tSafeReplica TSafeReplicaInterface
+	msFactory    msgstream.Factory
 }

 // checkReplica used to check replica info before init flow graph, it's a private method of dataSyncService
 func (dsService *dataSyncService) checkReplica(collectionID UniqueID) error {
 	// check if the collection exists
-	hisColl, err := dsService.historicalReplica.getCollectionByID(collectionID)
+	coll, err := dsService.metaReplica.getCollectionByID(collectionID)
 	if err != nil {
 		return err
 	}
-	strColl, err := dsService.streamingReplica.getCollectionByID(collectionID)
-	if err != nil {
-		return err
-	}
-	if hisColl.getLoadType() != strColl.getLoadType() {
-		return fmt.Errorf("inconsistent loadType of collection, collectionID = %d", collectionID)
-	}
-	for _, channel := range hisColl.getVChannels() {
+	for _, channel := range coll.getVChannels() {
 		if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil {
 			return fmt.Errorf("getTSafe failed, err = %s", err)
 		}
 	}
-	for _, channel := range hisColl.getVDeltaChannels() {
+	for _, channel := range coll.getVDeltaChannels() {
 		if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil {
 			return fmt.Errorf("getTSafe failed, err = %s", err)
 		}
@@ -89,7 +81,7 @@ func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID Uniqu
 		}
 		newFlowGraph, err := newQueryNodeFlowGraph(dsService.ctx,
 			collectionID,
-			dsService.streamingReplica,
+			dsService.metaReplica,
 			dsService.tSafeReplica,
 			channel,
 			dsService.msFactory)
@@ -133,7 +125,7 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni
 		}
 		newFlowGraph, err := newQueryNodeDeltaFlowGraph(dsService.ctx,
 			collectionID,
-			dsService.historicalReplica,
+			dsService.metaReplica,
 			dsService.tSafeReplica,
 			channel,
 			dsService.msFactory)
@@ -247,8 +239,7 @@ func (dsService *dataSyncService) removeFlowGraphsByDeltaChannels(channels []Cha
 // newDataSyncService returns a new dataSyncService
 func newDataSyncService(ctx context.Context,
-	streamingReplica ReplicaInterface,
-	historicalReplica ReplicaInterface,
+	metaReplica ReplicaInterface,
 	tSafeReplica TSafeReplicaInterface,
 	factory msgstream.Factory) *dataSyncService {

@@ -256,8 +247,7 @@ func newDataSyncService(ctx context.Context,
 		ctx:                    ctx,
 		dmlChannel2FlowGraph:   make(map[Channel]*queryNodeFlowGraph),
 		deltaChannel2FlowGraph: make(map[Channel]*queryNodeFlowGraph),
-		streamingReplica:       streamingReplica,
-		historicalReplica:      historicalReplica,
+		metaReplica:            metaReplica,
 		tSafeReplica:           tSafeReplica,
 		msFactory:              factory,
 	}
diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go
index 937650f47a..89a400f06b 100644
--- a/internal/querynode/data_sync_service_test.go
+++ b/internal/querynode/data_sync_service_test.go
@@ -21,25 +21,20 @@ import (
 	"testing"

 	"github.com/stretchr/testify/assert"
-
-	"github.com/milvus-io/milvus/internal/proto/schemapb"
 )

 func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	streamingReplica, err := genSimpleReplica()
-	assert.NoError(t, err)
-
-	historicalReplica, err := genSimpleReplica()
+	replica, err := genSimpleReplica()
 	assert.NoError(t, err)

 	fac := genFactory()
 	assert.NoError(t, err)

 	tSafe := newTSafeReplica()

-	dataSyncService := newDataSyncService(ctx, streamingReplica, historicalReplica, tSafe, fac)
+	dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
 	assert.NotNil(t, dataSyncService)

 	t.Run("test DMLFlowGraphs", func(t *testing.T) {
@@ -83,11 +78,11 @@ func TestDataSyncService_DMLFlowGraphs(t *testing.T) {
 	})

 	t.Run("test addFlowGraphsForDMLChannels checkReplica Failed", func(t *testing.T) {
-		err = dataSyncService.historicalReplica.removeCollection(defaultCollectionID)
+		err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
 		assert.NoError(t, err)
 		_, err = dataSyncService.addFlowGraphsForDMLChannels(defaultCollectionID, []Channel{defaultDMLChannel})
 		assert.Error(t, err)
-		dataSyncService.historicalReplica.addCollection(defaultCollectionID, genTestCollectionSchema(schemapb.DataType_Int64))
+		dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
 	})
 }

@@ -95,17 +90,14 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	streamingReplica, err := genSimpleReplica()
-	assert.NoError(t, err)
-
-	historicalReplica, err := genSimpleReplica()
+	replica, err := genSimpleReplica()
 	assert.NoError(t, err)

 	fac := genFactory()
 	assert.NoError(t, err)

 	tSafe := newTSafeReplica()

-	dataSyncService := newDataSyncService(ctx, streamingReplica, historicalReplica, tSafe, fac)
+	dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
 	assert.NotNil(t, dataSyncService)

 	t.Run("test DeltaFlowGraphs", func(t *testing.T) {
@@ -149,11 +141,11 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) {
 	})

 	t.Run("test addFlowGraphsForDeltaChannels checkReplica Failed", func(t *testing.T) {
-		err = dataSyncService.historicalReplica.removeCollection(defaultCollectionID)
+		err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
 		assert.NoError(t, err)
 		_, err = dataSyncService.addFlowGraphsForDeltaChannels(defaultCollectionID, []Channel{defaultDMLChannel})
 		assert.Error(t, err)
-		dataSyncService.historicalReplica.addCollection(defaultCollectionID, genTestCollectionSchema(schemapb.DataType_Int64))
+		dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
 	})
 }

@@ -161,17 +153,14 @@ func TestDataSyncService_checkReplica(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	streamingReplica, err := genSimpleReplica()
-	assert.NoError(t, err)
-
-	historicalReplica, err := genSimpleReplica()
+	replica, err := genSimpleReplica()
 	assert.NoError(t, err)

 	fac := genFactory()
 	assert.NoError(t, err)

 	tSafe := newTSafeReplica()

-	dataSyncService := newDataSyncService(ctx, streamingReplica, historicalReplica, tSafe, fac)
+	dataSyncService := newDataSyncService(ctx, replica, tSafe, fac)
 	assert.NotNil(t, dataSyncService)
 	defer dataSyncService.close()

@@ -181,37 +170,16 @@ func TestDataSyncService_checkReplica(t *testing.T) {
 	})

 	t.Run("test collection doesn't exist", func(t *testing.T) {
-		err = dataSyncService.streamingReplica.removeCollection(defaultCollectionID)
+		err = dataSyncService.metaReplica.removeCollection(defaultCollectionID)
 		assert.NoError(t, err)
 		err = dataSyncService.checkReplica(defaultCollectionID)
 		assert.Error(t, err)
-
-		err = dataSyncService.historicalReplica.removeCollection(defaultCollectionID)
-		assert.NoError(t, err)
-		err = dataSyncService.checkReplica(defaultCollectionID)
-		assert.Error(t, err)
-
-		coll := dataSyncService.historicalReplica.addCollection(defaultCollectionID, genTestCollectionSchema(schemapb.DataType_Int64))
+		coll := dataSyncService.metaReplica.addCollection(defaultCollectionID, genTestCollectionSchema())
 		assert.NotNil(t, coll)
-		coll = dataSyncService.streamingReplica.addCollection(defaultCollectionID, genTestCollectionSchema(schemapb.DataType_Int64))
-		assert.NotNil(t, coll)
-	})
-
-	t.Run("test different loadType", func(t *testing.T) {
-		coll, err := dataSyncService.historicalReplica.getCollectionByID(defaultCollectionID)
-		assert.NoError(t, err)
-		coll.setLoadType(loadTypePartition)
-
-		err = dataSyncService.checkReplica(defaultCollectionID)
-		assert.Error(t, err)
-
-		coll, err = dataSyncService.streamingReplica.getCollectionByID(defaultCollectionID)
-		assert.NoError(t, err)
-		coll.setLoadType(loadTypePartition)
 	})

 	t.Run("test cannot find tSafe", func(t *testing.T) {
-		coll, err := dataSyncService.historicalReplica.getCollectionByID(defaultCollectionID)
+		coll, err := dataSyncService.metaReplica.getCollectionByID(defaultCollectionID)
 		assert.NoError(t, err)
 		coll.addVDeltaChannels([]Channel{defaultDeltaChannel})
 		coll.addVChannels([]Channel{defaultDMLChannel})
diff --git a/internal/querynode/flow_graph_delete_node.go b/internal/querynode/flow_graph_delete_node.go
index ebccbd841e..9589ab2740 100644
--- a/internal/querynode/flow_graph_delete_node.go
+++ b/internal/querynode/flow_graph_delete_node.go
@@ -40,7 +40,7 @@ var newVarCharPrimaryKey = storage.NewVarCharPrimaryKey
 // deleteNode is the one of nodes in delta flow graph
 type deleteNode struct {
 	baseNode
-	replica ReplicaInterface // historical
+	metaReplica ReplicaInterface // historical
 }

 // Name returns the name of deleteNode
@@ -92,12 +92,12 @@ func (dNode *deleteNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 			zap.Int("numTS", len(delMsg.Timestamps)),
 			zap.Any("timestampBegin", delMsg.BeginTs()),
 			zap.Any("timestampEnd", delMsg.EndTs()),
-			zap.Any("segmentNum", dNode.replica.getSegmentNum()),
+			zap.Any("segmentNum", dNode.metaReplica.getSegmentNum(segmentTypeSealed)),
 			zap.Any("traceID", traceID),
 		)

-		if dNode.replica.getSegmentNum() != 0 {
-			err := processDeleteMessages(dNode.replica, delMsg, delData)
+		if dNode.metaReplica.getSegmentNum(segmentTypeSealed) != 0 {
+			err := processDeleteMessages(dNode.metaReplica, segmentTypeSealed, delMsg, delData)
 			if err != nil {
 				// error occurs when missing meta info or unexpected pk type, should not happen
 				err = fmt.Errorf("deleteNode processDeleteMessages failed, collectionID = %d, err = %s", delMsg.CollectionID, err)
@@ -109,7 +109,7 @@ func (dNode *deleteNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	// 2. do preDelete
 	for segmentID, pks := range delData.deleteIDs {
-		segment, err := dNode.replica.getSegmentByID(segmentID)
+		segment, err := dNode.metaReplica.getSegmentByID(segmentID, segmentTypeSealed)
 		if err != nil {
 			// should not happen, segment should be created before
 			err = fmt.Errorf("deleteNode getSegmentByID failed, err = %s", err)
@@ -150,7 +150,7 @@ func (dNode *deleteNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 // delete will do delete operation at segment which id is segmentID
 func (dNode *deleteNode) delete(deleteData *deleteData, segmentID UniqueID, wg *sync.WaitGroup) error {
 	defer wg.Done()
-	targetSegment, err := dNode.replica.getSegmentByID(segmentID)
+	targetSegment, err := dNode.metaReplica.getSegmentByID(segmentID, segmentTypeSealed)
 	if err != nil {
 		return fmt.Errorf("getSegmentByID failed, err = %s", err)
 	}
@@ -173,7 +173,7 @@ func (dNode *deleteNode) delete(deleteData *deleteData, segmentID UniqueID, wg *
 }

 // newDeleteNode returns a new deleteNode
-func newDeleteNode(historicalReplica ReplicaInterface) *deleteNode {
+func newDeleteNode(metaReplica ReplicaInterface) *deleteNode {
 	maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength
 	maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism

@@ -182,7 +182,7 @@ func newDeleteNode(historicalReplica ReplicaInterface) *deleteNode {
 	baseNode.SetMaxParallelism(maxParallelism)

 	return &deleteNode{
-		baseNode: baseNode,
-		replica:  historicalReplica,
+		baseNode:    baseNode,
+		metaReplica: metaReplica,
 	}
 }
diff --git a/internal/querynode/flow_graph_delete_node_test.go b/internal/querynode/flow_graph_delete_node_test.go
index 57187101b8..9c347a304b 100644
--- a/internal/querynode/flow_graph_delete_node_test.go
+++ b/internal/querynode/flow_graph_delete_node_test.go
@@ -122,7 +122,7 @@ func TestFlowGraphDeleteNode_operate(t *testing.T) {
 		}
 		msg := []flowgraph.Msg{&dMsg}
 		deleteNode.Operate(msg)
-		s, err := historical.getSegmentByID(defaultSegmentID)
+		s, err := historical.getSegmentByID(defaultSegmentID, segmentTypeSealed)
 		pks := make([]primaryKey, defaultMsgLength)
 		for i := 0; i < defaultMsgLength; i++ {
 			pks[i] = newInt64PrimaryKey(int64(i))
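With the replicas merged, the delta flow graph no longer encodes "historical" in a field name; deleteNode instead pins every call to the sealed-segment map. A hedged sketch of the resulting delete path, assuming the querynode package context (deleteOnSealed is an illustrative wrapper, not part of the patch):

package querynode

import "github.com/milvus-io/milvus/internal/mq/msgstream"

// deleteOnSealed condenses deleteNode.Operate above: the sealed map is chosen
// explicitly, while insertNode makes the identical call with segmentTypeGrowing.
func deleteOnSealed(replica ReplicaInterface, delMsg *msgstream.DeleteMsg, delData *deleteData) error {
	if replica.getSegmentNum(segmentTypeSealed) == 0 {
		return nil // nothing loaded, nothing to delete
	}
	return processDeleteMessages(replica, segmentTypeSealed, delMsg, delData)
}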
diff --git a/internal/querynode/flow_graph_filter_delete_node.go b/internal/querynode/flow_graph_filter_delete_node.go
index 779b867f36..5f87eb29eb 100644
--- a/internal/querynode/flow_graph_filter_delete_node.go
+++ b/internal/querynode/flow_graph_filter_delete_node.go
@@ -34,7 +34,7 @@ import (
 type filterDeleteNode struct {
 	baseNode
 	collectionID UniqueID
-	replica      ReplicaInterface
+	metaReplica  ReplicaInterface
 }

 // Name returns the name of filterDeleteNode
@@ -124,13 +124,13 @@ func (fddNode *filterDeleteNode) filterInvalidDeleteMessage(msg *msgstream.Delet
 	}

 	// check if collection exists
-	col, err := fddNode.replica.getCollectionByID(msg.CollectionID)
+	col, err := fddNode.metaReplica.getCollectionByID(msg.CollectionID)
 	if err != nil {
 		// QueryNode should add collection before start flow graph
 		return nil, fmt.Errorf("filter invalid delete message, collection does not exist, collectionID = %d", msg.CollectionID)
 	}

 	if col.getLoadType() == loadTypePartition {
-		if !fddNode.replica.hasPartition(msg.PartitionID) {
+		if !fddNode.metaReplica.hasPartition(msg.PartitionID) {
 			// filter out msg which not belongs to the loaded partitions
 			return nil, nil
 		}
 	}
@@ -139,7 +139,7 @@ func (fddNode *filterDeleteNode) filterInvalidDeleteMessage(msg *msgstream.Delet
 }

 // newFilteredDeleteNode returns a new filterDeleteNode
-func newFilteredDeleteNode(replica ReplicaInterface, collectionID UniqueID) *filterDeleteNode {
+func newFilteredDeleteNode(metaReplica ReplicaInterface, collectionID UniqueID) *filterDeleteNode {
 	maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength
 	maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism

@@ -151,6 +151,6 @@ func newFilteredDeleteNode(replica ReplicaInterface, collectionID UniqueID) *fil
 	return &filterDeleteNode{
 		baseNode:     baseNode,
 		collectionID: collectionID,
-		replica:      replica,
+		metaReplica:  metaReplica,
 	}
 }
diff --git a/internal/querynode/flow_graph_filter_delete_node_test.go b/internal/querynode/flow_graph_filter_delete_node_test.go
index 4a0aecc681..ee01b2430b 100644
--- a/internal/querynode/flow_graph_filter_delete_node_test.go
+++ b/internal/querynode/flow_graph_filter_delete_node_test.go
@@ -96,10 +96,10 @@ func TestFlowGraphFilterDeleteNode_filterInvalidDeleteMessage(t *testing.T) {
 		msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength)
 		fg, err := getFilterDeleteNode()
 		assert.NoError(t, err)
-		col, err := fg.replica.getCollectionByID(defaultCollectionID)
+		col, err := fg.metaReplica.getCollectionByID(defaultCollectionID)
 		assert.NoError(t, err)
 		col.setLoadType(loadTypePartition)
-		err = fg.replica.removePartition(defaultPartitionID)
+		err = fg.metaReplica.removePartition(defaultPartitionID)
 		assert.NoError(t, err)

 		res, err := fg.filterInvalidDeleteMessage(msg)
@@ -146,7 +146,7 @@ func TestFlowGraphFilterDeleteNode_Operate(t *testing.T) {
 	})

 	t.Run("invalid msgType", func(t *testing.T) {
-		iMsg, err := genSimpleInsertMsg(genTestCollectionSchema(schemapb.DataType_Int64), defaultDelLength)
+		iMsg, err := genSimpleInsertMsg(genTestCollectionSchema(), defaultDelLength)
 		assert.NoError(t, err)
 		msg := flowgraph.GenerateMsgStreamMsg([]msgstream.TsMsg{iMsg}, 0, 1000, nil, nil)
diff --git a/internal/querynode/flow_graph_filter_dm_node.go b/internal/querynode/flow_graph_filter_dm_node.go
index 4564f73619..c88565e766 100644
--- a/internal/querynode/flow_graph_filter_dm_node.go
+++ b/internal/querynode/flow_graph_filter_dm_node.go
@@ -34,7 +34,7 @@ import (
 type filterDmNode struct {
 	baseNode
 	collectionID UniqueID
-	replica      ReplicaInterface
+	metaReplica  ReplicaInterface
 }

 // Name returns the name of filterDmNode
@@ -140,13 +140,13 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg
 	}

 	// check if collection exist
-	col, err := fdmNode.replica.getCollectionByID(msg.CollectionID)
+	col, err := fdmNode.metaReplica.getCollectionByID(msg.CollectionID)
 	if err != nil {
 		// QueryNode should add collection before start flow graph
 		return nil, fmt.Errorf("filter invalid delete message, collection does not exist, collectionID = %d", msg.CollectionID)
 	}

 	if col.getLoadType() == loadTypePartition {
-		if !fdmNode.replica.hasPartition(msg.PartitionID) {
+		if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
 			// filter out msg which not belongs to the loaded partitions
 			return nil, nil
 		}
 	}
@@ -181,13 +181,13 @@ func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg
 	}

 	// check if collection exists
-	col, err := fdmNode.replica.getCollectionByID(msg.CollectionID)
+	col, err := fdmNode.metaReplica.getCollectionByID(msg.CollectionID)
 	if err != nil {
 		// QueryNode should add collection before start flow graph
 		return nil, fmt.Errorf("filter invalid insert message, collection does not exist, collectionID = %d", msg.CollectionID)
 	}

 	if col.getLoadType() == loadTypePartition {
-		if !fdmNode.replica.hasPartition(msg.PartitionID) {
+		if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
 			// filter out msg which not belongs to the loaded partitions
 			return nil, nil
 		}
 	}
@@ -196,7 +196,7 @@ func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg
 	// Check if the segment is in excluded segments,
 	// messages after seekPosition may contain the redundant data from flushed slice of segment,
 	// so we need to compare the endTimestamp of received messages and position's timestamp.
-	excludedSegments, err := fdmNode.replica.getExcludedSegments(fdmNode.collectionID)
+	excludedSegments, err := fdmNode.metaReplica.getExcludedSegments(fdmNode.collectionID)
 	if err != nil {
 		// QueryNode should addExcludedSegments for the current collection before start flow graph
 		return nil, err
 	}
@@ -221,7 +221,7 @@ func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg
 }

 // newFilteredDmNode returns a new filterDmNode
-func newFilteredDmNode(replica ReplicaInterface, collectionID UniqueID) *filterDmNode {
+func newFilteredDmNode(metaReplica ReplicaInterface, collectionID UniqueID) *filterDmNode {
 	maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength
 	maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism

@@ -233,6 +233,6 @@ func newFilteredDmNode(replica ReplicaInterface, collectionID UniqueID) *filterD
 	return &filterDmNode{
 		baseNode:     baseNode,
 		collectionID: collectionID,
-		replica:      replica,
+		metaReplica:  metaReplica,
 	}
 }
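filterDeleteNode and filterDmNode now share one replica and apply the same partition gate. A hedged sketch of that gate in isolation (shouldKeepMsg is an illustrative name, not in the patch):

package querynode

// shouldKeepMsg condenses the loadType check both filter nodes perform above:
// for a partition-loaded collection, messages from partitions that were never
// loaded are silently dropped rather than treated as errors.
func shouldKeepMsg(replica ReplicaInterface, collectionID, partitionID UniqueID) (bool, error) {
	col, err := replica.getCollectionByID(collectionID)
	if err != nil {
		return false, err // collection must exist before the flow graph starts
	}
	if col.getLoadType() == loadTypePartition && !replica.hasPartition(partitionID) {
		return false, nil
	}
	return true, nil
}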
diff --git a/internal/querynode/flow_graph_filter_dm_node_test.go b/internal/querynode/flow_graph_filter_dm_node_test.go
index b3ff39ae74..ea35271fdf 100644
--- a/internal/querynode/flow_graph_filter_dm_node_test.go
+++ b/internal/querynode/flow_graph_filter_dm_node_test.go
@@ -47,8 +47,7 @@ func TestFlowGraphFilterDmNode_filterDmNode(t *testing.T) {
 }

 func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	t.Run("valid test", func(t *testing.T) {
 		msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
 		assert.NoError(t, err)
@@ -79,7 +78,7 @@ func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
 		fg, err := getFilterDMNode()
 		assert.NoError(t, err)

-		col, err := fg.replica.getCollectionByID(defaultCollectionID)
+		col, err := fg.metaReplica.getCollectionByID(defaultCollectionID)
 		assert.NoError(t, err)
 		col.setLoadType(loadTypePartition)

@@ -104,7 +103,7 @@ func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
 		assert.NoError(t, err)
 		fg, err := getFilterDMNode()
 		assert.NoError(t, err)
-		fg.replica.removeExcludedSegments(defaultCollectionID)
+		fg.metaReplica.removeExcludedSegments(defaultCollectionID)
 		res, err := fg.filterInvalidInsertMessage(msg)
 		assert.Error(t, err)
 		assert.Nil(t, res)
@@ -115,7 +114,7 @@ func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
 		assert.NoError(t, err)
 		fg, err := getFilterDMNode()
 		assert.NoError(t, err)
-		fg.replica.addExcludedSegments(defaultCollectionID, []*datapb.SegmentInfo{
+		fg.metaReplica.addExcludedSegments(defaultCollectionID, []*datapb.SegmentInfo{
 			{
 				ID:           defaultSegmentID,
 				CollectionID: defaultCollectionID,
@@ -185,7 +184,7 @@ func TestFlowGraphFilterDmNode_filterInvalidDeleteMessage(t *testing.T) {
 		fg, err := getFilterDMNode()
 		assert.NoError(t, err)
-		col, err := fg.replica.getCollectionByID(defaultCollectionID)
+		col, err := fg.metaReplica.getCollectionByID(defaultCollectionID)
 		assert.NoError(t, err)
 		col.setLoadType(loadTypePartition)

@@ -233,8 +232,7 @@ func TestFlowGraphFilterDmNode_filterInvalidDeleteMessage(t *testing.T) {
 }

 func TestFlowGraphFilterDmNode_Operate(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	genFilterDMMsg := func() []flowgraph.Msg {
 		iMsg, err := genSimpleInsertMsg(schema, defaultMsgLength)
@@ -287,7 +285,7 @@ func TestFlowGraphFilterDmNode_Operate(t *testing.T) {
 	})

 	t.Run("invalid msgType", func(t *testing.T) {
-		iMsg, err := genSimpleInsertMsg(genTestCollectionSchema(schemapb.DataType_Int64), defaultDelLength)
+		iMsg, err := genSimpleInsertMsg(genTestCollectionSchema(), defaultDelLength)
 		assert.NoError(t, err)
 		iMsg.Base.MsgType = commonpb.MsgType_Search
 		msg := flowgraph.GenerateMsgStreamMsg([]msgstream.TsMsg{iMsg}, 0, 1000, nil, nil)
diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go
index c1a26b5156..b9f3448620 100644
--- a/internal/querynode/flow_graph_insert_node.go
+++ b/internal/querynode/flow_graph_insert_node.go
@@ -43,7 +43,7 @@ import (
 // insertNode is one of the nodes in query flow graph
 type insertNode struct {
 	baseNode
-	streamingReplica ReplicaInterface
+	metaReplica ReplicaInterface // streaming
 }

 // insertData stores the valid insert data
@@ -111,7 +111,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	})
 	for _, insertMsg := range iMsg.insertMessages {
 		// if loadType is loadCollection, check if partition exists, if not, create partition
-		col, err := iNode.streamingReplica.getCollectionByID(insertMsg.CollectionID)
+		col, err := iNode.metaReplica.getCollectionByID(insertMsg.CollectionID)
 		if err != nil {
 			// should not happen, QueryNode should create collection before start flow graph
 			err = fmt.Errorf("insertNode getCollectionByID failed, err = %s", err)
@@ -119,7 +119,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 			panic(err)
 		}
 		if col.getLoadType() == loadTypeCollection {
-			err = iNode.streamingReplica.addPartition(insertMsg.CollectionID, insertMsg.PartitionID)
+			err = iNode.metaReplica.addPartition(insertMsg.CollectionID, insertMsg.PartitionID)
 			if err != nil {
 				// error occurs only when collection cannot be found, should not happen
 				err = fmt.Errorf("insertNode addPartition failed, err = %s", err)
@@ -129,8 +129,13 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 		}

 		// check if segment exists, if not, create this segment
-		if !iNode.streamingReplica.hasSegment(insertMsg.SegmentID) {
-			err := iNode.streamingReplica.addSegment(insertMsg.SegmentID, insertMsg.PartitionID, insertMsg.CollectionID, insertMsg.ShardName, segmentTypeGrowing)
+		has, err := iNode.metaReplica.hasSegment(insertMsg.SegmentID, segmentTypeGrowing)
+		if err != nil {
+			log.Error(err.Error()) // never gonna happen
+			panic(err)
+		}
+		if !has {
+			err = iNode.metaReplica.addSegment(insertMsg.SegmentID, insertMsg.PartitionID, insertMsg.CollectionID, insertMsg.ShardName, segmentTypeGrowing)
 			if err != nil {
 				// error occurs when collection or partition cannot be found, collection and partition should be created before
 				err = fmt.Errorf("insertNode addSegment failed, err = %s", err)
@@ -154,7 +159,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 		} else {
 			typeutil.MergeFieldData(iData.insertRecords[insertMsg.SegmentID], insertRecord.FieldsData)
 		}
-		pks, err := getPrimaryKeys(insertMsg, iNode.streamingReplica)
+		pks, err := getPrimaryKeys(insertMsg, iNode.metaReplica)
 		if err != nil {
 			// error occurs when cannot find collection or data is misaligned, should not happen
 			err = fmt.Errorf("failed to get primary keys, err = %d", err)
@@ -166,7 +171,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {

 	// 2. do preInsert
 	for segmentID := range iData.insertRecords {
-		var targetSegment, err = iNode.streamingReplica.getSegmentByID(segmentID)
+		var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
 		if err != nil {
 			// should not happen, segment should be created before
 			err = fmt.Errorf("insertNode getSegmentByID failed, err = %s", err)
@@ -213,13 +218,13 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	}
 	// 1. filter segment by bloom filter
 	for _, delMsg := range iMsg.deleteMessages {
-		if iNode.streamingReplica.getSegmentNum() != 0 {
+		if iNode.metaReplica.getSegmentNum(segmentTypeGrowing) != 0 {
 			log.Debug("delete in streaming replica",
 				zap.Any("collectionID", delMsg.CollectionID),
 				zap.Any("collectionName", delMsg.CollectionName),
 				zap.Int64("numPKs", delMsg.NumRows),
 				zap.Any("timestamp", delMsg.Timestamps))
-			err := processDeleteMessages(iNode.streamingReplica, delMsg, delData)
+			err := processDeleteMessages(iNode.metaReplica, segmentTypeGrowing, delMsg, delData)
 			if err != nil {
 				// error occurs when missing meta info or unexpected pk type, should not happen
 				err = fmt.Errorf("insertNode processDeleteMessages failed, collectionID = %d, err = %s", delMsg.CollectionID, err)
@@ -231,7 +236,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {

 	// 2. do preDelete
 	for segmentID, pks := range delData.deleteIDs {
-		segment, err := iNode.streamingReplica.getSegmentByID(segmentID)
+		segment, err := iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
 		if err != nil {
 			// error occurs when segment cannot be found, should not happen
 			err = fmt.Errorf("insertNode getSegmentByID failed, err = %s", err)
@@ -269,7 +274,7 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 }

 // processDeleteMessages would execute delete operations for growing segments
-func processDeleteMessages(replica ReplicaInterface, msg *msgstream.DeleteMsg, delData *deleteData) error {
+func processDeleteMessages(replica ReplicaInterface, segType segmentType, msg *msgstream.DeleteMsg, delData *deleteData) error {
 	var partitionIDs []UniqueID
 	var err error
 	if msg.PartitionID != -1 {
@@ -282,7 +287,7 @@ func processDeleteMessages(replica ReplicaInterface, msg *msgstream.DeleteMsg, d
 	}
 	resultSegmentIDs := make([]UniqueID, 0)
 	for _, partitionID := range partitionIDs {
-		segmentIDs, err := replica.getSegmentIDs(partitionID)
+		segmentIDs, err := replica.getSegmentIDs(partitionID, segType)
 		if err != nil {
 			return err
 		}
@@ -291,7 +296,7 @@ func processDeleteMessages(replica ReplicaInterface, msg *msgstream.DeleteMsg, d
 	primaryKeys := storage.ParseIDs2PrimaryKeys(msg.PrimaryKeys)
 	for _, segmentID := range resultSegmentIDs {
-		segment, err := replica.getSegmentByID(segmentID)
+		segment, err := replica.getSegmentByID(segmentID, segType)
 		if err != nil {
 			return err
 		}
@@ -341,7 +346,7 @@ func filterSegmentsByPKs(pks []primaryKey, timestamps []Timestamp, segment *Segm
 func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID, wg *sync.WaitGroup) error {
 	defer wg.Done()
-	var targetSegment, err = iNode.streamingReplica.getSegmentByID(segmentID)
+	var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
 	if err != nil {
 		return fmt.Errorf("getSegmentByID failed, err = %s", err)
 	}
@@ -366,7 +371,7 @@ func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID, wg *sync.
 // delete would execute delete operations for specific growing segment
 func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID, wg *sync.WaitGroup) error {
 	defer wg.Done()
-	targetSegment, err := iNode.streamingReplica.getSegmentByID(segmentID)
+	targetSegment, err := iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
 	if err != nil {
 		return fmt.Errorf("getSegmentByID failed, err = %s", err)
 	}
@@ -390,14 +395,14 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID, wg *

 // TODO: remove this function to proper file
 // getPrimaryKeys would get primary keys by insert messages
-func getPrimaryKeys(msg *msgstream.InsertMsg, streamingReplica ReplicaInterface) ([]primaryKey, error) {
+func getPrimaryKeys(msg *msgstream.InsertMsg, metaReplica ReplicaInterface) ([]primaryKey, error) {
 	if err := msg.CheckAligned(); err != nil {
 		log.Warn("misaligned messages detected", zap.Error(err))
 		return nil, err
 	}
 	collectionID := msg.GetCollectionID()

-	collection, err := streamingReplica.getCollectionByID(collectionID)
+	collection, err := metaReplica.getCollectionByID(collectionID)
 	if err != nil {
 		log.Warn(err.Error())
 		return nil, err
@@ -498,7 +503,7 @@ func getPKsFromColumnBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.C
 }

 // newInsertNode returns a new insertNode
-func newInsertNode(streamingReplica ReplicaInterface) *insertNode {
+func newInsertNode(metaReplica ReplicaInterface) *insertNode {
 	maxQueueLength := Params.QueryNodeCfg.FlowGraphMaxQueueLength
 	maxParallelism := Params.QueryNodeCfg.FlowGraphMaxParallelism

@@ -507,7 +512,7 @@ func newInsertNode(streamingReplica ReplicaInterface) *insertNode {
 	baseNode.SetMaxParallelism(maxParallelism)

 	return &insertNode{
-		baseNode:         baseNode,
-		streamingReplica: streamingReplica,
+		baseNode:    baseNode,
+		metaReplica: metaReplica,
 	}
 }
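Note the behavioral widening in insertNode.Operate: hasSegment now returns (bool, error), since asking about an unknown segment type is an error rather than a silent false. A sketch of the new ensure-segment step, assuming the querynode package context (ensureGrowingSegment is a hypothetical name):

package querynode

// ensureGrowingSegment mirrors the check-then-create step of Operate above;
// the error return corresponds to the branch on which Operate panics.
func ensureGrowingSegment(replica ReplicaInterface, segID, partID, collID UniqueID, channel Channel) error {
	has, err := replica.hasSegment(segID, segmentTypeGrowing)
	if err != nil {
		return err // unexpected segment type
	}
	if !has {
		return replica.addSegment(segID, partID, collID, channel, segmentTypeGrowing)
	}
	return nil
}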
diff --git a/internal/querynode/flow_graph_insert_node_test.go b/internal/querynode/flow_graph_insert_node_test.go
index 417814d118..686e2c79dd 100644
--- a/internal/querynode/flow_graph_insert_node_test.go
+++ b/internal/querynode/flow_graph_insert_node_test.go
@@ -96,8 +96,7 @@ func genFlowGraphDeleteData() (*deleteData, error) {
 }

 func TestFlowGraphInsertNode_insert(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	t.Run("test insert", func(t *testing.T) {
 		insertNode, err := getInsertNode()
 		assert.NoError(t, err)
@@ -143,7 +142,7 @@ func TestFlowGraphInsertNode_insert(t *testing.T) {
 		insertData, err := genFlowGraphInsertData(schema, defaultMsgLength)
 		assert.NoError(t, err)

-		seg, err := insertNode.streamingReplica.getSegmentByID(defaultSegmentID)
+		seg, err := insertNode.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeGrowing)
 		assert.NoError(t, err)
 		seg.setType(segmentTypeSealed)
@@ -155,8 +154,7 @@ func TestFlowGraphInsertNode_insert(t *testing.T) {
 }

 func TestFlowGraphInsertNode_delete(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	t.Run("test insert and delete", func(t *testing.T) {
 		insertNode, err := getInsertNode()
 		assert.NoError(t, err)
@@ -222,7 +220,7 @@ func TestFlowGraphInsertNode_processDeleteMessages(t *testing.T) {
 		dData, err := genFlowGraphDeleteData()
 		assert.NoError(t, err)

-		err = processDeleteMessages(streaming, dMsg, dData)
+		err = processDeleteMessages(streaming, segmentTypeGrowing, dMsg, dData)
 		assert.NoError(t, err)
 	})

@@ -234,14 +232,13 @@ func TestFlowGraphInsertNode_processDeleteMessages(t *testing.T) {
 		dData, err := genFlowGraphDeleteData()
 		assert.NoError(t, err)

-		err = processDeleteMessages(streaming, dMsg, dData)
+		err = processDeleteMessages(streaming, segmentTypeGrowing, dMsg, dData)
 		assert.NoError(t, err)
 	})
 }

 func TestFlowGraphInsertNode_operate(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	genMsgStreamInsertMsg := func() *msgstream.InsertMsg {
 		iMsg, err := genSimpleInsertMsg(schema, defaultMsgLength)
@@ -269,7 +266,7 @@ func TestFlowGraphInsertNode_operate(t *testing.T) {
 		msg := []flowgraph.Msg{genInsertMsg()}
 		insertNode.Operate(msg)

-		s, err := insertNode.streamingReplica.getSegmentByID(defaultSegmentID)
+		s, err := insertNode.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeGrowing)
 		assert.Nil(t, err)
 		buf := make([]byte, 8)
 		for i := 0; i < defaultMsgLength; i++ {
@@ -345,7 +342,7 @@ func TestFlowGraphInsertNode_operate(t *testing.T) {

 		msg := []flowgraph.Msg{genInsertMsg()}

-		err = insertNode.streamingReplica.removeCollection(defaultCollectionID)
+		err = insertNode.metaReplica.removeCollection(defaultCollectionID)
 		assert.NoError(t, err)
 		assert.Panics(t, func() {
 			insertNode.Operate(msg)
@@ -356,7 +353,7 @@ func TestFlowGraphInsertNode_operate(t *testing.T) {
 		insertNode, err := getInsertNode()
 		assert.NoError(t, err)

-		col, err := insertNode.streamingReplica.getCollectionByID(defaultCollectionID)
+		col, err := insertNode.metaReplica.getCollectionByID(defaultCollectionID)
 		assert.NoError(t, err)

 		for i, field := range col.schema.GetFields() {
diff --git a/internal/querynode/flow_graph_message_test.go b/internal/querynode/flow_graph_message_test.go
index 6863cfd16a..6de9bb9a54 100644
--- a/internal/querynode/flow_graph_message_test.go
+++ b/internal/querynode/flow_graph_message_test.go
@@ -22,12 +22,10 @@ import (
 	"github.com/stretchr/testify/assert"

 	"github.com/milvus-io/milvus/internal/mq/msgstream"
-	"github.com/milvus-io/milvus/internal/proto/schemapb"
 )

 func TestFlowGraphMsg_insertMsg(t *testing.T) {
-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()
 	msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
 	assert.NoError(t, err)
 	timestampMax := Timestamp(1000)
diff --git a/internal/querynode/flow_graph_query_node.go b/internal/querynode/flow_graph_query_node.go
index 9dfcb5ed6c..4b3caee824 100644
--- a/internal/querynode/flow_graph_query_node.go
+++ b/internal/querynode/flow_graph_query_node.go
@@ -45,7 +45,7 @@ type queryNodeFlowGraph struct {
 // newQueryNodeFlowGraph returns a new queryNodeFlowGraph
 func newQueryNodeFlowGraph(ctx context.Context,
 	collectionID UniqueID,
-	streamingReplica ReplicaInterface,
+	metaReplica ReplicaInterface,
 	tSafeReplica TSafeReplicaInterface,
 	channel Channel,
 	factory msgstream.Factory) (*queryNodeFlowGraph, error) {
@@ -64,8 +64,8 @@ func newQueryNodeFlowGraph(ctx context.Context,
 	if err != nil {
 		return nil, err
 	}
-	var filterDmNode node = newFilteredDmNode(streamingReplica, collectionID)
-	var insertNode node = newInsertNode(streamingReplica)
+	var filterDmNode node = newFilteredDmNode(metaReplica, collectionID)
+	var insertNode node = newInsertNode(metaReplica)
 	var serviceTimeNode node = newServiceTimeNode(tSafeReplica, collectionID, channel)

 	q.flowGraph.AddNode(dmStreamNode)
@@ -115,7 +115,7 @@ func newQueryNodeFlowGraph(ctx context.Context,
 // newQueryNodeDeltaFlowGraph returns a new queryNodeFlowGraph
 func newQueryNodeDeltaFlowGraph(ctx context.Context,
 	collectionID UniqueID,
-	historicalReplica ReplicaInterface,
+	metaReplica ReplicaInterface,
 	tSafeReplica TSafeReplicaInterface,
 	channel Channel,
 	factory msgstream.Factory) (*queryNodeFlowGraph, error) {
@@ -134,8 +134,8 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context,
 	if err != nil {
 		return nil, err
 	}
-	var filterDeleteNode node = newFilteredDeleteNode(historicalReplica, collectionID)
-	var deleteNode node = newDeleteNode(historicalReplica)
+	var filterDeleteNode node = newFilteredDeleteNode(metaReplica, collectionID)
+	var deleteNode node = newDeleteNode(metaReplica)
 	var serviceTimeNode node = newServiceTimeNode(tSafeReplica, collectionID, channel)

 	q.flowGraph.AddNode(dmStreamNode)
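Both flow-graph constructors now take the same ReplicaInterface; only the nodes they install differ. A hedged wiring sketch, following the signatures above (buildFlowGraphs and its argument names are illustrative):

package querynode

import (
	"context"

	"github.com/milvus-io/milvus/internal/mq/msgstream"
)

// buildFlowGraphs shows that the DML and delta graphs share one replica
// instance; the growing/sealed split now lives inside the replica, not in
// the wiring.
func buildFlowGraphs(ctx context.Context, collectionID UniqueID, replica ReplicaInterface,
	tSafe TSafeReplicaInterface, dmlChannel, deltaChannel Channel, factory msgstream.Factory) error {
	if _, err := newQueryNodeFlowGraph(ctx, collectionID, replica, tSafe, dmlChannel, factory); err != nil {
		return err
	}
	_, err := newQueryNodeDeltaFlowGraph(ctx, collectionID, replica, tSafe, deltaChannel, factory)
	return err
}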
diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go
index ea2ae880bf..fb34644b81 100644
--- a/internal/querynode/impl.go
+++ b/internal/querynode/impl.go
@@ -354,13 +354,13 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseS
 	}
 	// collection lock is not needed since we guarantee not query/search will be dispatch from leader
 	for _, id := range in.SegmentIDs {
-		err := node.historical.removeSegment(id)
+		err := node.metaReplica.removeSegment(id, segmentTypeSealed)
 		if err != nil {
 			// not return, try to release all segments
 			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
 			status.Reason = err.Error()
 		}
-		err = node.streaming.removeSegment(id)
+		err = node.metaReplica.removeSegment(id, segmentTypeGrowing)
 		if err != nil {
 			// not return, try to release all segments
 			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
@@ -392,33 +392,8 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmen
 		segmentIDs[segmentID] = struct{}{}
 	}

-	// get info from historical
-	historicalSegmentInfos, err := node.historical.getSegmentInfosByColID(in.CollectionID)
-	if err != nil {
-		log.Warn("GetSegmentInfo: get historical segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err))
-		res := &queryPb.GetSegmentInfoResponse{
-			Status: &commonpb.Status{
-				ErrorCode: commonpb.ErrorCode_UnexpectedError,
-				Reason:    err.Error(),
-			},
-		}
-		return res, nil
-	}
-	segmentInfos = append(segmentInfos, filterSegmentInfo(historicalSegmentInfos, segmentIDs)...)
-
-	// get info from streaming
-	streamingSegmentInfos, err := node.streaming.getSegmentInfosByColID(in.CollectionID)
-	if err != nil {
-		log.Warn("GetSegmentInfo: get streaming segmentInfo failed", zap.Int64("collectionID", in.CollectionID), zap.Error(err))
-		res := &queryPb.GetSegmentInfoResponse{
-			Status: &commonpb.Status{
-				ErrorCode: commonpb.ErrorCode_UnexpectedError,
-				Reason:    err.Error(),
-			},
-		}
-		return res, nil
-	}
-	segmentInfos = append(segmentInfos, filterSegmentInfo(streamingSegmentInfos, segmentIDs)...)
+	infos := node.metaReplica.getSegmentInfosByColID(in.CollectionID)
+	segmentInfos = append(segmentInfos, filterSegmentInfo(infos, segmentIDs)...)

 	return &queryPb.GetSegmentInfoResponse{
 		Status: &commonpb.Status{
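ReleaseSegments keeps its try-everything semantics, but against one replica it must attempt both segment maps per ID. A sketch of that loop in isolation, under the same package assumptions (releaseAll is a hypothetical helper):

package querynode

// releaseAll mirrors the loop in ReleaseSegments above: failures are recorded
// but do not stop the remaining releases.
func releaseAll(replica ReplicaInterface, segmentIDs []UniqueID) (lastErr error) {
	for _, id := range segmentIDs {
		if err := replica.removeSegment(id, segmentTypeSealed); err != nil {
			lastErr = err
		}
		if err := replica.removeSegment(id, segmentTypeGrowing); err != nil {
			lastErr = err
		}
	}
	return lastErr
}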
diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go
index b975747a7d..8131eaf690 100644
--- a/internal/querynode/impl_test.go
+++ b/internal/querynode/impl_test.go
@@ -32,7 +32,6 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/milvuspb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	queryPb "github.com/milvus-io/milvus/internal/proto/querypb"
-	"github.com/milvus-io/milvus/internal/proto/schemapb"
 	"github.com/milvus-io/milvus/internal/util/etcd"
 	"github.com/milvus-io/milvus/internal/util/metricsinfo"
 	"github.com/milvus-io/milvus/internal/util/sessionutil"
@@ -90,9 +89,7 @@ func TestImpl_WatchDmChannels(t *testing.T) {
 	node, err := genSimpleQueryNode(ctx)
 	assert.NoError(t, err)

-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
-
+	schema := genTestCollectionSchema()
 	req := &queryPb.WatchDmChannelsRequest{
 		Base: &commonpb.MsgBase{
 			MsgType: commonpb.MsgType_WatchDmChannels,
@@ -120,8 +117,7 @@ func TestImpl_LoadSegments(t *testing.T) {
 	node, err := genSimpleQueryNode(ctx)
 	assert.NoError(t, err)

-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	req := &queryPb.LoadSegmentsRequest{
 		Base: &commonpb.MsgBase{
@@ -226,32 +222,11 @@ func TestImpl_GetSegmentInfo(t *testing.T) {
 		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
 	})

-	t.Run("test no collection in historical", func(t *testing.T) {
+	t.Run("test no collection in metaReplica", func(t *testing.T) {
 		node, err := genSimpleQueryNode(ctx)
 		assert.NoError(t, err)

-		err = node.historical.removeCollection(defaultCollectionID)
-		assert.NoError(t, err)
-
-		req := &queryPb.GetSegmentInfoRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_WatchQueryChannels,
-				MsgID:   rand.Int63(),
-			},
-			SegmentIDs:   []UniqueID{defaultSegmentID},
-			CollectionID: defaultCollectionID,
-		}
-
-		rsp, err := node.GetSegmentInfo(ctx, req)
-		assert.Nil(t, err)
-		assert.Equal(t, commonpb.ErrorCode_Success, rsp.Status.ErrorCode)
-	})
-
-	t.Run("test no collection in streaming", func(t *testing.T) {
-		node, err := genSimpleQueryNode(ctx)
-		assert.NoError(t, err)
-
-		err = node.streaming.removeCollection(defaultCollectionID)
+		err = node.metaReplica.removeCollection(defaultCollectionID)
 		assert.NoError(t, err)

 		req := &queryPb.GetSegmentInfoRequest{
@@ -281,7 +256,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) {
 			CollectionID: defaultCollectionID,
 		}

-		seg, err := node.streaming.getSegmentByID(defaultSegmentID)
+		seg, err := node.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed)
 		assert.NoError(t, err)
 		seg.setType(segmentTypeSealed)

@@ -304,7 +279,7 @@ func TestImpl_GetSegmentInfo(t *testing.T) {
 		node, err := genSimpleQueryNode(ctx)
 		assert.NoError(t, err)

-		seg, err := node.historical.getSegmentByID(defaultSegmentID)
+		seg, err := node.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed)
 		assert.NoError(t, err)

 		seg.setIndexedFieldInfo(simpleFloatVecField.id, &IndexedFieldInfo{
@@ -333,82 +308,6 @@ func TestImpl_GetSegmentInfo(t *testing.T) {
 		assert.NoError(t, err)
 		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
 	})
-
-	t.Run("test GetSegmentInfo without streaming partition", func(t *testing.T) {
-		node, err := genSimpleQueryNode(ctx)
-		assert.NoError(t, err)
-
-		req := &queryPb.GetSegmentInfoRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_WatchQueryChannels,
-				MsgID:   rand.Int63(),
-			},
-			SegmentIDs:   []UniqueID{},
-			CollectionID: defaultCollectionID,
-		}
-
-		node.streaming.(*metaReplica).partitions = make(map[UniqueID]*Partition)
-		rsp, err := node.GetSegmentInfo(ctx, req)
-		assert.NoError(t, err)
-		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
-	})
-
-	t.Run("test GetSegmentInfo without streaming segment", func(t *testing.T) {
-		node, err := genSimpleQueryNode(ctx)
-		assert.NoError(t, err)
-
-		req := &queryPb.GetSegmentInfoRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_WatchQueryChannels,
-				MsgID:   rand.Int63(),
-			},
-			SegmentIDs:   []UniqueID{},
-			CollectionID: defaultCollectionID,
-		}
-
-		node.streaming.(*metaReplica).segments = make(map[UniqueID]*Segment)
-		rsp, err := node.GetSegmentInfo(ctx, req)
-		assert.NoError(t, err)
-		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
-	})
-
-	t.Run("test GetSegmentInfo without historical partition", func(t *testing.T) {
-		node, err := genSimpleQueryNode(ctx)
-		assert.NoError(t, err)
-
-		req := &queryPb.GetSegmentInfoRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_WatchQueryChannels,
-				MsgID:   rand.Int63(),
-			},
-			SegmentIDs:   []UniqueID{},
-			CollectionID: defaultCollectionID,
-		}
-
-		node.historical.(*metaReplica).partitions = make(map[UniqueID]*Partition)
-		rsp, err := node.GetSegmentInfo(ctx, req)
-		assert.NoError(t, err)
-		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
-	})
-
-	t.Run("test GetSegmentInfo without historical segment", func(t *testing.T) {
-		node, err := genSimpleQueryNode(ctx)
-		assert.NoError(t, err)
-
-		req := &queryPb.GetSegmentInfoRequest{
-			Base: &commonpb.MsgBase{
-				MsgType: commonpb.MsgType_WatchQueryChannels,
-				MsgID:   rand.Int63(),
-			},
-			SegmentIDs:   []UniqueID{},
-			CollectionID: defaultCollectionID,
-		}
-
-		node.historical.(*metaReplica).segments = make(map[UniqueID]*Segment)
-		rsp, err := node.GetSegmentInfo(ctx, req)
-		assert.NoError(t, err)
-		assert.Equal(t, commonpb.ErrorCode_UnexpectedError, rsp.Status.ErrorCode)
-	})
 }

 func TestImpl_isHealthy(t *testing.T) {
@@ -530,10 +429,7 @@ func TestImpl_ReleaseSegments(t *testing.T) {
 		SegmentIDs: []UniqueID{defaultSegmentID},
 	}

-	err = node.historical.removeSegment(defaultSegmentID)
-	assert.NoError(t, err)
-
-	err = node.streaming.removeSegment(defaultSegmentID)
+	err = node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
 	assert.NoError(t, err)

 	status, err := node.ReleaseSegments(ctx, req)
@@ -550,8 +446,7 @@ func TestImpl_Search(t *testing.T) {
 	node, err := genSimpleQueryNode(ctx)
 	require.NoError(t, err)

-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	req, err := genSearchRequest(defaultNQ, IndexFaissIDMap, schema)
 	require.NoError(t, err)
@@ -573,8 +468,7 @@ func TestImpl_Query(t *testing.T) {
 	defer node.Stop()
 	require.NoError(t, err)

-	pkType := schemapb.DataType_Int64
-	schema := genTestCollectionSchema(pkType)
+	schema := genTestCollectionSchema()

 	req, err := genRetrieveRequest(schema)
 	require.NoError(t, err)
diff --git a/internal/querynode/meta_replica.go b/internal/querynode/meta_replica.go
index f98437709b..cc953efc65 100644
--- a/internal/querynode/meta_replica.go
+++ b/internal/querynode/meta_replica.go
@@ -37,7 +37,6 @@ import (
 	"go.uber.org/zap"

 	"github.com/milvus-io/milvus/internal/common"
-	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/metrics"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
@@ -70,7 +69,7 @@ type ReplicaInterface interface {
 	// getPKFieldIDsByCollectionID returns vector field ids of collection
 	getPKFieldIDByCollectionID(collectionID UniqueID) (FieldID, error)
 	// getSegmentInfosByColID return segments info by collectionID
-	getSegmentInfosByColID(collectionID UniqueID) ([]*querypb.SegmentInfo, error)
+	getSegmentInfosByColID(collectionID UniqueID) []*querypb.SegmentInfo

 	// partition
 	// addPartition adds a new partition to collection
@@ -84,7 +83,7 @@ type ReplicaInterface interface {
 	// getPartitionNum returns num of partitions
 	getPartitionNum() int
 	// getSegmentIDs returns segment ids
-	getSegmentIDs(partitionID UniqueID) ([]UniqueID, error)
+	getSegmentIDs(partitionID UniqueID, segType segmentType) ([]UniqueID, error)
 	// getSegmentIDsByVChannel returns segment ids which virtual channel is vChannel
 	getSegmentIDsByVChannel(partitionID UniqueID, vChannel Channel) ([]UniqueID, error)
@@ -94,13 +93,13 @@ type ReplicaInterface interface {
 	// setSegment adds a segment to collectionReplica
 	setSegment(segment *Segment) error
 	// removeSegment removes a segment from collectionReplica
-	removeSegment(segmentID UniqueID) error
+	removeSegment(segmentID UniqueID, segType segmentType) error
 	// getSegmentByID returns the segment which id is segmentID
-	getSegmentByID(segmentID UniqueID) (*Segment, error)
+	getSegmentByID(segmentID UniqueID, segType segmentType) (*Segment, error)
 	// hasSegment returns true if collectionReplica has the segment, false otherwise
-	hasSegment(segmentID UniqueID) bool
+	hasSegment(segmentID UniqueID, segType segmentType) (bool, error)
 	// getSegmentNum returns num of segments in collectionReplica
-	getSegmentNum() int
+	getSegmentNum(segType segmentType) int

 	// getSegmentStatistics returns the statistics of segments in collectionReplica
 	getSegmentStatistics() []*internalpb.SegmentStats
@@ -123,14 +122,13 @@ type ReplicaInterface interface {

 // collectionReplica is the data replication of memory data in query node.
 // It implements `ReplicaInterface` interface.
 type metaReplica struct {
-	mu          sync.RWMutex // guards all
-	collections map[UniqueID]*Collection
-	partitions  map[UniqueID]*Partition
-	segments    map[UniqueID]*Segment
+	mu              sync.RWMutex // guards all
+	collections     map[UniqueID]*Collection
+	partitions      map[UniqueID]*Partition
+	growingSegments map[UniqueID]*Segment
+	sealedSegments  map[UniqueID]*Segment

 	excludedSegments map[UniqueID][]*datapb.SegmentInfo // map[collectionID]segmentIDs
-
-	etcdKV *etcdkv.EtcdKV
 }

 // getSegmentsMemSize get the memory size in bytes of all the Segments
@@ -139,7 +137,10 @@ func (replica *metaReplica) getSegmentsMemSize() int64 {
 	defer replica.mu.RUnlock()

 	memSize := int64(0)
-	for _, segment := range replica.segments {
+	for _, segment := range replica.growingSegments {
+		memSize += segment.getMemSize()
+	}
+	for _, segment := range replica.sealedSegments {
 		memSize += segment.getMemSize()
 	}
 	return memSize
@@ -152,7 +153,8 @@ func (replica *metaReplica) printReplica() {
 	log.Info("collections in collectionReplica", zap.Any("info", replica.collections))
 	log.Info("partitions in collectionReplica", zap.Any("info", replica.partitions))
-	log.Info("segments in collectionReplica", zap.Any("info", replica.segments))
+	log.Info("growingSegments in collectionReplica", zap.Any("info", replica.growingSegments))
+	log.Info("sealedSegments in collectionReplica", zap.Any("info", replica.sealedSegments))
 	log.Info("excludedSegments in collectionReplica", zap.Any("info", replica.excludedSegments))
 }

@@ -262,9 +264,7 @@ func (replica *metaReplica) getPartitionIDs(collectionID UniqueID) ([]UniqueID,
 		return nil, err
 	}

-	parID := make([]UniqueID, len(collection.partitionIDs))
-	copy(parID, collection.partitionIDs)
-	return parID, nil
+	return collection.getPartitionIDs(), nil
 }

 func (replica *metaReplica) getIndexedFieldIDByCollectionIDPrivate(collectionID UniqueID, segment *Segment) ([]FieldID, error) {
@@ -338,33 +338,31 @@ func (replica *metaReplica) getFieldsByCollectionIDPrivate(collectionID UniqueID
 }

 // getSegmentInfosByColID return segments info by collectionID
-func (replica *metaReplica) getSegmentInfosByColID(collectionID UniqueID) ([]*querypb.SegmentInfo, error) {
+func (replica *metaReplica) getSegmentInfosByColID(collectionID UniqueID) []*querypb.SegmentInfo {
 	replica.mu.RLock()
 	defer replica.mu.RUnlock()

 	segmentInfos := make([]*querypb.SegmentInfo, 0)
-	collection, ok := replica.collections[collectionID]
+	_, ok := replica.collections[collectionID]
 	if !ok {
 		// collection not exist, so result segmentInfos is empty
-		return segmentInfos, nil
+		return segmentInfos
 	}

-	for _, partitionID := range collection.partitionIDs {
-		partition, ok := replica.partitions[partitionID]
-		if !ok {
-			return nil, fmt.Errorf("the meta of collection %d and partition %d are inconsistent in QueryNode", collectionID, partitionID)
+	for _, segment := range replica.growingSegments {
+		if segment.collectionID == collectionID {
+			segmentInfo := replica.getSegmentInfo(segment)
+			segmentInfos = append(segmentInfos, segmentInfo)
 		}
-		for _, segmentID := range partition.segmentIDs {
-			segment, ok := replica.segments[segmentID]
-			if !ok {
-				return nil, fmt.Errorf("the meta of partition %d and segment %d are inconsistent in QueryNode", partitionID, segmentID)
-			}
+	}
+	for _, segment := range replica.sealedSegments {
+		if segment.collectionID == collectionID {
 			segmentInfo := replica.getSegmentInfo(segment)
 			segmentInfos = append(segmentInfos, segmentInfo)
 		}
 	}

-	return segmentInfos, nil
+	return segmentInfos
 }
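Since getSegmentInfosByColID now scans both maps directly and can no longer hit inconsistent partition metadata, its error return disappears. A caller's view, sketched under the same package assumptions (segmentInfoCount is a hypothetical helper):

package querynode

// segmentInfoCount shows the infallible call shape; an unknown collection
// simply yields an empty slice.
func segmentInfoCount(replica ReplicaInterface, collectionID UniqueID) int {
	return len(replica.getSegmentInfosByColID(collectionID))
}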
//----------------------------------------------------------------------------------------------------- partition @@ -418,9 +416,15 @@ func (replica *metaReplica) removePartitionPrivate(partitionID UniqueID, locked } // delete segments - for _, segmentID := range partition.segmentIDs { + ids, _ := partition.getSegmentIDs(segmentTypeGrowing) + for _, segmentID := range ids { // try to delete, ignore error - _ = replica.removeSegmentPrivate(segmentID) + _ = replica.removeSegmentPrivate(segmentID, segmentTypeGrowing) + } + ids, _ = partition.getSegmentIDs(segmentTypeSealed) + for _, segmentID := range ids { + // try to delete, ignore error + _ = replica.removeSegmentPrivate(segmentID, segmentTypeSealed) } collection.removePartitionID(partitionID) @@ -468,23 +472,24 @@ func (replica *metaReplica) getPartitionNum() int { } // getSegmentIDs returns segment ids -func (replica *metaReplica) getSegmentIDs(partitionID UniqueID) ([]UniqueID, error) { +func (replica *metaReplica) getSegmentIDs(partitionID UniqueID, segType segmentType) ([]UniqueID, error) { replica.mu.RLock() defer replica.mu.RUnlock() - return replica.getSegmentIDsPrivate(partitionID) + + return replica.getSegmentIDsPrivate(partitionID, segType) } // getSegmentIDsByVChannel returns the ids of segments whose virtual channel is vChannel func (replica *metaReplica) getSegmentIDsByVChannel(partitionID UniqueID, vChannel Channel) ([]UniqueID, error) { replica.mu.RLock() defer replica.mu.RUnlock() - segmentIDs, err := replica.getSegmentIDsPrivate(partitionID) + segmentIDs, err := replica.getSegmentIDsPrivate(partitionID, segmentTypeGrowing) if err != nil { return nil, err } segmentIDsTmp := make([]UniqueID, 0) for _, segmentID := range segmentIDs { - segment, err := replica.getSegmentByIDPrivate(segmentID) + segment, err := replica.getSegmentByIDPrivate(segmentID, segmentTypeGrowing) if err != nil { return nil, err } @@ -497,15 +502,13 @@ func (replica *metaReplica) getSegmentIDsByVChannel(partitionID UniqueID, vChann } // getSegmentIDsPrivate is a private function in collectionReplica; it returns segment ids -func (replica *metaReplica) getSegmentIDsPrivate(partitionID UniqueID) ([]UniqueID, error) { +func (replica *metaReplica) getSegmentIDsPrivate(partitionID UniqueID, segType segmentType) ([]UniqueID, error) { partition, err2 := replica.getPartitionByIDPrivate(partitionID) if err2 != nil { return nil, err2 } - segIDs := make([]UniqueID, len(partition.segmentIDs)) - copy(segIDs, partition.segmentIDs) - return segIDs, nil + return partition.getSegmentIDs(segType) } //----------------------------------------------------------------------------------------------------- segment @@ -513,6 +516,7 @@ func (replica *metaReplica) getSegmentIDsPrivate(partitionID UniqueID) ([]Unique func (replica *metaReplica) addSegment(segmentID UniqueID, partitionID UniqueID, collectionID UniqueID, vChannelID Channel, segType segmentType) error { replica.mu.Lock() defer replica.mu.Unlock() + collection, err := replica.getCollectionByIDPrivate(collectionID) if err != nil { return err } @@ -531,11 +535,24 @@ func (replica *metaReplica) addSegmentPrivate(segmentID UniqueID, partitionID Un return err } - if replica.hasSegmentPrivate(segmentID) { + segType := segment.getType() + ok, err := replica.hasSegmentPrivate(segmentID, segType) + if err != nil { + return err + } + if ok { return nil } - partition.addSegmentID(segmentID) - replica.segments[segmentID] = segment + partition.addSegmentID(segmentID, segType) + + switch segType { + case segmentTypeGrowing: +
replica.growingSegments[segmentID] = segment + case segmentTypeSealed: + replica.sealedSegments[segmentID] = segment + default: + return fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segmentID, segType.String()) + } metrics.QueryNodeNumSegments.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Inc() return nil @@ -545,23 +562,30 @@ func (replica *metaReplica) addSegmentPrivate(segmentID UniqueID, partitionID Un func (replica *metaReplica) setSegment(segment *Segment) error { replica.mu.Lock() defer replica.mu.Unlock() + + if segment == nil { + return fmt.Errorf("nil segment when setSegment") + } + _, err := replica.getCollectionByIDPrivate(segment.collectionID) if err != nil { return err } + return replica.addSegmentPrivate(segment.segmentID, segment.partitionID, segment) } // removeSegment removes a segment from collectionReplica -func (replica *metaReplica) removeSegment(segmentID UniqueID) error { +func (replica *metaReplica) removeSegment(segmentID UniqueID, segType segmentType) error { replica.mu.Lock() defer replica.mu.Unlock() - return replica.removeSegmentPrivate(segmentID) + + return replica.removeSegmentPrivate(segmentID, segType) } // removeSegmentPrivate is a private function in collectionReplica that removes a segment from collectionReplica -func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID) error { - segment, err := replica.getSegmentByIDPrivate(segmentID) +func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID, segType segmentType) error { + segment, err := replica.getSegmentByIDPrivate(segmentID, segType) if err != nil { return err } @@ -570,76 +594,89 @@ func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID) error { if err2 != nil { return err } - - partition.removeSegmentID(segmentID) - delete(replica.segments, segmentID) + partition.removeSegmentID(segmentID, segType) + switch segType { + case segmentTypeGrowing: + delete(replica.growingSegments, segmentID) + case segmentTypeSealed: + delete(replica.sealedSegments, segmentID) + default: + err = fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segmentID, segType.String()) + } deleteSegment(segment) metrics.QueryNodeNumSegments.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Dec() - return nil + return err } // getSegmentByID returns the segment whose id is segmentID -func (replica *metaReplica) getSegmentByID(segmentID UniqueID) (*Segment, error) { +func (replica *metaReplica) getSegmentByID(segmentID UniqueID, segType segmentType) (*Segment, error) { replica.mu.RLock() defer replica.mu.RUnlock() - return replica.getSegmentByIDPrivate(segmentID) + return replica.getSegmentByIDPrivate(segmentID, segType) } // getSegmentByIDPrivate is a private function in collectionReplica; it returns the segment whose id is segmentID -func (replica *metaReplica) getSegmentByIDPrivate(segmentID UniqueID) (*Segment, error) { - segment, ok := replica.segments[segmentID] - if !ok { - return nil, fmt.Errorf("cannot find segment %d in QueryNode", segmentID) +func (replica *metaReplica) getSegmentByIDPrivate(segmentID UniqueID, segType segmentType) (*Segment, error) { + switch segType { + case segmentTypeGrowing: + segment, ok := replica.growingSegments[segmentID] + if !ok { + return nil, fmt.Errorf("cannot find growing segment %d in QueryNode", segmentID) + } + return segment, nil + case segmentTypeSealed: + segment, ok := replica.sealedSegments[segmentID] + if !ok { + return nil, fmt.Errorf("cannot find sealed segment %d in
QueryNode", segmentID) + } + return segment, nil + default: + return nil, fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segmentID, segType.String()) } - - return segment, nil } // hasSegment returns true if collectionReplica has the segment, false otherwise -func (replica *metaReplica) hasSegment(segmentID UniqueID) bool { +func (replica *metaReplica) hasSegment(segmentID UniqueID, segType segmentType) (bool, error) { replica.mu.RLock() defer replica.mu.RUnlock() - return replica.hasSegmentPrivate(segmentID) + return replica.hasSegmentPrivate(segmentID, segType) } // hasSegmentPrivate is private function in collectionReplica, to check if collectionReplica has the segment -func (replica *metaReplica) hasSegmentPrivate(segmentID UniqueID) bool { - _, ok := replica.segments[segmentID] - return ok +func (replica *metaReplica) hasSegmentPrivate(segmentID UniqueID, segType segmentType) (bool, error) { + switch segType { + case segmentTypeGrowing: + _, ok := replica.growingSegments[segmentID] + return ok, nil + case segmentTypeSealed: + _, ok := replica.sealedSegments[segmentID] + return ok, nil + default: + return false, fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segmentID, segType.String()) + } } // getSegmentNum returns num of segments in collectionReplica -func (replica *metaReplica) getSegmentNum() int { +func (replica *metaReplica) getSegmentNum(segType segmentType) int { replica.mu.RLock() defer replica.mu.RUnlock() - return len(replica.segments) + + switch segType { + case segmentTypeGrowing: + return len(replica.growingSegments) + case segmentTypeSealed: + return len(replica.sealedSegments) + default: + log.Error("unexpected segment type", zap.String("segmentType", segType.String())) + return 0 + } } // getSegmentStatistics returns the statistics of segments in collectionReplica func (replica *metaReplica) getSegmentStatistics() []*internalpb.SegmentStats { - replica.mu.RLock() - defer replica.mu.RUnlock() - - var statisticData = make([]*internalpb.SegmentStats, 0) - - for segmentID, segment := range replica.segments { - currentMemSize := segment.getMemSize() - segment.lastMemSize = currentMemSize - segmentNumOfRows := segment.getRowCount() - - stat := internalpb.SegmentStats{ - SegmentID: segmentID, - MemorySize: currentMemSize, - NumRows: segmentNumOfRows, - RecentlyModified: segment.getRecentlyModified(), - } - - statisticData = append(statisticData, &stat) - segment.setRecentlyModified(false) - } - - return statisticData + // TODO: deprecated + return nil } // removeExcludedSegments will remove excludedSegments from collectionReplica @@ -685,23 +722,19 @@ func (replica *metaReplica) freeAll() { replica.collections = make(map[UniqueID]*Collection) replica.partitions = make(map[UniqueID]*Partition) - replica.segments = make(map[UniqueID]*Segment) + replica.growingSegments = make(map[UniqueID]*Segment) + replica.sealedSegments = make(map[UniqueID]*Segment) } // newCollectionReplica returns a new ReplicaInterface -func newCollectionReplica(etcdKv *etcdkv.EtcdKV) ReplicaInterface { - collections := make(map[UniqueID]*Collection) - partitions := make(map[UniqueID]*Partition) - segments := make(map[UniqueID]*Segment) - excludedSegments := make(map[UniqueID][]*datapb.SegmentInfo) - +func newCollectionReplica() ReplicaInterface { var replica ReplicaInterface = &metaReplica{ - collections: collections, - partitions: partitions, - segments: segments, + collections: make(map[UniqueID]*Collection), + partitions: make(map[UniqueID]*Partition), 
+ growingSegments: make(map[UniqueID]*Segment), + sealedSegments: make(map[UniqueID]*Segment), - excludedSegments: excludedSegments, - etcdKV: etcdKv, + excludedSegments: make(map[UniqueID][]*datapb.SegmentInfo), } return replica diff --git a/internal/querynode/meta_replica_test.go b/internal/querynode/meta_replica_test.go index 681cb64506..849d4d045d 100644 --- a/internal/querynode/meta_replica_test.go +++ b/internal/querynode/meta_replica_test.go @@ -21,312 +21,252 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" ) -//----------------------------------------------------------------------------------------------------- collection -func TestMetaReplica_getCollectionNum(t *testing.T) { - node := newQueryNodeMock() - initTestMeta(t, node, 0, 0) - assert.Equal(t, node.historical.getCollectionNum(), 1) - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_addCollection(t *testing.T) { - node := newQueryNodeMock() - initTestMeta(t, node, 0, 0) - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_removeCollection(t *testing.T) { - node := newQueryNodeMock() - initTestMeta(t, node, 0, 0) - assert.Equal(t, node.historical.getCollectionNum(), 1) - - err := node.historical.removeCollection(0) - assert.NoError(t, err) - assert.Equal(t, node.historical.getCollectionNum(), 0) - err = node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_getCollectionByID(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - targetCollection, err := node.historical.getCollectionByID(collectionID) - assert.NoError(t, err) - assert.NotNil(t, targetCollection) - assert.Equal(t, targetCollection.ID(), collectionID) - err = node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_hasCollection(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - hasCollection := node.historical.hasCollection(collectionID) - assert.Equal(t, hasCollection, true) - hasCollection = node.historical.hasCollection(UniqueID(1)) - assert.Equal(t, hasCollection, false) - - err := node.Stop() - assert.NoError(t, err) -} - -//----------------------------------------------------------------------------------------------------- partition -func TestMetaReplica_getPartitionNum(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - partitionIDs := []UniqueID{1, 2, 3} - for _, id := range partitionIDs { - err := node.historical.addPartition(collectionID, id) +func TestMetaReplica_collection(t *testing.T) { + t.Run("test getCollectionNum", func(t *testing.T) { + replica, err := genSimpleReplica() assert.NoError(t, err) - partition, err := node.historical.getPartitionByID(id) + defer replica.freeAll() + assert.Equal(t, 1, replica.getCollectionNum()) + }) + + t.Run("test addCollection", func(t *testing.T) { + replica, err := genSimpleReplica() assert.NoError(t, err) - assert.Equal(t, partition.ID(), id) - } + defer replica.freeAll() + replica.addCollection(defaultCollectionID+1, genTestCollectionSchema()) + assert.Equal(t, 2, replica.getCollectionNum()) + }) - partitionNum := node.historical.getPartitionNum() - assert.Equal(t, partitionNum, len(partitionIDs)) - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_addPartition(t 
*testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - partitionIDs := []UniqueID{1, 2, 3} - for _, id := range partitionIDs { - err := node.historical.addPartition(collectionID, id) + t.Run("test removeCollection", func(t *testing.T) { + replica, err := genSimpleReplica() assert.NoError(t, err) - partition, err := node.historical.getPartitionByID(id) + defer replica.freeAll() + err = replica.removeCollection(defaultCollectionID) assert.NoError(t, err) - assert.Equal(t, partition.ID(), id) - } - err := node.Stop() - assert.NoError(t, err) -} + }) -func TestMetaReplica_removePartition(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - partitionIDs := []UniqueID{1, 2, 3} - - for _, id := range partitionIDs { - err := node.historical.addPartition(collectionID, id) + t.Run("test getCollectionByID", func(t *testing.T) { + replica, err := genSimpleReplica() assert.NoError(t, err) - partition, err := node.historical.getPartitionByID(id) + defer replica.freeAll() + + targetCollection, err := replica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - assert.Equal(t, partition.ID(), id) - err = node.historical.removePartition(id) + assert.NotNil(t, targetCollection) + assert.Equal(t, defaultCollectionID, targetCollection.ID()) + }) + + t.Run("test hasCollection", func(t *testing.T) { + replica, err := genSimpleReplica() assert.NoError(t, err) - } - err := node.Stop() - assert.NoError(t, err) -} + defer replica.freeAll() -func TestMetaReplica_getPartitionByTag(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) + hasCollection := replica.hasCollection(defaultCollectionID) + assert.Equal(t, true, hasCollection) + hasCollection = replica.hasCollection(defaultCollectionID + 1) + assert.Equal(t, false, hasCollection) + }) - collection, err := node.historical.getCollectionByID(collectionID) - assert.NoError(t, err) - - for _, id := range collection.partitionIDs { - err := node.historical.addPartition(collectionID, id) - assert.NoError(t, err) - partition, err := node.historical.getPartitionByID(id) - assert.NoError(t, err) - assert.Equal(t, partition.ID(), id) - assert.NotNil(t, partition) - } - err = node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_hasPartition(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - collection, err := node.historical.getCollectionByID(collectionID) - assert.NoError(t, err) - err = node.historical.addPartition(collectionID, collection.partitionIDs[0]) - assert.NoError(t, err) - hasPartition := node.historical.hasPartition(defaultPartitionID) - assert.Equal(t, hasPartition, true) - hasPartition = node.historical.hasPartition(defaultPartitionID + 1) - assert.Equal(t, hasPartition, false) - err = node.Stop() - assert.NoError(t, err) -} - -//----------------------------------------------------------------------------------------------------- segment -func TestMetaReplica_addSegment(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - const segmentNum = 3 - for i := 0; i < segmentNum; i++ { - err := node.historical.addSegment(UniqueID(i), defaultPartitionID, collectionID, "", segmentTypeGrowing) - assert.NoError(t, err) - targetSeg, err := node.historical.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, 
targetSeg.segmentID, UniqueID(i)) - } - - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_removeSegment(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - const segmentNum = 3 - - for i := 0; i < segmentNum; i++ { - err := node.historical.addSegment(UniqueID(i), defaultPartitionID, collectionID, "", segmentTypeGrowing) - assert.NoError(t, err) - targetSeg, err := node.historical.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - err = node.historical.removeSegment(UniqueID(i)) - assert.NoError(t, err) - } - - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_getSegmentByID(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - const segmentNum = 3 - - for i := 0; i < segmentNum; i++ { - err := node.historical.addSegment(UniqueID(i), defaultPartitionID, collectionID, "", segmentTypeGrowing) - assert.NoError(t, err) - targetSeg, err := node.historical.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - } - - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_getSegmentInfosByColID(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) - collection := node.historical.addCollection(collectionID, schema) - node.historical.addPartition(collectionID, defaultPartitionID) - - // test get indexed segment info - vectorFieldIDDs, err := node.historical.getVecFieldIDsByCollectionID(collectionID) - assert.NoError(t, err) - assert.Equal(t, 2, len(vectorFieldIDDs)) - fieldID := vectorFieldIDDs[0] - - indexID := UniqueID(10000) - indexInfo := &IndexedFieldInfo{ - indexInfo: &querypb.FieldIndexInfo{ - IndexName: "test-index-name", - IndexID: indexID, - EnableIndex: true, - }, - } - - segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, collectionID, "", segmentTypeGrowing) - assert.NoError(t, err) - err = node.historical.setSegment(segment1) - assert.NoError(t, err) - - segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID, collectionID, "", segmentTypeSealed) - assert.NoError(t, err) - segment2.setIndexedFieldInfo(fieldID, indexInfo) - err = node.historical.setSegment(segment2) - assert.NoError(t, err) - - targetSegs, err := node.historical.getSegmentInfosByColID(collectionID) - assert.NoError(t, err) - assert.Equal(t, 2, len(targetSegs)) - for _, segment := range targetSegs { - if segment.GetSegmentState() == segmentTypeGrowing { - assert.Equal(t, UniqueID(0), segment.IndexID) - } else { - assert.Equal(t, indexID, segment.IndexID) - } - } - - err = node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_hasSegment(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - const segmentNum = 3 - - for i := 0; i < segmentNum; i++ { - err := node.historical.addSegment(UniqueID(i), defaultPartitionID, collectionID, "", segmentTypeGrowing) - assert.NoError(t, err) - targetSeg, err := node.historical.getSegmentByID(UniqueID(i)) - assert.NoError(t, err) - assert.Equal(t, targetSeg.segmentID, UniqueID(i)) - hasSeg := node.historical.hasSegment(UniqueID(i)) - assert.Equal(t, hasSeg, true) - hasSeg = node.historical.hasSegment(UniqueID(i + 100)) - assert.Equal(t, hasSeg, false) - } - - err := node.Stop() - 
assert.NoError(t, err) -} - -func TestMetaReplica_freeAll(t *testing.T) { - node := newQueryNodeMock() - collectionID := UniqueID(0) - initTestMeta(t, node, collectionID, 0) - - err := node.Stop() - assert.NoError(t, err) -} - -func TestMetaReplica_statistic(t *testing.T) { t.Run("test getCollectionIDs", func(t *testing.T) { replica, err := genSimpleReplica() assert.NoError(t, err) + defer replica.freeAll() ids := replica.getCollectionIDs() assert.Len(t, ids, 1) assert.Equal(t, defaultCollectionID, ids[0]) }) +} - t.Run("test getCollectionIDs", func(t *testing.T) { +func TestMetaReplica_partition(t *testing.T) { + t.Run("test addPartition, getPartitionNum and getPartitionByID", func(t *testing.T) { replica, err := genSimpleReplica() assert.NoError(t, err) - num := replica.getSegmentNum() - assert.Equal(t, 0, num) + defer replica.freeAll() + + partitionIDs := []UniqueID{1, 2, 3} + for _, id := range partitionIDs { + err := replica.addPartition(defaultCollectionID, id) + assert.NoError(t, err) + partition, err := replica.getPartitionByID(id) + assert.NoError(t, err) + assert.Equal(t, id, partition.ID()) + } + + partitionNum := replica.getPartitionNum() + assert.Equal(t, len(partitionIDs), partitionNum) + }) + + t.Run("test removePartition", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + partitionIDs := []UniqueID{1, 2, 3} + + for _, id := range partitionIDs { + err := replica.addPartition(defaultCollectionID, id) + assert.NoError(t, err) + partition, err := replica.getPartitionByID(id) + assert.NoError(t, err) + assert.Equal(t, id, partition.ID()) + err = replica.removePartition(id) + assert.NoError(t, err) + _, err = replica.getPartitionByID(id) + assert.Error(t, err) + } + }) + + t.Run("test hasPartition", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + collection, err := replica.getCollectionByID(defaultCollectionID) + assert.NoError(t, err) + err = replica.addPartition(defaultCollectionID, collection.partitionIDs[0]) + assert.NoError(t, err) + hasPartition := replica.hasPartition(defaultPartitionID) + assert.Equal(t, true, hasPartition) + hasPartition = replica.hasPartition(defaultPartitionID + 1) + assert.Equal(t, false, hasPartition) }) } + +func TestMetaReplica_segment(t *testing.T) { + t.Run("test addSegment and getSegmentByID", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + const segmentNum = 3 + for i := 0; i < segmentNum; i++ { + err := replica.addSegment(UniqueID(i), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing) + assert.NoError(t, err) + targetSeg, err := replica.getSegmentByID(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, UniqueID(i), targetSeg.segmentID) + } + }) + + t.Run("test removeSegment", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + const segmentNum = 3 + for i := 0; i < segmentNum; i++ { + err := replica.addSegment(UniqueID(i), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing) + assert.NoError(t, err) + targetSeg, err := replica.getSegmentByID(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, UniqueID(i), targetSeg.segmentID) + err = replica.removeSegment(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + } + }) + + t.Run("test hasSegment", func(t *testing.T) { + replica, err := genSimpleReplica() + 
assert.NoError(t, err) + defer replica.freeAll() + + const segmentNum = 3 + for i := 0; i < segmentNum; i++ { + err := replica.addSegment(UniqueID(i), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing) + assert.NoError(t, err) + targetSeg, err := replica.getSegmentByID(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, UniqueID(i), targetSeg.segmentID) + hasSeg, err := replica.hasSegment(UniqueID(i), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, true, hasSeg) + hasSeg, err = replica.hasSegment(UniqueID(i+100), segmentTypeGrowing) + assert.NoError(t, err) + assert.Equal(t, false, hasSeg) + } + }) + + t.Run("test invalid segment type", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + invalidType := commonpb.SegmentState_NotExist + err = replica.addSegment(defaultSegmentID, defaultPartitionID, defaultCollectionID, "", invalidType) + assert.Error(t, err) + _, err = replica.getSegmentByID(defaultSegmentID, invalidType) + assert.Error(t, err) + _, err = replica.getSegmentIDs(defaultPartitionID, invalidType) + assert.Error(t, err) + err = replica.removeSegment(defaultSegmentID, invalidType) + assert.Error(t, err) + _, err = replica.hasSegment(defaultSegmentID, invalidType) + assert.Error(t, err) + num := replica.getSegmentNum(invalidType) + assert.Equal(t, 0, num) + }) + + t.Run("test getSegmentInfosByColID", func(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + defer replica.freeAll() + + schema := genTestCollectionSchema() + collection := replica.addCollection(defaultCollectionID, schema) + replica.addPartition(defaultCollectionID, defaultPartitionID) + + // test get indexed segment info + vectorFieldIDDs, err := replica.getVecFieldIDsByCollectionID(defaultCollectionID) + assert.NoError(t, err) + assert.Equal(t, 2, len(vectorFieldIDDs)) + fieldID := vectorFieldIDDs[0] + + indexID := UniqueID(10000) + indexInfo := &IndexedFieldInfo{ + indexInfo: &querypb.FieldIndexInfo{ + IndexName: "test-index-name", + IndexID: indexID, + EnableIndex: true, + }, + } + + segment1, err := newSegment(collection, UniqueID(1), defaultPartitionID, defaultCollectionID, "", segmentTypeGrowing) + assert.NoError(t, err) + err = replica.setSegment(segment1) + assert.NoError(t, err) + + segment2, err := newSegment(collection, UniqueID(2), defaultPartitionID, defaultCollectionID, "", segmentTypeSealed) + assert.NoError(t, err) + segment2.setIndexedFieldInfo(fieldID, indexInfo) + err = replica.setSegment(segment2) + assert.NoError(t, err) + + targetSegs := replica.getSegmentInfosByColID(defaultCollectionID) + assert.Equal(t, 2, len(targetSegs)) + for _, segment := range targetSegs { + if segment.GetSegmentState() == segmentTypeGrowing { + assert.Equal(t, UniqueID(0), segment.IndexID) + } else { + assert.Equal(t, indexID, segment.IndexID) + } + } + }) +} + +func TestMetaReplica_freeAll(t *testing.T) { + replica, err := genSimpleReplica() + assert.NoError(t, err) + replica.freeAll() + num := replica.getCollectionNum() + assert.Equal(t, 0, num) + num = replica.getPartitionNum() + assert.Equal(t, 0, num) + num = replica.getSegmentNum(segmentTypeGrowing) + assert.Equal(t, 0, num) + num = replica.getSegmentNum(segmentTypeSealed) + assert.Equal(t, 0, num) +} diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 3d9d8b34db..265bfcc909 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -343,7 +343,7 @@ func 
loadIndexForSegment(ctx context.Context, node *QueryNode, segmentID UniqueI return err } - segment, err := node.loader.historicalReplica.getSegmentByID(segmentID) + segment, err := node.loader.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) if err != nil { return err } @@ -499,7 +499,7 @@ func genIndexParams(indexType, metricType string) (map[string]string, map[string return typeParams, indexParams } -func genTestCollectionSchema(pkType schemapb.DataType) *schemapb.CollectionSchema { +func genTestCollectionSchema(pkTypes ...schemapb.DataType) *schemapb.CollectionSchema { fieldBool := genConstantFieldSchema(simpleBoolField) fieldInt8 := genConstantFieldSchema(simpleInt8Field) fieldInt16 := genConstantFieldSchema(simpleInt16Field) @@ -509,6 +509,12 @@ func genTestCollectionSchema(pkType schemapb.DataType) *schemapb.CollectionSchem floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField) binVecFieldSchema := genVectorFieldSchema(simpleBinVecField) var pkFieldSchema *schemapb.FieldSchema + var pkType schemapb.DataType + if len(pkTypes) == 0 { + pkType = schemapb.DataType_Int64 + } else { + pkType = pkTypes[0] + } switch pkType { case schemapb.DataType_Int64: pkFieldSchema = genPKFieldSchema(simpleInt64Field) @@ -1213,7 +1219,7 @@ func genSealedSegment(schema *schemapb.CollectionSchema, } func genSimpleSealedSegment(msgLength int) (*Segment, error) { - schema := genTestCollectionSchema(schemapb.DataType_Int64) + schema := genTestCollectionSchema() return genSealedSegment(schema, defaultCollectionID, defaultPartitionID, @@ -1223,27 +1229,23 @@ func genSimpleSealedSegment(msgLength int) (*Segment, error) { } func genSimpleReplica() (ReplicaInterface, error) { - kv, err := genEtcdKV() - if err != nil { - return nil, err - } - r := newCollectionReplica(kv) - schema := genTestCollectionSchema(schemapb.DataType_Int64) + r := newCollectionReplica() + schema := genTestCollectionSchema() r.addCollection(defaultCollectionID, schema) - err = r.addPartition(defaultCollectionID, defaultPartitionID) + err := r.addPartition(defaultCollectionID, defaultPartitionID) return r, err } -func genSimpleSegmentLoaderWithMqFactory(historicalReplica ReplicaInterface, streamingReplica ReplicaInterface, factory msgstream.Factory) (*segmentLoader, error) { +func genSimpleSegmentLoaderWithMqFactory(metaReplica ReplicaInterface, factory msgstream.Factory) (*segmentLoader, error) { kv, err := genEtcdKV() if err != nil { return nil, err } cm := storage.NewLocalChunkManager(storage.RootPath(defaultLocalStorage)) - return newSegmentLoader(historicalReplica, streamingReplica, kv, cm, factory), nil + return newSegmentLoader(metaReplica, kv, cm, factory), nil } -func genSimpleHistorical(ctx context.Context) (ReplicaInterface, error) { +func genSimpleReplicaWithSealSegment(ctx context.Context) (ReplicaInterface, error) { r, err := genSimpleReplica() if err != nil { return nil, err @@ -1266,7 +1268,7 @@ func genSimpleHistorical(ctx context.Context) (ReplicaInterface, error) { return r, nil } -func genSimpleStreaming(ctx context.Context) (ReplicaInterface, error) { +func genSimpleReplicaWithGrowingSegment() (ReplicaInterface, error) { r, err := genSimpleReplica() if err != nil { return nil, err @@ -1652,23 +1654,18 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory node.tSafeReplica = newTSafeReplica() - streaming, err := genSimpleStreaming(ctx) + replica, err := genSimpleReplicaWithSealSegment(ctx) if err != nil { return nil, err } node.tSafeReplica.addTSafe(defaultDMLChannel) - 
historical, err := genSimpleHistorical(ctx) - if err != nil { - return nil, err - } node.tSafeReplica.addTSafe(defaultDeltaChannel) - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, streaming, historical, node.tSafeReplica, node.factory) + node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, replica, node.tSafeReplica, node.factory) - node.streaming = streaming - node.historical = historical + node.metaReplica = replica - loader, err := genSimpleSegmentLoaderWithMqFactory(historical, streaming, fac) + loader, err := genSimpleSegmentLoaderWithMqFactory(replica, fac) if err != nil { return nil, err } @@ -1681,7 +1678,7 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node) node.queryShardService = newQueryShardService(node.queryNodeLoopCtx, - node.historical, node.streaming, node.tSafeReplica, + node.metaReplica, node.tSafeReplica, node.ShardClusterService, node.factory, node.scheduler) node.UpdateStateCode(internalpb.StateCode_Healthy) diff --git a/internal/querynode/partition.go b/internal/querynode/partition.go index 1e72d8108f..44c9b34713 100644 --- a/internal/querynode/partition.go +++ b/internal/querynode/partition.go @@ -29,6 +29,8 @@ package querynode */ import "C" import ( + "fmt" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" @@ -36,9 +38,10 @@ import ( // Partition is a logical division of Collection and can be considered as an attribute of Segment. type Partition struct { - collectionID UniqueID - partitionID UniqueID - segmentIDs []UniqueID + collectionID UniqueID + partitionID UniqueID + growingSegmentIDs []UniqueID + sealedSegmentIDs []UniqueID } // ID returns the identity of the partition. 
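// Illustrative only, not part of this patch: with growingSegmentIDs and sealedSegmentIDs
// kept separately, a caller that wants every segment of a partition must ask for each
// type explicitly through getSegmentIDs (added in the next hunk). allSegmentIDs is a
// hypothetical helper written only to show the calling pattern:
func allSegmentIDs(p *Partition) []UniqueID {
	ids := make([]UniqueID, 0, len(p.growingSegmentIDs)+len(p.sealedSegmentIDs))
	for _, segType := range []segmentType{segmentTypeGrowing, segmentTypeSealed} {
		typed, err := p.getSegmentIDs(segType)
		if err != nil {
			// only reachable for an unexpected segment type
			continue
		}
		ids = append(ids, typed...)
	}
	return ids
}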
@@ -46,21 +49,58 @@ func (p *Partition) ID() UniqueID { return p.partitionID } +// getSegmentIDs returns a copy of the segment ids of the given segment type +func (p *Partition) getSegmentIDs(segType segmentType) ([]UniqueID, error) { + switch segType { + case segmentTypeGrowing: + dst := make([]UniqueID, len(p.growingSegmentIDs)) + copy(dst, p.growingSegmentIDs) + return dst, nil + case segmentTypeSealed: + dst := make([]UniqueID, len(p.sealedSegmentIDs)) + copy(dst, p.sealedSegmentIDs) + return dst, nil + default: + return nil, fmt.Errorf("unexpected segmentType %s", segType.String()) + } +} + // addSegmentID adds segmentID to segmentIDs -func (p *Partition) addSegmentID(segmentID UniqueID) { - p.segmentIDs = append(p.segmentIDs, segmentID) - log.Info("add a segment to replica", zap.Int64("collectionID", p.collectionID), zap.Int64("partitionID", p.partitionID), zap.Int64("segmentID", segmentID)) +func (p *Partition) addSegmentID(segmentID UniqueID, segType segmentType) { + switch segType { + case segmentTypeGrowing: + p.growingSegmentIDs = append(p.growingSegmentIDs, segmentID) + case segmentTypeSealed: + p.sealedSegmentIDs = append(p.sealedSegmentIDs, segmentID) + default: + return + } + log.Info("add a segment to replica", + zap.Int64("collectionID", p.collectionID), + zap.Int64("partitionID", p.partitionID), + zap.Int64("segmentID", segmentID), + zap.String("segmentType", segType.String())) } // removeSegmentID removes segmentID from segmentIDs -func (p *Partition) removeSegmentID(segmentID UniqueID) { - tmpIDs := make([]UniqueID, 0) - for _, id := range p.segmentIDs { - if id != segmentID { - tmpIDs = append(tmpIDs, id) +func (p *Partition) removeSegmentID(segmentID UniqueID, segType segmentType) { + deleteFunc := func(segmentIDs []UniqueID) []UniqueID { + tmpIDs := make([]UniqueID, 0) + for _, id := range segmentIDs { + if id != segmentID { + tmpIDs = append(tmpIDs, id) + } } + return tmpIDs + } + switch segType { + case segmentTypeGrowing: + p.growingSegmentIDs = deleteFunc(p.growingSegmentIDs) + case segmentTypeSealed: + p.sealedSegmentIDs = deleteFunc(p.sealedSegmentIDs) + default: + return + } - p.segmentIDs = tmpIDs log.Info("remove a segment from replica", zap.Int64("collectionID", p.collectionID), zap.Int64("partitionID", p.partitionID), zap.Int64("segmentID", segmentID)) } diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 145de61ef5..0549afeb9b 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -27,13 +27,11 @@ import ( "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/planpb" - "github.com/milvus-io/milvus/internal/proto/schemapb" ) func TestPlan_Plan(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema()
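// Note: genTestCollectionSchema is variadic after this patch (see mock_test.go above);
// called with no arguments it defaults the primary key type to schemapb.DataType_Int64,
// which is why the explicit pkType argument disappears from these tests.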
collection := newCollection(collectionID, schema) dslString := "{\"bool\": { \n\"vector\": {\n \"floatVectorField\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\n \"topk\": 10 \n,\"round_decimal\": 6\n } \n } \n } \n }" diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index a56cacaf34..4626f0941a 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -90,8 +90,7 @@ type QueryNode struct { initOnce sync.Once // internal components - historical ReplicaInterface - streaming ReplicaInterface + metaReplica ReplicaInterface // tSafeReplica tSafeReplica TSafeReplicaInterface @@ -304,17 +303,15 @@ func (node *QueryNode) Init() error { node.etcdKV = etcdkv.NewEtcdKV(node.etcdCli, Params.EtcdCfg.MetaRootPath) log.Info("queryNode try to connect etcd success", zap.Any("MetaRootPath", Params.EtcdCfg.MetaRootPath)) - node.streaming = newCollectionReplica(node.etcdKV) - node.historical = newCollectionReplica(node.etcdKV) + node.metaReplica = newCollectionReplica() node.loader = newSegmentLoader( - node.historical, - node.streaming, + node.metaReplica, node.etcdKV, node.vectorStorage, node.factory) - node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.streaming, node.historical, node.tSafeReplica, node.factory) + node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.factory) node.InitSegcore() @@ -351,7 +348,7 @@ func (node *QueryNode) Start() error { // create shardClusterService for shardLeader functions. node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node) // create shard-level query service - node.queryShardService = newQueryShardService(node.queryNodeLoopCtx, node.historical, node.streaming, node.tSafeReplica, + node.queryShardService = newQueryShardService(node.queryNodeLoopCtx, node.metaReplica, node.tSafeReplica, node.ShardClusterService, node.factory, node.scheduler) Params.QueryNodeCfg.CreatedTime = time.Now() @@ -377,12 +374,8 @@ func (node *QueryNode) Stop() error { node.dataSyncService.close() } - // release streaming first for query/search holds query lock in streaming collection - if node.streaming != nil { - node.streaming.freeAll() - } - if node.historical != nil { - node.historical.freeAll() + if node.metaReplica != nil { + node.metaReplica.freeAll() } if node.queryShardService != nil { diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index fca9d226c7..9fcae140fd 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -38,7 +38,6 @@ import ( etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -57,112 +56,20 @@ func setup() { Params.EtcdCfg.MetaRootPath = "/etcd/test/root/querynode" } -//func genTestCollectionSchema(collectionID UniqueID, isBinary bool, dim int) *schemapb.CollectionSchema { -// var fieldVec schemapb.FieldSchema -// if isBinary { -// fieldVec = schemapb.FieldSchema{ -// FieldID: UniqueID(100), -// Name: "vec", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_BinaryVector, -// TypeParams: []*commonpb.KeyValuePair{ -// { -// Key: "dim", -// Value: strconv.Itoa(dim * 8), -// }, -// }, -// IndexParams: 
[]*commonpb.KeyValuePair{ -// { -// Key: "metric_type", -// Value: "JACCARD", -// }, -// }, -// } -// } else { -// fieldVec = schemapb.FieldSchema{ -// FieldID: UniqueID(100), -// Name: "vec", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_FloatVector, -// TypeParams: []*commonpb.KeyValuePair{ -// { -// Key: "dim", -// Value: strconv.Itoa(dim), -// }, -// }, -// IndexParams: []*commonpb.KeyValuePair{ -// { -// Key: "metric_type", -// Value: "L2", -// }, -// }, -// } -// } -// -// fieldInt := schemapb.FieldSchema{ -// FieldID: UniqueID(101), -// Name: "age", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_Int32, -// } -// -// schema := &schemapb.CollectionSchema{ -// AutoID: true, -// Fields: []*schemapb.FieldSchema{ -// &fieldVec, &fieldInt, -// }, -// } -// -// return schema -//} -// -//func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.CollectionInfo { -// schema := genTestCollectionSchema(collectionID, isBinary, 16) -// -// collectionMeta := etcdpb.CollectionInfo{ -// ID: collectionID, -// Schema: schema, -// CreateTime: Timestamp(0), -// PartitionIDs: []UniqueID{defaultPartitionID}, -// } -// -// return &collectionMeta -//} -// -//func genTestCollectionMetaWithPK(collectionID UniqueID, isBinary bool) *etcdpb.CollectionInfo { -// schema := genTestCollectionSchema(collectionID, isBinary, 16) -// schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ -// FieldID: UniqueID(0), -// Name: "id", -// IsPrimaryKey: true, -// DataType: schemapb.DataType_Int64, -// }) -// -// collectionMeta := etcdpb.CollectionInfo{ -// ID: collectionID, -// Schema: schema, -// CreateTime: Timestamp(0), -// PartitionIDs: []UniqueID{defaultPartitionID}, -// } -// -// return &collectionMeta -//} - func initTestMeta(t *testing.T, node *QueryNode, collectionID UniqueID, segmentID UniqueID, optional ...bool) { - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() - node.historical.addCollection(defaultCollectionID, schema) + node.metaReplica.addCollection(defaultCollectionID, schema) - collection, err := node.historical.getCollectionByID(collectionID) + collection, err := node.metaReplica.getCollectionByID(collectionID) assert.NoError(t, err) assert.Equal(t, collection.ID(), collectionID) - assert.Equal(t, node.historical.getCollectionNum(), 1) + assert.Equal(t, node.metaReplica.getCollectionNum(), 1) - err = node.historical.addPartition(collection.ID(), defaultPartitionID) + err = node.metaReplica.addPartition(collection.ID(), defaultPartitionID) assert.NoError(t, err) - err = node.historical.addSegment(segmentID, defaultPartitionID, collectionID, "", segmentTypeSealed) + err = node.metaReplica.addSegment(segmentID, defaultPartitionID, collectionID, "", segmentTypeSealed) assert.NoError(t, err) } @@ -190,12 +97,10 @@ func newQueryNodeMock() *QueryNode { factory := newMessageStreamFactory() svr := NewQueryNode(ctx, factory) tsReplica := newTSafeReplica() - streamingReplica := newCollectionReplica(etcdKV) - historicalReplica := newCollectionReplica(etcdKV) - svr.historical = streamingReplica - svr.streaming = historicalReplica - svr.dataSyncService = newDataSyncService(ctx, svr.streaming, svr.historical, tsReplica, factory) - svr.statsService = newStatsService(ctx, svr.historical, factory) + replica := newCollectionReplica() + svr.metaReplica = replica + svr.dataSyncService = newDataSyncService(ctx, svr.metaReplica, tsReplica, factory) + svr.statsService = newStatsService(ctx, svr.metaReplica, factory) 
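// The mock now threads a single shared metaReplica through dataSyncService,
// statsService and (below) the segment loader, where separate historical and
// streaming replicas were wired before.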
svr.vectorStorage, err = factory.NewVectorStorageChunkManager(ctx) if err != nil { panic(err) @@ -204,7 +109,7 @@ func newQueryNodeMock() *QueryNode { if err != nil { panic(err) } - svr.loader = newSegmentLoader(svr.historical, svr.streaming, etcdKV, svr.vectorStorage, factory) + svr.loader = newSegmentLoader(svr.metaReplica, etcdKV, svr.vectorStorage, factory) svr.etcdKV = etcdKV return svr @@ -333,7 +238,7 @@ func TestQueryNode_adjustByChangeInfo(t *testing.T) { node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) assert.NoError(t, err) - err = node.historical.removeSegment(defaultSegmentID) + err = node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) segmentChangeInfos := genSimpleChangeInfo() @@ -405,7 +310,7 @@ func TestQueryNode_watchChangeInfo(t *testing.T) { node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx) assert.NoError(t, err) - err = node.historical.removeSegment(defaultSegmentID) + err = node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) segmentChangeInfos := genSimpleChangeInfo() diff --git a/internal/querynode/query_shard.go b/internal/querynode/query_shard.go index 9d7c840752..d99c739091 100644 --- a/internal/querynode/query_shard.go +++ b/internal/querynode/query_shard.go @@ -41,8 +41,7 @@ type queryShard struct { clusterService *ShardClusterService tSafeReplica TSafeReplicaInterface - historical ReplicaInterface - streaming ReplicaInterface + metaReplica ReplicaInterface vectorChunkManager *storage.VectorChunkManager localCacheEnabled bool @@ -55,15 +54,14 @@ func newQueryShard( channel Channel, replicaID int64, clusterService *ShardClusterService, - historical ReplicaInterface, - streaming ReplicaInterface, + metaReplica ReplicaInterface, tSafeReplica TSafeReplicaInterface, localChunkManager storage.ChunkManager, remoteChunkManager storage.ChunkManager, localCacheEnabled bool, ) (*queryShard, error) { - collection, err := streaming.getCollectionByID(collectionID) + collection, err := metaReplica.getCollectionByID(collectionID) if err != nil { return nil, err } @@ -91,8 +89,7 @@ func newQueryShard( channel: channel, replicaID: replicaID, clusterService: clusterService, - historical: historical, - streaming: streaming, + metaReplica: metaReplica, vectorChunkManager: vectorChunkManager, tSafeReplica: tSafeReplica, } diff --git a/internal/querynode/query_shard_service.go b/internal/querynode/query_shard_service.go index 4a937b1c0a..5f5ef1e9c3 100644 --- a/internal/querynode/query_shard_service.go +++ b/internal/querynode/query_shard_service.go @@ -37,8 +37,7 @@ type queryShardService struct { factory dependency.Factory - historical ReplicaInterface - streaming ReplicaInterface + metaReplica ReplicaInterface tSafeReplica TSafeReplicaInterface shardClusterService *ShardClusterService @@ -48,7 +47,7 @@ type queryShardService struct { scheduler *taskScheduler } -func newQueryShardService(ctx context.Context, historical ReplicaInterface, streaming ReplicaInterface, tSafeReplica TSafeReplicaInterface, clusterService *ShardClusterService, factory dependency.Factory, scheduler *taskScheduler) *queryShardService { +func newQueryShardService(ctx context.Context, metaReplica ReplicaInterface, tSafeReplica TSafeReplicaInterface, clusterService *ShardClusterService, factory dependency.Factory, scheduler *taskScheduler) *queryShardService { queryShardServiceCtx, queryShardServiceCancel := context.WithCancel(ctx) path := Params.LoadWithDefault("localStorage.Path", "/tmp/milvus/data") @@ -60,8 +59,7 
@@ func newQueryShardService(ctx context.Context, historical ReplicaInterface, stre ctx: queryShardServiceCtx, cancel: queryShardServiceCancel, queryShards: make(map[Channel]*queryShard), - historical: historical, - streaming: streaming, + metaReplica: metaReplica, tSafeReplica: tSafeReplica, shardClusterService: clusterService, localChunkManager: localChunkManager, @@ -85,8 +83,7 @@ func (q *queryShardService) addQueryShard(collectionID UniqueID, channel Channel channel, replicaID, q.shardClusterService, - q.historical, - q.streaming, + q.metaReplica, q.tSafeReplica, q.localChunkManager, q.remoteChunkManager, diff --git a/internal/querynode/query_shard_service_test.go b/internal/querynode/query_shard_service_test.go index 3587b37164..fd9ab913ef 100644 --- a/internal/querynode/query_shard_service_test.go +++ b/internal/querynode/query_shard_service_test.go @@ -28,7 +28,7 @@ func TestQueryShardService(t *testing.T) { qn, err := genSimpleQueryNode(context.Background()) require.NoError(t, err) - qss := newQueryShardService(context.Background(), qn.historical, qn.streaming, qn.tSafeReplica, qn.ShardClusterService, qn.factory, qn.scheduler) + qss := newQueryShardService(context.Background(), qn.metaReplica, qn.tSafeReplica, qn.ShardClusterService, qn.factory, qn.scheduler) err = qss.addQueryShard(0, "vchan1", 0) assert.NoError(t, err) found1 := qss.hasQueryShard("vchan1") @@ -50,7 +50,7 @@ func TestQueryShardService_InvalidChunkManager(t *testing.T) { qn, err := genSimpleQueryNode(context.Background()) require.NoError(t, err) - qss := newQueryShardService(context.Background(), qn.historical, qn.streaming, qn.tSafeReplica, qn.ShardClusterService, qn.factory, qn.scheduler) + qss := newQueryShardService(context.Background(), qn.metaReplica, qn.tSafeReplica, qn.ShardClusterService, qn.factory, qn.scheduler) lcm := qss.localChunkManager qss.localChunkManager = nil diff --git a/internal/querynode/query_shard_test.go b/internal/querynode/query_shard_test.go index 937eb1bda2..02d5df73d4 100644 --- a/internal/querynode/query_shard_test.go +++ b/internal/querynode/query_shard_test.go @@ -31,17 +31,11 @@ import ( func genSimpleQueryShard(ctx context.Context) (*queryShard, error) { tSafe := newTSafeReplica() - historical, err := genSimpleHistorical(ctx) + replica, err := genSimpleReplica() if err != nil { return nil, err } tSafe.addTSafe(defaultDMLChannel) - - streaming, err := genSimpleStreaming(ctx) - if err != nil { - return nil, err - } - tSafe.addTSafe(defaultDeltaChannel) localCM, err := genLocalChunkManager() if err != nil { @@ -61,7 +55,7 @@ func genSimpleQueryShard(ctx context.Context) (*queryShard, error) { shardClusterService.clusters.Store(defaultDMLChannel, shardCluster) qs, err := newQueryShard(ctx, defaultCollectionID, defaultDMLChannel, defaultReplicaID, shardClusterService, - historical, streaming, tSafe, localCM, remoteCM, false) + replica, tSafe, localCM, remoteCM, false) if err != nil { return nil, err } @@ -80,10 +74,7 @@ func updateQueryShardTSafe(qs *queryShard, timestamp Timestamp) error { func TestNewQueryShard_IllegalCases(t *testing.T) { ctx := context.Background() tSafe := newTSafeReplica() - historical, err := genSimpleHistorical(ctx) - require.NoError(t, err) - - streaming, err := genSimpleStreaming(ctx) + replica, err := genSimpleReplica() require.NoError(t, err) localCM, err := genLocalChunkManager() @@ -100,15 +91,15 @@ func TestNewQueryShard_IllegalCases(t *testing.T) { shardClusterService.clusters.Store(defaultDMLChannel, shardCluster) _, err = newQueryShard(ctx, 
defaultCollectionID-1, defaultDMLChannel, defaultReplicaID, shardClusterService, - historical, streaming, tSafe, localCM, remoteCM, false) + replica, tSafe, localCM, remoteCM, false) assert.Error(t, err) _, err = newQueryShard(ctx, defaultCollectionID, defaultDMLChannel, defaultReplicaID, shardClusterService, - historical, streaming, tSafe, nil, remoteCM, false) + replica, tSafe, nil, remoteCM, false) assert.Error(t, err) _, err = newQueryShard(ctx, defaultCollectionID, defaultDMLChannel, defaultReplicaID, shardClusterService, - historical, streaming, tSafe, localCM, nil, false) + replica, tSafe, localCM, nil, false) assert.Error(t, err) } diff --git a/internal/querynode/reduce_test.go b/internal/querynode/reduce_test.go index 7a6ff4d8da..10dc5fbc5c 100644 --- a/internal/querynode/reduce_test.go +++ b/internal/querynode/reduce_test.go @@ -51,10 +51,10 @@ func TestReduce_AllFunc(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - collection, err := node.historical.getCollectionByID(defaultCollectionID) + collection, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - segment, err := node.historical.getSegmentByID(defaultSegmentID) + segment, err := node.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) // TODO: replace below by genPlaceholderGroup(nq) diff --git a/internal/querynode/retrieve.go b/internal/querynode/retrieve.go index 60026b476f..a3032a9535 100644 --- a/internal/querynode/retrieve.go +++ b/internal/querynode/retrieve.go @@ -23,11 +23,11 @@ import ( // retrieveOnSegments performs retrieve on listed segments // all segment ids are validated before calling this function -func retrieveOnSegments(replica ReplicaInterface, collID UniqueID, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) { +func retrieveOnSegments(replica ReplicaInterface, segType segmentType, collID UniqueID, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) { var retrieveResults []*segcorepb.RetrieveResults for _, segID := range segIDs { - seg, err := replica.getSegmentByID(segID) + seg, err := replica.getSegmentByID(segID, segType) if err != nil { return nil, err } @@ -54,7 +54,7 @@ func retrieveHistorical(replica ReplicaInterface, plan *RetrievePlan, collID Uni return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } - retrieveResults, err = retrieveOnSegments(replica, collID, plan, retrieveSegmentIDs, vcm) + retrieveResults, err = retrieveOnSegments(replica, segmentTypeSealed, collID, plan, retrieveSegmentIDs, vcm) return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err } @@ -69,6 +69,6 @@ func retrieveStreaming(replica ReplicaInterface, plan *RetrievePlan, collID Uniq if err != nil { return retrieveResults, retrieveSegmentIDs, retrievePartIDs, err } - retrieveResults, err = retrieveOnSegments(replica, collID, plan, retrieveSegmentIDs, vcm) + retrieveResults, err = retrieveOnSegments(replica, segmentTypeGrowing, collID, plan, retrieveSegmentIDs, vcm) return retrieveResults, retrievePartIDs, retrieveSegmentIDs, err } diff --git a/internal/querynode/retrieve_test.go b/internal/querynode/retrieve_test.go index 4c7f00410c..be7f60f433 100644 --- a/internal/querynode/retrieve_test.go +++ b/internal/querynode/retrieve_test.go @@ -17,7 +17,6 @@ package querynode import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -26,10 +25,7 @@ import ( ) func TestStreaming_retrieve(t 
*testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -40,7 +36,7 @@ func TestStreaming_retrieve(t *testing.T) { insertMsg, err := genSimpleInsertMsg(collection.schema, defaultMsgLength) assert.NoError(t, err) - segment, err := streaming.getSegmentByID(defaultSegmentID) + segment, err := streaming.getSegmentByID(defaultSegmentID, segmentTypeGrowing) assert.NoError(t, err) offset, err := segment.segmentPreInsert(len(insertMsg.RowIDs)) diff --git a/internal/querynode/search.go b/internal/querynode/search.go index 22499f1f8a..a7be5913ff 100644 --- a/internal/querynode/search.go +++ b/internal/querynode/search.go @@ -26,7 +26,7 @@ import ( // searchOnSegments performs search on listed segments // all segment ids are validated before calling this function -func searchOnSegments(replica ReplicaInterface, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) { +func searchOnSegments(replica ReplicaInterface, segType segmentType, searchReq *searchRequest, segIDs []UniqueID) ([]*SearchResult, error) { // results variables searchResults := make([]*SearchResult, len(segIDs)) errs := make([]error, len(segIDs)) @@ -37,7 +37,7 @@ func searchOnSegments(replica ReplicaInterface, searchReq *searchRequest, segIDs wg.Add(1) go func(segID UniqueID, i int) { defer wg.Done() - seg, err := replica.getSegmentByID(segID) + seg, err := replica.getSegmentByID(segID, segType) if err != nil { return } @@ -75,7 +75,7 @@ func searchHistorical(replica ReplicaInterface, searchReq *searchRequest, collID if err != nil { return searchResults, searchSegmentIDs, searchPartIDs, err } - searchResults, err = searchOnSegments(replica, searchReq, searchSegmentIDs) + searchResults, err = searchOnSegments(replica, segmentTypeSealed, searchReq, searchSegmentIDs) return searchResults, searchPartIDs, searchSegmentIDs, err } @@ -91,6 +91,6 @@ func searchStreaming(replica ReplicaInterface, searchReq *searchRequest, collID if err != nil { return searchResults, searchSegmentIDs, searchPartIDs, err } - searchResults, err = searchOnSegments(replica, searchReq, searchSegmentIDs) + searchResults, err = searchOnSegments(replica, segmentTypeGrowing, searchReq, searchSegmentIDs) return searchResults, searchPartIDs, searchSegmentIDs, err } diff --git a/internal/querynode/search_test.go b/internal/querynode/search_test.go index 7523a52b63..087777dfea 100644 --- a/internal/querynode/search_test.go +++ b/internal/querynode/search_test.go @@ -28,7 +28,7 @@ func TestHistorical_Search(t *testing.T) { defer cancel() t.Run("test search", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) collection, err := his.getCollectionByID(defaultCollectionID) @@ -41,7 +41,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search partitions", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) collection, err := his.getCollectionByID(defaultCollectionID) @@ -57,7 +57,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search all collection", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) collection, err := 
his.getCollectionByID(defaultCollectionID) @@ -73,7 +73,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test load partition and partition has been released", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) collection, err := his.getCollectionByID(defaultCollectionID) @@ -93,7 +93,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no partition in collection", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) collection, err := his.getCollectionByID(defaultCollectionID) @@ -112,11 +112,8 @@ func TestHistorical_Search(t *testing.T) { } func TestStreaming_search(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - t.Run("test search", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -133,7 +130,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -150,7 +147,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadCollection", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -175,7 +172,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadPartition", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -198,7 +195,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test no partitions in collection", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -218,7 +215,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test search failed", func(t *testing.T) { - streaming, err := genSimpleStreaming(ctx) + streaming, err := genSimpleReplicaWithGrowingSegment() assert.NoError(t, err) collection, err := streaming.getCollectionByID(defaultCollectionID) @@ -226,7 +223,7 @@ func TestStreaming_search(t *testing.T) { searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) assert.NoError(t, err) - seg, err := streaming.getSegmentByID(defaultSegmentID) + seg, err := streaming.getSegmentByID(defaultSegmentID, segmentTypeGrowing) assert.NoError(t, err) seg.segmentPtr = nil diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 0455b9ad3c..c8204fb64d 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -49,8 +49,7 @@ import ( // segmentLoader is only responsible for loading the field data from binlog type segmentLoader struct { - historicalReplica ReplicaInterface - streamingReplica ReplicaInterface + metaReplica ReplicaInterface dataCoord types.DataCoord @@ -64,22 +63,9 @@ type segmentLoader struct { } func 
(loader *segmentLoader) getFieldType(segment *Segment, fieldID FieldID) (schemapb.DataType, error) { - var coll *Collection - var err error - - switch segment.getType() { - case segmentTypeGrowing: - coll, err = loader.streamingReplica.getCollectionByID(segment.collectionID) - if err != nil { - return schemapb.DataType_None, err - } - case segmentTypeSealed: - coll, err = loader.historicalReplica.getCollectionByID(segment.collectionID) - if err != nil { - return schemapb.DataType_None, err - } - default: - return schemapb.DataType_None, fmt.Errorf("invalid segment type: %s", segment.getType().String()) + coll, err := loader.metaReplica.getCollectionByID(segment.collectionID) + if err != nil { + return schemapb.DataType_None, err } return coll.getFieldType(fieldID) @@ -95,20 +81,6 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme return nil } - var metaReplica ReplicaInterface - switch segmentType { - case segmentTypeGrowing: - metaReplica = loader.streamingReplica - case segmentTypeSealed: - metaReplica = loader.historicalReplica - default: - err := fmt.Errorf("illegal segment type when load segment, collectionID = %d", req.CollectionID) - log.Error("load segment failed, illegal segment type", - zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), - zap.Error(err)) - return err - } - log.Info("segmentLoader start loading...", zap.Any("collectionID", req.CollectionID), zap.Any("numOfSegments", len(req.Infos)), @@ -150,7 +122,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme partitionID := info.PartitionID collectionID := info.CollectionID - collection, err := loader.historicalReplica.getCollectionByID(collectionID) + collection, err := loader.metaReplica.getCollectionByID(collectionID) if err != nil { segmentGC() return err @@ -206,7 +178,7 @@ func (loader *segmentLoader) loadSegment(req *querypb.LoadSegmentsRequest, segme // set segment to meta replica for _, s := range newSegments { - err = metaReplica.setSegment(s) + err = loader.metaReplica.setSegment(s) if err != nil { log.Error("load segment failed, set segment to meta failed", zap.Int64("collectionID", s.collectionID), @@ -232,7 +204,7 @@ func (loader *segmentLoader) loadSegmentInternal(segment *Segment, zap.Int64("partitionID", partitionID), zap.Int64("segmentID", segmentID)) - pkFieldID, err := loader.historicalReplica.getPKFieldIDByCollectionID(collectionID) + pkFieldID, err := loader.metaReplica.getPKFieldIDByCollectionID(collectionID) if err != nil { return err } @@ -496,7 +468,7 @@ func (loader *segmentLoader) loadGrowingSegments(segment *Segment, Version: internalpb.InsertDataVersion_ColumnBased, }, } - pks, err := getPrimaryKeys(tmpInsertMsg, loader.streamingReplica) + pks, err := getPrimaryKeys(tmpInsertMsg, loader.metaReplica) if err != nil { return err } @@ -665,7 +637,14 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection zap.String("vChannelName", position.GetChannelName()), zap.Any("msg id", position.GetMsgID()), ) - processDeleteMessages(loader.historicalReplica, dmsg, delData) + err = processDeleteMessages(loader.metaReplica, segmentTypeSealed, dmsg, delData) + if err != nil { + // TODO: panic? 
+ // error occurs when missing meta info or unexpected pk type, should not happen + err = fmt.Errorf("deleteNode processDeleteMessages failed, collectionID = %d, err = %s", dmsg.CollectionID, err) + log.Error(err.Error()) + return err + } } ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID) @@ -686,7 +665,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection log.Info("All data has been read, there is no more data", zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName), zap.Any("msg id", position.GetMsgID())) for segmentID, pks := range delData.deleteIDs { - segment, err := loader.historicalReplica.getSegmentByID(segmentID) + segment, err := loader.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) if err != nil { log.Warn(err.Error()) continue @@ -698,7 +677,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection wg := sync.WaitGroup{} for segmentID := range delData.deleteOffset { wg.Add(1) - go deletePk(loader.historicalReplica, delData, segmentID, &wg) + go deletePk(loader.metaReplica, delData, segmentID, &wg) } wg.Wait() log.Info("from dml check point load done", zap.Any("msg id", position.GetMsgID())) @@ -708,7 +687,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection func deletePk(replica ReplicaInterface, deleteData *deleteData, segmentID UniqueID, wg *sync.WaitGroup) { defer wg.Done() log.Debug("QueryNode::iNode::delete", zap.Any("SegmentID", segmentID)) - targetSegment, err := replica.getSegmentByID(segmentID) + targetSegment, err := replica.getSegmentByID(segmentID, segmentTypeSealed) if err != nil { log.Error(err.Error()) return @@ -774,8 +753,7 @@ func (loader *segmentLoader) checkSegmentSize(collectionID UniqueID, segmentLoad } func newSegmentLoader( - historicalReplica ReplicaInterface, - streamingReplica ReplicaInterface, + metaReplica ReplicaInterface, etcdKV *etcdkv.EtcdKV, cm storage.ChunkManager, factory msgstream.Factory) *segmentLoader { @@ -797,8 +775,7 @@ func newSegmentLoader( } loader := &segmentLoader{ - historicalReplica: historicalReplica, - streamingReplica: streamingReplica, + metaReplica: metaReplica, cm: cm, etcdKV: etcdKV, diff --git a/internal/querynode/segment_loader_test.go b/internal/querynode/segment_loader_test.go index 5ea9845629..5ebae27bd2 100644 --- a/internal/querynode/segment_loader_test.go +++ b/internal/querynode/segment_loader_test.go @@ -40,8 +40,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) assert.NoError(t, err) @@ -49,7 +48,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - err = node.historical.removeSegment(defaultSegmentID) + err = node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) loader := node.loader @@ -80,7 +79,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - err = node.historical.removePartition(defaultPartitionID) + err = node.metaReplica.removePartition(defaultPartitionID) assert.NoError(t, err) loader := node.loader @@ -169,7 +168,7 @@ func TestSegmentLoader_loadSegmentFieldsData(t *testing.T) { 
schema.Fields = append(schema.Fields, genVectorFieldSchema(simpleBinVecField)) } - err = loader.historicalReplica.removeSegment(defaultSegmentID) + err = loader.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) col := newCollection(defaultCollectionID, schema) @@ -220,7 +219,7 @@ func TestSegmentLoader_invalid(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - err = node.historical.removeCollection(defaultCollectionID) + err = node.metaReplica.removeCollection(defaultCollectionID) assert.NoError(t, err) req := &querypb.LoadSegmentsRequest{ @@ -248,7 +247,7 @@ func TestSegmentLoader_invalid(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - err = node.historical.removeCollection(defaultCollectionID) + err = node.metaReplica.removeCollection(defaultCollectionID) assert.NoError(t, err) schema := &schemapb.CollectionSchema{ @@ -259,7 +258,7 @@ func TestSegmentLoader_invalid(t *testing.T) { genPKFieldSchema(simpleInt64Field), }, } - loader.historicalReplica.addCollection(defaultCollectionID, schema) + loader.metaReplica.addCollection(defaultCollectionID, schema) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ @@ -329,7 +328,7 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - collection, err := node.historical.getCollectionByID(defaultCollectionID) + collection, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing) @@ -358,7 +357,7 @@ func TestSegmentLoader_testLoadGrowing(t *testing.T) { loader := node.loader assert.NotNil(t, loader) - collection, err := node.historical.getCollectionByID(defaultCollectionID) + collection, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) segment, err := newSegment(collection, defaultSegmentID+1, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeGrowing) @@ -386,8 +385,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) assert.NoError(t, err) @@ -422,7 +420,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { err = loader.loadSegment(req1, segmentTypeSealed) assert.NoError(t, err) - segment1, err := loader.historicalReplica.getSegmentByID(segmentID1) + segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeSealed) assert.NoError(t, err) assert.Equal(t, segment1.getRowCount(), int64(100)) @@ -448,7 +446,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { err = loader.loadSegment(req2, segmentTypeSealed) assert.NoError(t, err) - segment2, err := loader.historicalReplica.getSegmentByID(segmentID2) + segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeSealed) assert.NoError(t, err) // Note: getRowCount currently does not return accurate counts. The deleted rows are also counted. 
assert.Equal(t, segment2.getRowCount(), int64(100)) // accurate counts should be 98 @@ -482,7 +480,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { err = loader.loadSegment(req1, segmentTypeGrowing) assert.NoError(t, err) - segment1, err := loader.streamingReplica.getSegmentByID(segmentID1) + segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeGrowing) assert.NoError(t, err) assert.Equal(t, segment1.getRowCount(), int64(100)) @@ -508,7 +506,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { err = loader.loadSegment(req2, segmentTypeGrowing) assert.NoError(t, err) - segment2, err := loader.streamingReplica.getSegmentByID(segmentID2) + segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeGrowing) assert.NoError(t, err) // Note: getRowCount currently does not return accurate counts. The deleted rows are also counted. assert.Equal(t, segment2.getRowCount(), int64(100)) // accurate counts should be 98 @@ -519,8 +517,7 @@ func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() // generate insert binlog fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) @@ -568,7 +565,7 @@ func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) { err = loader.loadSegment(req, segmentTypeSealed) assert.NoError(t, err) - segment, err := node.historical.getSegmentByID(segmentID) + segment, err := node.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) assert.NoError(t, err) vecFieldInfo, err := segment.getIndexedFieldInfo(simpleFloatVecField.id) assert.NoError(t, err) @@ -752,16 +749,17 @@ func newMockReplicaInterface() *mockReplicaInterface { } func TestSegmentLoader_getFieldType_err(t *testing.T) { - loader := &segmentLoader{} - // nor growing or sealed. - segment := &Segment{segmentType: 200} - _, err := loader.getFieldType(segment, 100) + replica, err := genSimpleReplica() + assert.NoError(t, err) + loader := &segmentLoader{metaReplica: replica} + segment := &Segment{collectionID: 200} + _, err = loader.getFieldType(segment, 100) assert.Error(t, err) } func TestSegmentLoader_getFieldType(t *testing.T) { replica := newMockReplicaInterface() - loader := &segmentLoader{streamingReplica: replica, historicalReplica: replica} + loader := &segmentLoader{metaReplica: replica} // failed to get collection. 
segment := &Segment{segmentType: segmentTypeSealed} diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 1381d28d91..1e3d1de2f4 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -43,8 +43,7 @@ import ( //-------------------------------------------------------------------------------------- constructor and destructor func TestSegment_newSegment(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collectionMeta := genCollectionMeta(collectionID, schema) collection := newCollection(collectionMeta.ID, collectionMeta.Schema) @@ -68,8 +67,7 @@ func TestSegment_newSegment(t *testing.T) { func TestSegment_deleteSegment(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collectionMeta := genCollectionMeta(collectionID, schema) collection := newCollection(collectionMeta.ID, schema) @@ -94,8 +92,7 @@ func TestSegment_deleteSegment(t *testing.T) { //-------------------------------------------------------------------------------------- stats functions func TestSegment_getRowCount(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -137,8 +134,7 @@ func TestSegment_getRowCount(t *testing.T) { func TestSegment_retrieve(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -217,8 +213,7 @@ func TestSegment_retrieve(t *testing.T) { func TestSegment_getDeletedCount(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -267,8 +262,7 @@ func TestSegment_getDeletedCount(t *testing.T) { func TestSegment_getMemSize(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -304,8 +298,7 @@ func TestSegment_getMemSize(t *testing.T) { //-------------------------------------------------------------------------------------- dm & search functions func TestSegment_segmentInsert(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -349,8 +342,7 @@ func TestSegment_segmentInsert(t *testing.T) { func TestSegment_segmentDelete(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -391,10 +383,10 @@ func TestSegment_segmentSearch(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - 
collection, err := node.historical.getCollectionByID(defaultCollectionID) + collection, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - segment, err := node.historical.getSegmentByID(defaultSegmentID) + segment, err := node.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) // TODO: replace below by genPlaceholderGroup(nq) @@ -446,8 +438,7 @@ func TestSegment_segmentSearch(t *testing.T) { //-------------------------------------------------------------------------------------- preDm functions func TestSegment_segmentPreInsert(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -466,8 +457,7 @@ func TestSegment_segmentPreInsert(t *testing.T) { func TestSegment_segmentPreDelete(t *testing.T) { collectionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -530,8 +520,7 @@ func TestSegment_segmentLoadDeletedRecord(t *testing.T) { } func TestSegment_segmentLoadFieldData(t *testing.T) { - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() _, err := genSealedSegment(schema, defaultCollectionID, defaultPartitionID, @@ -561,8 +550,7 @@ func TestSegment_ConcurrentOperation(t *testing.T) { collectionID := UniqueID(0) partitionID := UniqueID(0) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(collectionID, schema) assert.Equal(t, collection.ID(), collectionID) @@ -593,10 +581,10 @@ func TestSegment_indexInfo(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - h, err := genSimpleHistorical(ctx) + replica, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) - seg, err := h.getSegmentByID(defaultSegmentID) + seg, err := replica.getSegmentByID(defaultSegmentID, segmentTypeSealed) assert.NoError(t, err) fieldID := simpleFloatVecField.id @@ -634,8 +622,7 @@ func TestSegment_indexInfo(t *testing.T) { } func TestSegment_BasicMetrics(t *testing.T) { - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(defaultCollectionID, schema) segment, err := newSegment(collection, defaultSegmentID, @@ -682,8 +669,7 @@ func TestSegment_fillIndexedFieldsData(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() collection := newCollection(defaultCollectionID, schema) segment, err := newSegment(collection, defaultSegmentID, @@ -1009,15 +995,15 @@ func Test_fillFieldData(t *testing.T) { func TestUpdateBloomFilter(t *testing.T) { t.Run("test int64 pk", func(t *testing.T) { - historical, err := genSimpleReplica() + replica, err := genSimpleReplica() assert.NoError(t, err) - err = historical.addSegment(defaultSegmentID, + err = replica.addSegment(defaultSegmentID, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeSealed) assert.NoError(t, err) - seg, err := historical.getSegmentByID(defaultSegmentID) 
+ seg, err := replica.getSegmentByID(defaultSegmentID, segmentTypeSealed) assert.Nil(t, err) pkValues := []int64{1, 2} pks := make([]primaryKey, len(pkValues)) @@ -1032,15 +1018,15 @@ func TestUpdateBloomFilter(t *testing.T) { } }) t.Run("test string pk", func(t *testing.T) { - historical, err := genSimpleReplica() + replica, err := genSimpleReplica() assert.NoError(t, err) - err = historical.addSegment(defaultSegmentID, + err = replica.addSegment(defaultSegmentID, defaultPartitionID, defaultCollectionID, defaultDMLChannel, segmentTypeSealed) assert.NoError(t, err) - seg, err := historical.getSegmentByID(defaultSegmentID) + seg, err := replica.getSegmentByID(defaultSegmentID, segmentTypeSealed) assert.Nil(t, err) pkValues := []string{"test1", "test2"} pks := make([]primaryKey, len(pkValues)) diff --git a/internal/querynode/stats_service_test.go b/internal/querynode/stats_service_test.go index 677a5a4c1c..352e2f8cf1 100644 --- a/internal/querynode/stats_service_test.go +++ b/internal/querynode/stats_service_test.go @@ -32,7 +32,7 @@ func TestStatsService_start(t *testing.T) { initTestMeta(t, node, 0, 0) factory := dependency.NewDefaultFactory(true) - node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical, factory) + node.statsService = newStatsService(node.queryNodeLoopCtx, node.metaReplica, factory) node.statsService.start() node.Stop() } @@ -57,7 +57,7 @@ func TestSegmentManagement_sendSegmentStatistic(t *testing.T) { var statsMsgStream msgstream.MsgStream = statsStream - node.statsService = newStatsService(node.queryNodeLoopCtx, node.historical, factory) + node.statsService = newStatsService(node.queryNodeLoopCtx, node.metaReplica, factory) node.statsService.statsStream = statsMsgStream node.statsService.statsStream.Start() diff --git a/internal/querynode/task.go b/internal/querynode/task.go index a00cd7b1f8..0ef6f4a335 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -154,8 +154,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { ) // init collection meta - sCol := w.node.streaming.addCollection(collectionID, w.req.Schema) - hCol := w.node.historical.addCollection(collectionID, w.req.Schema) + coll := w.node.metaReplica.addCollection(collectionID, w.req.Schema) //add shard cluster for _, vchannel := range vChannels { @@ -203,12 +202,16 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { // update partition info from unFlushedSegments and loadMeta for _, info := range req.Infos { - w.node.streaming.addPartition(collectionID, info.PartitionID) - w.node.historical.addPartition(collectionID, info.PartitionID) + err = w.node.metaReplica.addPartition(collectionID, info.PartitionID) + if err != nil { + return err + } } for _, partitionID := range req.GetLoadMeta().GetPartitionIDs() { - w.node.historical.addPartition(collectionID, partitionID) - w.node.streaming.addPartition(collectionID, partitionID) + err = w.node.metaReplica.addPartition(collectionID, partitionID) + if err != nil { + return err + } } log.Info("loading growing segments in WatchDmChannels...", @@ -228,12 +231,12 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { // remove growing segment if watch dmChannels failed defer func() { if err != nil { - collection, err2 := w.node.streaming.getCollectionByID(collectionID) + collection, err2 := w.node.metaReplica.getCollectionByID(collectionID) if err2 == nil { collection.Lock() defer collection.Unlock() for _, segmentID := range unFlushedSegmentIDs { - 
w.node.streaming.removeSegment(segmentID) + w.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing) } } } @@ -260,7 +263,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { for _, info := range w.req.Infos { unFlushedCheckPointInfos = append(unFlushedCheckPointInfos, info.UnflushedSegments...) } - w.node.streaming.addExcludedSegments(collectionID, unFlushedCheckPointInfos) + w.node.metaReplica.addExcludedSegments(collectionID, unFlushedCheckPointInfos) unflushedSegmentIDs := make([]UniqueID, 0) for i := 0; i < len(unFlushedCheckPointInfos); i++ { unflushedSegmentIDs = append(unflushedSegmentIDs, unFlushedCheckPointInfos[i].GetID()) @@ -284,7 +287,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { } } } - w.node.streaming.addExcludedSegments(collectionID, flushedCheckPointInfos) + w.node.metaReplica.addExcludedSegments(collectionID, flushedCheckPointInfos) log.Info("watchDMChannel, add check points info for flushed segments done", zap.Int64("collectionID", collectionID), zap.Any("flushedCheckPointInfos", flushedCheckPointInfos), @@ -304,7 +307,7 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { } } } - w.node.streaming.addExcludedSegments(collectionID, droppedCheckPointInfos) + w.node.metaReplica.addExcludedSegments(collectionID, droppedCheckPointInfos) log.Info("watchDMChannel, add check points info for dropped segments done", zap.Int64("collectionID", collectionID), zap.Any("droppedCheckPointInfos", droppedCheckPointInfos), @@ -356,13 +359,10 @@ func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) - sCol.addVChannels(vChannels) - sCol.addPChannels(pChannels) - sCol.setLoadType(lType) + coll.addVChannels(vChannels) + coll.addPChannels(pChannels) + coll.setLoadType(lType) - hCol.addVChannels(vChannels) - hCol.addPChannels(pChannels) - hCol.setLoadType(lType) log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) // create tSafe @@ -415,18 +415,10 @@ func (w *watchDeltaChannelsTask) Execute(ctx context.Context) error { zap.Any("collectionID", collectionID), ) - if hasCollectionInHistorical := w.node.historical.hasCollection(collectionID); !hasCollectionInHistorical { + if hasColl := w.node.metaReplica.hasCollection(collectionID); !hasColl { return fmt.Errorf("cannot find collection with collectionID, %d", collectionID) } - hCol, err := w.node.historical.getCollectionByID(collectionID) - if err != nil { - return err - } - - if hasCollectionInStreaming := w.node.streaming.hasCollection(collectionID); !hasCollectionInStreaming { - return fmt.Errorf("cannot find collection with collectionID, %d", collectionID) - } - sCol, err := w.node.streaming.getCollectionByID(collectionID) + coll, err := w.node.metaReplica.getCollectionByID(collectionID) if err != nil { return err } @@ -467,11 +459,8 @@ func (w *watchDeltaChannelsTask) Execute(ctx context.Context) error { log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) //set collection replica - hCol.addVDeltaChannels(vDeltaChannels) - hCol.addPDeltaChannels(pDeltaChannels) - - sCol.addVDeltaChannels(vDeltaChannels) - sCol.addPDeltaChannels(pDeltaChannels) + coll.addVDeltaChannels(vDeltaChannels) + coll.addPDeltaChannels(pDeltaChannels) 
// create tSafe for _, channel := range vDeltaChannels { @@ -506,14 +495,9 @@ func (l *loadSegmentsTask) PreExecute(ctx context.Context) error { var err error // init meta collectionID := l.req.GetCollectionID() - l.node.historical.addCollection(collectionID, l.req.GetSchema()) - l.node.streaming.addCollection(collectionID, l.req.GetSchema()) + l.node.metaReplica.addCollection(collectionID, l.req.GetSchema()) for _, partitionID := range l.req.GetLoadMeta().GetPartitionIDs() { - err = l.node.historical.addPartition(collectionID, partitionID) - if err != nil { - return err - } - err = l.node.streaming.addPartition(collectionID, partitionID) + err = l.node.metaReplica.addPartition(collectionID, partitionID) if err != nil { return err } @@ -522,7 +506,11 @@ func (l *loadSegmentsTask) PreExecute(ctx context.Context) error { // filter segments that are already loaded in this querynode var filteredInfos []*queryPb.SegmentLoadInfo for _, info := range l.req.Infos { - if !l.node.historical.hasSegment(info.SegmentID) { + has, err := l.node.metaReplica.hasSegment(info.SegmentID, segmentTypeSealed) + if err != nil { + return err + } + if !has { filteredInfos = append(filteredInfos, info) } else { log.Debug("ignore segment that is already loaded", zap.Int64("segmentID", info.SegmentID)) @@ -562,28 +550,7 @@ func (r *releaseCollectionTask) Execute(ctx context.Context) error { zap.Any("collectionID", r.req.CollectionID), ) - err := r.releaseReplica(r.node.streaming, replicaStreaming) - if err != nil { - return fmt.Errorf("release collection failed, collectionID = %d, err = %s", r.req.CollectionID, err) - } - - // remove collection metas in streaming and historical - log.Info("release historical", zap.Any("collectionID", r.req.CollectionID)) - err = r.releaseReplica(r.node.historical, replicaHistorical) - if err != nil { - return fmt.Errorf("release collection failed, collectionID = %d, err = %s", r.req.CollectionID, err) - } - - debug.FreeOSMemory() - - r.node.queryShardService.releaseCollection(r.req.CollectionID) - - log.Info("ReleaseCollection done", zap.Int64("collectionID", r.req.CollectionID)) - return nil -} - -func (r *releaseCollectionTask) releaseReplica(replica ReplicaInterface, replicaType ReplicaType) error { - collection, err := replica.getCollectionByID(r.req.CollectionID) + collection, err := r.node.metaReplica.getCollectionByID(r.req.CollectionID) if err != nil { return err } @@ -592,31 +559,33 @@ func (r *releaseCollectionTask) releaseReplica(replica ReplicaInterface, replica collection.setReleaseTime(r.req.Base.Timestamp, true) // remove all flow graphs of the target collection - var channels []Channel - if replicaType == replicaStreaming { - channels = collection.getVChannels() - r.node.dataSyncService.removeFlowGraphsByDMLChannels(channels) - } else { - // remove all tSafes and flow graphs of the target collection - channels = collection.getVDeltaChannels() - r.node.dataSyncService.removeFlowGraphsByDeltaChannels(channels) - } + vChannels := collection.getVChannels() + vDeltaChannels := collection.getVDeltaChannels() + r.node.dataSyncService.removeFlowGraphsByDMLChannels(vChannels) + r.node.dataSyncService.removeFlowGraphsByDeltaChannels(vDeltaChannels) // remove all tSafes of the target collection - for _, channel := range channels { - log.Info("Releasing tSafe in releaseCollectionTask...", - zap.Any("collectionID", r.req.CollectionID), - zap.Any("vDeltaChannel", channel), - ) + for _, channel := range vChannels { r.node.tSafeReplica.removeTSafe(channel) } + for _, channel := range 
vDeltaChannels { + r.node.tSafeReplica.removeTSafe(channel) + } + log.Info("Release tSafe in releaseCollectionTask", + zap.Int64("collectionID", r.req.CollectionID), + zap.Strings("vChannels", vChannels), + zap.Strings("vDeltaChannels", vDeltaChannels), + ) - // remove excludedSegments record - replica.removeExcludedSegments(r.req.CollectionID) - err = replica.removeCollection(r.req.CollectionID) + r.node.metaReplica.removeExcludedSegments(r.req.CollectionID) + r.node.queryShardService.releaseCollection(r.req.CollectionID) + err = r.node.metaReplica.removeCollection(r.req.CollectionID) if err != nil { return err } + + debug.FreeOSMemory() + log.Info("ReleaseCollection done", zap.Int64("collectionID", r.req.CollectionID)) return nil } @@ -630,12 +599,7 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error { const gracefulReleaseTime = 1 time.Sleep(gracefulReleaseTime * time.Second) - // get collection from streaming and historical - _, err := r.node.historical.getCollectionByID(r.req.CollectionID) - if err != nil { - return fmt.Errorf("release partitions failed, collectionID = %d, err = %s", r.req.CollectionID, err) - } - _, err = r.node.streaming.getCollectionByID(r.req.CollectionID) + _, err := r.node.metaReplica.getCollectionByID(r.req.CollectionID) if err != nil { return fmt.Errorf("release partitions failed, collectionID = %d, err = %s", r.req.CollectionID, err) } @@ -643,17 +607,9 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error { for _, id := range r.req.PartitionIDs { // remove partition from streaming and historical - hasPartitionInHistorical := r.node.historical.hasPartition(id) - if hasPartitionInHistorical { - err := r.node.historical.removePartition(id) - if err != nil { - // not return, try to release all partitions - log.Warn(err.Error()) - } - } - hasPartitionInStreaming := r.node.streaming.hasPartition(id) - if hasPartitionInStreaming { - err := r.node.streaming.removePartition(id) + hasPartition := r.node.metaReplica.hasPartition(id) + if hasPartition { + err := r.node.metaReplica.removePartition(id) if err != nil { // not return, try to release all partitions log.Warn(err.Error()) diff --git a/internal/querynode/task_query.go b/internal/querynode/task_query.go index 28ba3517da..7c0444c0a4 100644 --- a/internal/querynode/task_query.go +++ b/internal/querynode/task_query.go @@ -50,6 +50,7 @@ func (q *queryTask) PreExecute(ctx context.Context) error { return nil } +// TODO: merge queryOnStreaming and queryOnHistorical? 
func (q *queryTask) queryOnStreaming() error { // check ctx timeout if !funcutil.CheckCtxValid(q.Ctx()) { @@ -57,7 +58,7 @@ func (q *queryTask) queryOnStreaming() error { } // check if collection has been released, check streaming since it's released first - _, err := q.QS.streaming.getCollectionByID(q.CollectionID) + _, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) if err != nil { return err } @@ -76,7 +77,7 @@ func (q *queryTask) queryOnStreaming() error { } defer plan.delete() - sResults, _, _, sErr := retrieveStreaming(q.QS.streaming, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager) + sResults, _, _, sErr := retrieveStreaming(q.QS.metaReplica, plan, q.CollectionID, q.iReq.GetPartitionIDs(), q.QS.channel, q.QS.vectorChunkManager) if sErr != nil { return sErr } @@ -103,7 +104,7 @@ func (q *queryTask) queryOnHistorical() error { } // check if collection has been released, check historical since it's released first - _, err := q.QS.streaming.getCollectionByID(q.CollectionID) + _, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) if err != nil { return err } @@ -122,7 +123,7 @@ func (q *queryTask) queryOnHistorical() error { return err } defer plan.delete() - retrieveResults, _, _, err := retrieveHistorical(q.QS.historical, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager) + retrieveResults, _, _, err := retrieveHistorical(q.QS.metaReplica, plan, q.CollectionID, nil, q.req.SegmentIDs, q.QS.vectorChunkManager) if err != nil { return err } diff --git a/internal/querynode/task_search.go b/internal/querynode/task_search.go index 0b0465012e..e06671b612 100644 --- a/internal/querynode/task_search.go +++ b/internal/querynode/task_search.go @@ -81,6 +81,7 @@ func (s *searchTask) init() error { return nil } +// TODO: merge searchOnStreaming and searchOnHistorical? 
func (s *searchTask) searchOnStreaming() error { // check ctx timeout if !funcutil.CheckCtxValid(s.Ctx()) { @@ -88,7 +89,7 @@ func (s *searchTask) searchOnStreaming() error { } // check if collection has been released, check streaming since it's released first - _, err := s.QS.streaming.getCollectionByID(s.CollectionID) + _, err := s.QS.metaReplica.getCollectionByID(s.CollectionID) if err != nil { return err } @@ -107,7 +108,7 @@ func (s *searchTask) searchOnStreaming() error { defer searchReq.delete() // TODO add context - partResults, _, _, sErr := searchStreaming(s.QS.streaming, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannel()) + partResults, _, _, sErr := searchStreaming(s.QS.metaReplica, searchReq, s.CollectionID, s.iReq.GetPartitionIDs(), s.req.GetDmlChannel()) if sErr != nil { log.Debug("failed to search streaming data", zap.Int64("collectionID", s.CollectionID), zap.Error(sErr)) return sErr @@ -123,7 +124,7 @@ func (s *searchTask) searchOnHistorical() error { } // check if collection has been released, check streaming since it's released first - _, err := s.QS.streaming.getCollectionByID(s.CollectionID) + _, err := s.QS.metaReplica.getCollectionByID(s.CollectionID) if err != nil { return err } @@ -142,7 +143,7 @@ func (s *searchTask) searchOnHistorical() error { } defer searchReq.delete() - partResults, _, _, err := searchHistorical(s.QS.historical, searchReq, s.CollectionID, nil, segmentIDs) + partResults, _, _, err := searchHistorical(s.QS.metaReplica, searchReq, s.CollectionID, nil, segmentIDs) if err != nil { return err } diff --git a/internal/querynode/task_test.go b/internal/querynode/task_test.go index 7a8fabf1b3..b510a6a349 100644 --- a/internal/querynode/task_test.go +++ b/internal/querynode/task_test.go @@ -29,15 +29,13 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/typeutil" ) func TestTask_watchDmChannelsTask(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() genWatchDMChannelsRequest := func() *querypb.WatchDmChannelsRequest { req := &querypb.WatchDmChannelsRequest{ @@ -348,8 +346,7 @@ func TestTask_watchDeltaChannelsTask(t *testing.T) { func TestTask_loadSegmentsTask(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() genLoadEmptySegmentsRequest := func() *querypb.LoadSegmentsRequest { req := &querypb.LoadSegmentsRequest{ @@ -445,7 +442,7 @@ func TestTask_loadSegmentsTask(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) // expected only one segment in replica - num := node.historical.getSegmentNum() + num := node.metaReplica.getSegmentNum(segmentTypeSealed) assert.Equal(t, 1, num) }) @@ -493,7 +490,7 @@ func TestTask_loadSegmentsTask(t *testing.T) { totalRAM := Params.QueryNodeCfg.CacheSize * 1024 * 1024 * 1024 - col, err := node.historical.getCollectionByID(defaultCollectionID) + col, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) sizePerRecord, err := typeutil.EstimateSizePerRecord(col.schema) @@ -577,9 +574,7 @@ func TestTask_releaseCollectionTask(t *testing.T) { 
node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - err = node.streaming.removeCollection(defaultCollectionID) - assert.NoError(t, err) - err = node.historical.removeCollection(defaultCollectionID) + err = node.metaReplica.removeCollection(defaultCollectionID) assert.NoError(t, err) task := releaseCollectionTask{ @@ -598,7 +593,7 @@ func TestTask_releaseCollectionTask(t *testing.T) { err = node.queryService.addQueryCollection(defaultCollectionID) assert.NoError(t, err)*/ - col, err := node.historical.getCollectionByID(defaultCollectionID) + col, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) col.addVDeltaChannels([]Channel{defaultDeltaChannel}) @@ -673,10 +668,7 @@ func TestTask_releasePartitionTask(t *testing.T) { req: genReleasePartitionsRequest(), node: node, } - err = node.historical.removeCollection(defaultCollectionID) - assert.NoError(t, err) - - err = node.streaming.removeCollection(defaultCollectionID) + err = node.metaReplica.removeCollection(defaultCollectionID) assert.NoError(t, err) err = task.Execute(ctx) @@ -687,17 +679,14 @@ func TestTask_releasePartitionTask(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) - hisCol, err := node.historical.getCollectionByID(defaultCollectionID) - assert.NoError(t, err) - strCol, err := node.streaming.getCollectionByID(defaultCollectionID) + col, err := node.metaReplica.getCollectionByID(defaultCollectionID) assert.NoError(t, err) - err = node.historical.removePartition(defaultPartitionID) + err = node.metaReplica.removePartition(defaultPartitionID) assert.NoError(t, err) - hisCol.addVDeltaChannels([]Channel{defaultDeltaChannel}) - hisCol.setLoadType(loadTypePartition) - strCol.setLoadType(loadTypePartition) + col.addVDeltaChannels([]Channel{defaultDeltaChannel}) + col.setLoadType(loadTypePartition) /* err = node.queryService.addQueryCollection(defaultCollectionID) diff --git a/internal/querynode/validate.go b/internal/querynode/validate.go index 16bc2394b4..b186ffe527 100644 --- a/internal/querynode/validate.go +++ b/internal/querynode/validate.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus/internal/log" ) +// TODO: merge validate? 
func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID, partitionIDs []UniqueID, segmentIDs []UniqueID) ([]UniqueID, []UniqueID, error) { var err error var searchPartIDs []UniqueID @@ -63,7 +64,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID var newSegmentIDs []UniqueID if len(segmentIDs) == 0 { for _, partID := range searchPartIDs { - segIDs, err2 := replica.getSegmentIDs(partID) + segIDs, err2 := replica.getSegmentIDs(partID, segmentTypeSealed) if err2 != nil { return searchPartIDs, newSegmentIDs, err } @@ -73,7 +74,7 @@ func validateOnHistoricalReplica(replica ReplicaInterface, collectionID UniqueID newSegmentIDs = segmentIDs for _, segmentID := range newSegmentIDs { var segment *Segment - if segment, err = replica.getSegmentByID(segmentID); err != nil { + if segment, err = replica.getSegmentByID(segmentID, segmentTypeSealed); err != nil { return searchPartIDs, newSegmentIDs, err } if !inList(searchPartIDs, segment.partitionID) { diff --git a/internal/querynode/validate_test.go b/internal/querynode/validate_test.go index 500c1237ad..2e5f4cfff6 100644 --- a/internal/querynode/validate_test.go +++ b/internal/querynode/validate_test.go @@ -20,7 +20,6 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" ) @@ -29,47 +28,46 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { defer cancel() t.Run("test normal validate", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) t.Run("test normal validate2", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) t.Run("test validate non-existent collection", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) t.Run("test validate non-existent partition", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) t.Run("test validate non-existent segment", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) assert.Error(t, err) }) t.Run("test validate segment not in given partition", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) err = his.addPartition(defaultCollectionID, defaultPartitionID+1) assert.NoError(t, err) - pkType := schemapb.DataType_Int64 - schema := genTestCollectionSchema(pkType) + schema := genTestCollectionSchema() seg, err := genSealedSegment(schema, defaultCollectionID, 
defaultPartitionID+1, @@ -86,7 +84,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { }) t.Run("test validate after partition release", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) err = his.removePartition(defaultPartitionID) assert.NoError(t, err) @@ -95,7 +93,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { }) t.Run("test validate after partition release2", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) col, err := his.getCollectionByID(defaultCollectionID) assert.NoError(t, err) @@ -107,7 +105,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { }) t.Run("test validate after partition release3", func(t *testing.T) { - his, err := genSimpleHistorical(ctx) + his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) col, err := his.getCollectionByID(defaultCollectionID) assert.NoError(t, err)
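
Taken together, the hunks above collapse the historical/streaming replica pair into one metaReplica whose lookups are keyed by (segmentID, segmentType). A minimal, self-contained sketch of that contract follows; the type and constant names (UniqueID, segmentType, segmentTypeGrowing, segmentTypeSealed, getSegmentByID, getSegmentNum) are taken from the diff, while the map-backed implementation is a toy stand-in, not the actual Milvus code:

package main

import "fmt"

type UniqueID = int64

// segmentType mirrors the diff's segmentTypeGrowing / segmentTypeSealed.
type segmentType int32

const (
	segmentTypeGrowing segmentType = iota
	segmentTypeSealed
)

type Segment struct {
	segmentID UniqueID
}

// metaReplica is a toy stand-in for the consolidated replica: one object,
// two internal pools, every lookup qualified by segment type.
type metaReplica struct {
	growingSegments map[UniqueID]*Segment
	sealedSegments  map[UniqueID]*Segment
}

func (r *metaReplica) pool(segType segmentType) (map[UniqueID]*Segment, error) {
	switch segType {
	case segmentTypeGrowing:
		return r.growingSegments, nil
	case segmentTypeSealed:
		return r.sealedSegments, nil
	default:
		return nil, fmt.Errorf("invalid segment type: %d", segType)
	}
}

func (r *metaReplica) getSegmentByID(segID UniqueID, segType segmentType) (*Segment, error) {
	pool, err := r.pool(segType)
	if err != nil {
		return nil, err
	}
	seg, ok := pool[segID]
	if !ok {
		return nil, fmt.Errorf("segment %d not found", segID)
	}
	return seg, nil
}

func (r *metaReplica) getSegmentNum(segType segmentType) int {
	pool, err := r.pool(segType)
	if err != nil {
		return 0
	}
	return len(pool)
}

func main() {
	r := &metaReplica{
		growingSegments: map[UniqueID]*Segment{},
		sealedSegments:  map[UniqueID]*Segment{1: {segmentID: 1}},
	}
	seg, err := r.getSegmentByID(1, segmentTypeSealed)
	if err != nil {
		panic(err)
	}
	fmt.Println(seg.segmentID)                       // 1
	fmt.Println(r.getSegmentNum(segmentTypeGrowing)) // 0
}

The point of the extra parameter is visible throughout the test changes: callers such as TestReduce_AllFunc no longer encode the sealed/growing distinction in which replica object they hold, so they must state it at the call site.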
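searchOnSegments and retrieveOnSegments now take a segType, which turns searchHistorical/searchStreaming (and the retrieve counterparts) into thin wrappers around one shared scan. A sequential sketch of that shape, under the simplifying assumptions that segments are plain strings and that the per-segment goroutine fan-out with sync.WaitGroup in the real searchOnSegments is elided:

package main

import "fmt"

type UniqueID = int64
type segmentType int32

const (
	segmentTypeGrowing segmentType = iota
	segmentTypeSealed
)

// ReplicaInterface is reduced to the single method this sketch needs.
type ReplicaInterface interface {
	getSegmentByID(segID UniqueID, segType segmentType) (string, error)
}

// searchOnSegments scans the listed segments of one type; segment IDs are
// assumed to be validated by the caller, as the diff's comment notes.
func searchOnSegments(replica ReplicaInterface, segType segmentType, segIDs []UniqueID) ([]string, error) {
	results := make([]string, 0, len(segIDs))
	for _, segID := range segIDs {
		seg, err := replica.getSegmentByID(segID, segType)
		if err != nil {
			return nil, err
		}
		results = append(results, seg)
	}
	return results, nil
}

// The historical/streaming split survives only as the segType argument.
func searchHistorical(replica ReplicaInterface, segIDs []UniqueID) ([]string, error) {
	return searchOnSegments(replica, segmentTypeSealed, segIDs)
}

func searchStreaming(replica ReplicaInterface, segIDs []UniqueID) ([]string, error) {
	return searchOnSegments(replica, segmentTypeGrowing, segIDs)
}

type toyReplica map[segmentType]map[UniqueID]string

func (t toyReplica) getSegmentByID(segID UniqueID, segType segmentType) (string, error) {
	if seg, ok := t[segType][segID]; ok {
		return seg, nil
	}
	return "", fmt.Errorf("segment %d (type %d) not found", segID, segType)
}

func main() {
	r := toyReplica{segmentTypeSealed: {1: "sealed-1"}, segmentTypeGrowing: {2: "growing-2"}}
	fmt.Println(searchHistorical(r, []UniqueID{1})) // [sealed-1] <nil>
	fmt.Println(searchStreaming(r, []UniqueID{2}))  // [growing-2] <nil>
}

The TODO comments added in task_query.go and task_search.go ("merge queryOnStreaming and queryOnHistorical?") suggest the same collapse may eventually reach the task layer as well.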
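One behavioral detail worth noting in loadSegmentsTask.PreExecute: hasSegment now returns (bool, error) instead of a bare bool, presumably because with a single replica an invalid segment type is a real error rather than merely "not present". A hypothetical reduction of the new filter loop, with querypb.SegmentLoadInfo shrunk to an ID:

package main

import "fmt"

type UniqueID = int64
type segmentType int32

const segmentTypeSealed segmentType = 1

// SegmentLoadInfo stands in for querypb.SegmentLoadInfo.
type SegmentLoadInfo struct {
	SegmentID UniqueID
}

type replica struct {
	sealed map[UniqueID]bool
}

// hasSegment reports presence; an unsupported type surfaces as an error.
func (r *replica) hasSegment(segID UniqueID, segType segmentType) (bool, error) {
	if segType != segmentTypeSealed {
		return false, fmt.Errorf("invalid segment type: %d", segType)
	}
	return r.sealed[segID], nil
}

// filterLoaded drops infos whose segments this node already holds,
// mirroring the PreExecute hunk in the diff.
func filterLoaded(r *replica, infos []*SegmentLoadInfo) ([]*SegmentLoadInfo, error) {
	var filtered []*SegmentLoadInfo
	for _, info := range infos {
		has, err := r.hasSegment(info.SegmentID, segmentTypeSealed)
		if err != nil {
			return nil, err
		}
		if !has {
			filtered = append(filtered, info)
		} else {
			fmt.Println("ignore segment that is already loaded:", info.SegmentID)
		}
	}
	return filtered, nil
}

func main() {
	r := &replica{sealed: map[UniqueID]bool{1: true}}
	out, err := filterLoaded(r, []*SegmentLoadInfo{{SegmentID: 1}, {SegmentID: 2}})
	fmt.Println(len(out), err) // 1 <nil>
}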
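Finally, releaseCollectionTask.Execute drops the per-replica releaseReplica helper and tears down DML and delta channels in a single pass. A toy of that teardown order, with the flow-graph and tSafe registries reduced to maps; the real task also removes excluded segments, releases the query shard, and calls debug.FreeOSMemory():

package main

import "fmt"

type Channel = string

type collection struct {
	vChannels      []Channel
	vDeltaChannels []Channel
}

type node struct {
	flowGraphs map[Channel]bool
	tSafes     map[Channel]bool
}

func (n *node) removeFlowGraph(c Channel) { delete(n.flowGraphs, c) }
func (n *node) removeTSafe(c Channel)     { delete(n.tSafes, c) }

// releaseCollection tears down DML and delta channels together; with a
// single metaReplica there is no streaming pass followed by a historical one.
func (n *node) releaseCollection(coll *collection) {
	for _, c := range coll.vChannels {
		n.removeFlowGraph(c)
		n.removeTSafe(c)
	}
	for _, c := range coll.vDeltaChannels {
		n.removeFlowGraph(c)
		n.removeTSafe(c)
	}
}

func main() {
	n := &node{
		flowGraphs: map[Channel]bool{"dml-0": true, "delta-0": true},
		tSafes:     map[Channel]bool{"dml-0": true, "delta-0": true},
	}
	n.releaseCollection(&collection{vChannels: []Channel{"dml-0"}, vDeltaChannels: []Channel{"delta-0"}})
	fmt.Println(len(n.flowGraphs), len(n.tSafes)) // 0 0
}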