mirror of https://github.com/milvus-io/milvus.git
Cherry-pick from master pr: #40928 Related to #40737 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/40976/head
parent
f0346a149a
commit
51efe9a60c
|
@ -43,6 +43,8 @@ type Column interface {
|
|||
AppendNull() error
|
||||
IsNull(int) (bool, error)
|
||||
Nullable() bool
|
||||
SetNullable(bool)
|
||||
ValidateNullable() error
|
||||
}
|
||||
|
||||
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
|
||||
|
|
|
@ -18,6 +18,7 @@ package column
|
|||
|
||||
import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
|
@ -32,6 +33,8 @@ type genericColumnBase[T any] struct {
|
|||
fieldType entity.FieldType
|
||||
values []T
|
||||
|
||||
// nullable related fields
|
||||
// note that nullable must be set to true explicitly
|
||||
nullable bool
|
||||
validData []bool
|
||||
}
|
||||
|
@ -169,9 +172,8 @@ func (c *genericColumnBase[T]) AppendNull() error {
|
|||
if !c.nullable {
|
||||
return errors.New("append null to not nullable column")
|
||||
}
|
||||
// var v T
|
||||
c.validData = append(c.validData, true)
|
||||
// c.values = append(c.values, v)
|
||||
|
||||
c.validData = append(c.validData, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -189,6 +191,40 @@ func (c *genericColumnBase[T]) Nullable() bool {
|
|||
return c.nullable
|
||||
}
|
||||
|
||||
// SetNullable update the nullable flag and change the valid data array according to the flag value.
|
||||
// NOTE: set nullable to false will erase all the validData previously set.
|
||||
func (c *genericColumnBase[T]) SetNullable(nullable bool) {
|
||||
c.nullable = nullable
|
||||
// initialize validData only when
|
||||
if c.nullable && c.validData == nil {
|
||||
// set valid flag for all exisiting values
|
||||
c.validData = lo.RepeatBy(len(c.values), func(_ int) bool { return true })
|
||||
}
|
||||
|
||||
if !c.nullable {
|
||||
c.validData = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateNullable performs the sanity check for nullable column.
|
||||
// it checks the length of data and the valid number indicated by validData slice,
|
||||
// which shall be the same by definition
|
||||
func (c *genericColumnBase[T]) ValidateNullable() error {
|
||||
// skip check if column not nullable
|
||||
if !c.nullable {
|
||||
return nil
|
||||
}
|
||||
|
||||
// count valid entries
|
||||
validCnt := lo.CountBy(c.validData, func(v bool) bool {
|
||||
return v
|
||||
})
|
||||
if validCnt != len(c.values) {
|
||||
return errors.Newf("values number(%d) does not match valid count(%d)", len(c.values), validCnt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *genericColumnBase[T]) withValidData(validData []bool) {
|
||||
if len(validData) > 0 {
|
||||
c.nullable = true
|
||||
|
|
|
@ -148,6 +148,30 @@ func (s *GenericBaseSuite) TestConversion() {
|
|||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *GenericBaseSuite) TestNullable() {
|
||||
name := fmt.Sprintf("test_%d", rand.Intn(10))
|
||||
var values []int64
|
||||
gb := &genericColumnBase[int64]{
|
||||
name: name,
|
||||
fieldType: entity.FieldTypeInt64,
|
||||
values: values,
|
||||
}
|
||||
|
||||
s.False(gb.Nullable())
|
||||
s.NoError(gb.ValidateNullable())
|
||||
s.Error(gb.AppendNull())
|
||||
s.EqualValues(0, gb.Len())
|
||||
|
||||
gb.SetNullable(true)
|
||||
s.NoError(gb.ValidateNullable())
|
||||
s.NoError(gb.AppendNull())
|
||||
s.EqualValues(1, gb.Len())
|
||||
|
||||
gb.SetNullable(false)
|
||||
s.NoError(gb.ValidateNullable())
|
||||
s.EqualValues(0, gb.Len())
|
||||
}
|
||||
|
||||
func TestGenericBase(t *testing.T) {
|
||||
suite.Run(t, new(GenericBaseSuite))
|
||||
}
|
||||
|
|
|
@ -16,11 +16,6 @@
|
|||
|
||||
package column
|
||||
|
||||
import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var (
|
||||
// scalars
|
||||
NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New
|
||||
|
@ -57,18 +52,10 @@ type NullableColumnCreator[col interface {
|
|||
}
|
||||
|
||||
func (c NullableColumnCreator[col, T]) New(name string, values []T, validData []bool) (col, error) {
|
||||
var result col
|
||||
validCnt := lo.CountBy(validData, func(v bool) bool {
|
||||
return v
|
||||
})
|
||||
if validCnt != len(values) {
|
||||
return result, errors.Newf("values number(%d) does not match valid count(%d)", len(values), validCnt)
|
||||
}
|
||||
|
||||
result = c.base(name, values)
|
||||
result := c.base(name, values)
|
||||
result.withValidData(validData)
|
||||
|
||||
return result, nil
|
||||
return result, result.ValidateNullable()
|
||||
}
|
||||
|
||||
func NewNullableColumnCreator[col interface {
|
||||
|
|
|
@ -57,7 +57,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeBool, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnBool(name, data, nil)
|
||||
_, err = NewNullableColumnBool(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -87,7 +87,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeInt8, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnInt8(name, data, nil)
|
||||
_, err = NewNullableColumnInt8(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -117,7 +117,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeInt16, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnInt16(name, data, nil)
|
||||
_, err = NewNullableColumnInt16(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -147,7 +147,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeInt32, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnInt32(name, data, nil)
|
||||
_, err = NewNullableColumnInt32(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -177,7 +177,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeInt64, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnInt64(name, data, nil)
|
||||
_, err = NewNullableColumnInt64(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -207,7 +207,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeFloat, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnFloat(name, data, nil)
|
||||
_, err = NewNullableColumnFloat(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
|
@ -237,7 +237,7 @@ func (s *NullableScalarSuite) TestBasic() {
|
|||
s.Equal(entity.FieldTypeDouble, column.Type())
|
||||
}
|
||||
|
||||
_, err = NewNullableColumnDouble(name, data, nil)
|
||||
_, err = NewNullableColumnDouble(name, data, []bool{false, false})
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -88,49 +88,41 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum
|
|||
if field.PrimaryKey && field.AutoID {
|
||||
continue
|
||||
}
|
||||
|
||||
var col column.Column
|
||||
switch field.DataType {
|
||||
case entity.FieldTypeBool:
|
||||
data := make([]bool, 0, rowsLen)
|
||||
col := column.NewColumnBool(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnBool(field.Name, data)
|
||||
case entity.FieldTypeInt8:
|
||||
data := make([]int8, 0, rowsLen)
|
||||
col := column.NewColumnInt8(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnInt8(field.Name, data)
|
||||
case entity.FieldTypeInt16:
|
||||
data := make([]int16, 0, rowsLen)
|
||||
col := column.NewColumnInt16(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnInt16(field.Name, data)
|
||||
case entity.FieldTypeInt32:
|
||||
data := make([]int32, 0, rowsLen)
|
||||
col := column.NewColumnInt32(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnInt32(field.Name, data)
|
||||
case entity.FieldTypeInt64:
|
||||
data := make([]int64, 0, rowsLen)
|
||||
col := column.NewColumnInt64(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnInt64(field.Name, data)
|
||||
case entity.FieldTypeFloat:
|
||||
data := make([]float32, 0, rowsLen)
|
||||
col := column.NewColumnFloat(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnFloat(field.Name, data)
|
||||
case entity.FieldTypeDouble:
|
||||
data := make([]float64, 0, rowsLen)
|
||||
col := column.NewColumnDouble(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnDouble(field.Name, data)
|
||||
case entity.FieldTypeString, entity.FieldTypeVarChar:
|
||||
data := make([]string, 0, rowsLen)
|
||||
col := column.NewColumnVarChar(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnVarChar(field.Name, data)
|
||||
case entity.FieldTypeJSON:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
col := column.NewColumnJSONBytes(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnJSONBytes(field.Name, data)
|
||||
case entity.FieldTypeArray:
|
||||
col := NewArrayColumn(field)
|
||||
if col == nil {
|
||||
return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String())
|
||||
}
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeFloatVector:
|
||||
data := make([][]float32, 0, rowsLen)
|
||||
dimStr, has := field.TypeParams[entity.TypeParamDim]
|
||||
|
@ -141,37 +133,38 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum
|
|||
if err != nil {
|
||||
return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error())
|
||||
}
|
||||
col := column.NewColumnFloatVector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnFloatVector(field.Name, int(dim), data)
|
||||
case entity.FieldTypeBinaryVector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnBinaryVector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnBinaryVector(field.Name, int(dim), data)
|
||||
case entity.FieldTypeFloat16Vector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnFloat16Vector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnFloat16Vector(field.Name, int(dim), data)
|
||||
case entity.FieldTypeBFloat16Vector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnBFloat16Vector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnBFloat16Vector(field.Name, int(dim), data)
|
||||
case entity.FieldTypeSparseVector:
|
||||
data := make([]entity.SparseEmbedding, 0, rowsLen)
|
||||
col := column.NewColumnSparseVectors(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
col = column.NewColumnSparseVectors(field.Name, data)
|
||||
}
|
||||
|
||||
if field.Nullable {
|
||||
col.SetNullable(true)
|
||||
}
|
||||
|
||||
nameColumns[field.Name] = col
|
||||
}
|
||||
|
||||
if isDynamic {
|
||||
|
|
Loading…
Reference in New Issue