enhance: make `ColumnBasedInsertMsgToInsertData` check field missing (#29758)

fix: #29757

Previously, `ColumnBasedInsertMsgToInsertData` added an empty field when
the insertMsg parameter did not contain a column defined in the schema.
This could lead to unexpected behavior in calling functions.

This PR:
- Add column missing check
- Add column length check
- Generate BlobInfo for ColumnBasedInsertMsgToInsertData result

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/29681/head
congqixia 2024-01-09 11:50:48 +08:00 committed by GitHub
parent 60e88fb833
commit f18a7191f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 96 deletions

View File

@ -63,6 +63,8 @@ func (s *WriteBufferSuite) SetupTest() {
func (s *WriteBufferSuite) TestDefaultOption() {
s.Run("default BFPkOracle", func() {
paramtable.Get().Save(paramtable.Get().DataCoordCfg.EnableLevelZeroSegment.Key, "false")
defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.EnableLevelZeroSegment.Key)
wb, err := NewWriteBuffer(s.channelName, s.metacache, s.storageCache, s.syncMgr)
s.NoError(err)
_, ok := wb.(*bfWriteBuffer)

View File

@ -28,6 +28,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -451,6 +452,15 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
return idata, nil
}
// ColumnBasedInsertMsgToInsertData converts an InsertMsg msg into InsertData based
// on provided CollectionSchema collSchema.
//
// This function checks whether every field in collSchema.Fields is provided in the msg.
// If any user field (field ID >= common.StartOfUserFieldID) is missing in the msg, an error will be returned.
//
// This function also checks the length of each column. All columns shall have the same length.
// The returned InsertData.Infos contains a BlobInfo carrying this length.
// When the lengths are not aligned, an error will be returned.
func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema) (idata *InsertData, err error) {
srcFields := make(map[FieldID]*schemapb.FieldData)
for _, field := range msg.FieldsData {
@ -459,11 +469,14 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
idata = &InsertData{
Data: make(map[FieldID]FieldData),
// TODO: handle Infos.
Infos: nil,
}
length := 0
for _, field := range collSchema.Fields {
srcField, ok := srcFields[field.GetFieldID()]
if !ok && field.GetFieldID() >= common.StartOfUserFieldID {
return nil, merr.WrapErrFieldNotFound(field.GetFieldID(), fmt.Sprintf("field %s not found when converting insert msg to insert data", field.GetName()))
}
var fieldData FieldData
switch field.DataType {
case schemapb.DataType_FloatVector:
dim, err := GetDimFromParams(field.TypeParams)
@ -472,15 +485,11 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
return nil, err
}
srcData := srcFields[field.FieldID].GetVectors().GetFloatVector().GetData()
fieldData := &FloatVectorFieldData{
Data: make([]float32, 0, len(srcData)),
srcData := srcField.GetVectors().GetFloatVector().GetData()
fieldData = &FloatVectorFieldData{
Data: lo.Map(srcData, func(v float32, _ int) float32 { return v }),
Dim: dim,
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_BinaryVector:
dim, err := GetDimFromParams(field.TypeParams)
@ -489,15 +498,12 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
return nil, err
}
srcData := srcFields[field.FieldID].GetVectors().GetBinaryVector()
srcData := srcField.GetVectors().GetBinaryVector()
fieldData := &BinaryVectorFieldData{
Data: make([]byte, 0, len(srcData)),
fieldData = &BinaryVectorFieldData{
Data: lo.Map(srcData, func(v byte, _ int) byte { return v }),
Dim: dim,
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Float16Vector:
dim, err := GetDimFromParams(field.TypeParams)
@ -506,136 +512,111 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
return nil, err
}
srcData := srcFields[field.FieldID].GetVectors().GetFloat16Vector()
srcData := srcField.GetVectors().GetFloat16Vector()
fieldData := &Float16VectorFieldData{
Data: make([]byte, 0, len(srcData)),
fieldData = &Float16VectorFieldData{
Data: lo.Map(srcData, func(v byte, _ int) byte { return v }),
Dim: dim,
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Bool:
srcData := srcFields[field.FieldID].GetScalars().GetBoolData().GetData()
srcData := srcField.GetScalars().GetBoolData().GetData()
fieldData := &BoolFieldData{
Data: make([]bool, 0, len(srcData)),
fieldData = &BoolFieldData{
Data: lo.Map(srcData, func(v bool, _ int) bool { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Int8:
srcData := srcFields[field.FieldID].GetScalars().GetIntData().GetData()
srcData := srcField.GetScalars().GetIntData().GetData()
fieldData := &Int8FieldData{
Data: make([]int8, 0, len(srcData)),
fieldData = &Int8FieldData{
Data: lo.Map(srcData, func(v int32, _ int) int8 { return int8(v) }),
}
int8SrcData := make([]int8, len(srcData))
for i := 0; i < len(srcData); i++ {
int8SrcData[i] = int8(srcData[i])
}
fieldData.Data = append(fieldData.Data, int8SrcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Int16:
srcData := srcFields[field.FieldID].GetScalars().GetIntData().GetData()
srcData := srcField.GetScalars().GetIntData().GetData()
fieldData := &Int16FieldData{
Data: make([]int16, 0, len(srcData)),
fieldData = &Int16FieldData{
Data: lo.Map(srcData, func(v int32, _ int) int16 { return int16(v) }),
}
int16SrcData := make([]int16, len(srcData))
for i := 0; i < len(srcData); i++ {
int16SrcData[i] = int16(srcData[i])
}
fieldData.Data = append(fieldData.Data, int16SrcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Int32:
srcData := srcFields[field.FieldID].GetScalars().GetIntData().GetData()
srcData := srcField.GetScalars().GetIntData().GetData()
fieldData := &Int32FieldData{
Data: make([]int32, 0, len(srcData)),
fieldData = &Int32FieldData{
Data: lo.Map(srcData, func(v int32, _ int) int32 { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Int64:
fieldData := &Int64FieldData{
Data: make([]int64, 0),
}
switch field.FieldID {
case 0: // rowIDs
fieldData.Data = make([]int64, 0, len(msg.RowIDs))
fieldData.Data = append(fieldData.Data, msg.RowIDs...)
case 1: // Timestamps
fieldData.Data = make([]int64, 0, len(msg.Timestamps))
for _, ts := range msg.Timestamps {
fieldData.Data = append(fieldData.Data, int64(ts))
case common.RowIDField: // rowIDs
fieldData = &Int64FieldData{
Data: lo.Map(msg.GetRowIDs(), func(v int64, _ int) int64 { return v }),
}
case common.TimeStampField: // Timestamps
fieldData = &Int64FieldData{
Data: lo.Map(msg.GetTimestamps(), func(v uint64, _ int) int64 { return int64(v) }),
}
default:
srcData := srcFields[field.FieldID].GetScalars().GetLongData().GetData()
fieldData.Data = make([]int64, 0, len(srcData))
fieldData.Data = append(fieldData.Data, srcData...)
srcData := srcField.GetScalars().GetLongData().GetData()
fieldData = &Int64FieldData{
Data: lo.Map(srcData, func(v int64, _ int) int64 { return v }),
}
}
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Float:
srcData := srcFields[field.FieldID].GetScalars().GetFloatData().GetData()
srcData := srcField.GetScalars().GetFloatData().GetData()
fieldData := &FloatFieldData{
Data: make([]float32, 0, len(srcData)),
fieldData = &FloatFieldData{
Data: lo.Map(srcData, func(v float32, _ int) float32 { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Double:
srcData := srcFields[field.FieldID].GetScalars().GetDoubleData().GetData()
srcData := srcField.GetScalars().GetDoubleData().GetData()
fieldData := &DoubleFieldData{
Data: make([]float64, 0, len(srcData)),
fieldData = &DoubleFieldData{
Data: lo.Map(srcData, func(v float64, _ int) float64 { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_String, schemapb.DataType_VarChar:
srcData := srcFields[field.FieldID].GetScalars().GetStringData().GetData()
srcData := srcField.GetScalars().GetStringData().GetData()
fieldData := &StringFieldData{
Data: make([]string, 0, len(srcData)),
fieldData = &StringFieldData{
Data: lo.Map(srcData, func(v string, _ int) string { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_Array:
srcData := srcFields[field.FieldID].GetScalars().GetArrayData().GetData()
srcData := srcField.GetScalars().GetArrayData().GetData()
fieldData := &ArrayFieldData{
fieldData = &ArrayFieldData{
ElementType: field.GetElementType(),
Data: make([]*schemapb.ScalarField, 0, len(srcData)),
Data: lo.Map(srcData, func(v *schemapb.ScalarField, _ int) *schemapb.ScalarField { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
case schemapb.DataType_JSON:
srcData := srcFields[field.FieldID].GetScalars().GetJsonData().GetData()
srcData := srcField.GetScalars().GetJsonData().GetData()
fieldData := &JSONFieldData{
Data: make([][]byte, 0, len(srcData)),
fieldData = &JSONFieldData{
Data: lo.Map(srcData, func(v []byte, _ int) []byte { return v }),
}
fieldData.Data = append(fieldData.Data, srcData...)
idata.Data[field.FieldID] = fieldData
default:
return nil, merr.WrapErrServiceInternal("data type not handled", field.GetDataType().String())
}
if length == 0 {
length = fieldData.RowNum()
}
if fieldData.RowNum() != length {
return nil, merr.WrapErrServiceInternal("row num not match", fmt.Sprintf("field %s row num not match %d, other column %d", field.GetName(), fieldData.RowNum(), length))
}
idata.Data[field.FieldID] = fieldData
}
idata.Infos = []BlobInfo{
{Length: length},
}
return idata, nil
}

View File

@ -26,6 +26,7 @@ import (
"testing"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -639,6 +640,8 @@ func genColumnBasedInsertMsg(schema *schemapb.CollectionSchema, numRows, fVecDim
FieldsData: nil,
NumRows: uint64(numRows),
Version: msgpb.InsertDataVersion_ColumnBased,
RowIDs: lo.RepeatBy(numRows, func(idx int) int64 { return int64(idx) }),
Timestamps: lo.RepeatBy(numRows, func(idx int) uint64 { return uint64(idx) }),
},
}
pks = make([]int64, 0)