mirror of https://github.com/milvus-io/milvus.git
enhance: Unify data type check APIs for go (#31945)
cherry-pick from master pr: #31887 Issue: #22837 Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>pull/31999/head
parent
958f933810
commit
4b430097dd
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/requestutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
var RestRequestInterceptorErr = errors.New("interceptor error placeholder")
|
||||
|
@ -383,7 +384,7 @@ func (h *HandlersV1) getCollectionDetails(c *gin.Context) {
|
|||
}
|
||||
vectorField := ""
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if IsVectorField(field) {
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
vectorField = field.Name
|
||||
break
|
||||
}
|
||||
|
|
|
@ -344,8 +344,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
|
|||
}
|
||||
vectorField := ""
|
||||
for _, field := range coll.Schema.Fields {
|
||||
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector ||
|
||||
field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector {
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
vectorField = field.Name
|
||||
break
|
||||
}
|
||||
|
@ -760,7 +759,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
|||
var vectorField *schemapb.FieldSchema
|
||||
if len(fieldName) == 0 {
|
||||
for _, field := range collSchema.Fields {
|
||||
if IsVectorField(field) {
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
if len(fieldName) == 0 {
|
||||
fieldName = field.Name
|
||||
vectorField = field
|
||||
|
@ -771,7 +770,7 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
|
|||
}
|
||||
} else {
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.Name == fieldName && IsVectorField(field) {
|
||||
if field.Name == fieldName && typeutil.IsVectorType(field.DataType) {
|
||||
vectorField = field
|
||||
break
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/parameterutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
func ParseUsernamePassword(c *gin.Context) (string, string, bool) {
|
||||
|
@ -124,14 +125,6 @@ func checkGetPrimaryKey(coll *schemapb.CollectionSchema, idResult gjson.Result)
|
|||
|
||||
// --------------------- collection details --------------------- //
|
||||
|
||||
func IsVectorField(field *schemapb.FieldSchema) bool {
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func printFields(fields []*schemapb.FieldSchema) []gin.H {
|
||||
return printFieldDetails(fields, true)
|
||||
}
|
||||
|
@ -150,7 +143,7 @@ func printFieldDetails(fields []*schemapb.FieldSchema, oldVersion bool) []gin.H
|
|||
HTTPReturnFieldAutoID: field.AutoID,
|
||||
HTTPReturnDescription: field.Description,
|
||||
}
|
||||
if IsVectorField(field) {
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
fieldDetail[HTTPReturnFieldType] = field.DataType.String()
|
||||
if oldVersion {
|
||||
dim, _ := getDim(field)
|
||||
|
|
|
@ -17,30 +17,24 @@
|
|||
package indexnode
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType) (uint64, error) {
|
||||
if dataType == schemapb.DataType_FloatVector {
|
||||
var value float32
|
||||
/* #nosec G103 */
|
||||
return uint64(dim) * uint64(numRows) * uint64(unsafe.Sizeof(value)), nil
|
||||
}
|
||||
if dataType == schemapb.DataType_BinaryVector {
|
||||
switch dataType {
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return uint64(dim) / 8 * uint64(numRows), nil
|
||||
}
|
||||
if dataType == schemapb.DataType_Float16Vector {
|
||||
case schemapb.DataType_FloatVector:
|
||||
return uint64(dim) * uint64(numRows) * 4, nil
|
||||
case schemapb.DataType_Float16Vector:
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
return uint64(dim) * uint64(numRows) * 2, nil
|
||||
}
|
||||
if dataType == schemapb.DataType_BFloat16Vector {
|
||||
return uint64(dim) * uint64(numRows) * 2, nil
|
||||
}
|
||||
if dataType == schemapb.DataType_SparseFloatVector {
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
return 0, errors.New("could not estimate field data size of SparseFloatVector")
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
|
|
@ -129,16 +129,20 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
|
|||
if !typeutil.IsVectorType(dataType) {
|
||||
return nil, fmt.Errorf("field (%s) to search is not of vector data type", vectorFieldName)
|
||||
}
|
||||
if dataType == schemapb.DataType_FloatVector {
|
||||
vectorType = planpb.VectorType_FloatVector
|
||||
} else if dataType == schemapb.DataType_BinaryVector {
|
||||
switch dataType {
|
||||
case schemapb.DataType_BinaryVector:
|
||||
vectorType = planpb.VectorType_BinaryVector
|
||||
} else if dataType == schemapb.DataType_Float16Vector {
|
||||
case schemapb.DataType_FloatVector:
|
||||
vectorType = planpb.VectorType_FloatVector
|
||||
case schemapb.DataType_Float16Vector:
|
||||
vectorType = planpb.VectorType_Float16Vector
|
||||
} else if dataType == schemapb.DataType_BFloat16Vector {
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
vectorType = planpb.VectorType_BFloat16Vector
|
||||
} else if dataType == schemapb.DataType_SparseFloatVector {
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
vectorType = planpb.VectorType_SparseFloatVector
|
||||
default:
|
||||
log.Error("Invalid dataType", zap.Any("dataType", dataType))
|
||||
return nil, err
|
||||
}
|
||||
planNode := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
|
|
|
@ -323,7 +323,7 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
// validate dense vector field type parameters
|
||||
if isVectorType(field.DataType) {
|
||||
if typeutil.IsVectorType(field.DataType) {
|
||||
err = validateDimension(field)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1567,7 +1567,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
|
|||
|
||||
unindexedVecFields := make([]string, 0)
|
||||
for _, field := range collSchema.GetFields() {
|
||||
if isVectorType(field.GetDataType()) {
|
||||
if typeutil.IsVectorType(field.GetDataType()) {
|
||||
if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
|
||||
unindexedVecFields = append(unindexedVecFields, field.GetName())
|
||||
}
|
||||
|
@ -1810,7 +1810,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
|
|||
for _, index := range indexResponse.IndexInfos {
|
||||
fieldIndexIDs[index.FieldID] = index.IndexID
|
||||
for _, field := range collSchema.Fields {
|
||||
if index.FieldID == field.FieldID && isVectorType(field.DataType) {
|
||||
if index.FieldID == field.FieldID && typeutil.IsVectorType(field.DataType) {
|
||||
hasVecIndex = true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -308,7 +308,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
|
|||
}
|
||||
|
||||
func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
|
||||
if !isVectorType(field.GetDataType()) {
|
||||
if !typeutil.IsVectorType(field.GetDataType()) {
|
||||
return nil
|
||||
}
|
||||
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
|
||||
|
@ -338,7 +338,7 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
|
|||
return fmt.Errorf("invalid index type: %s", indexType)
|
||||
}
|
||||
|
||||
if !isSparseVectorType(field.DataType) {
|
||||
if !typeutil.IsSparseFloatVectorType(field.DataType) {
|
||||
if err := fillDimension(field, indexParams); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio
|
|||
outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1)
|
||||
if len(outputFields) == 0 {
|
||||
for _, field := range schema.Fields {
|
||||
if field.FieldID >= common.StartOfUserFieldID && !isVectorType(field.DataType) {
|
||||
if field.FieldID >= common.StartOfUserFieldID && !typeutil.IsVectorType(field.DataType) {
|
||||
outputFieldIDs = append(outputFieldIDs, field.FieldID)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -192,7 +192,7 @@ func constructCollectionSchemaByDataType(collectionName string, fieldName2DataTy
|
|||
DataType: dataType,
|
||||
}
|
||||
idx++
|
||||
if isVectorType(dataType) {
|
||||
if typeutil.IsVectorType(dataType) {
|
||||
fieldSchema.TypeParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
|
@ -2507,7 +2507,7 @@ func Test_loadCollectionTask_Execute(t *testing.T) {
|
|||
t.Run("not all vector fields with index", func(t *testing.T) {
|
||||
vecFields := make([]*schemapb.FieldSchema, 0)
|
||||
for _, field := range newTestSchema().GetFields() {
|
||||
if isVectorType(field.GetDataType()) {
|
||||
if typeutil.IsVectorType(field.GetDataType()) {
|
||||
vecFields = append(vecFields, field)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,18 +93,6 @@ func isNumber(c uint8) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func isVectorType(dataType schemapb.DataType) bool {
|
||||
return dataType == schemapb.DataType_FloatVector ||
|
||||
dataType == schemapb.DataType_BinaryVector ||
|
||||
dataType == schemapb.DataType_Float16Vector ||
|
||||
dataType == schemapb.DataType_BFloat16Vector ||
|
||||
dataType == schemapb.DataType_SparseFloatVector
|
||||
}
|
||||
|
||||
func isSparseVectorType(dataType schemapb.DataType) bool {
|
||||
return dataType == schemapb.DataType_SparseFloatVector
|
||||
}
|
||||
|
||||
func validateMaxQueryResultWindow(offset int64, limit int64) error {
|
||||
if offset < 0 {
|
||||
return fmt.Errorf("%s [%d] is invalid, should be gte than 0", OffsetKey, offset)
|
||||
|
@ -301,7 +289,7 @@ func validateDimension(field *schemapb.FieldSchema) error {
|
|||
break
|
||||
}
|
||||
}
|
||||
if isSparseVectorType(field.DataType) {
|
||||
if typeutil.IsSparseFloatVectorType(field.DataType) {
|
||||
if exist {
|
||||
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.Name, field.FieldID)
|
||||
}
|
||||
|
@ -315,7 +303,7 @@ func validateDimension(field *schemapb.FieldSchema) error {
|
|||
return fmt.Errorf("invalid dimension: %d. should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
|
||||
}
|
||||
|
||||
if field.DataType != schemapb.DataType_BinaryVector {
|
||||
if typeutil.IsFloatVectorType(field.DataType) {
|
||||
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64() {
|
||||
return fmt.Errorf("invalid dimension: %d. float vector dimension should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
|
||||
}
|
||||
|
@ -379,7 +367,7 @@ func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchem
|
|||
}
|
||||
|
||||
func validateVectorFieldMetricType(field *schemapb.FieldSchema) error {
|
||||
if !isVectorType(field.DataType) {
|
||||
if !typeutil.IsVectorType(field.DataType) {
|
||||
return nil
|
||||
}
|
||||
for _, params := range field.IndexParams {
|
||||
|
@ -520,7 +508,7 @@ func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) err
|
|||
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
|
||||
switch metricTypeStr {
|
||||
case metric.L2, metric.IP, metric.COSINE:
|
||||
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector || dataType == schemapb.DataType_SparseFloatVector {
|
||||
if typeutil.IsFloatVectorType(dataType) {
|
||||
return nil
|
||||
}
|
||||
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE:
|
||||
|
@ -581,7 +569,7 @@ func validateSchema(coll *schemapb.CollectionSchema) error {
|
|||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
if !isSparseVectorType(field.DataType) {
|
||||
if !typeutil.IsSparseFloatVectorType(field.DataType) {
|
||||
dimStr, ok := typeKv[common.DimKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
|
||||
|
@ -626,7 +614,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
|
|||
for i := range schema.Fields {
|
||||
name := schema.Fields[i].Name
|
||||
dType := schema.Fields[i].DataType
|
||||
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector || dType == schemapb.DataType_SparseFloatVector
|
||||
isVec := typeutil.IsVectorType(dType)
|
||||
if isVec && vecExist && !enableMultipleVectorFields {
|
||||
return fmt.Errorf(
|
||||
"multiple vector fields is not supported, fields name: %s, %s",
|
||||
|
|
|
@ -157,7 +157,7 @@ func (writer *InsertBinlogWriter) NextInsertEventWriter(dim ...int) (*insertEven
|
|||
|
||||
var event *insertEventWriter
|
||||
var err error
|
||||
if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseVectorType(writer.PayloadDataType) {
|
||||
if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseFloatVectorType(writer.PayloadDataType) {
|
||||
if len(dim) != 1 {
|
||||
return nil, fmt.Errorf("incorrect input numbers")
|
||||
}
|
||||
|
|
|
@ -215,7 +215,7 @@ func newDescriptorEvent() *descriptorEvent {
|
|||
func newInsertEventWriter(dataType schemapb.DataType, dim ...int) (*insertEventWriter, error) {
|
||||
var payloadWriter PayloadWriterInterface
|
||||
var err error
|
||||
if typeutil.IsVectorType(dataType) && !typeutil.IsSparseVectorType(dataType) {
|
||||
if typeutil.IsVectorType(dataType) && !typeutil.IsSparseFloatVectorType(dataType) {
|
||||
if len(dim) != 1 {
|
||||
return nil, fmt.Errorf("incorrect input numbers")
|
||||
}
|
||||
|
|
|
@ -433,7 +433,7 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
|||
}
|
||||
|
||||
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) {
|
||||
if !typeutil.IsSparseVectorType(r.colType) {
|
||||
if !typeutil.IsSparseFloatVectorType(r.colType) {
|
||||
return nil, -1, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
|
||||
}
|
||||
values := make([]parquet.ByteArray, r.numRows)
|
||||
|
|
|
@ -51,7 +51,7 @@ type NativePayloadWriter struct {
|
|||
func NewPayloadWriter(colType schemapb.DataType, dim ...int) (PayloadWriterInterface, error) {
|
||||
var arrowType arrow.DataType
|
||||
// writer for sparse float vector doesn't require dim
|
||||
if typeutil.IsVectorType(colType) && !typeutil.IsSparseVectorType(colType) {
|
||||
if typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) {
|
||||
if len(dim) != 1 {
|
||||
return nil, fmt.Errorf("incorrect input numbers")
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type floatVectorBaseChecker struct {
|
||||
|
@ -28,8 +29,8 @@ func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error {
|
|||
}
|
||||
|
||||
func (c floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_Float16Vector && dType != schemapb.DataType_BFloat16Vector {
|
||||
return fmt.Errorf("float or float16 or bfloat16 vector are only supported")
|
||||
if !typeutil.IsDenseFloatVectorType(dType) {
|
||||
return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type hnswChecker struct {
|
||||
|
@ -32,7 +33,7 @@ func (c hnswChecker) CheckTrain(params map[string]string) error {
|
|||
|
||||
func (c hnswChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
// TODO(SPARSE) we'll add sparse vector support in HNSW later in cardinal
|
||||
if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_Float16Vector && dType != schemapb.DataType_BFloat16Vector {
|
||||
if !typeutil.IsDenseFloatVectorType(dType) {
|
||||
return fmt.Errorf("HNSW only support float vector data type")
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// sparse vector don't check for dim, but baseChecker does, thus not including baseChecker
|
||||
|
@ -32,7 +33,7 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error
|
|||
}
|
||||
|
||||
func (c sparseFloatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error {
|
||||
if dType != schemapb.DataType_SparseFloatVector {
|
||||
if !typeutil.IsSparseFloatVectorType(dType) {
|
||||
return fmt.Errorf("only sparse float vector is supported for the specified index tpye")
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -13,7 +13,7 @@ func GetDim(field *schemapb.FieldSchema) (int64, error) {
|
|||
if !IsVectorType(field.GetDataType()) {
|
||||
return 0, fmt.Errorf("%s is not of vector type", field.GetDataType())
|
||||
}
|
||||
if IsSparseVectorType(field.GetDataType()) {
|
||||
if IsSparseFloatVectorType(field.GetDataType()) {
|
||||
return 0, fmt.Errorf("typeutil.GetDim should not invoke on sparse vector type")
|
||||
}
|
||||
h := NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...))
|
||||
|
|
|
@ -371,20 +371,32 @@ func (helper *SchemaHelper) GetVectorDimFromID(fieldID int64) (int, error) {
|
|||
return 0, fmt.Errorf("fieldID(%d) not has dim", fieldID)
|
||||
}
|
||||
|
||||
// IsVectorType returns true if input is a vector type, otherwise false
|
||||
func IsVectorType(dataType schemapb.DataType) bool {
|
||||
func IsBinaryVectorType(dataType schemapb.DataType) bool {
|
||||
return dataType == schemapb.DataType_BinaryVector
|
||||
}
|
||||
|
||||
func IsDenseFloatVectorType(dataType schemapb.DataType) bool {
|
||||
switch dataType {
|
||||
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
|
||||
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsSparseVectorType(dataType schemapb.DataType) bool {
|
||||
func IsSparseFloatVectorType(dataType schemapb.DataType) bool {
|
||||
return dataType == schemapb.DataType_SparseFloatVector
|
||||
}
|
||||
|
||||
func IsFloatVectorType(dataType schemapb.DataType) bool {
|
||||
return IsDenseFloatVectorType(dataType) || IsSparseFloatVectorType(dataType)
|
||||
}
|
||||
|
||||
// IsVectorType returns true if input is a vector type, otherwise false
|
||||
func IsVectorType(dataType schemapb.DataType) bool {
|
||||
return IsBinaryVectorType(dataType) || IsFloatVectorType(dataType)
|
||||
}
|
||||
|
||||
// IsIntegerType returns true if input is an integer type, otherwise false
|
||||
func IsIntegerType(dataType schemapb.DataType) bool {
|
||||
switch dataType {
|
||||
|
|
|
@ -254,19 +254,19 @@ func TestSchema(t *testing.T) {
|
|||
assert.False(t, IsFloatingType(schemapb.DataType_BFloat16Vector))
|
||||
assert.False(t, IsFloatingType(schemapb.DataType_SparseFloatVector))
|
||||
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Bool))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Int8))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Int16))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Int32))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Int64))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Float))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Double))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_String))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_BinaryVector))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_FloatVector))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_Float16Vector))
|
||||
assert.False(t, IsSparseVectorType(schemapb.DataType_BFloat16Vector))
|
||||
assert.True(t, IsSparseVectorType(schemapb.DataType_SparseFloatVector))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Bool))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int8))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int16))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int32))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Int64))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Float))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Double))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_String))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_BinaryVector))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_FloatVector))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_Float16Vector))
|
||||
assert.False(t, IsSparseFloatVectorType(schemapb.DataType_BFloat16Vector))
|
||||
assert.True(t, IsSparseFloatVectorType(schemapb.DataType_SparseFloatVector))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue