enhance: Unify data type check APIs for go (#31945)

cherry-pick from master
pr: #31887 
Issue: #22837

Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
pull/31999/head
Cai Yudong 2024-04-08 14:11:17 +08:00 committed by GitHub
parent 958f933810
commit 4b430097dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 81 additions and 87 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/requestutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var RestRequestInterceptorErr = errors.New("interceptor error placeholder")
@ -383,7 +384,7 @@ func (h *HandlersV1) getCollectionDetails(c *gin.Context) {
}
vectorField := ""
for _, field := range coll.Schema.Fields {
if IsVectorField(field) {
if typeutil.IsVectorType(field.DataType) {
vectorField = field.Name
break
}

View File

@ -344,8 +344,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
}
vectorField := ""
for _, field := range coll.Schema.Fields {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector ||
field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector {
if typeutil.IsVectorType(field.DataType) {
vectorField = field.Name
break
}
@ -760,7 +759,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
var vectorField *schemapb.FieldSchema
if len(fieldName) == 0 {
for _, field := range collSchema.Fields {
if IsVectorField(field) {
if typeutil.IsVectorType(field.DataType) {
if len(fieldName) == 0 {
fieldName = field.Name
vectorField = field
@ -771,7 +770,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
}
} else {
for _, field := range collSchema.Fields {
if field.Name == fieldName && IsVectorField(field) {
if field.Name == fieldName && typeutil.IsVectorType(field.DataType) {
vectorField = field
break
}

View File

@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/parameterutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func ParseUsernamePassword(c *gin.Context) (string, string, bool) {
@ -124,14 +125,6 @@ func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result)
// --------------------- collection details --------------------- //
// IsVectorField reports whether field's data type is BinaryVector, FloatVector,
// Float16Vector or BFloat16Vector. SparseFloatVector is not in this list, so a
// sparse float vector field yields false.
func IsVectorField(field *schemapb.FieldSchema) bool {
switch field.DataType {
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
return true
}
return false
}
// printFields returns the result of printFieldDetails for fields with the
// oldVersion argument set to true.
func printFields(fields []*schemapb.FieldSchema) []gin.H {
return printFieldDetails(fields, true)
}
@ -150,7 +143,7 @@ func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H
HTTPReturnFieldAutoID: field.AutoID,
HTTPReturnDescription: field.Description,
}
if IsVectorField(field) {
if typeutil.IsVectorType(field.DataType) {
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
if oldVersion {
dim, _ := getDim(field)

View File

@ -17,30 +17,24 @@
package indexnode
import (
"unsafe"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType) (uint64, error) {
if dataType == schemapb.DataType_FloatVector {
var value float32
/* #nosec G103 */
return uint64(dim) * uint64(numRows) * uint64(unsafe.Sizeof(value)), nil
}
if dataType == schemapb.DataType_BinaryVector {
switch dataType {
case schemapb.DataType_BinaryVector:
return uint64(dim) / 8 * uint64(numRows), nil
}
if dataType == schemapb.DataType_Float16Vector {
case schemapb.DataType_FloatVector:
return uint64(dim) * uint64(numRows) * 4, nil
case schemapb.DataType_Float16Vector:
case schemapb.DataType_BFloat16Vector:
return uint64(dim) * uint64(numRows) * 2, nil
}
if dataType == schemapb.DataType_BFloat16Vector {
return uint64(dim) * uint64(numRows) * 2, nil
}
if dataType == schemapb.DataType_SparseFloatVector {
case schemapb.DataType_SparseFloatVector:
return 0, errors.New("could not estimate field data size of SparseFloatVector")
default:
return 0, nil
}
return 0, nil
}

View File

@ -129,16 +129,20 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
if !typeutil.IsVectorType(dataType) {
return nil, fmt.Errorf("field (%s) to search is not of vector data type", vectorFieldName)
}
if dataType == schemapb.DataType_FloatVector {
vectorType = planpb.VectorType_FloatVector
} else if dataType == schemapb.DataType_BinaryVector {
switch dataType {
case schemapb.DataType_BinaryVector:
vectorType = planpb.VectorType_BinaryVector
} else if dataType == schemapb.DataType_Float16Vector {
case schemapb.DataType_FloatVector:
vectorType = planpb.VectorType_FloatVector
case schemapb.DataType_Float16Vector:
vectorType = planpb.VectorType_Float16Vector
} else if dataType == schemapb.DataType_BFloat16Vector {
case schemapb.DataType_BFloat16Vector:
vectorType = planpb.VectorType_BFloat16Vector
} else if dataType == schemapb.DataType_SparseFloatVector {
case schemapb.DataType_SparseFloatVector:
vectorType = planpb.VectorType_SparseFloatVector
default:
log.Error("Invalid dataType", zap.Any("dataType", dataType))
return nil, err
}
planNode := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{

View File

@ -323,7 +323,7 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
return err
}
// validate dense vector field type parameters
if isVectorType(field.DataType) {
if typeutil.IsVectorType(field.DataType) {
err = validateDimension(field)
if err != nil {
return err
@ -1567,7 +1567,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
unindexedVecFields := make([]string, 0)
for _, field := range collSchema.GetFields() {
if isVectorType(field.GetDataType()) {
if typeutil.IsVectorType(field.GetDataType()) {
if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
unindexedVecFields = append(unindexedVecFields, field.GetName())
}
@ -1810,7 +1810,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
for _, index := range indexResponse.IndexInfos {
fieldIndexIDs[index.FieldID] = index.IndexID
for _, field := range collSchema.Fields {
if index.FieldID == field.FieldID && isVectorType(field.DataType) {
if index.FieldID == field.FieldID && typeutil.IsVectorType(field.DataType) {
hasVecIndex = true
}
}

View File

@ -308,7 +308,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
}
func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
if !isVectorType(field.GetDataType()) {
if !typeutil.IsVectorType(field.GetDataType()) {
return nil
}
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
@ -338,7 +338,7 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
return fmt.Errorf("invalid index type: %s", indexType)
}
if !isSparseVectorType(field.DataType) {
if !typeutil.IsSparseFloatVectorType(field.DataType) {
if err := fillDimension(field, indexParams); err != nil {
return err
}

View File

@ -78,7 +78,7 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio
outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1)
if len(outputFields) == 0 {
for _, field := range schema.Fields {
if field.FieldID >= common.StartOfUserFieldID && !isVectorType(field.DataType) {
if field.FieldID >= common.StartOfUserFieldID && !typeutil.IsVectorType(field.DataType) {
outputFieldIDs = append(outputFieldIDs, field.FieldID)
}
}

View File

@ -192,7 +192,7 @@ func constructCollectionSchemaByDataType(collectionName string, fieldName2DataTy
DataType: dataType,
}
idx++
if isVectorType(dataType) {
if typeutil.IsVectorType(dataType) {
fieldSchema.TypeParams = []*commonpb.KeyValuePair{
{
Key: common.DimKey,
@ -2507,7 +2507,7 @@ func Test_loadCollectionTask_Execute(t *testing.T) {
t.Run("not all vector fields with index", func(t *testing.T) {
vecFields := make([]*schemapb.FieldSchema, 0)
for _, field := range newTestSchema().GetFields() {
if isVectorType(field.GetDataType()) {
if typeutil.IsVectorType(field.GetDataType()) {
vecFields = append(vecFields, field)
}
}

View File

@ -93,18 +93,6 @@ func isNumber(c uint8) bool {
return true
}
// isVectorType reports whether dataType is one of FloatVector, BinaryVector,
// Float16Vector, BFloat16Vector or SparseFloatVector.
func isVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_FloatVector ||
dataType == schemapb.DataType_BinaryVector ||
dataType == schemapb.DataType_Float16Vector ||
dataType == schemapb.DataType_BFloat16Vector ||
dataType == schemapb.DataType_SparseFloatVector
}
// isSparseVectorType reports whether dataType is SparseFloatVector.
func isSparseVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}
func validateMaxQueryResultWindow(offset int64, limit int64) error {
if offset < 0 {
return fmt.Errorf("%s [%d] is invalid, should be gte than 0", OffsetKey, offset)
@ -301,7 +289,7 @@ func validateDimension(field *schemapb.FieldSchema) error {
break
}
}
if isSparseVectorType(field.DataType) {
if typeutil.IsSparseFloatVectorType(field.DataType) {
if exist {
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.Name, field.FieldID)
}
@ -315,7 +303,7 @@ func validateDimension(field *schemapb.FieldSchema) error {
return fmt.Errorf("invalid dimension: %d. should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
}
if field.DataType != schemapb.DataType_BinaryVector {
if typeutil.IsFloatVectorType(field.DataType) {
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64() {
return fmt.Errorf("invalid dimension: %d. float vector dimension should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
}
@ -379,7 +367,7 @@ func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchem
}
func validateVectorFieldMetricType(field *schemapb.FieldSchema) error {
if !isVectorType(field.DataType) {
if !typeutil.IsVectorType(field.DataType) {
return nil
}
for _, params := range field.IndexParams {
@ -520,7 +508,7 @@ func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) err
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
switch metricTypeStr {
case metric.L2, metric.IP, metric.COSINE:
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector || dataType == schemapb.DataType_SparseFloatVector {
if typeutil.IsFloatVectorType(dataType) {
return nil
}
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE:
@ -581,7 +569,7 @@ func validateSchema(coll *schemapb.CollectionSchema) error {
if err2 != nil {
return err2
}
if !isSparseVectorType(field.DataType) {
if !typeutil.IsSparseFloatVectorType(field.DataType) {
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
@ -626,7 +614,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
for i := range schema.Fields {
name := schema.Fields[i].Name
dType := schema.Fields[i].DataType
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector || dType == schemapb.DataType_SparseFloatVector
isVec := typeutil.IsVectorType(dType)
if isVec && vecExist && !enableMultipleVectorFields {
return fmt.Errorf(
"multiple vector fields is not supported, fields name: %s, %s",

View File

@ -157,7 +157,7 @@ func (writer *InsertBinlogWriter) NextInsertEventWriter(dim ...int) (*insertEven
var event *insertEventWriter
var err error
if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseVectorType(writer.PayloadDataType) {
if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseFloatVectorType(writer.PayloadDataType) {
if len(dim) != 1 {
return nil, fmt.Errorf("incorrect input numbers")
}

View File

@ -215,7 +215,7 @@ func newDescriptorEvent() *descriptorEvent {
func newInsertEventWriter(dataType schemapb.DataType, dim ...int) (*insertEventWriter, error) {
var payloadWriter PayloadWriterInterface
var err error
if typeutil.IsVectorType(dataType) && !typeutil.IsSparseVectorType(dataType) {
if typeutil.IsVectorType(dataType) && !typeutil.IsSparseFloatVectorType(dataType) {
if len(dim) != 1 {
return nil, fmt.Errorf("incorrect input numbers")
}

View File

@ -433,7 +433,7 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
}
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) {
if !typeutil.IsSparseVectorType(r.colType) {
if !typeutil.IsSparseFloatVectorType(r.colType) {
return nil, -1, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
}
values := make([]parquet.ByteArray, r.numRows)

View File

@ -51,7 +51,7 @@ type NativePayloadWriter struct {
func NewPayloadWriter(colType schemapb.DataType, dim ...int) (PayloadWriterInterface, error) {
var arrowType arrow.DataType
// writer for sparse float vector doesn't require dim
if typeutil.IsVectorType(colType) && !typeutil.IsSparseVectorType(colType) {
if typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) {
if len(dim) != 1 {
return nil, fmt.Errorf("incorrect input numbers")
}

View File

@ -5,6 +5,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type floatVectorBaseChecker struct {
@ -28,8 +29,8 @@ func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error {
}
// CheckValidDataType returns nil when dType is FloatVector, Float16Vector or
// BFloat16Vector, and an error for any other data type.
func (c floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_Float16Vector && dType != schemapb.DataType_BFloat16Vector {
return fmt.Errorf("float or float16 or bfloat16 vector are only supported")
if !typeutil.IsDenseFloatVectorType(dType) {
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
}
return nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type hnswChecker struct {
@ -32,7 +33,7 @@ func (c hnswChecker) CheckTrain(params map[string]string) error {
func (c hnswChecker) CheckValidDataType(dType schemapb.DataType) error {
// TODO(SPARSE) we'll add sparse vector support in HNSW later in cardinal
if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_Float16Vector && dType != schemapb.DataType_BFloat16Vector {
if !typeutil.IsDenseFloatVectorType(dType) {
return fmt.Errorf("HNSW only support float vector data type")
}
return nil

View File

@ -6,6 +6,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
@ -32,7 +33,7 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
}
func (c sparseFloatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
if dType != schemapb.DataType_SparseFloatVector {
if !typeutil.IsSparseFloatVectorType(dType) {
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
}
return nil

View File

@ -13,7 +13,7 @@ func GetDim(field *schemapb.FieldSchema) (int64, error) {
if !IsVectorType(field.GetDataType()) {
return 0, fmt.Errorf("%s is not of vector type", field.GetDataType())
}
if IsSparseVectorType(field.GetDataType()) {
if IsSparseFloatVectorType(field.GetDataType()) {
return 0, fmt.Errorf("typeutil.GetDim should not invoke on sparse vector type")
}
h := NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...))

View File

@ -371,20 +371,32 @@ func (helper *SchemaHelper) GetVectorDimFromID(fieldID int64) (int, error) {
return 0, fmt.Errorf("fieldID(%d) not has dim", fieldID)
}
// IsVectorType returns true if input is a vector type, otherwise false
func IsVectorType(dataType schemapb.DataType) bool {
// IsBinaryVectorType returns true if dataType is BinaryVector, otherwise false
func IsBinaryVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_BinaryVector
}
// IsDenseFloatVectorType returns true if dataType is FloatVector, Float16Vector
// or BFloat16Vector, otherwise false
func IsDenseFloatVectorType(dataType schemapb.DataType) bool {
switch dataType {
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
return true
default:
return false
}
}
func IsSparseVectorType(dataType schemapb.DataType) bool {
// IsSparseFloatVectorType returns true if dataType is SparseFloatVector, otherwise false
func IsSparseFloatVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}
// IsFloatVectorType returns true if dataType is a dense float vector type or
// the sparse float vector type, otherwise false
func IsFloatVectorType(dataType schemapb.DataType) bool {
return IsDenseFloatVectorType(dataType) || IsSparseFloatVectorType(dataType)
}
// IsVectorType returns true if input is a vector type, otherwise false.
// A data type counts as a vector type when IsBinaryVectorType or
// IsFloatVectorType accepts it.
func IsVectorType(dataType schemapb.DataType) bool {
return IsBinaryVectorType(dataType) || IsFloatVectorType(dataType)
}
// IsIntegerType returns true if input is an integer type, otherwise false
func IsIntegerType(dataType schemapb.DataType) bool {
switch dataType {

View File

@ -254,19 +254,19 @@ func TestSchema(t *testing.T) {
assert.False(t, IsFloatingType(schemapb.DataType_BFloat16Vector))
assert.False(t, IsFloatingType(schemapb.DataType_SparseFloatVector))
assert.False(t, IsSparseVectorType(schemapb.DataType_Bool))
assert.False(t, IsSparseVectorType(schemapb.DataType_Int8))
assert.False(t, IsSparseVectorType(schemapb.DataType_Int16))
assert.False(t, IsSparseVectorType(schemapb.DataType_Int32))
assert.False(t, IsSparseVectorType(schemapb.DataType_Int64))
assert.False(t, IsSparseVectorType(schemapb.DataType_Float))
assert.False(t, IsSparseVectorType(schemapb.DataType_Double))
assert.False(t, IsSparseVectorType(schemapb.DataType_String))
assert.False(t, IsSparseVectorType(schemapb.DataType_BinaryVector))
assert.False(t, IsSparseVectorType(schemapb.DataType_FloatVector))
assert.False(t, IsSparseVectorType(schemapb.DataType_Float16Vector))
assert.False(t, IsSparseVectorType(schemapb.DataType_BFloat16Vector))
assert.True(t, IsSparseVectorType(schemapb.DataType_SparseFloatVector))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Bool))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int8))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int16))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int32))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int64))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Float))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Double))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_String))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_BinaryVector))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_FloatVector))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Float16Vector))
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_BFloat16Vector))
assert.True(t, IsSparseFloatVectorType(schemapb.DataType_SparseFloatVector))
})
}