fix: [2.5][GoSDK] Set nullable according to fieldSchema for RowBased insert (#40928) (#40962)

Cherry-pick from master
pr: #40928

Related to #40737

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/40976/head
congqixia 2025-03-28 10:30:20 +08:00 committed by GitHub
parent f0346a149a
commit 51efe9a60c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 96 additions and 54 deletions

View File

@ -43,6 +43,8 @@ type Column interface {
AppendNull() error
IsNull(int) (bool, error)
Nullable() bool
SetNullable(bool)
ValidateNullable() error
}
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")

View File

@ -18,6 +18,7 @@ package column
import (
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
@ -32,6 +33,8 @@ type genericColumnBase[T any] struct {
fieldType entity.FieldType
values []T
// nullable related fields
// note that nullable must be set to true explicitly
nullable bool
validData []bool
}
@ -169,9 +172,8 @@ func (c *genericColumnBase[T]) AppendNull() error {
if !c.nullable {
return errors.New("append null to not nullable column")
}
// var v T
c.validData = append(c.validData, true)
// c.values = append(c.values, v)
c.validData = append(c.validData, false)
return nil
}
@ -189,6 +191,40 @@ func (c *genericColumnBase[T]) Nullable() bool {
return c.nullable
}
// SetNullable update the nullable flag and change the valid data array according to the flag value.
// NOTE: set nullable to false will erase all the validData previously set.
func (c *genericColumnBase[T]) SetNullable(nullable bool) {
c.nullable = nullable
// initialize validData only when
if c.nullable && c.validData == nil {
// set valid flag for all exisiting values
c.validData = lo.RepeatBy(len(c.values), func(_ int) bool { return true })
}
if !c.nullable {
c.validData = nil
}
}
// ValidateNullable performs the sanity check for nullable column.
// it checks the length of data and the valid number indicated by validData slice,
// which shall be the same by definition
func (c *genericColumnBase[T]) ValidateNullable() error {
// skip check if column not nullable
if !c.nullable {
return nil
}
// count valid entries
validCnt := lo.CountBy(c.validData, func(v bool) bool {
return v
})
if validCnt != len(c.values) {
return errors.Newf("values number(%d) does not match valid count(%d)", len(c.values), validCnt)
}
return nil
}
func (c *genericColumnBase[T]) withValidData(validData []bool) {
if len(validData) > 0 {
c.nullable = true

View File

@ -148,6 +148,30 @@ func (s *GenericBaseSuite) TestConversion() {
s.Error(err)
}
func (s *GenericBaseSuite) TestNullable() {
name := fmt.Sprintf("test_%d", rand.Intn(10))
var values []int64
gb := &genericColumnBase[int64]{
name: name,
fieldType: entity.FieldTypeInt64,
values: values,
}
s.False(gb.Nullable())
s.NoError(gb.ValidateNullable())
s.Error(gb.AppendNull())
s.EqualValues(0, gb.Len())
gb.SetNullable(true)
s.NoError(gb.ValidateNullable())
s.NoError(gb.AppendNull())
s.EqualValues(1, gb.Len())
gb.SetNullable(false)
s.NoError(gb.ValidateNullable())
s.EqualValues(0, gb.Len())
}
func TestGenericBase(t *testing.T) {
suite.Run(t, new(GenericBaseSuite))
}

View File

@ -16,11 +16,6 @@
package column
import (
"github.com/cockroachdb/errors"
"github.com/samber/lo"
)
var (
// scalars
NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New
@ -57,18 +52,10 @@ type NullableColumnCreator[col interface {
}
func (c NullableColumnCreator[col, T]) New(name string, values []T, validData []bool) (col, error) {
var result col
validCnt := lo.CountBy(validData, func(v bool) bool {
return v
})
if validCnt != len(values) {
return result, errors.Newf("values number(%d) does not match valid count(%d)", len(values), validCnt)
}
result = c.base(name, values)
result := c.base(name, values)
result.withValidData(validData)
return result, nil
return result, result.ValidateNullable()
}
func NewNullableColumnCreator[col interface {

View File

@ -57,7 +57,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeBool, column.Type())
}
_, err = NewNullableColumnBool(name, data, nil)
_, err = NewNullableColumnBool(name, data, []bool{false, false})
s.Error(err)
})
@ -87,7 +87,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeInt8, column.Type())
}
_, err = NewNullableColumnInt8(name, data, nil)
_, err = NewNullableColumnInt8(name, data, []bool{false, false})
s.Error(err)
})
@ -117,7 +117,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeInt16, column.Type())
}
_, err = NewNullableColumnInt16(name, data, nil)
_, err = NewNullableColumnInt16(name, data, []bool{false, false})
s.Error(err)
})
@ -147,7 +147,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeInt32, column.Type())
}
_, err = NewNullableColumnInt32(name, data, nil)
_, err = NewNullableColumnInt32(name, data, []bool{false, false})
s.Error(err)
})
@ -177,7 +177,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeInt64, column.Type())
}
_, err = NewNullableColumnInt64(name, data, nil)
_, err = NewNullableColumnInt64(name, data, []bool{false, false})
s.Error(err)
})
@ -207,7 +207,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeFloat, column.Type())
}
_, err = NewNullableColumnFloat(name, data, nil)
_, err = NewNullableColumnFloat(name, data, []bool{false, false})
s.Error(err)
})
@ -237,7 +237,7 @@ func (s *NullableScalarSuite) TestBasic() {
s.Equal(entity.FieldTypeDouble, column.Type())
}
_, err = NewNullableColumnDouble(name, data, nil)
_, err = NewNullableColumnDouble(name, data, []bool{false, false})
s.Error(err)
})
}

