mirror of https://github.com/milvus-io/milvus.git

Optimization of delete and insert (#20990)

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>

pull/21082/head
parent eb7ef01b9a
commit 18cad3a1fb
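The gist of the change, condensed from the hunks below: the proxy's insertTask and deleteTask stop embedding the msgstream message and instead hold it behind an explicit pointer field (insertMsg / deleteMsg), so accessors go through it.insertMsg.X and dt.deleteMsg.X rather than the embedded fields, and the segment-assignment and primary-key checks move into standalone functions that take the message as a parameter. Below is a minimal, self-contained sketch of that shape only; the message types here are stand-ins, not the real msgstream structs, and the field sets are abbreviated.

package main

import "fmt"

// Stand-ins for msgstream.InsertMsg / msgstream.DeleteMsg, trimmed to one field.
type InsertMsg struct{ CollectionName string }
type DeleteMsg struct{ CollectionName string }

// After the refactor the tasks carry the message behind a pointer field
// instead of embedding it, so helpers can take the message directly.
type insertTask struct {
    insertMsg *InsertMsg // was: embedded BaseInsertTask
}

type deleteTask struct {
    deleteMsg *DeleteMsg // was: embedded BaseDeleteTask
}

// Accessors forward through the explicit field rather than the embedded struct.
func (it *insertTask) collectionName() string { return it.insertMsg.CollectionName }
func (dt *deleteTask) collectionName() string { return dt.deleteMsg.CollectionName }

func main() {
    it := insertTask{insertMsg: &InsertMsg{CollectionName: "demo"}}
    dt := deleteTask{deleteMsg: &DeleteMsg{CollectionName: "demo"}}
    fmt.Println(it.collectionName(), dt.collectionName())
}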
@@ -1983,7 +1983,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
ctx: ctx,
Condition: NewTaskCondition(ctx),
// req: request,
BaseInsertTask: BaseInsertTask{
insertMsg: &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys,
},
@@ -2007,8 +2007,8 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
chTicker: node.chTicker,
}

if len(it.PartitionName) <= 0 {
it.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
if len(it.insertMsg.PartitionName) <= 0 {
it.insertMsg.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
}

constructFailedResponse := func(err error) *milvuspb.MutationResult {
@@ -2045,7 +2045,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
log.Debug("Detail of insert request in Proxy",
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", it.Base.MsgID),
zap.Uint64("BeginTS", it.BeginTs()),
zap.Uint64("EndTS", it.EndTs()),
zap.String("db", request.DbName),
@@ -2082,6 +2081,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
metrics.ProxyInsertVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("lxg debug", zap.Any("insertResult", it.result))
return it.result, nil
}

@@ -2112,7 +2112,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
ctx: ctx,
Condition: NewTaskCondition(ctx),
deleteExpr: request.Expr,
BaseDeleteTask: BaseDeleteTask{
deleteMsg: &BaseDeleteTask{
BaseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys,
},
@@ -2154,7 +2154,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
log.Debug("Detail of delete request in Proxy",
zap.String("role", typeutil.ProxyRole),
zap.Uint64("timestamp", dt.Base.Timestamp),
zap.Uint64("timestamp", dt.deleteMsg.Base.Timestamp),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),

@@ -0,0 +1,184 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
    "context"

    "github.com/milvus-io/milvus-proto/go-api/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/schemapb"
    "github.com/milvus-io/milvus/internal/allocator"
    "github.com/milvus-io/milvus/internal/log"
    "github.com/milvus-io/milvus/internal/mq/msgstream"
    "github.com/milvus-io/milvus/internal/proto/internalpb"
    "github.com/milvus-io/milvus/internal/util/commonpbutil"
    "github.com/milvus-io/milvus/internal/util/retry"
    "github.com/milvus-io/milvus/internal/util/typeutil"
    "go.uber.org/zap"
)

func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, channelNames []string, idAllocator *allocator.IDAllocator, segIDAssigner *segIDAssigner) (*msgstream.MsgPack, error) {
    threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt()
    log.Debug("assign segmentid", zap.Int("threshold", threshold))

    msgPack := &msgstream.MsgPack{
        BeginTs: insertMsg.BeginTs(),
        EndTs: insertMsg.EndTs(),
    }

    // generate hash value for every primary key
    if len(insertMsg.HashValues) != 0 {
        log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
    }
    insertMsg.HashValues = typeutil.HashPK2Channels(result.IDs, channelNames)
    // groupedHashKeys represents the dmChannel index
    channel2RowOffsets := make(map[string][]int) // channelName to count
    channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp

    // assert len(it.hashValues) < maxInt
    for offset, channelID := range insertMsg.HashValues {
        channelName := channelNames[channelID]
        if _, ok := channel2RowOffsets[channelName]; !ok {
            channel2RowOffsets[channelName] = []int{}
        }
        channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)

        if _, ok := channelMaxTSMap[channelName]; !ok {
            channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
        }
        ts := insertMsg.Timestamps[offset]
        if channelMaxTSMap[channelName] < ts {
            channelMaxTSMap[channelName] = ts
        }
    }

    // pre-alloc msg id by batch
    var idBegin, idEnd int64
    var err error

    // fetch next id, if not id available, fetch next batch
    // lazy fetch, get first batch after first getMsgID called
    getMsgID := func() (int64, error) {
        if idBegin == idEnd {
            err = retry.Do(ctx, func() error {
                idBegin, idEnd, err = idAllocator.Alloc(16)
                return err
            })
            if err != nil {
                log.Error("failed to allocate msg id", zap.Int64("base.MsgID", insertMsg.Base.MsgID), zap.Error(err))
                return 0, err
            }
        }
        result := idBegin
        idBegin++
        return result, nil
    }

    // create empty insert message
    createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg {
        insertReq := internalpb.InsertRequest{
            Base: commonpbutil.NewMsgBase(
                commonpbutil.WithMsgType(commonpb.MsgType_Insert),
                commonpbutil.WithMsgID(msgID),
                commonpbutil.WithTimeStamp(insertMsg.BeginTimestamp), // entity's timestamp was set to equal it.BeginTimestamp in preExecute()
                commonpbutil.WithSourceID(insertMsg.Base.SourceID),
            ),
            CollectionID: insertMsg.CollectionID,
            PartitionID: insertMsg.PartitionID,
            CollectionName: insertMsg.CollectionName,
            PartitionName: insertMsg.PartitionName,
            SegmentID: segmentID,
            ShardName: channelName,
            Version: internalpb.InsertDataVersion_ColumnBased,
        }
        insertReq.FieldsData = make([]*schemapb.FieldData, len(insertMsg.GetFieldsData()))

        msg := &msgstream.InsertMsg{
            BaseMsg: msgstream.BaseMsg{
                Ctx: ctx,
            },
            InsertRequest: insertReq,
        }

        return msg
    }

    // repack the row data corresponding to the offset to insertMsg
    getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) {
        repackedMsgs := make([]msgstream.TsMsg, 0)
        requestSize := 0
        msgID, err := getMsgID()
        if err != nil {
            return nil, err
        }
        msg := createInsertMsg(segmentID, channelName, msgID)
        for _, offset := range rowOffsets {
            curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset)
            if err != nil {
                return nil, err
            }

            // if insertMsg's size is greater than the threshold, split into multiple insertMsgs
            if requestSize+curRowMessageSize >= maxMessageSize {
                repackedMsgs = append(repackedMsgs, msg)
                msgID, err = getMsgID()
                if err != nil {
                    return nil, err
                }
                msg = createInsertMsg(segmentID, channelName, msgID)
                requestSize = 0
            }

            typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset))
            msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
            msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
            msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
            msg.NumRows++
            requestSize += curRowMessageSize
        }
        repackedMsgs = append(repackedMsgs, msg)

        return repackedMsgs, nil
    }

    // get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID
    for channelName, rowOffsets := range channel2RowOffsets {
        assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, insertMsg.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName])
        if err != nil {
            log.Error("allocate segmentID for insert data failed", zap.Int64("collectionID", insertMsg.CollectionID), zap.String("channel name", channelName),
                zap.Int("allocate count", len(rowOffsets)),
                zap.Error(err))
            return nil, err
        }

        startPos := 0
        for segmentID, count := range assignedSegmentInfos {
            subRowOffsets := rowOffsets[startPos : startPos+int(count)]
            insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold)
            if err != nil {
                log.Error("repack insert data to insert msgs failed", zap.Int64("collectionID", insertMsg.CollectionID),
                    zap.Error(err))
                return nil, err
            }
            msgPack.Msgs = append(msgPack.Msgs, insertMsgs...)
            startPos += int(count)
        }
    }

    return msgPack, nil
}
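The core of the repack above is two-level grouping: rows are first bucketed by the DML channel their primary key hashes to, then each channel's rows are packed into messages that are flushed whenever the running size would cross the Pulsar max-message threshold. The toy, self-contained sketch below illustrates just that bucketing-and-splitting pattern; the hash function, the per-row size, and the threshold value are illustrative stand-ins (not the real typeutil helpers or config), and the per-segment assignment step is left out.

package main

import "fmt"

// toyHash stands in for typeutil.HashPK2Channels: map each primary key to a channel index.
func toyHash(pks []int64, numChannels int) []uint32 {
    out := make([]uint32, len(pks))
    for i, pk := range pks {
        out[i] = uint32(pk % int64(numChannels))
    }
    return out
}

func main() {
    pks := []int64{10, 11, 12, 13, 14, 15, 16, 17}
    channelNames := []string{"dml-0", "dml-1"}
    hashValues := toyHash(pks, len(channelNames))

    // Group row offsets by channel, like channel2RowOffsets above.
    channel2RowOffsets := make(map[string][]int)
    for offset, channelID := range hashValues {
        name := channelNames[channelID]
        channel2RowOffsets[name] = append(channel2RowOffsets[name], offset)
    }

    // Pack each channel's rows into batches, starting a new batch when the
    // estimated size would cross the threshold (each row "costs" 1 here).
    const maxMessageSize = 3
    for name, offsets := range channel2RowOffsets {
        batch, size := []int{}, 0
        var batches [][]int
        for _, off := range offsets {
            if size+1 >= maxMessageSize && len(batch) > 0 {
                batches = append(batches, batch)
                batch, size = []int{}, 0
            }
            batch = append(batch, off)
            size++
        }
        batches = append(batches, batch)
        fmt.Println(name, "->", batches)
    }
}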
|
@ -27,7 +27,7 @@ type BaseDeleteTask = msgstream.DeleteMsg
|
|||
|
||||
type deleteTask struct {
|
||||
Condition
|
||||
BaseDeleteTask
|
||||
deleteMsg *BaseDeleteTask
|
||||
ctx context.Context
|
||||
deleteExpr string
|
||||
//req *milvuspb.DeleteRequest
|
||||
|
@ -46,15 +46,15 @@ func (dt *deleteTask) TraceCtx() context.Context {
|
|||
}
|
||||
|
||||
func (dt *deleteTask) ID() UniqueID {
|
||||
return dt.Base.MsgID
|
||||
return dt.deleteMsg.Base.MsgID
|
||||
}
|
||||
|
||||
func (dt *deleteTask) SetID(uid UniqueID) {
|
||||
dt.Base.MsgID = uid
|
||||
dt.deleteMsg.Base.MsgID = uid
|
||||
}
|
||||
|
||||
func (dt *deleteTask) Type() commonpb.MsgType {
|
||||
return dt.Base.MsgType
|
||||
return dt.deleteMsg.Base.MsgType
|
||||
}
|
||||
|
||||
func (dt *deleteTask) Name() string {
|
||||
|
@ -62,19 +62,19 @@ func (dt *deleteTask) Name() string {
|
|||
}
|
||||
|
||||
func (dt *deleteTask) BeginTs() Timestamp {
|
||||
return dt.Base.Timestamp
|
||||
return dt.deleteMsg.Base.Timestamp
|
||||
}
|
||||
|
||||
func (dt *deleteTask) EndTs() Timestamp {
|
||||
return dt.Base.Timestamp
|
||||
return dt.deleteMsg.Base.Timestamp
|
||||
}
|
||||
|
||||
func (dt *deleteTask) SetTs(ts Timestamp) {
|
||||
dt.Base.Timestamp = ts
|
||||
dt.deleteMsg.Base.Timestamp = ts
|
||||
}
|
||||
|
||||
func (dt *deleteTask) OnEnqueue() error {
|
||||
dt.DeleteRequest.Base = commonpbutil.NewMsgBase()
|
||||
dt.deleteMsg.Base = commonpbutil.NewMsgBase()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ func (dt *deleteTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
|||
}
|
||||
|
||||
func (dt *deleteTask) getChannels() ([]pChan, error) {
|
||||
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName)
|
||||
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -154,8 +154,8 @@ func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res
|
|||
}
|
||||
|
||||
func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
||||
dt.Base.MsgType = commonpb.MsgType_Delete
|
||||
dt.Base.SourceID = paramtable.GetNodeID()
|
||||
dt.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
|
||||
dt.deleteMsg.Base.SourceID = paramtable.GetNodeID()
|
||||
|
||||
dt.result = &milvuspb.MutationResult{
|
||||
Status: &commonpb.Status{
|
||||
|
@ -167,7 +167,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|||
Timestamp: dt.BeginTs(),
|
||||
}
|
||||
|
||||
collName := dt.CollectionName
|
||||
collName := dt.deleteMsg.CollectionName
|
||||
if err := validateCollectionName(collName); err != nil {
|
||||
log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err))
|
||||
return err
|
||||
|
@ -177,12 +177,12 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|||
log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
dt.DeleteRequest.CollectionID = collID
|
||||
dt.deleteMsg.CollectionID = collID
|
||||
dt.collectionID = collID
|
||||
|
||||
// If partitionName is not empty, partitionID will be set.
|
||||
if len(dt.PartitionName) > 0 {
|
||||
partName := dt.PartitionName
|
||||
if len(dt.deleteMsg.PartitionName) > 0 {
|
||||
partName := dt.deleteMsg.PartitionName
|
||||
if err := validatePartitionTag(partName, true); err != nil {
|
||||
log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
|
||||
return err
|
||||
|
@ -192,9 +192,9 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|||
log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
dt.DeleteRequest.PartitionID = partID
|
||||
dt.deleteMsg.PartitionID = partID
|
||||
} else {
|
||||
dt.DeleteRequest.PartitionID = common.InvalidPartitionID
|
||||
dt.deleteMsg.PartitionID = common.InvalidPartitionID
|
||||
}
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, collName)
|
||||
|
@ -211,17 +211,17 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
dt.DeleteRequest.NumRows = numRow
|
||||
dt.DeleteRequest.PrimaryKeys = primaryKeys
|
||||
log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.DeleteRequest.NumRows))
|
||||
dt.deleteMsg.NumRows = numRow
|
||||
dt.deleteMsg.PrimaryKeys = primaryKeys
|
||||
log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.deleteMsg.NumRows))
|
||||
|
||||
// set result
|
||||
dt.result.IDs = primaryKeys
|
||||
dt.result.DeleteCnt = dt.DeleteRequest.NumRows
|
||||
dt.result.DeleteCnt = dt.deleteMsg.NumRows
|
||||
|
||||
dt.Timestamps = make([]uint64, numRow)
|
||||
for index := range dt.Timestamps {
|
||||
dt.Timestamps[index] = dt.BeginTs()
|
||||
dt.deleteMsg.Timestamps = make([]uint64, numRow)
|
||||
for index := range dt.deleteMsg.Timestamps {
|
||||
dt.deleteMsg.Timestamps[index] = dt.BeginTs()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -233,7 +233,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
|
||||
|
||||
collID := dt.DeleteRequest.CollectionID
|
||||
collID := dt.deleteMsg.CollectionID
|
||||
stream, err := dt.chMgr.getOrCreateDmlStream(collID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -247,10 +247,10 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
dt.result.Status.Reason = err.Error()
|
||||
return err
|
||||
}
|
||||
dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
|
||||
dt.deleteMsg.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
|
||||
|
||||
log.Debug("send delete request to virtual channels",
|
||||
zap.String("collection", dt.GetCollectionName()),
|
||||
zap.String("collection", dt.deleteMsg.GetCollectionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", dt.ID()))
|
||||
|
@ -258,19 +258,19 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
tr.Record("get vchannels")
|
||||
// repack delete msg by dmChannel
|
||||
result := make(map[uint32]msgstream.TsMsg)
|
||||
collectionName := dt.CollectionName
|
||||
collectionID := dt.CollectionID
|
||||
partitionID := dt.PartitionID
|
||||
partitionName := dt.PartitionName
|
||||
proxyID := dt.Base.SourceID
|
||||
for index, key := range dt.HashValues {
|
||||
ts := dt.Timestamps[index]
|
||||
collectionName := dt.deleteMsg.CollectionName
|
||||
collectionID := dt.deleteMsg.CollectionID
|
||||
partitionID := dt.deleteMsg.PartitionID
|
||||
partitionName := dt.deleteMsg.PartitionName
|
||||
proxyID := dt.deleteMsg.Base.SourceID
|
||||
for index, key := range dt.deleteMsg.HashValues {
|
||||
ts := dt.deleteMsg.Timestamps[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
sliceRequest := internalpb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
||||
commonpbutil.WithMsgID(dt.Base.MsgID),
|
||||
commonpbutil.WithMsgID(dt.deleteMsg.Base.MsgID),
|
||||
commonpbutil.WithTimeStamp(ts),
|
||||
commonpbutil.WithSourceID(proxyID),
|
||||
),
|
||||
|
@ -289,9 +289,9 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
|||
result[key] = deleteMsg
|
||||
}
|
||||
curMsg := result[key].(*msgstream.DeleteMsg)
|
||||
curMsg.HashValues = append(curMsg.HashValues, dt.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.PrimaryKeys, index)
|
||||
curMsg.HashValues = append(curMsg.HashValues, dt.deleteMsg.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index)
|
||||
curMsg.NumRows++
|
||||
}
|
||||
|
||||
|
|
|
@ -12,21 +12,17 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type insertTask struct {
|
||||
BaseInsertTask
|
||||
// req *milvuspb.InsertRequest
|
||||
Condition
|
||||
ctx context.Context
|
||||
insertMsg *BaseInsertTask
|
||||
ctx context.Context
|
||||
|
||||
result *milvuspb.MutationResult
|
||||
idAllocator *allocator.IDAllocator
|
||||
|
@ -44,11 +40,11 @@ func (it *insertTask) TraceCtx() context.Context {
|
|||
}
|
||||
|
||||
func (it *insertTask) ID() UniqueID {
|
||||
return it.Base.MsgID
|
||||
return it.insertMsg.Base.MsgID
|
||||
}
|
||||
|
||||
func (it *insertTask) SetID(uid UniqueID) {
|
||||
it.Base.MsgID = uid
|
||||
it.insertMsg.Base.MsgID = uid
|
||||
}
|
||||
|
||||
func (it *insertTask) Name() string {
|
||||
|
@ -56,20 +52,20 @@ func (it *insertTask) Name() string {
|
|||
}
|
||||
|
||||
func (it *insertTask) Type() commonpb.MsgType {
|
||||
return it.Base.MsgType
|
||||
return it.insertMsg.Base.MsgType
|
||||
}
|
||||
|
||||
func (it *insertTask) BeginTs() Timestamp {
|
||||
return it.BeginTimestamp
|
||||
return it.insertMsg.BeginTimestamp
|
||||
}
|
||||
|
||||
func (it *insertTask) SetTs(ts Timestamp) {
|
||||
it.BeginTimestamp = ts
|
||||
it.EndTimestamp = ts
|
||||
it.insertMsg.BeginTimestamp = ts
|
||||
it.insertMsg.EndTimestamp = ts
|
||||
}
|
||||
|
||||
func (it *insertTask) EndTs() Timestamp {
|
||||
return it.EndTimestamp
|
||||
return it.insertMsg.EndTimestamp
|
||||
}
|
||||
|
||||
func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
||||
|
@ -93,7 +89,7 @@ func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
|
|||
}
|
||||
|
||||
func (it *insertTask) getChannels() ([]pChan, error) {
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName)
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -104,71 +100,6 @@ func (it *insertTask) OnEnqueue() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkLengthOfFieldsData() error {
|
||||
neededFieldsNum := 0
|
||||
for _, field := range it.schema.Fields {
|
||||
if !field.AutoID {
|
||||
neededFieldsNum++
|
||||
}
|
||||
}
|
||||
|
||||
if len(it.FieldsData) < neededFieldsNum {
|
||||
return errFieldsLessThanNeeded(len(it.FieldsData), neededFieldsNum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkPrimaryFieldData() error {
|
||||
rowNums := uint32(it.NRows())
|
||||
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
||||
if it.NRows() <= 0 {
|
||||
return errNumRowsLessThanOrEqualToZero(rowNums)
|
||||
}
|
||||
|
||||
if err := it.checkLengthOfFieldsData(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema)
|
||||
if err != nil {
|
||||
log.Error("get primary field schema failed", zap.String("collectionName", it.CollectionName), zap.Any("schema", it.schema), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// get primaryFieldData whether autoID is true or not
|
||||
var primaryFieldData *schemapb.FieldData
|
||||
if !primaryFieldSchema.AutoID {
|
||||
primaryFieldData, err = typeutil.GetPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema)
|
||||
if err != nil {
|
||||
log.Error("get primary field data failed", zap.String("collectionName", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// check primary key data not exist
|
||||
if typeutil.IsPrimaryFieldDataExist(it.GetFieldsData(), primaryFieldSchema) {
|
||||
return fmt.Errorf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
|
||||
}
|
||||
// if autoID == true, currently only support autoID for int64 PrimaryField
|
||||
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs)
|
||||
if err != nil {
|
||||
log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// if autoID == true, set the primary field data
|
||||
it.FieldsData = append(it.FieldsData, primaryFieldData)
|
||||
}
|
||||
|
||||
// parse primaryFieldData to result.IDs, and as returned primary keys
|
||||
it.result.IDs, err = parsePrimaryFieldData2IDs(primaryFieldData)
|
||||
if err != nil {
|
||||
log.Error("parse primary field data to IDs failed", zap.String("collectionName", it.CollectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute")
|
||||
defer sp.Finish()
|
||||
|
@ -183,13 +114,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
Timestamp: it.EndTs(),
|
||||
}
|
||||
|
||||
collectionName := it.CollectionName
|
||||
collectionName := it.insertMsg.CollectionName
|
||||
if err := validateCollectionName(collectionName); err != nil {
|
||||
log.Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
partitionTag := it.PartitionName
|
||||
partitionTag := it.insertMsg.PartitionName
|
||||
if err := validatePartitionTag(partitionTag, true); err != nil {
|
||||
log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
|
||||
return err
|
||||
|
@ -202,7 +133,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
it.schema = collSchema
|
||||
|
||||
rowNums := uint32(it.NRows())
|
||||
rowNums := uint32(it.insertMsg.NRows())
|
||||
// set insertTask.rowIDs
|
||||
var rowIDBegin UniqueID
|
||||
var rowIDEnd UniqueID
|
||||
|
@ -210,16 +141,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
rowIDBegin, rowIDEnd, _ = it.idAllocator.Alloc(rowNums)
|
||||
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
it.RowIDs = make([]UniqueID, rowNums)
|
||||
it.insertMsg.RowIDs = make([]UniqueID, rowNums)
|
||||
for i := rowIDBegin; i < rowIDEnd; i++ {
|
||||
offset := i - rowIDBegin
|
||||
it.RowIDs[offset] = i
|
||||
it.insertMsg.RowIDs[offset] = i
|
||||
}
|
||||
// set insertTask.timeStamps
|
||||
rowNum := it.NRows()
|
||||
it.Timestamps = make([]uint64, rowNum)
|
||||
for index := range it.Timestamps {
|
||||
it.Timestamps[index] = it.BeginTimestamp
|
||||
rowNum := it.insertMsg.NRows()
|
||||
it.insertMsg.Timestamps = make([]uint64, rowNum)
|
||||
for index := range it.insertMsg.Timestamps {
|
||||
it.insertMsg.Timestamps[index] = it.insertMsg.BeginTimestamp
|
||||
}
|
||||
|
||||
// set result.SuccIndex
|
||||
|
@ -231,7 +162,8 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
|
||||
// check primaryFieldData whether autoID is true or not
|
||||
// set rowIDs as primary data if autoID == true
|
||||
err = it.checkPrimaryFieldData()
|
||||
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
||||
it.result.IDs, err = checkPrimaryFieldData(it.schema, it.insertMsg)
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))
|
||||
if err != nil {
|
||||
log.Error("check primary field data and hash primary key failed",
|
||||
|
@ -240,7 +172,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// set field ID to insert field data
|
||||
err = fillFieldIDBySchema(it.GetFieldsData(), collSchema)
|
||||
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), collSchema)
|
||||
if err != nil {
|
||||
log.Error("set fieldID to fieldData failed",
|
||||
zap.Error(err))
|
||||
|
@ -248,7 +180,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// check that all field's number rows are equal
|
||||
if err = it.CheckAligned(); err != nil {
|
||||
if err = it.insertMsg.CheckAligned(); err != nil {
|
||||
log.Error("field data is not aligned",
|
||||
zap.Error(err))
|
||||
return err
|
||||
|
@ -259,160 +191,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) assignSegmentID(channelNames []string) (*msgstream.MsgPack, error) {
|
||||
threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt()
|
||||
log.Debug("assign segmentid", zap.Int("threshold", threshold))
|
||||
|
||||
result := &msgstream.MsgPack{
|
||||
BeginTs: it.BeginTs(),
|
||||
EndTs: it.EndTs(),
|
||||
}
|
||||
|
||||
// generate hash value for every primary key
|
||||
if len(it.HashValues) != 0 {
|
||||
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
|
||||
}
|
||||
it.HashValues = typeutil.HashPK2Channels(it.result.IDs, channelNames)
|
||||
// groupedHashKeys represents the dmChannel index
|
||||
channel2RowOffsets := make(map[string][]int) // channelName to count
|
||||
channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp
|
||||
|
||||
// assert len(it.hashValues) < maxInt
|
||||
for offset, channelID := range it.HashValues {
|
||||
channelName := channelNames[channelID]
|
||||
if _, ok := channel2RowOffsets[channelName]; !ok {
|
||||
channel2RowOffsets[channelName] = []int{}
|
||||
}
|
||||
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
|
||||
|
||||
if _, ok := channelMaxTSMap[channelName]; !ok {
|
||||
channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
|
||||
}
|
||||
ts := it.Timestamps[offset]
|
||||
if channelMaxTSMap[channelName] < ts {
|
||||
channelMaxTSMap[channelName] = ts
|
||||
}
|
||||
}
|
||||
|
||||
// pre-alloc msg id by batch
|
||||
var idBegin, idEnd int64
|
||||
var err error
|
||||
|
||||
// fetch next id, if not id available, fetch next batch
|
||||
// lazy fetch, get first batch after first getMsgID called
|
||||
getMsgID := func() (int64, error) {
|
||||
if idBegin == idEnd {
|
||||
err = retry.Do(it.ctx, func() error {
|
||||
idBegin, idEnd, err = it.idAllocator.Alloc(16)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("failed to allocate msg id", zap.Int64("base.MsgID", it.Base.MsgID), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
result := idBegin
|
||||
idBegin++
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// create empty insert message
|
||||
createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg {
|
||||
insertReq := internalpb.InsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
||||
commonpbutil.WithMsgID(msgID),
|
||||
commonpbutil.WithTimeStamp(it.BeginTimestamp), // entity's timestamp was set to equal it.BeginTimestamp in preExecute()
|
||||
commonpbutil.WithSourceID(it.Base.SourceID),
|
||||
),
|
||||
CollectionID: it.CollectionID,
|
||||
PartitionID: it.PartitionID,
|
||||
CollectionName: it.CollectionName,
|
||||
PartitionName: it.PartitionName,
|
||||
SegmentID: segmentID,
|
||||
ShardName: channelName,
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
}
|
||||
insertReq.FieldsData = make([]*schemapb.FieldData, len(it.GetFieldsData()))
|
||||
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: it.TraceCtx(),
|
||||
},
|
||||
InsertRequest: insertReq,
|
||||
}
|
||||
|
||||
return insertMsg
|
||||
}
|
||||
|
||||
// repack the row data corresponding to the offset to insertMsg
|
||||
getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) {
|
||||
repackedMsgs := make([]msgstream.TsMsg, 0)
|
||||
requestSize := 0
|
||||
msgID, err := getMsgID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
insertMsg := createInsertMsg(segmentID, channelName, msgID)
|
||||
for _, offset := range rowOffsets {
|
||||
curRowMessageSize, err := typeutil.EstimateEntitySize(it.InsertRequest.GetFieldsData(), offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if insertMsg's size is greater than the threshold, split into multiple insertMsgs
|
||||
if requestSize+curRowMessageSize >= maxMessageSize {
|
||||
repackedMsgs = append(repackedMsgs, insertMsg)
|
||||
msgID, err = getMsgID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
insertMsg = createInsertMsg(segmentID, channelName, msgID)
|
||||
requestSize = 0
|
||||
}
|
||||
|
||||
typeutil.AppendFieldData(insertMsg.FieldsData, it.GetFieldsData(), int64(offset))
|
||||
insertMsg.HashValues = append(insertMsg.HashValues, it.HashValues[offset])
|
||||
insertMsg.Timestamps = append(insertMsg.Timestamps, it.Timestamps[offset])
|
||||
insertMsg.RowIDs = append(insertMsg.RowIDs, it.RowIDs[offset])
|
||||
insertMsg.NumRows++
|
||||
requestSize += curRowMessageSize
|
||||
}
|
||||
repackedMsgs = append(repackedMsgs, insertMsg)
|
||||
|
||||
return repackedMsgs, nil
|
||||
}
|
||||
|
||||
// get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID
|
||||
for channelName, rowOffsets := range channel2RowOffsets {
|
||||
assignedSegmentInfos, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName])
|
||||
if err != nil {
|
||||
log.Error("allocate segmentID for insert data failed",
|
||||
zap.Int64("collectionID", it.CollectionID),
|
||||
zap.String("channel name", channelName),
|
||||
zap.Int("allocate count", len(rowOffsets)),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startPos := 0
|
||||
for segmentID, count := range assignedSegmentInfos {
|
||||
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
|
||||
insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold)
|
||||
if err != nil {
|
||||
log.Error("repack insert data to insert msgs failed",
|
||||
zap.Int64("collectionID", it.CollectionID),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
result.Msgs = append(result.Msgs, insertMsgs...)
|
||||
startPos += int(count)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (it *insertTask) Execute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-Execute")
|
||||
defer sp.Finish()
|
||||
|
@ -420,15 +198,15 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
|
||||
defer tr.Elapse("insert execute done")
|
||||
|
||||
collectionName := it.CollectionName
|
||||
collectionName := it.insertMsg.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
it.CollectionID = collID
|
||||
it.insertMsg.CollectionID = collID
|
||||
var partitionID UniqueID
|
||||
if len(it.PartitionName) > 0 {
|
||||
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.PartitionName)
|
||||
if len(it.insertMsg.PartitionName) > 0 {
|
||||
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.insertMsg.PartitionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -438,7 +216,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
it.PartitionID = partitionID
|
||||
it.insertMsg.PartitionID = partitionID
|
||||
tr.Record("get collection id & partition id from cache")
|
||||
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(collID)
|
||||
|
@ -458,15 +236,16 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
log.Ctx(ctx).Debug("send insert request to virtual channels",
|
||||
zap.String("collection", it.GetCollectionName()),
|
||||
zap.String("partition", it.GetPartitionName()),
|
||||
zap.String("collection", it.insertMsg.GetCollectionName()),
|
||||
zap.String("partition", it.insertMsg.GetPartitionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Int64("partition_id", partitionID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", it.ID()))
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
msgPack, err := it.assignSegmentID(channelNames)
|
||||
var msgPack *msgstream.MsgPack
|
||||
msgPack, err = assignSegmentID(it.TraceCtx(), it.insertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner)
|
||||
if err != nil {
|
||||
log.Error("assign segmentID and repack insert data failed",
|
||||
zap.Int64("collectionID", collID),
|
||||
|
@ -477,7 +256,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
|
|||
}
|
||||
log.Debug("assign segmentID for insert data success",
|
||||
zap.Int64("collectionID", collID),
|
||||
zap.String("collectionName", it.CollectionName))
|
||||
zap.String("collectionName", it.insertMsg.CollectionName))
|
||||
tr.Record("assign segment id")
|
||||
err = stream.Produce(msgPack)
|
||||
if err != nil {
|
||||
|
|
|
@ -9,140 +9,12 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInsertTask_checkLengthOfFieldsData(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// schema is empty, though won't happen in system
|
||||
case1 := insertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{},
|
||||
},
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
},
|
||||
DbName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
CollectionName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
PartitionName: "TestInsertTask_checkLengthOfFieldsData",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = case1.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has two fields, neither of them are autoID
|
||||
case2 := insertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
// case2.BaseInsertTask = BaseInsertTask{
|
||||
// InsertRequest: internalpb.InsertRequest{
|
||||
// Base: &commonpb.MsgBase{
|
||||
// MsgType: commonpb.MsgType_Insert,
|
||||
// MsgID: 0,
|
||||
// SourceID: paramtable.GetNodeID(),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// the num of passed fields is less than needed
|
||||
case2.FieldsData = []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// satisfied
|
||||
case2.FieldsData = []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err = case2.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has two field, one of them are autoID
|
||||
case3 := insertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
// case3.req = &milvuspb.InsertRequest{}
|
||||
err = case3.checkLengthOfFieldsData()
|
||||
assert.NotEqual(t, nil, err)
|
||||
// satisfied
|
||||
case3.FieldsData = []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
err = case3.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// schema has one field which is autoID
|
||||
case4 := insertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
Description: "TestInsertTask_checkLengthOfFieldsData",
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
AutoID: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// passed fields is empty
|
||||
// satisfied
|
||||
// case4.req = &milvuspb.InsertRequest{}
|
||||
err = case4.checkLengthOfFieldsData()
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestInsertTask_CheckAligned(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// passed NumRows is less than 0
|
||||
case1 := insertTask{
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
|
@ -151,7 +23,7 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
err = case1.CheckAligned()
|
||||
err = case1.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData
|
||||
|
@ -170,7 +42,7 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
numRows := 20
|
||||
dim := 128
|
||||
case2 := insertTask{
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
|
@ -200,8 +72,8 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
}
|
||||
|
||||
// satisfied
|
||||
case2.NumRows = uint64(numRows)
|
||||
case2.FieldsData = []*schemapb.FieldData{
|
||||
case2.insertMsg.NumRows = uint64(numRows)
|
||||
case2.insertMsg.FieldsData = []*schemapb.FieldData{
|
||||
newScalarFieldData(boolFieldSchema, "Bool", numRows),
|
||||
newScalarFieldData(int8FieldSchema, "Int8", numRows),
|
||||
newScalarFieldData(int16FieldSchema, "Int16", numRows),
|
||||
|
@ -213,136 +85,136 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
|||
newBinaryVectorFieldData("BinaryVector", numRows, dim),
|
||||
newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
|
||||
}
|
||||
err = case2.CheckAligned()
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less bool data
|
||||
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more bool data
|
||||
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less int8 data
|
||||
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more int8 data
|
||||
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less int16 data
|
||||
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more int16 data
|
||||
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less int32 data
|
||||
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more int32 data
|
||||
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less int64 data
|
||||
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more int64 data
|
||||
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less float data
|
||||
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more float data
|
||||
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, nil, err)
|
||||
|
||||
// less double data
|
||||
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more double data
|
||||
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, nil, err)
|
||||
|
||||
// less float vectors
|
||||
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more float vectors
|
||||
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less binary vectors
|
||||
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more binary vectors
|
||||
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// less double data
|
||||
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// more double data
|
||||
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.Error(t, err)
|
||||
// revert
|
||||
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
|
||||
err = case2.CheckAligned()
|
||||
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
|
||||
err = case2.insertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -1390,7 +1390,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
t.Run("insert", func(t *testing.T) {
|
||||
hash := generateHashKeys(nb)
|
||||
task := &insertTask{
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: hash,
|
||||
},
|
||||
|
@ -1434,7 +1434,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
}
|
||||
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb))
|
||||
task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
|
||||
}
|
||||
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
|
@ -1446,7 +1446,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
t.Run("delete", func(t *testing.T) {
|
||||
task := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
BaseDeleteTask: msgstream.DeleteMsg{
|
||||
deleteMsg: &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
|
@ -1486,7 +1486,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
task.SetID(id)
|
||||
assert.Equal(t, id, task.ID())
|
||||
|
||||
task.Base.MsgType = commonpb.MsgType_Delete
|
||||
task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
|
||||
assert.Equal(t, commonpb.MsgType_Delete, task.Type())
|
||||
|
||||
ts := Timestamp(time.Now().UnixNano())
|
||||
|
@ -1500,7 +1500,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
|||
|
||||
task2 := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
BaseDeleteTask: msgstream.DeleteMsg{
|
||||
deleteMsg: &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
|
@ -1643,7 +1643,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
t.Run("insert", func(t *testing.T) {
|
||||
hash := generateHashKeys(nb)
|
||||
task := &insertTask{
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
insertMsg: &BaseInsertTask{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
HashValues: hash,
|
||||
},
|
||||
|
@ -1688,7 +1688,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
|
||||
fieldID := common.StartOfUserFieldID
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb))
|
||||
task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
|
||||
fieldID++
|
||||
}
|
||||
|
||||
|
@ -1701,7 +1701,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
t.Run("delete", func(t *testing.T) {
|
||||
task := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
BaseDeleteTask: msgstream.DeleteMsg{
|
||||
deleteMsg: &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
|
@ -1741,7 +1741,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
task.SetID(id)
|
||||
assert.Equal(t, id, task.ID())
|
||||
|
||||
task.Base.MsgType = commonpb.MsgType_Delete
|
||||
task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
|
||||
assert.Equal(t, commonpb.MsgType_Delete, task.Type())
|
||||
|
||||
ts := Timestamp(time.Now().UnixNano())
|
||||
|
@ -1755,7 +1755,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
|||
|
||||
task2 := &deleteTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
BaseDeleteTask: msgstream.DeleteMsg{
|
||||
deleteMsg: &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
DeleteRequest: internalpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
|
|
|
@@ -24,6 +24,7 @@ import (
    "strings"
    "time"

    "github.com/milvus-io/milvus/internal/mq/msgstream"
    "github.com/milvus-io/milvus/internal/proto/querypb"
    "github.com/milvus-io/milvus/internal/types"
@@ -872,3 +873,69 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, p
    }
    return false, nil
}

func checkLengthOfFieldsData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
    neededFieldsNum := 0
    for _, field := range schema.Fields {
        if !field.AutoID {
            neededFieldsNum++
        }
    }

    if len(insertMsg.FieldsData) < neededFieldsNum {
        return errFieldsLessThanNeeded(len(insertMsg.FieldsData), neededFieldsNum)
    }

    return nil
}

func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
    rowNums := uint32(insertMsg.NRows())
    // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
    if insertMsg.NRows() <= 0 {
        return nil, errNumRowsLessThanOrEqualToZero(rowNums)
    }

    if err := checkLengthOfFieldsData(schema, insertMsg); err != nil {
        return nil, err
    }

    primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
    if err != nil {
        log.Error("get primary field schema failed", zap.String("collectionName", insertMsg.CollectionName), zap.Any("schema", schema), zap.Error(err))
        return nil, err
    }

    // get primaryFieldData whether autoID is true or not
    var primaryFieldData *schemapb.FieldData
    if !primaryFieldSchema.AutoID {
        primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
        if err != nil {
            log.Error("get primary field data failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
            return nil, err
        }
    } else {
        // check primary key data not exist
        if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
            return nil, fmt.Errorf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
        }
        // if autoID == true, currently only support autoID for int64 PrimaryField
        primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
        if err != nil {
            log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
            return nil, err
        }
        // if autoID == true, set the primary field data
        // insertMsg.fieldsData need append primaryFieldData
        insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
    }

    // parse primaryFieldData to result.IDs, and as returned primary keys
    ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
    if err != nil {
        log.Error("parse primary field data to IDs failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
        return nil, err
    }

    return ids, nil
}
|
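checkPrimaryFieldData therefore resolves the returned IDs in one of two ways: when the primary field is not auto-ID, the client-supplied primary column is parsed into the IDs; when it is auto-ID, any client-supplied primary data is rejected and the IDs are generated from the pre-allocated row IDs and appended to insertMsg.FieldsData. A compressed standalone sketch of that branching, with hypothetical helper and type names rather than the proxy package's real ones:

package main

import (
	"errors"
	"fmt"
)

// Hypothetical, simplified mirror of the decision made by checkPrimaryFieldData.
type primarySchema struct {
	Name   string
	AutoID bool
}

func resolvePrimaryIDs(pk primarySchema, clientPK []int64, rowIDs []int64) ([]int64, error) {
	if !pk.AutoID {
		// autoID == false: the caller must supply the primary column itself.
		if len(clientPK) == 0 {
			return nil, errors.New("primary field data not found")
		}
		return clientPK, nil
	}
	// autoID == true: a client-supplied primary column is an error ...
	if len(clientPK) != 0 {
		return nil, fmt.Errorf("can not assign primary field data when auto id enabled %v", pk.Name)
	}
	// ... and the IDs are derived from the pre-allocated row IDs instead.
	return rowIDs, nil
}

func main() {
	pk := primarySchema{Name: "int64Field", AutoID: true}
	ids, err := resolvePrimaryIDs(pk, nil, []int64{100, 101, 102})
	fmt.Println(ids, err) // [100 101 102] <nil>
}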
@ -924,3 +924,305 @@ func Test_isPartitionIsLoaded(t *testing.T) {
assert.False(t, loaded)
})
}

func Test_InsertTaskCheckLengthOfFieldsData(t *testing.T) {
var err error

// schema is empty, though this won't happen in the system
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestInsertTask_checkLengthOfFieldsData",
CollectionName: "TestInsertTask_checkLengthOfFieldsData",
PartitionName: "TestInsertTask_checkLengthOfFieldsData",
},
},
}

err = checkLengthOfFieldsData(case1.schema, case1.insertMsg)
assert.Equal(t, nil, err)

// schema has two fields, neither of them is an autoID field
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields are empty
// case2.BaseInsertTask = BaseInsertTask{
// InsertRequest: internalpb.insertRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_Insert,
// MsgID: 0,
// SourceID: paramtable.GetNodeID(),
// },
// },
// }
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)
// the number of passed fields is less than needed
case2.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)
// satisfied
case2.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.Equal(t, nil, err)

// schema has two fields, one of them is an autoID field
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields are empty
// case3.req = &milvuspb.InsertRequest{}
err = checkLengthOfFieldsData(case3.schema, case3.insertMsg)
assert.NotEqual(t, nil, err)
// satisfied
case3.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case3.schema, case3.insertMsg)
assert.Equal(t, nil, err)

// schema has one field, which is autoID
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields are empty
// satisfied
// case4.req = &milvuspb.InsertRequest{}
err = checkLengthOfFieldsData(case4.schema, case4.insertMsg)
assert.Equal(t, nil, err)
}
func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
// schema is empty, though this won't happen in the system
// num_rows(0) should be greater than 0
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestInsertTask_checkPrimaryFieldData",
CollectionName: "TestInsertTask_checkPrimaryFieldData",
PartitionName: "TestInsertTask_checkPrimaryFieldData",
},
},
}

_, err := checkPrimaryFieldData(case1.schema, case1.insertMsg)
assert.NotEqual(t, nil, err)

// the number of passed fields is less than needed
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
},
Version: internalpb.InsertDataVersion_RowBased,
},
},
}
_, err = checkPrimaryFieldData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == false, no primary field schema
// primary field is not found
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{},
{},
},
},
},
}
_, err = checkPrimaryFieldData(case3.schema, case3.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == true, has primary field schema, but primary field data exists
// can not assign primary field data when auto id enabled int64Field
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
FieldID: 1,
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "int64Field",
},
},
},
},
}
case4.schema.Fields[0].IsPrimaryKey = true
case4.schema.Fields[0].AutoID = true
case4.insertMsg.FieldsData[0] = newScalarFieldData(case4.schema.Fields[0], case4.schema.Fields[0].Name, 10)
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == true, has primary field schema, but the DataType doesn't match
// the data type of the data and the schema do not match
case4.schema.Fields[0].IsPrimaryKey = false
case4.schema.Fields[1].IsPrimaryKey = true
case4.schema.Fields[1].AutoID = true
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)
}