diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index d8f0e9f2ab..2599374f3c 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -1983,7 +1983,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
 		ctx:       ctx,
 		Condition: NewTaskCondition(ctx),
 		// req:       request,
-		BaseInsertTask: BaseInsertTask{
+		insertMsg: &msgstream.InsertMsg{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: request.HashKeys,
 			},
@@ -2007,8 +2007,8 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
 		chTicker:      node.chTicker,
 	}

-	if len(it.PartitionName) <= 0 {
-		it.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
+	if len(it.insertMsg.PartitionName) <= 0 {
+		it.insertMsg.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
 	}

 	constructFailedResponse := func(err error) *milvuspb.MutationResult {
@@ -2045,7 +2045,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
 	log.Debug("Detail of insert request in Proxy",
 		zap.String("role", typeutil.ProxyRole),
-		zap.Int64("msgID", it.Base.MsgID),
 		zap.Uint64("BeginTS", it.BeginTs()),
 		zap.Uint64("EndTS", it.EndTs()),
 		zap.String("db", request.DbName),
@@ -2082,6 +2081,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
 	metrics.ProxyInsertVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
 	metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
 	metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
+	log.Debug("Finish processing insert request in Proxy", zap.Any("insertResult", it.result))

 	return it.result, nil
 }
@@ -2112,7 +2112,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
 		ctx:        ctx,
 		Condition:  NewTaskCondition(ctx),
 		deleteExpr: request.Expr,
-		BaseDeleteTask: BaseDeleteTask{
+		deleteMsg: &BaseDeleteTask{
 			BaseMsg: msgstream.BaseMsg{
 				HashValues: request.HashKeys,
 			},
@@ -2154,7 +2154,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
 	log.Debug("Detail of delete request in Proxy",
 		zap.String("role", typeutil.ProxyRole),
-		zap.Uint64("timestamp", dt.Base.Timestamp),
+		zap.Uint64("timestamp", dt.deleteMsg.Base.Timestamp),
 		zap.String("db", request.DbName),
 		zap.String("collection", request.CollectionName),
 		zap.String("partition", request.PartitionName),
diff --git a/internal/proxy/msg_pack.go b/internal/proxy/msg_pack.go
new file mode 100644
index 0000000000..b365fe68b5
--- /dev/null
+++ b/internal/proxy/msg_pack.go
@@ -0,0 +1,184 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proxy
+
+import (
+	"context"
+
+	"github.com/milvus-io/milvus-proto/go-api/commonpb"
+	"github.com/milvus-io/milvus-proto/go-api/milvuspb"
+	"github.com/milvus-io/milvus-proto/go-api/schemapb"
+	"github.com/milvus-io/milvus/internal/allocator"
+	"github.com/milvus-io/milvus/internal/log"
+	"github.com/milvus-io/milvus/internal/mq/msgstream"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/util/commonpbutil"
+	"github.com/milvus-io/milvus/internal/util/retry"
+	"github.com/milvus-io/milvus/internal/util/typeutil"
+	"go.uber.org/zap"
+)
+
+func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, channelNames []string, idAllocator *allocator.IDAllocator, segIDAssigner *segIDAssigner) (*msgstream.MsgPack, error) {
+	threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt()
+	log.Debug("assign segmentid", zap.Int("threshold", threshold))
+
+	msgPack := &msgstream.MsgPack{
+		BeginTs: insertMsg.BeginTs(),
+		EndTs:   insertMsg.EndTs(),
+	}
+
+	// generate a hash value (dmChannel index) for every primary key
+	if len(insertMsg.HashValues) != 0 {
+		log.Warn("the hash values passed in by the client are not supported now and will be overwritten")
+	}
+	insertMsg.HashValues = typeutil.HashPK2Channels(result.IDs, channelNames)
+	channel2RowOffsets := make(map[string][]int)  // channelName -> row offsets
+	channelMaxTSMap := make(map[string]Timestamp) // channelName -> max Timestamp
+
+	// assert len(insertMsg.HashValues) < maxInt
+	for offset, channelID := range insertMsg.HashValues {
+		channelName := channelNames[channelID]
+		if _, ok := channel2RowOffsets[channelName]; !ok {
+			channel2RowOffsets[channelName] = []int{}
+		}
+		channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
+
+		if _, ok := channelMaxTSMap[channelName]; !ok {
+			channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
+		}
+		ts := insertMsg.Timestamps[offset]
+		if channelMaxTSMap[channelName] < ts {
+			channelMaxTSMap[channelName] = ts
+		}
+	}
+
+	// pre-allocate msg ids in batches
+	var idBegin, idEnd int64
+	var err error
+
+	// fetch the next id; if none is available, allocate the next batch.
+	// Allocation is lazy: the first batch is fetched on the first getMsgID call.
+	getMsgID := func() (int64, error) {
+		if idBegin == idEnd {
+			err = retry.Do(ctx, func() error {
+				idBegin, idEnd, err = idAllocator.Alloc(16)
+				return err
+			})
+			if err != nil {
+				log.Error("failed to allocate msg id", zap.Int64("base.MsgID", insertMsg.Base.MsgID), zap.Error(err))
+				return 0, err
+			}
+		}
+		result := idBegin
+		idBegin++
+		return result, nil
+	}
+
+	// create an empty insert message for the given segment and channel
+	createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg {
+		insertReq := internalpb.InsertRequest{
+			Base: commonpbutil.NewMsgBase(
+				commonpbutil.WithMsgType(commonpb.MsgType_Insert),
+				commonpbutil.WithMsgID(msgID),
+				commonpbutil.WithTimeStamp(insertMsg.BeginTimestamp), // entity timestamps were set to insertMsg.BeginTimestamp in PreExecute()
+				commonpbutil.WithSourceID(insertMsg.Base.SourceID),
+			),
+			CollectionID:   insertMsg.CollectionID,
+			PartitionID:    insertMsg.PartitionID,
+			CollectionName: insertMsg.CollectionName,
+			PartitionName:  insertMsg.PartitionName,
+			SegmentID:      segmentID,
+			ShardName:      channelName,
+			Version:        internalpb.InsertDataVersion_ColumnBased,
+		}
+		insertReq.FieldsData = 
make([]*schemapb.FieldData, len(insertMsg.GetFieldsData())) + + msg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + }, + InsertRequest: insertReq, + } + + return msg + } + + // repack the row data corresponding to the offset to insertMsg + getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) { + repackedMsgs := make([]msgstream.TsMsg, 0) + requestSize := 0 + msgID, err := getMsgID() + if err != nil { + return nil, err + } + msg := createInsertMsg(segmentID, channelName, msgID) + for _, offset := range rowOffsets { + curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset) + if err != nil { + return nil, err + } + + // if insertMsg's size is greater than the threshold, split into multiple insertMsgs + if requestSize+curRowMessageSize >= maxMessageSize { + repackedMsgs = append(repackedMsgs, msg) + msgID, err = getMsgID() + if err != nil { + return nil, err + } + msg = createInsertMsg(segmentID, channelName, msgID) + requestSize = 0 + } + + typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset)) + msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset]) + msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset]) + msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset]) + msg.NumRows++ + requestSize += curRowMessageSize + } + repackedMsgs = append(repackedMsgs, msg) + + return repackedMsgs, nil + } + + // get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID + for channelName, rowOffsets := range channel2RowOffsets { + assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, insertMsg.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName]) + if err != nil { + log.Error("allocate segmentID for insert data failed", zap.Int64("collectionID", insertMsg.CollectionID), zap.String("channel name", channelName), + zap.Int("allocate count", len(rowOffsets)), + zap.Error(err)) + return nil, err + } + + startPos := 0 + for segmentID, count := range assignedSegmentInfos { + subRowOffsets := rowOffsets[startPos : startPos+int(count)] + insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold) + if err != nil { + log.Error("repack insert data to insert msgs failed", zap.Int64("collectionID", insertMsg.CollectionID), + zap.Error(err)) + return nil, err + } + msgPack.Msgs = append(msgPack.Msgs, insertMsgs...) 
+ startPos += int(count) + } + } + + return msgPack, nil +} diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 7b46776f01..a1923ece78 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -27,7 +27,7 @@ type BaseDeleteTask = msgstream.DeleteMsg type deleteTask struct { Condition - BaseDeleteTask + deleteMsg *BaseDeleteTask ctx context.Context deleteExpr string //req *milvuspb.DeleteRequest @@ -46,15 +46,15 @@ func (dt *deleteTask) TraceCtx() context.Context { } func (dt *deleteTask) ID() UniqueID { - return dt.Base.MsgID + return dt.deleteMsg.Base.MsgID } func (dt *deleteTask) SetID(uid UniqueID) { - dt.Base.MsgID = uid + dt.deleteMsg.Base.MsgID = uid } func (dt *deleteTask) Type() commonpb.MsgType { - return dt.Base.MsgType + return dt.deleteMsg.Base.MsgType } func (dt *deleteTask) Name() string { @@ -62,19 +62,19 @@ func (dt *deleteTask) Name() string { } func (dt *deleteTask) BeginTs() Timestamp { - return dt.Base.Timestamp + return dt.deleteMsg.Base.Timestamp } func (dt *deleteTask) EndTs() Timestamp { - return dt.Base.Timestamp + return dt.deleteMsg.Base.Timestamp } func (dt *deleteTask) SetTs(ts Timestamp) { - dt.Base.Timestamp = ts + dt.deleteMsg.Base.Timestamp = ts } func (dt *deleteTask) OnEnqueue() error { - dt.DeleteRequest.Base = commonpbutil.NewMsgBase() + dt.deleteMsg.Base = commonpbutil.NewMsgBase() return nil } @@ -99,7 +99,7 @@ func (dt *deleteTask) getPChanStats() (map[pChan]pChanStatistics, error) { } func (dt *deleteTask) getChannels() ([]pChan, error) { - collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName) if err != nil { return nil, err } @@ -154,8 +154,8 @@ func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res } func (dt *deleteTask) PreExecute(ctx context.Context) error { - dt.Base.MsgType = commonpb.MsgType_Delete - dt.Base.SourceID = paramtable.GetNodeID() + dt.deleteMsg.Base.MsgType = commonpb.MsgType_Delete + dt.deleteMsg.Base.SourceID = paramtable.GetNodeID() dt.result = &milvuspb.MutationResult{ Status: &commonpb.Status{ @@ -167,7 +167,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { Timestamp: dt.BeginTs(), } - collName := dt.CollectionName + collName := dt.deleteMsg.CollectionName if err := validateCollectionName(collName); err != nil { log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err)) return err @@ -177,12 +177,12 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err)) return err } - dt.DeleteRequest.CollectionID = collID + dt.deleteMsg.CollectionID = collID dt.collectionID = collID // If partitionName is not empty, partitionID will be set. 
- if len(dt.PartitionName) > 0 { - partName := dt.PartitionName + if len(dt.deleteMsg.PartitionName) > 0 { + partName := dt.deleteMsg.PartitionName if err := validatePartitionTag(partName, true); err != nil { log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) return err @@ -192,9 +192,9 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err)) return err } - dt.DeleteRequest.PartitionID = partID + dt.deleteMsg.PartitionID = partID } else { - dt.DeleteRequest.PartitionID = common.InvalidPartitionID + dt.deleteMsg.PartitionID = common.InvalidPartitionID } schema, err := globalMetaCache.GetCollectionSchema(ctx, collName) @@ -211,17 +211,17 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { return err } - dt.DeleteRequest.NumRows = numRow - dt.DeleteRequest.PrimaryKeys = primaryKeys - log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.DeleteRequest.NumRows)) + dt.deleteMsg.NumRows = numRow + dt.deleteMsg.PrimaryKeys = primaryKeys + log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.deleteMsg.NumRows)) // set result dt.result.IDs = primaryKeys - dt.result.DeleteCnt = dt.DeleteRequest.NumRows + dt.result.DeleteCnt = dt.deleteMsg.NumRows - dt.Timestamps = make([]uint64, numRow) - for index := range dt.Timestamps { - dt.Timestamps[index] = dt.BeginTs() + dt.deleteMsg.Timestamps = make([]uint64, numRow) + for index := range dt.deleteMsg.Timestamps { + dt.deleteMsg.Timestamps[index] = dt.BeginTs() } return nil @@ -233,7 +233,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) - collID := dt.DeleteRequest.CollectionID + collID := dt.deleteMsg.CollectionID stream, err := dt.chMgr.getOrCreateDmlStream(collID) if err != nil { return err @@ -247,10 +247,10 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { dt.result.Status.Reason = err.Error() return err } - dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames) + dt.deleteMsg.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames) log.Debug("send delete request to virtual channels", - zap.String("collection", dt.GetCollectionName()), + zap.String("collection", dt.deleteMsg.GetCollectionName()), zap.Int64("collection_id", collID), zap.Strings("virtual_channels", channelNames), zap.Int64("task_id", dt.ID())) @@ -258,19 +258,19 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { tr.Record("get vchannels") // repack delete msg by dmChannel result := make(map[uint32]msgstream.TsMsg) - collectionName := dt.CollectionName - collectionID := dt.CollectionID - partitionID := dt.PartitionID - partitionName := dt.PartitionName - proxyID := dt.Base.SourceID - for index, key := range dt.HashValues { - ts := dt.Timestamps[index] + collectionName := dt.deleteMsg.CollectionName + collectionID := dt.deleteMsg.CollectionID + partitionID := dt.deleteMsg.PartitionID + partitionName := dt.deleteMsg.PartitionName + proxyID := dt.deleteMsg.Base.SourceID + for index, key := range dt.deleteMsg.HashValues { + ts := dt.deleteMsg.Timestamps[index] _, ok := result[key] if !ok { sliceRequest := internalpb.DeleteRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_Delete), - commonpbutil.WithMsgID(dt.Base.MsgID), + 
commonpbutil.WithMsgID(dt.deleteMsg.Base.MsgID), commonpbutil.WithTimeStamp(ts), commonpbutil.WithSourceID(proxyID), ), @@ -289,9 +289,9 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { result[key] = deleteMsg } curMsg := result[key].(*msgstream.DeleteMsg) - curMsg.HashValues = append(curMsg.HashValues, dt.HashValues[index]) - curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index]) - typeutil.AppendIDs(curMsg.PrimaryKeys, dt.PrimaryKeys, index) + curMsg.HashValues = append(curMsg.HashValues, dt.deleteMsg.HashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index]) + typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index) curMsg.NumRows++ } diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index b7aabe055e..1a6095b589 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -12,21 +12,17 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/mq/msgstream" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/paramtable" - "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/trace" - "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" ) type insertTask struct { - BaseInsertTask // req *milvuspb.InsertRequest Condition - ctx context.Context + insertMsg *BaseInsertTask + ctx context.Context result *milvuspb.MutationResult idAllocator *allocator.IDAllocator @@ -44,11 +40,11 @@ func (it *insertTask) TraceCtx() context.Context { } func (it *insertTask) ID() UniqueID { - return it.Base.MsgID + return it.insertMsg.Base.MsgID } func (it *insertTask) SetID(uid UniqueID) { - it.Base.MsgID = uid + it.insertMsg.Base.MsgID = uid } func (it *insertTask) Name() string { @@ -56,20 +52,20 @@ func (it *insertTask) Name() string { } func (it *insertTask) Type() commonpb.MsgType { - return it.Base.MsgType + return it.insertMsg.Base.MsgType } func (it *insertTask) BeginTs() Timestamp { - return it.BeginTimestamp + return it.insertMsg.BeginTimestamp } func (it *insertTask) SetTs(ts Timestamp) { - it.BeginTimestamp = ts - it.EndTimestamp = ts + it.insertMsg.BeginTimestamp = ts + it.insertMsg.EndTimestamp = ts } func (it *insertTask) EndTs() Timestamp { - return it.EndTimestamp + return it.insertMsg.EndTimestamp } func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) { @@ -93,7 +89,7 @@ func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) { } func (it *insertTask) getChannels() ([]pChan, error) { - collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName) + collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName) if err != nil { return nil, err } @@ -104,71 +100,6 @@ func (it *insertTask) OnEnqueue() error { return nil } -func (it *insertTask) checkLengthOfFieldsData() error { - neededFieldsNum := 0 - for _, field := range it.schema.Fields { - if !field.AutoID { - neededFieldsNum++ - } - } - - if len(it.FieldsData) < neededFieldsNum { - return errFieldsLessThanNeeded(len(it.FieldsData), neededFieldsNum) - } - - return nil -} - -func (it *insertTask) checkPrimaryFieldData() error { - rowNums := uint32(it.NRows()) - // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields 
- if it.NRows() <= 0 { - return errNumRowsLessThanOrEqualToZero(rowNums) - } - - if err := it.checkLengthOfFieldsData(); err != nil { - return err - } - - primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema) - if err != nil { - log.Error("get primary field schema failed", zap.String("collectionName", it.CollectionName), zap.Any("schema", it.schema), zap.Error(err)) - return err - } - - // get primaryFieldData whether autoID is true or not - var primaryFieldData *schemapb.FieldData - if !primaryFieldSchema.AutoID { - primaryFieldData, err = typeutil.GetPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema) - if err != nil { - log.Error("get primary field data failed", zap.String("collectionName", it.CollectionName), zap.Error(err)) - return err - } - } else { - // check primary key data not exist - if typeutil.IsPrimaryFieldDataExist(it.GetFieldsData(), primaryFieldSchema) { - return fmt.Errorf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name) - } - // if autoID == true, currently only support autoID for int64 PrimaryField - primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs) - if err != nil { - log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", it.CollectionName), zap.Error(err)) - return err - } - // if autoID == true, set the primary field data - it.FieldsData = append(it.FieldsData, primaryFieldData) - } - - // parse primaryFieldData to result.IDs, and as returned primary keys - it.result.IDs, err = parsePrimaryFieldData2IDs(primaryFieldData) - if err != nil { - log.Error("parse primary field data to IDs failed", zap.String("collectionName", it.CollectionName), zap.Error(err)) - return err - } - - return nil -} - func (it *insertTask) PreExecute(ctx context.Context) error { sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute") defer sp.Finish() @@ -183,13 +114,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error { Timestamp: it.EndTs(), } - collectionName := it.CollectionName + collectionName := it.insertMsg.CollectionName if err := validateCollectionName(collectionName); err != nil { log.Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err)) return err } - partitionTag := it.PartitionName + partitionTag := it.insertMsg.PartitionName if err := validatePartitionTag(partitionTag, true); err != nil { log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err)) return err @@ -202,7 +133,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } it.schema = collSchema - rowNums := uint32(it.NRows()) + rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs var rowIDBegin UniqueID var rowIDEnd UniqueID @@ -210,16 +141,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error { rowIDBegin, rowIDEnd, _ = it.idAllocator.Alloc(rowNums) metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds())) - it.RowIDs = make([]UniqueID, rowNums) + it.insertMsg.RowIDs = make([]UniqueID, rowNums) for i := rowIDBegin; i < rowIDEnd; i++ { offset := i - rowIDBegin - it.RowIDs[offset] = i + it.insertMsg.RowIDs[offset] = i } // set insertTask.timeStamps - rowNum := it.NRows() - it.Timestamps = make([]uint64, rowNum) - for index := range it.Timestamps { - it.Timestamps[index] = it.BeginTimestamp + rowNum := it.insertMsg.NRows() + 
it.insertMsg.Timestamps = make([]uint64, rowNum) + for index := range it.insertMsg.Timestamps { + it.insertMsg.Timestamps[index] = it.insertMsg.BeginTimestamp } // set result.SuccIndex @@ -231,7 +162,8 @@ func (it *insertTask) PreExecute(ctx context.Context) error { // check primaryFieldData whether autoID is true or not // set rowIDs as primary data if autoID == true - err = it.checkPrimaryFieldData() + // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields + it.result.IDs, err = checkPrimaryFieldData(it.schema, it.insertMsg) log := log.Ctx(ctx).With(zap.String("collectionName", collectionName)) if err != nil { log.Error("check primary field data and hash primary key failed", @@ -240,7 +172,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } // set field ID to insert field data - err = fillFieldIDBySchema(it.GetFieldsData(), collSchema) + err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), collSchema) if err != nil { log.Error("set fieldID to fieldData failed", zap.Error(err)) @@ -248,7 +180,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } // check that all field's number rows are equal - if err = it.CheckAligned(); err != nil { + if err = it.insertMsg.CheckAligned(); err != nil { log.Error("field data is not aligned", zap.Error(err)) return err @@ -259,160 +191,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return nil } -func (it *insertTask) assignSegmentID(channelNames []string) (*msgstream.MsgPack, error) { - threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt() - log.Debug("assign segmentid", zap.Int("threshold", threshold)) - - result := &msgstream.MsgPack{ - BeginTs: it.BeginTs(), - EndTs: it.EndTs(), - } - - // generate hash value for every primary key - if len(it.HashValues) != 0 { - log.Warn("the hashvalues passed through client is not supported now, and will be overwritten") - } - it.HashValues = typeutil.HashPK2Channels(it.result.IDs, channelNames) - // groupedHashKeys represents the dmChannel index - channel2RowOffsets := make(map[string][]int) // channelName to count - channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp - - // assert len(it.hashValues) < maxInt - for offset, channelID := range it.HashValues { - channelName := channelNames[channelID] - if _, ok := channel2RowOffsets[channelName]; !ok { - channel2RowOffsets[channelName] = []int{} - } - channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset) - - if _, ok := channelMaxTSMap[channelName]; !ok { - channelMaxTSMap[channelName] = typeutil.ZeroTimestamp - } - ts := it.Timestamps[offset] - if channelMaxTSMap[channelName] < ts { - channelMaxTSMap[channelName] = ts - } - } - - // pre-alloc msg id by batch - var idBegin, idEnd int64 - var err error - - // fetch next id, if not id available, fetch next batch - // lazy fetch, get first batch after first getMsgID called - getMsgID := func() (int64, error) { - if idBegin == idEnd { - err = retry.Do(it.ctx, func() error { - idBegin, idEnd, err = it.idAllocator.Alloc(16) - return err - }) - if err != nil { - log.Error("failed to allocate msg id", zap.Int64("base.MsgID", it.Base.MsgID), zap.Error(err)) - return 0, err - } - } - result := idBegin - idBegin++ - return result, nil - } - - // create empty insert message - createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg { - insertReq := internalpb.InsertRequest{ - Base: commonpbutil.NewMsgBase( - 
commonpbutil.WithMsgType(commonpb.MsgType_Insert), - commonpbutil.WithMsgID(msgID), - commonpbutil.WithTimeStamp(it.BeginTimestamp), // entity's timestamp was set to equal it.BeginTimestamp in preExecute() - commonpbutil.WithSourceID(it.Base.SourceID), - ), - CollectionID: it.CollectionID, - PartitionID: it.PartitionID, - CollectionName: it.CollectionName, - PartitionName: it.PartitionName, - SegmentID: segmentID, - ShardName: channelName, - Version: internalpb.InsertDataVersion_ColumnBased, - } - insertReq.FieldsData = make([]*schemapb.FieldData, len(it.GetFieldsData())) - - insertMsg := &msgstream.InsertMsg{ - BaseMsg: msgstream.BaseMsg{ - Ctx: it.TraceCtx(), - }, - InsertRequest: insertReq, - } - - return insertMsg - } - - // repack the row data corresponding to the offset to insertMsg - getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) { - repackedMsgs := make([]msgstream.TsMsg, 0) - requestSize := 0 - msgID, err := getMsgID() - if err != nil { - return nil, err - } - insertMsg := createInsertMsg(segmentID, channelName, msgID) - for _, offset := range rowOffsets { - curRowMessageSize, err := typeutil.EstimateEntitySize(it.InsertRequest.GetFieldsData(), offset) - if err != nil { - return nil, err - } - - // if insertMsg's size is greater than the threshold, split into multiple insertMsgs - if requestSize+curRowMessageSize >= maxMessageSize { - repackedMsgs = append(repackedMsgs, insertMsg) - msgID, err = getMsgID() - if err != nil { - return nil, err - } - insertMsg = createInsertMsg(segmentID, channelName, msgID) - requestSize = 0 - } - - typeutil.AppendFieldData(insertMsg.FieldsData, it.GetFieldsData(), int64(offset)) - insertMsg.HashValues = append(insertMsg.HashValues, it.HashValues[offset]) - insertMsg.Timestamps = append(insertMsg.Timestamps, it.Timestamps[offset]) - insertMsg.RowIDs = append(insertMsg.RowIDs, it.RowIDs[offset]) - insertMsg.NumRows++ - requestSize += curRowMessageSize - } - repackedMsgs = append(repackedMsgs, insertMsg) - - return repackedMsgs, nil - } - - // get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID - for channelName, rowOffsets := range channel2RowOffsets { - assignedSegmentInfos, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName]) - if err != nil { - log.Error("allocate segmentID for insert data failed", - zap.Int64("collectionID", it.CollectionID), - zap.String("channel name", channelName), - zap.Int("allocate count", len(rowOffsets)), - zap.Error(err)) - return nil, err - } - - startPos := 0 - for segmentID, count := range assignedSegmentInfos { - subRowOffsets := rowOffsets[startPos : startPos+int(count)] - insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold) - if err != nil { - log.Error("repack insert data to insert msgs failed", - zap.Int64("collectionID", it.CollectionID), - zap.Error(err)) - return nil, err - } - result.Msgs = append(result.Msgs, insertMsgs...) 
- startPos += int(count) - } - } - - return result, nil -} - func (it *insertTask) Execute(ctx context.Context) error { sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-Execute") defer sp.Finish() @@ -420,15 +198,15 @@ func (it *insertTask) Execute(ctx context.Context) error { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID())) defer tr.Elapse("insert execute done") - collectionName := it.CollectionName + collectionName := it.insertMsg.CollectionName collID, err := globalMetaCache.GetCollectionID(ctx, collectionName) if err != nil { return err } - it.CollectionID = collID + it.insertMsg.CollectionID = collID var partitionID UniqueID - if len(it.PartitionName) > 0 { - partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.PartitionName) + if len(it.insertMsg.PartitionName) > 0 { + partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.insertMsg.PartitionName) if err != nil { return err } @@ -438,7 +216,7 @@ func (it *insertTask) Execute(ctx context.Context) error { return err } } - it.PartitionID = partitionID + it.insertMsg.PartitionID = partitionID tr.Record("get collection id & partition id from cache") stream, err := it.chMgr.getOrCreateDmlStream(collID) @@ -458,15 +236,16 @@ func (it *insertTask) Execute(ctx context.Context) error { } log.Ctx(ctx).Debug("send insert request to virtual channels", - zap.String("collection", it.GetCollectionName()), - zap.String("partition", it.GetPartitionName()), + zap.String("collection", it.insertMsg.GetCollectionName()), + zap.String("partition", it.insertMsg.GetPartitionName()), zap.Int64("collection_id", collID), zap.Int64("partition_id", partitionID), zap.Strings("virtual_channels", channelNames), zap.Int64("task_id", it.ID())) // assign segmentID for insert data and repack data by segmentID - msgPack, err := it.assignSegmentID(channelNames) + var msgPack *msgstream.MsgPack + msgPack, err = assignSegmentID(it.TraceCtx(), it.insertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner) if err != nil { log.Error("assign segmentID and repack insert data failed", zap.Int64("collectionID", collID), @@ -477,7 +256,7 @@ func (it *insertTask) Execute(ctx context.Context) error { } log.Debug("assign segmentID for insert data success", zap.Int64("collectionID", collID), - zap.String("collectionName", it.CollectionName)) + zap.String("collectionName", it.insertMsg.CollectionName)) tr.Record("assign segment id") err = stream.Produce(msgPack) if err != nil { diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 73beeefab5..09f6250a4d 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -9,140 +9,12 @@ import ( "github.com/stretchr/testify/assert" ) -func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { - var err error - - // schema is empty, though won't happen in system - case1 := insertTask{ - schema: &schemapb.CollectionSchema{ - Name: "TestInsertTask_checkLengthOfFieldsData", - Description: "TestInsertTask_checkLengthOfFieldsData", - AutoID: false, - Fields: []*schemapb.FieldSchema{}, - }, - BaseInsertTask: BaseInsertTask{ - InsertRequest: internalpb.InsertRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Insert, - }, - DbName: "TestInsertTask_checkLengthOfFieldsData", - CollectionName: "TestInsertTask_checkLengthOfFieldsData", - PartitionName: "TestInsertTask_checkLengthOfFieldsData", - }, - }, - } - - err = case1.checkLengthOfFieldsData() - assert.Equal(t, 
nil, err) - - // schema has two fields, neither of them are autoID - case2 := insertTask{ - schema: &schemapb.CollectionSchema{ - Name: "TestInsertTask_checkLengthOfFieldsData", - Description: "TestInsertTask_checkLengthOfFieldsData", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - { - AutoID: false, - DataType: schemapb.DataType_Int64, - }, - { - AutoID: false, - DataType: schemapb.DataType_Int64, - }, - }, - }, - } - // passed fields is empty - // case2.BaseInsertTask = BaseInsertTask{ - // InsertRequest: internalpb.InsertRequest{ - // Base: &commonpb.MsgBase{ - // MsgType: commonpb.MsgType_Insert, - // MsgID: 0, - // SourceID: paramtable.GetNodeID(), - // }, - // }, - // } - err = case2.checkLengthOfFieldsData() - assert.NotEqual(t, nil, err) - // the num of passed fields is less than needed - case2.FieldsData = []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - }, - } - err = case2.checkLengthOfFieldsData() - assert.NotEqual(t, nil, err) - // satisfied - case2.FieldsData = []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - }, - { - Type: schemapb.DataType_Int64, - }, - } - err = case2.checkLengthOfFieldsData() - assert.Equal(t, nil, err) - - // schema has two field, one of them are autoID - case3 := insertTask{ - schema: &schemapb.CollectionSchema{ - Name: "TestInsertTask_checkLengthOfFieldsData", - Description: "TestInsertTask_checkLengthOfFieldsData", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - { - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - { - AutoID: false, - DataType: schemapb.DataType_Int64, - }, - }, - }, - } - // passed fields is empty - // case3.req = &milvuspb.InsertRequest{} - err = case3.checkLengthOfFieldsData() - assert.NotEqual(t, nil, err) - // satisfied - case3.FieldsData = []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - }, - } - err = case3.checkLengthOfFieldsData() - assert.Equal(t, nil, err) - - // schema has one field which is autoID - case4 := insertTask{ - schema: &schemapb.CollectionSchema{ - Name: "TestInsertTask_checkLengthOfFieldsData", - Description: "TestInsertTask_checkLengthOfFieldsData", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - { - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - }, - }, - } - // passed fields is empty - // satisfied - // case4.req = &milvuspb.InsertRequest{} - err = case4.checkLengthOfFieldsData() - assert.Equal(t, nil, err) -} - func TestInsertTask_CheckAligned(t *testing.T) { var err error // passed NumRows is less than 0 case1 := insertTask{ - BaseInsertTask: BaseInsertTask{ + insertMsg: &BaseInsertTask{ InsertRequest: internalpb.InsertRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Insert, @@ -151,7 +23,7 @@ func TestInsertTask_CheckAligned(t *testing.T) { }, }, } - err = case1.CheckAligned() + err = case1.insertMsg.CheckAligned() assert.NoError(t, err) // checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData @@ -170,7 +42,7 @@ func TestInsertTask_CheckAligned(t *testing.T) { numRows := 20 dim := 128 case2 := insertTask{ - BaseInsertTask: BaseInsertTask{ + insertMsg: &BaseInsertTask{ InsertRequest: internalpb.InsertRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Insert, @@ -200,8 +72,8 @@ func TestInsertTask_CheckAligned(t *testing.T) { } // satisfied - case2.NumRows = uint64(numRows) - case2.FieldsData = []*schemapb.FieldData{ + case2.insertMsg.NumRows = uint64(numRows) + case2.insertMsg.FieldsData = []*schemapb.FieldData{ newScalarFieldData(boolFieldSchema, "Bool", numRows), 
newScalarFieldData(int8FieldSchema, "Int8", numRows), newScalarFieldData(int16FieldSchema, "Int16", numRows), @@ -213,136 +85,136 @@ func TestInsertTask_CheckAligned(t *testing.T) { newBinaryVectorFieldData("BinaryVector", numRows, dim), newScalarFieldData(varCharFieldSchema, "VarChar", numRows), } - err = case2.CheckAligned() + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less bool data - case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more bool data - case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less int8 data - case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more int8 data - case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less int16 data - case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more int16 data - case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less int32 data - case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more int32 data - case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows) - err = case2.CheckAligned() + 
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less int64 data - case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more int64 data - case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less float data - case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more float data - case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, nil, err) // less double data - case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more double data - case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, nil, err) // less float vectors - case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more float vectors - case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less binary vectors - case2.FieldsData[7] = 
newBinaryVectorFieldData("BinaryVector", numRows/2, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more binary vectors - case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) // less double data - case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // more double data - case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2) + err = case2.insertMsg.CheckAligned() assert.Error(t, err) // revert - case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows) - err = case2.CheckAligned() + case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows) + err = case2.insertMsg.CheckAligned() assert.NoError(t, err) } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 9c6f901848..4b3ccaeef5 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1390,7 +1390,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { t.Run("insert", func(t *testing.T) { hash := generateHashKeys(nb) task := &insertTask{ - BaseInsertTask: BaseInsertTask{ + insertMsg: &BaseInsertTask{ BaseMsg: msgstream.BaseMsg{ HashValues: hash, }, @@ -1434,7 +1434,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { } for fieldName, dataType := range fieldName2Types { - task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb)) + task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb)) } assert.NoError(t, task.OnEnqueue()) @@ -1446,7 +1446,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { t.Run("delete", func(t *testing.T) { task := &deleteTask{ Condition: NewTaskCondition(ctx), - BaseDeleteTask: msgstream.DeleteMsg{ + deleteMsg: &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{}, DeleteRequest: internalpb.DeleteRequest{ Base: &commonpb.MsgBase{ @@ -1486,7 +1486,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { task.SetID(id) assert.Equal(t, id, task.ID()) - task.Base.MsgType = commonpb.MsgType_Delete + task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete assert.Equal(t, commonpb.MsgType_Delete, task.Type()) ts := Timestamp(time.Now().UnixNano()) @@ -1500,7 +1500,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { task2 := &deleteTask{ Condition: NewTaskCondition(ctx), - BaseDeleteTask: msgstream.DeleteMsg{ + deleteMsg: &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{}, DeleteRequest: internalpb.DeleteRequest{ Base: &commonpb.MsgBase{ @@ -1643,7 +1643,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { t.Run("insert", func(t *testing.T) { hash := generateHashKeys(nb) task := &insertTask{ - BaseInsertTask: 
BaseInsertTask{ + insertMsg: &BaseInsertTask{ BaseMsg: msgstream.BaseMsg{ HashValues: hash, }, @@ -1688,7 +1688,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { fieldID := common.StartOfUserFieldID for fieldName, dataType := range fieldName2Types { - task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb)) + task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb)) fieldID++ } @@ -1701,7 +1701,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { t.Run("delete", func(t *testing.T) { task := &deleteTask{ Condition: NewTaskCondition(ctx), - BaseDeleteTask: msgstream.DeleteMsg{ + deleteMsg: &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{}, DeleteRequest: internalpb.DeleteRequest{ Base: &commonpb.MsgBase{ @@ -1741,7 +1741,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { task.SetID(id) assert.Equal(t, id, task.ID()) - task.Base.MsgType = commonpb.MsgType_Delete + task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete assert.Equal(t, commonpb.MsgType_Delete, task.Type()) ts := Timestamp(time.Now().UnixNano()) @@ -1755,7 +1755,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { task2 := &deleteTask{ Condition: NewTaskCondition(ctx), - BaseDeleteTask: msgstream.DeleteMsg{ + deleteMsg: &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{}, DeleteRequest: internalpb.DeleteRequest{ Base: &commonpb.MsgBase{ diff --git a/internal/proxy/util.go b/internal/proxy/util.go index d630d7f7b8..4cd7763a0d 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -24,6 +24,7 @@ import ( "strings" "time" + "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -872,3 +873,69 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, p } return false, nil } + +func checkLengthOfFieldsData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error { + neededFieldsNum := 0 + for _, field := range schema.Fields { + if !field.AutoID { + neededFieldsNum++ + } + } + + if len(insertMsg.FieldsData) < neededFieldsNum { + return errFieldsLessThanNeeded(len(insertMsg.FieldsData), neededFieldsNum) + } + + return nil +} + +func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) { + rowNums := uint32(insertMsg.NRows()) + // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields + if insertMsg.NRows() <= 0 { + return nil, errNumRowsLessThanOrEqualToZero(rowNums) + } + + if err := checkLengthOfFieldsData(schema, insertMsg); err != nil { + return nil, err + } + + primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + log.Error("get primary field schema failed", zap.String("collectionName", insertMsg.CollectionName), zap.Any("schema", schema), zap.Error(err)) + return nil, err + } + + // get primaryFieldData whether autoID is true or not + var primaryFieldData *schemapb.FieldData + if !primaryFieldSchema.AutoID { + primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema) + if err != nil { + log.Error("get primary field data failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err)) + return nil, err + } + } else { + // check primary key data not exist + if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) { + return nil, fmt.Errorf("can not assign primary field data when auto id enabled %v", 
primaryFieldSchema.Name)
+		}
+		// if autoID == true, currently only support autoID for the int64 PrimaryField
+		primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
+		if err != nil {
+			log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
+			return nil, err
+		}
+		// if autoID == true, set the primary field data:
+		// the generated primaryFieldData must be appended to insertMsg.FieldsData
+		insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
+	}
+
+	// parse primaryFieldData to result.IDs, which are the returned primary keys
+	ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
+	if err != nil {
+		log.Error("parse primary field data to IDs failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
+		return nil, err
+	}
+
+	return ids, nil
+}
diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go
index 781e1c1fe0..97fcd32fef 100644
--- a/internal/proxy/util_test.go
+++ b/internal/proxy/util_test.go
@@ -924,3 +924,305 @@ func Test_isPartitionIsLoaded(t *testing.T) {
 		assert.False(t, loaded)
 	})
 }
+
+func Test_InsertTaskCheckLengthOfFieldsData(t *testing.T) {
+	var err error
+
+	// schema is empty, though this won't happen in the system
+	case1 := insertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields:      []*schemapb.FieldSchema{},
+		},
+		insertMsg: &BaseInsertTask{
+			InsertRequest: internalpb.InsertRequest{
+				Base: &commonpb.MsgBase{
+					MsgType: commonpb.MsgType_Insert,
+				},
+				DbName:         "TestInsertTask_checkLengthOfFieldsData",
+				CollectionName: "TestInsertTask_checkLengthOfFieldsData",
+				PartitionName:  "TestInsertTask_checkLengthOfFieldsData",
+			},
+		},
+	}
+
+	err = checkLengthOfFieldsData(case1.schema, case1.insertMsg)
+	assert.Equal(t, nil, err)
+
+	// schema has two fields, neither of which is autoID
+	case2 := insertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{
+					AutoID:   false,
+					DataType: schemapb.DataType_Int64,
+				},
+				{
+					AutoID:   false,
+					DataType: schemapb.DataType_Int64,
+				},
+			},
+		},
+		insertMsg: &BaseInsertTask{
+			InsertRequest: internalpb.InsertRequest{
+				Base: &commonpb.MsgBase{
+					MsgType: commonpb.MsgType_Insert,
+				},
+			},
+		},
+	}
+	// passed fields is empty
+	// case2.BaseInsertTask = BaseInsertTask{
+	// 	InsertRequest: internalpb.InsertRequest{
+	// 		Base: &commonpb.MsgBase{
+	// 			MsgType:  commonpb.MsgType_Insert,
+	// 			MsgID:    0,
+	// 			SourceID: paramtable.GetNodeID(),
+	// 		},
+	// 	},
+	// }
+	err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
+	assert.NotEqual(t, nil, err)
+	// the num of passed fields is less than needed
+	case2.insertMsg.FieldsData = []*schemapb.FieldData{
+		{
+			Type: schemapb.DataType_Int64,
+		},
+	}
+	err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
+	assert.NotEqual(t, nil, err)
+	// satisfied
+	case2.insertMsg.FieldsData = []*schemapb.FieldData{
+		{
+			Type: schemapb.DataType_Int64,
+		},
+		{
+			Type: schemapb.DataType_Int64,
+		},
+	}
+	err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
+	assert.Equal(t, nil, err)
+
+	// schema has two fields, one of them is autoID
+	case3 := insertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: 
"TestInsertTask_checkLengthOfFieldsData", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + }, + }, + insertMsg: &BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + }, + }, + } + // passed fields is empty + // case3.req = &milvuspb.InsertRequest{} + err = checkLengthOfFieldsData(case3.schema, case3.insertMsg) + assert.NotEqual(t, nil, err) + // satisfied + case3.insertMsg.FieldsData = []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + }, + } + err = checkLengthOfFieldsData(case3.schema, case3.insertMsg) + assert.Equal(t, nil, err) + + // schema has one field which is autoID + case4 := insertTask{ + schema: &schemapb.CollectionSchema{ + Name: "TestInsertTask_checkLengthOfFieldsData", + Description: "TestInsertTask_checkLengthOfFieldsData", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + }, + }, + insertMsg: &BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + }, + }, + } + // passed fields is empty + // satisfied + // case4.req = &milvuspb.InsertRequest{} + err = checkLengthOfFieldsData(case4.schema, case4.insertMsg) + assert.Equal(t, nil, err) +} + +func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) { + // schema is empty, though won't happen in system + // num_rows(0) should be greater than 0 + case1 := insertTask{ + schema: &schemapb.CollectionSchema{ + Name: "TestInsertTask_checkPrimaryFieldData", + Description: "TestInsertTask_checkPrimaryFieldData", + AutoID: false, + Fields: []*schemapb.FieldSchema{}, + }, + insertMsg: &BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + DbName: "TestInsertTask_checkPrimaryFieldData", + CollectionName: "TestInsertTask_checkPrimaryFieldData", + PartitionName: "TestInsertTask_checkPrimaryFieldData", + }, + }, + } + + _, err := checkPrimaryFieldData(case1.schema, case1.insertMsg) + assert.NotEqual(t, nil, err) + + // the num of passed fields is less than needed + case2 := insertTask{ + schema: &schemapb.CollectionSchema{ + Name: "TestInsertTask_checkPrimaryFieldData", + Description: "TestInsertTask_checkPrimaryFieldData", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + { + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + }, + }, + insertMsg: &BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + RowData: []*commonpb.Blob{ + {}, + {}, + }, + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + }, + }, + Version: internalpb.InsertDataVersion_RowBased, + }, + }, + } + _, err = checkPrimaryFieldData(case2.schema, case2.insertMsg) + assert.NotEqual(t, nil, err) + + // autoID == false, no primary field schema + // primary field is not found + case3 := insertTask{ + schema: &schemapb.CollectionSchema{ + Name: "TestInsertTask_checkPrimaryFieldData", + Description: "TestInsertTask_checkPrimaryFieldData", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + Name: "int64Field", + DataType: schemapb.DataType_Int64, + }, + { + Name: "floatField", + DataType: schemapb.DataType_Float, + }, + }, + }, + insertMsg: &BaseInsertTask{ + InsertRequest: 
internalpb.InsertRequest{
+				Base: &commonpb.MsgBase{
+					MsgType: commonpb.MsgType_Insert,
+				},
+				RowData: []*commonpb.Blob{
+					{},
+					{},
+				},
+				FieldsData: []*schemapb.FieldData{
+					{},
+					{},
+				},
+			},
+		},
+	}
+	_, err = checkPrimaryFieldData(case3.schema, case3.insertMsg)
+	assert.NotEqual(t, nil, err)
+
+	// autoID == true, has a primary field schema, but primary field data exists:
+	// can not assign primary field data when auto id enabled int64Field
+	case4 := insertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkPrimaryFieldData",
+			Description: "TestInsertTask_checkPrimaryFieldData",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "int64Field",
+					FieldID:  1,
+					DataType: schemapb.DataType_Int64,
+				},
+				{
+					Name:     "floatField",
+					FieldID:  2,
+					DataType: schemapb.DataType_Float,
+				},
+			},
+		},
+		insertMsg: &BaseInsertTask{
+			InsertRequest: internalpb.InsertRequest{
+				Base: &commonpb.MsgBase{
+					MsgType: commonpb.MsgType_Insert,
+				},
+				RowData: []*commonpb.Blob{
+					{},
+					{},
+				},
+				FieldsData: []*schemapb.FieldData{
+					{
+						Type:      schemapb.DataType_Int64,
+						FieldName: "int64Field",
+					},
+				},
+			},
+		},
+	}
+	case4.schema.Fields[0].IsPrimaryKey = true
+	case4.schema.Fields[0].AutoID = true
+	case4.insertMsg.FieldsData[0] = newScalarFieldData(case4.schema.Fields[0], case4.schema.Fields[0].Name, 10)
+	_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
+	assert.NotEqual(t, nil, err)
+
+	// autoID == true, has a primary field schema, but the DataType doesn't match:
+	// the data type of the data and the schema do not match
+	case4.schema.Fields[0].IsPrimaryKey = false
+	case4.schema.Fields[1].IsPrimaryKey = true
+	case4.schema.Fields[1].AutoID = true
+	_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
+	assert.NotEqual(t, nil, err)
+}
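
The heart of the patch is a switch from struct embedding to explicit composition in the proxy tasks. A minimal sketch of the before/after shapes, using the type names from the diff (BaseInsertTask aliases msgstream.InsertMsg, just as task_delete.go declares BaseDeleteTask = msgstream.DeleteMsg); the stripped-down structs below are illustrative only, not code from the repository:

    // Before: the message type was embedded, so its fields were promoted
    // into the task's namespace (it.Base, it.PartitionName, it.FieldsData, ...).
    type insertTaskBefore struct {
        BaseInsertTask
        ctx context.Context
    }

    // After: the message is a named pointer field. Nothing is promoted, every
    // access spells out the message, and the message can be handed on its own
    // to package-level helpers such as assignSegmentID.
    type insertTaskAfter struct {
        insertMsg *BaseInsertTask
        ctx       context.Context
    }

    // Task accessors now delegate explicitly instead of relying on promotion.
    func (it *insertTaskAfter) ID() UniqueID       { return it.insertMsg.Base.MsgID }
    func (it *insertTaskAfter) SetID(uid UniqueID) { it.insertMsg.Base.MsgID = uid }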
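
The two extracted helpers take the message rather than the task, so they no longer depend on insertTask internals. A condensed sketch of how they compose across PreExecute and Execute; the wrapper function and its signature are hypothetical, while the helper calls mirror the ones in the diff (validation, metrics, and logging elided):

    func insertFlowSketch(ctx context.Context, it *insertTask, channelNames []string) error {
        // PreExecute: check field lengths and the primary key field, and derive
        // the IDs returned to the client (auto-generated when the PK is autoID).
        ids, err := checkPrimaryFieldData(it.schema, it.insertMsg)
        if err != nil {
            return err
        }
        it.result.IDs = ids

        // Execute: hash rows to dmChannels, allocate segments, and repack the
        // rows into one InsertMsg per (channel, segment) pair, size-capped by
        // Params.PulsarCfg.MaxMessageSize.
        msgPack, err := assignSegmentID(ctx, it.insertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner)
        if err != nil {
            return err
        }

        // produce the repacked messages onto the collection's DML stream
        stream, err := it.chMgr.getOrCreateDmlStream(it.insertMsg.CollectionID)
        if err != nil {
            return err
        }
        return stream.Produce(msgPack)
    }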