View File

@ -88,49 +88,41 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum
if field.PrimaryKey && field.AutoID {
continue
}
var col column.Column
switch field.DataType {
case entity.FieldTypeBool:
data := make([]bool, 0, rowsLen)
col := column.NewColumnBool(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnBool(field.Name, data)
case entity.FieldTypeInt8:
data := make([]int8, 0, rowsLen)
col := column.NewColumnInt8(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnInt8(field.Name, data)
case entity.FieldTypeInt16:
data := make([]int16, 0, rowsLen)
col := column.NewColumnInt16(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnInt16(field.Name, data)
case entity.FieldTypeInt32:
data := make([]int32, 0, rowsLen)
col := column.NewColumnInt32(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnInt32(field.Name, data)
case entity.FieldTypeInt64:
data := make([]int64, 0, rowsLen)
col := column.NewColumnInt64(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnInt64(field.Name, data)
case entity.FieldTypeFloat:
data := make([]float32, 0, rowsLen)
col := column.NewColumnFloat(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnFloat(field.Name, data)
case entity.FieldTypeDouble:
data := make([]float64, 0, rowsLen)
col := column.NewColumnDouble(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnDouble(field.Name, data)
case entity.FieldTypeString, entity.FieldTypeVarChar:
data := make([]string, 0, rowsLen)
col := column.NewColumnVarChar(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnVarChar(field.Name, data)
case entity.FieldTypeJSON:
data := make([][]byte, 0, rowsLen)
col := column.NewColumnJSONBytes(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnJSONBytes(field.Name, data)
case entity.FieldTypeArray:
col := NewArrayColumn(field)
if col == nil {
return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String())
}
nameColumns[field.Name] = col
case entity.FieldTypeFloatVector:
data := make([][]float32, 0, rowsLen)
dimStr, has := field.TypeParams[entity.TypeParamDim]
@ -141,37 +133,38 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum
if err != nil {
return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error())
}
col := column.NewColumnFloatVector(field.Name, int(dim), data)
nameColumns[field.Name] = col
col = column.NewColumnFloatVector(field.Name, int(dim), data)
case entity.FieldTypeBinaryVector:
data := make([][]byte, 0, rowsLen)
dim, err := field.GetDim()
if err != nil {
return []column.Column{}, err
}
col := column.NewColumnBinaryVector(field.Name, int(dim), data)
nameColumns[field.Name] = col
col = column.NewColumnBinaryVector(field.Name, int(dim), data)
case entity.FieldTypeFloat16Vector:
data := make([][]byte, 0, rowsLen)
dim, err := field.GetDim()
if err != nil {
return []column.Column{}, err
}
col := column.NewColumnFloat16Vector(field.Name, int(dim), data)
nameColumns[field.Name] = col
col = column.NewColumnFloat16Vector(field.Name, int(dim), data)
case entity.FieldTypeBFloat16Vector:
data := make([][]byte, 0, rowsLen)
dim, err := field.GetDim()
if err != nil {
return []column.Column{}, err
}
col := column.NewColumnBFloat16Vector(field.Name, int(dim), data)
nameColumns[field.Name] = col
col = column.NewColumnBFloat16Vector(field.Name, int(dim), data)
case entity.FieldTypeSparseVector:
data := make([]entity.SparseEmbedding, 0, rowsLen)
col := column.NewColumnSparseVectors(field.Name, data)
nameColumns[field.Name] = col
col = column.NewColumnSparseVectors(field.Name, data)
}
if field.Nullable {
col.SetNullable(true)
}
nameColumns[field.Name] = col
}
if isDynamic {