mirror of https://github.com/milvus-io/milvus.git
Related to #35853 This PR contains following changes: - Add function and related proto and helper functions - Remove the insert column missing check and leave it to server - Add text as search input data - Add some unit tests for logic above --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/37595/head
parent a45a288a25
commit 24c6a4bb29
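Taken together, the pieces in this diff enable a doc-in-doc-out workflow: a VarChar field with the analyzer enabled feeds a BM25 function whose output is a sparse vector field, and raw text can then be passed directly as search data. Below is a minimal sketch of that flow, assuming a connected client `cli`, the package paths shown, and illustrative collection/field names (none of these names come from this diff):

package main

import (
	"context"

	"github.com/milvus-io/milvus/client/v2/entity"
	"github.com/milvus-io/milvus/client/v2/milvusclient"
)

// sketchDocInDocOut shows how the new schema Function and entity.Text search
// input are meant to be combined; collection creation and error handling are elided.
func sketchDocInDocOut(ctx context.Context, cli *milvusclient.Client) error {
	// A VarChar field with the analyzer enabled feeds a BM25 function that
	// produces a sparse vector field.
	_ = entity.NewSchema().WithName("documents").
		WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
		WithField(entity.NewField().WithName("text").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true)).
		WithField(entity.NewField().WithName("sparse").WithDataType(entity.FieldTypeSparseVector)).
		WithFunction(entity.NewFunction().WithName("text_bm25").WithType(entity.FunctionTypeBM25).
			WithInputFields("text").WithOutputFields("sparse"))
	// ... create the collection with this schema and insert documents ...

	// Doc-in-doc-out search: plain text is accepted as search data and
	// embedded on the server by the BM25 function.
	_, err := cli.Search(ctx, milvusclient.NewSearchOption("documents", 10,
		[]entity.Vector{entity.Text("what is milvus")}).WithANNSField("sparse"))
	return err
}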
@@ -29,4 +29,5 @@ const (
	TANIMOTO       MetricType = "TANIMOTO"
	SUBSTRUCTURE   MetricType = "SUBSTRUCTURE"
	SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
	BM25           MetricType = "BM25"
)
@@ -0,0 +1,401 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package entity

import (
	"encoding/json"
	"strconv"

	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// FieldType field data type alias type
// used in go:generate trick, DO NOT modify names & string
type FieldType int32

// Name returns field type name
func (t FieldType) Name() string {
	switch t {
	case FieldTypeBool:
		return "Bool"
	case FieldTypeInt8:
		return "Int8"
	case FieldTypeInt16:
		return "Int16"
	case FieldTypeInt32:
		return "Int32"
	case FieldTypeInt64:
		return "Int64"
	case FieldTypeFloat:
		return "Float"
	case FieldTypeDouble:
		return "Double"
	case FieldTypeString:
		return "String"
	case FieldTypeVarChar:
		return "VarChar"
	case FieldTypeArray:
		return "Array"
	case FieldTypeJSON:
		return "JSON"
	case FieldTypeBinaryVector:
		return "BinaryVector"
	case FieldTypeFloatVector:
		return "FloatVector"
	case FieldTypeFloat16Vector:
		return "Float16Vector"
	case FieldTypeBFloat16Vector:
		return "BFloat16Vector"
	default:
		return "undefined"
	}
}

// String returns the string representation of the field type
func (t FieldType) String() string {
	switch t {
	case FieldTypeBool:
		return "bool"
	case FieldTypeInt8:
		return "int8"
	case FieldTypeInt16:
		return "int16"
	case FieldTypeInt32:
		return "int32"
	case FieldTypeInt64:
		return "int64"
	case FieldTypeFloat:
		return "float32"
	case FieldTypeDouble:
		return "float64"
	case FieldTypeString:
		return "string"
	case FieldTypeVarChar:
		return "string"
	case FieldTypeArray:
		return "Array"
	case FieldTypeJSON:
		return "JSON"
	case FieldTypeBinaryVector:
		return "[]byte"
	case FieldTypeFloatVector:
		return "[]float32"
	case FieldTypeFloat16Vector:
		return "[]byte"
	case FieldTypeBFloat16Vector:
		return "[]byte"
	default:
		return "undefined"
	}
}

// PbFieldType represents FieldType corresponding schema pb type
func (t FieldType) PbFieldType() (string, string) {
	switch t {
	case FieldTypeBool:
		return "Bool", "bool"
	case FieldTypeInt8:
		fallthrough
	case FieldTypeInt16:
		fallthrough
	case FieldTypeInt32:
		return "Int", "int32"
	case FieldTypeInt64:
		return "Long", "int64"
	case FieldTypeFloat:
		return "Float", "float32"
	case FieldTypeDouble:
		return "Double", "float64"
	case FieldTypeString:
		return "String", "string"
	case FieldTypeVarChar:
		return "VarChar", "string"
	case FieldTypeJSON:
		return "JSON", "JSON"
	case FieldTypeBinaryVector:
		return "[]byte", ""
	case FieldTypeFloatVector:
		return "[]float32", ""
	case FieldTypeFloat16Vector:
		return "[]byte", ""
	case FieldTypeBFloat16Vector:
		return "[]byte", ""
	default:
		return "undefined", ""
	}
}

// Match schema definition
const (
	// FieldTypeNone zero value placeholder
	FieldTypeNone FieldType = 0 // zero value placeholder
	// FieldTypeBool field type boolean
	FieldTypeBool FieldType = 1
	// FieldTypeInt8 field type int8
	FieldTypeInt8 FieldType = 2
	// FieldTypeInt16 field type int16
	FieldTypeInt16 FieldType = 3
	// FieldTypeInt32 field type int32
	FieldTypeInt32 FieldType = 4
	// FieldTypeInt64 field type int64
	FieldTypeInt64 FieldType = 5
	// FieldTypeFloat field type float
	FieldTypeFloat FieldType = 10
	// FieldTypeDouble field type double
	FieldTypeDouble FieldType = 11
	// FieldTypeString field type string
	FieldTypeString FieldType = 20
	// FieldTypeVarChar field type varchar
	FieldTypeVarChar FieldType = 21 // variable-length strings with a specified maximum length
	// FieldTypeArray field type Array
	FieldTypeArray FieldType = 22
	// FieldTypeJSON field type JSON
	FieldTypeJSON FieldType = 23
	// FieldTypeBinaryVector field type binary vector
	FieldTypeBinaryVector FieldType = 100
	// FieldTypeFloatVector field type float vector
	FieldTypeFloatVector FieldType = 101
	// FieldTypeFloat16Vector field type float16 vector
	FieldTypeFloat16Vector FieldType = 102
	// FieldTypeBFloat16Vector field type bf16 vector
	FieldTypeBFloat16Vector FieldType = 103
	// FieldTypeSparseVector field type sparse vector
	FieldTypeSparseVector FieldType = 104
)

// Field represents a field schema in milvus
type Field struct {
	ID              int64  // field id, generated when collection is created, input value is ignored
	Name            string // field name
	PrimaryKey      bool   // is primary key
	AutoID          bool   // is auto id
	Description     string
	DataType        FieldType
	TypeParams      map[string]string
	IndexParams     map[string]string
	IsDynamic       bool
	IsPartitionKey  bool
	IsClusteringKey bool
	ElementType     FieldType
}

// ProtoMessage generates corresponding FieldSchema
func (f *Field) ProtoMessage() *schemapb.FieldSchema {
	return &schemapb.FieldSchema{
		FieldID:         f.ID,
		Name:            f.Name,
		Description:     f.Description,
		IsPrimaryKey:    f.PrimaryKey,
		AutoID:          f.AutoID,
		DataType:        schemapb.DataType(f.DataType),
		TypeParams:      MapKvPairs(f.TypeParams),
		IndexParams:     MapKvPairs(f.IndexParams),
		IsDynamic:       f.IsDynamic,
		IsPartitionKey:  f.IsPartitionKey,
		IsClusteringKey: f.IsClusteringKey,
		ElementType:     schemapb.DataType(f.ElementType),
	}
}

// NewField creates a new Field with map initialized.
func NewField() *Field {
	return &Field{
		TypeParams:  make(map[string]string),
		IndexParams: make(map[string]string),
	}
}

func (f *Field) WithName(name string) *Field {
	f.Name = name
	return f
}

func (f *Field) WithDescription(desc string) *Field {
	f.Description = desc
	return f
}

func (f *Field) WithDataType(dataType FieldType) *Field {
	f.DataType = dataType
	return f
}

func (f *Field) WithIsPrimaryKey(isPrimaryKey bool) *Field {
	f.PrimaryKey = isPrimaryKey
	return f
}

func (f *Field) WithIsAutoID(isAutoID bool) *Field {
	f.AutoID = isAutoID
	return f
}

func (f *Field) WithIsDynamic(isDynamic bool) *Field {
	f.IsDynamic = isDynamic
	return f
}

func (f *Field) WithIsPartitionKey(isPartitionKey bool) *Field {
	f.IsPartitionKey = isPartitionKey
	return f
}

func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field {
	f.IsClusteringKey = isClusteringKey
	return f
}

/*
func (f *Field) WithDefaultValueBool(defaultValue bool) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_BoolData{
			BoolData: defaultValue,
		},
	}
	return f
}

func (f *Field) WithDefaultValueInt(defaultValue int32) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_IntData{
			IntData: defaultValue,
		},
	}
	return f
}

func (f *Field) WithDefaultValueLong(defaultValue int64) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_LongData{
			LongData: defaultValue,
		},
	}
	return f
}

func (f *Field) WithDefaultValueFloat(defaultValue float32) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_FloatData{
			FloatData: defaultValue,
		},
	}
	return f
}

func (f *Field) WithDefaultValueDouble(defaultValue float64) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_DoubleData{
			DoubleData: defaultValue,
		},
	}
	return f
}

func (f *Field) WithDefaultValueString(defaultValue string) *Field {
	f.DefaultValue = &schemapb.ValueField{
		Data: &schemapb.ValueField_StringData{
			StringData: defaultValue,
		},
	}
	return f
}*/

func (f *Field) WithTypeParams(key string, value string) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	f.TypeParams[key] = value
	return f
}

func (f *Field) WithDim(dim int64) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	f.TypeParams[TypeParamDim] = strconv.FormatInt(dim, 10)
	return f
}

func (f *Field) GetDim() (int64, error) {
	dimStr, has := f.TypeParams[TypeParamDim]
	if !has {
		return -1, errors.New("field with no dim")
	}
	dim, err := strconv.ParseInt(dimStr, 10, 64)
	if err != nil {
		return -1, errors.Newf("field with bad format dim: %s", err.Error())
	}
	return dim, nil
}

func (f *Field) WithMaxLength(maxLen int64) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	f.TypeParams[TypeParamMaxLength] = strconv.FormatInt(maxLen, 10)
	return f
}

func (f *Field) WithElementType(eleType FieldType) *Field {
	f.ElementType = eleType
	return f
}

func (f *Field) WithMaxCapacity(maxCap int64) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	f.TypeParams[TypeParamMaxCapacity] = strconv.FormatInt(maxCap, 10)
	return f
}

// WithEnableAnalyzer sets the "enable_analyzer" type param so the field text can be tokenized.
func (f *Field) WithEnableAnalyzer(enable bool) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	f.TypeParams["enable_analyzer"] = strconv.FormatBool(enable)
	return f
}

// WithAnalyzerParams sets the "analyzer_params" type param as a JSON-encoded map.
func (f *Field) WithAnalyzerParams(params map[string]any) *Field {
	if f.TypeParams == nil {
		f.TypeParams = make(map[string]string)
	}
	bs, _ := json.Marshal(params)
	f.TypeParams["analyzer_params"] = string(bs)
	return f
}

// ReadProto parses FieldSchema
func (f *Field) ReadProto(p *schemapb.FieldSchema) *Field {
	f.ID = p.GetFieldID()
	f.Name = p.GetName()
	f.PrimaryKey = p.GetIsPrimaryKey()
	f.AutoID = p.GetAutoID()
	f.Description = p.GetDescription()
	f.DataType = FieldType(p.GetDataType())
	f.TypeParams = KvPairsMap(p.GetTypeParams())
	f.IndexParams = KvPairsMap(p.GetIndexParams())
	f.IsDynamic = p.GetIsDynamic()
	f.IsPartitionKey = p.GetIsPartitionKey()
	f.IsClusteringKey = p.GetIsClusteringKey()
	f.ElementType = FieldType(p.GetElementType())

	return f
}
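The analyzer builders above only record type params on the field. A small sketch of what they produce; the `{"type": "standard"}` analyzer params are illustrative, not a verified option set:

package entity_test

import (
	"fmt"

	"github.com/milvus-io/milvus/client/v2/entity"
)

// Sketch: how WithEnableAnalyzer / WithAnalyzerParams surface in TypeParams.
func ExampleField_analyzerParams() {
	f := entity.NewField().
		WithName("text").
		WithDataType(entity.FieldTypeVarChar).
		WithMaxLength(65535).
		WithEnableAnalyzer(true).
		WithAnalyzerParams(map[string]any{"type": "standard"}) // illustrative analyzer params

	// Both settings are carried as plain type params into the FieldSchema proto.
	fmt.Println(f.TypeParams["enable_analyzer"])
	fmt.Println(f.TypeParams["analyzer_params"])
	// Output:
	// true
	// {"type":"standard"}
}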
@@ -0,0 +1,74 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package entity

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestFieldSchema(t *testing.T) {
	fields := []*Field{
		NewField().WithName("int_field").WithDataType(FieldTypeInt64).WithIsAutoID(true).WithIsPrimaryKey(true).WithDescription("int_field desc"),
		NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"),
		NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true),
		NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128),
		NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true),
		NewField().WithName("varchar_text").WithDataType(FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true).WithAnalyzerParams(map[string]any{}),
		/*
			NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
			NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
			NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
			NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
			NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
			NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/
	}

	for _, field := range fields {
		fieldSchema := field.ProtoMessage()
		assert.Equal(t, field.ID, fieldSchema.GetFieldID())
		assert.Equal(t, field.Name, fieldSchema.GetName())
		assert.EqualValues(t, field.DataType, fieldSchema.GetDataType())
		assert.Equal(t, field.AutoID, fieldSchema.GetAutoID())
		assert.Equal(t, field.PrimaryKey, fieldSchema.GetIsPrimaryKey())
		assert.Equal(t, field.IsPartitionKey, fieldSchema.GetIsPartitionKey())
		assert.Equal(t, field.IsClusteringKey, fieldSchema.GetIsClusteringKey())
		assert.Equal(t, field.IsDynamic, fieldSchema.GetIsDynamic())
		assert.Equal(t, field.Description, fieldSchema.GetDescription())
		assert.Equal(t, field.TypeParams, KvPairsMap(fieldSchema.GetTypeParams()))
		assert.EqualValues(t, field.ElementType, fieldSchema.GetElementType())
		// marshal & unmarshal, still equals
		nf := &Field{}
		nf = nf.ReadProto(fieldSchema)
		assert.Equal(t, field.ID, nf.ID)
		assert.Equal(t, field.Name, nf.Name)
		assert.EqualValues(t, field.DataType, nf.DataType)
		assert.Equal(t, field.AutoID, nf.AutoID)
		assert.Equal(t, field.PrimaryKey, nf.PrimaryKey)
		assert.Equal(t, field.Description, nf.Description)
		assert.Equal(t, field.IsDynamic, nf.IsDynamic)
		assert.Equal(t, field.IsPartitionKey, nf.IsPartitionKey)
		assert.Equal(t, field.IsClusteringKey, nf.IsClusteringKey)
		assert.EqualValues(t, field.TypeParams, nf.TypeParams)
		assert.EqualValues(t, field.ElementType, nf.ElementType)
	}

	assert.NotPanics(t, func() {
		(&Field{}).WithTypeParams("a", "b")
	})
}
@@ -1,171 +0,0 @@
(removed file: the Apache license header, the FieldType Name/String/PbFieldType methods, and the FieldType constant block — identical to the definitions now living in the new field file above)
@@ -0,0 +1,109 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package entity

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

type FunctionType = schemapb.FunctionType

// provide package alias
const (
	FunctionTypeUnknown       = schemapb.FunctionType_Unknown
	FunctionTypeBM25          = schemapb.FunctionType_BM25
	FunctionTypeTextEmbedding = schemapb.FunctionType_TextEmbedding
)

type Function struct {
	Name        string
	Description string
	Type        FunctionType

	InputFieldNames  []string
	OutputFieldNames []string
	Params           map[string]string

	// ids shall be private
	id             int64
	inputFieldIDs  []int64
	outputFieldIDs []int64
}

func NewFunction() *Function {
	return &Function{
		Params: make(map[string]string),
	}
}

func (f *Function) WithName(name string) *Function {
	f.Name = name
	return f
}

func (f *Function) WithInputFields(inputFields ...string) *Function {
	f.InputFieldNames = inputFields
	return f
}

func (f *Function) WithOutputFields(outputFields ...string) *Function {
	f.OutputFieldNames = outputFields
	return f
}

func (f *Function) WithType(funcType FunctionType) *Function {
	f.Type = funcType
	return f
}

func (f *Function) WithParam(key string, value any) *Function {
	f.Params[key] = fmt.Sprintf("%v", value)
	return f
}

// ProtoMessage returns corresponding schemapb.FunctionSchema
func (f *Function) ProtoMessage() *schemapb.FunctionSchema {
	r := &schemapb.FunctionSchema{
		Name:             f.Name,
		Description:      f.Description,
		Type:             f.Type,
		InputFieldNames:  f.InputFieldNames,
		OutputFieldNames: f.OutputFieldNames,
		Params:           MapKvPairs(f.Params),
	}

	return r
}

// ReadProto parses the proto FunctionSchema
func (f *Function) ReadProto(p *schemapb.FunctionSchema) *Function {
	f.Name = p.GetName()
	f.Description = p.GetDescription()
	f.Type = p.GetType()

	f.InputFieldNames = p.GetInputFieldNames()
	f.OutputFieldNames = p.GetOutputFieldNames()
	f.Params = KvPairsMap(p.GetParams())

	f.id = p.GetId()
	f.inputFieldIDs = p.GetInputFieldIds()
	f.outputFieldIDs = p.GetOutputFieldIds()

	return f
}
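Note that WithParam stringifies any value with fmt.Sprintf("%v", ...). A short sketch of the builder; the "bm25_k1" key is illustrative and not confirmed by this diff:

package entity_test

import (
	"fmt"

	"github.com/milvus-io/milvus/client/v2/entity"
)

// Sketch: building a BM25 function; non-string params are stringified.
func ExampleFunction_builder() {
	fn := entity.NewFunction().
		WithName("text_bm25").
		WithType(entity.FunctionTypeBM25).
		WithInputFields("text").
		WithOutputFields("sparse").
		WithParam("bm25_k1", 1.2) // stored as the string "1.2"

	fmt.Println(fn.ProtoMessage().GetType(), fn.Params["bm25_k1"])
	// Output: BM25 1.2
}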
@@ -0,0 +1,48 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package entity

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestFunctionSchema(t *testing.T) {
	functions := []*Function{
		NewFunction().WithName("text_bm25_emb").WithType(FunctionTypeBM25).WithInputFields("a", "b").WithOutputFields("c").WithParam("key", "value"),
		NewFunction().WithName("other_emb").WithType(FunctionTypeTextEmbedding).WithInputFields("c").WithOutputFields("b", "a"),
	}

	for _, function := range functions {
		funcSchema := function.ProtoMessage()
		assert.Equal(t, function.Name, funcSchema.Name)
		assert.Equal(t, function.Type, funcSchema.Type)
		assert.Equal(t, function.InputFieldNames, funcSchema.InputFieldNames)
		assert.Equal(t, function.OutputFieldNames, funcSchema.OutputFieldNames)
		assert.Equal(t, function.Params, KvPairsMap(funcSchema.GetParams()))

		nf := NewFunction()
		nf.ReadProto(funcSchema)

		assert.Equal(t, function.Name, nf.Name)
		assert.Equal(t, function.Type, nf.Type)
		assert.Equal(t, function.InputFieldNames, nf.InputFieldNames)
		assert.Equal(t, function.OutputFieldNames, nf.OutputFieldNames)
		assert.Equal(t, function.Params, nf.Params)
	}
}
@@ -17,9 +17,7 @@
package entity

import (
	"strconv"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

@@ -62,6 +60,7 @@ type Schema struct {
	AutoID             bool
	Fields             []*Field
	EnableDynamicField bool
	Functions          []*Function

	pkField *Field
}

@@ -102,6 +101,11 @@ func (s *Schema) WithField(f *Field) *Schema {
	return s
}

func (s *Schema) WithFunction(f *Function) *Schema {
	s.Functions = append(s.Functions, f)
	return s
}

// ProtoMessage returns corresponding server.CollectionSchema
func (s *Schema) ProtoMessage() *schemapb.CollectionSchema {
	r := &schemapb.CollectionSchema{

@@ -110,10 +114,14 @@ func (s *Schema) ProtoMessage() *schemapb.CollectionSchema {
		AutoID:             s.AutoID,
		EnableDynamicField: s.EnableDynamicField,
	}
	r.Fields = make([]*schemapb.FieldSchema, 0, len(s.Fields))
	for _, field := range s.Fields {
		r.Fields = append(r.Fields, field.ProtoMessage())
	}
	r.Fields = lo.Map(s.Fields, func(field *Field, _ int) *schemapb.FieldSchema {
		return field.ProtoMessage()
	})

	r.Functions = lo.Map(s.Functions, func(function *Function, _ int) *schemapb.FunctionSchema {
		return function.ProtoMessage()
	})

	return r
}

@@ -121,6 +129,8 @@ func (s *Schema) ProtoMessage() *schemapb.CollectionSchema {
func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
	s.Description = p.GetDescription()
	s.CollectionName = p.GetName()
	s.EnableDynamicField = p.GetEnableDynamicField()
	// fields
	s.Fields = make([]*Field, 0, len(p.GetFields()))
	for _, fp := range p.GetFields() {
		field := NewField().ReadProto(fp)

@@ -132,7 +142,11 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
		}
		s.Fields = append(s.Fields, field)
	}
	s.EnableDynamicField = p.GetEnableDynamicField()
	// functions
	s.Functions = lo.Map(p.GetFunctions(), func(fn *schemapb.FunctionSchema, _ int) *Function {
		return NewFunction().ReadProto(fn)
	})

	return s
}
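With these changes the Functions list survives the schema's proto round trip. A small fragment, written as if inside the entity package (mirroring the schema tests further below):

sch := NewSchema().WithName("documents").
	WithFunction(NewFunction().WithName("text_bm25").WithType(FunctionTypeBM25).
		WithInputFields("text").WithOutputFields("sparse"))

p := sch.ProtoMessage()              // p.GetFunctions() now carries one entry
restored := (&Schema{}).ReadProto(p) // restored.Functions has one entry again
_ = restored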
@@ -149,210 +163,6 @@ func (s *Schema) PKField() *Field {
	return s.pkField
}
(removed from this file: the Field struct, its ProtoMessage/ReadProto methods, and the With* builder helpers — identical to the definitions added in the new field file above)
// MapKvPairs converts map into commonpb.KeyValuePair slice
func MapKvPairs(m map[string]string) []*commonpb.KeyValuePair {
	pairs := make([]*commonpb.KeyValuePair, 0, len(m))
@@ -37,56 +37,6 @@ func TestCL_CommonCL(t *testing.T) {
	}
}
(removed from this file: the TestFieldSchema test — moved to the new field test file above, which adds an analyzer-enabled VarChar case)
type SchemaSuite struct {
	suite.Suite
}
@@ -101,7 +51,8 @@ func (s *SchemaSuite) TestBasic() {
			"test_collection",
			NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(false).
				WithField(NewField().WithName("ID").WithDataType(FieldTypeInt64).WithIsPrimaryKey(true)).
				WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)),
				WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)).
				WithFunction(NewFunction()),
			"ID",
		},
		{

@@ -122,6 +73,7 @@ func (s *SchemaSuite) TestBasic() {
			s.Equal(sch.Description, p.GetDescription())
			s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField())
			s.Equal(len(sch.Fields), len(p.GetFields()))
			s.Equal(len(sch.Functions), len(p.GetFunctions()))

			nsch := &Schema{}
			nsch = nsch.ReadProto(p)

@@ -130,6 +82,7 @@ func (s *SchemaSuite) TestBasic() {
			s.Equal(sch.Description, nsch.Description)
			s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField)
			s.Equal(len(sch.Fields), len(nsch.Fields))
			s.Equal(len(sch.Functions), len(nsch.Functions))
			s.Equal(c.pkName, sch.PKFieldName())
			s.Equal(c.pkName, nsch.PKFieldName())
		})
@@ -104,3 +104,19 @@ func (bv BinaryVector) Serialize() []byte {
func (bv BinaryVector) FieldType() FieldType {
	return FieldTypeBinaryVector
}

// Text is plain-text search data used for doc-in-doc-out (e.g. BM25) search.
type Text string

// Dim returns the vector dimension; Text is not a vector, so it is always 0.
func (t Text) Dim() int {
	return 0
}

// FieldType returns the corresponding field type.
func (t Text) FieldType() FieldType {
	return FieldTypeVarChar
}

func (t Text) Serialize() []byte {
	return []byte(t)
}
@@ -30,7 +30,10 @@ import (
)

func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
	req := option.Request()
	req, err := option.Request()
	if err != nil {
		return nil, err
	}
	collection, err := c.getCollection(ctx, req.GetCollectionName())
	if err != nil {
		return nil, err
@@ -0,0 +1,139 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package milvusclient

import (
	"math/rand"
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/client/v2/entity"
)

type SearchOptionSuite struct {
	suite.Suite
}

type nonSupportData struct{}

func (d nonSupportData) Serialize() []byte {
	return []byte{}
}

func (d nonSupportData) Dim() int {
	return 0
}

func (d nonSupportData) FieldType() entity.FieldType {
	return entity.FieldType(0)
}

func (s *SearchOptionSuite) TestBasic() {
	collName := "search_opt_basic"

	topK := rand.Intn(100) + 1
	opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})

	opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000")

	req, err := opt.Request()
	s.Require().NoError(err)

	s.Equal(collName, req.GetCollectionName())
	s.Equal("ID > 1000", req.GetDsl())
	s.ElementsMatch([]string{"ID", "Value"}, req.GetOutputFields())
	searchParams := entity.KvPairsMap(req.GetSearchParams())
	annField, ok := searchParams[spAnnsField]
	s.Require().True(ok)
	s.Equal("test_field", annField)

	opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
	_, err = opt.Request()
	s.Error(err)
}

func (s *SearchOptionSuite) TestPlaceHolder() {
	type testCase struct {
		tag         string
		input       []entity.Vector
		expectError bool
		expectType  commonpb.PlaceholderType
	}

	sparse, err := entity.NewSliceSparseEmbedding([]uint32{0, 10, 12}, []float32{0.1, 0.2, 0.3})
	s.Require().NoError(err)

	cases := []*testCase{
		{
			tag:        "empty_input",
			input:      nil,
			expectType: commonpb.PlaceholderType_None,
		},
		{
			tag:        "float_vector",
			input:      []entity.Vector{entity.FloatVector([]float32{0.1, 0.2, 0.3})},
			expectType: commonpb.PlaceholderType_FloatVector,
		},
		{
			tag:        "sparse_vector",
			input:      []entity.Vector{sparse},
			expectType: commonpb.PlaceholderType_SparseFloatVector,
		},
		{
			tag:        "fp16_vector",
			input:      []entity.Vector{entity.Float16Vector([]byte{})},
			expectType: commonpb.PlaceholderType_Float16Vector,
		},
		{
			tag:        "bf16_vector",
			input:      []entity.Vector{entity.BFloat16Vector([]byte{})},
			expectType: commonpb.PlaceholderType_BFloat16Vector,
		},
		{
			tag:        "binary_vector",
			input:      []entity.Vector{entity.BinaryVector([]byte{})},
			expectType: commonpb.PlaceholderType_BinaryVector,
		},
		{
			tag:        "text",
			input:      []entity.Vector{entity.Text("abc")},
			expectType: commonpb.PlaceholderType_VarChar,
		},
		{
			tag:         "non_supported",
			input:       []entity.Vector{nonSupportData{}},
			expectError: true,
		},
	}
	for _, tc := range cases {
		s.Run(tc.tag, func() {
			phv, err := vector2Placeholder(tc.input)
			if tc.expectError {
				s.Error(err)
			} else {
				s.NoError(err)
				s.Equal(tc.expectType, phv.GetType())
			}
		})
	}
}

func TestSearchOption(t *testing.T) {
	suite.Run(t, new(SearchOptionSuite))
}
@@ -20,6 +20,7 @@ import (
	"encoding/json"
	"strconv"

	"github.com/cockroachdb/errors"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"

@@ -40,7 +41,7 @@
)

type SearchOption interface {
	Request() *milvuspb.SearchRequest
	Request() (*milvuspb.SearchRequest, error)
}

var _ SearchOption = (*searchOption)(nil)

@@ -70,12 +71,12 @@
	groupByField string
}

func (opt *searchOption) Request() *milvuspb.SearchRequest {
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
	// TODO check whether search is hybrid after logic merged
	return opt.prepareSearchRequest(opt.request)
}

func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest {
func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) (*milvuspb.SearchRequest, error) {
	request := &milvuspb.SearchRequest{
		CollectionName: opt.collectionName,
		PartitionNames: opt.partitionNames,

@@ -104,11 +105,15 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
		}
		request.SearchParams = entity.MapKvPairs(params)

		var err error
		// placeholder group
		request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
		request.PlaceholderGroup, err = vector2PlaceholderGroupBytes(annRequest.vectors)
		if err != nil {
			return nil, err
		}
	}

	return request
	return request, nil
}

func (opt *searchOption) WithFilter(expr string) *searchOption {

@@ -159,25 +164,29 @@ func NewSearchOption(collectionName string, limit int, vectors []entity.Vector)
	}
}

func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
func vector2PlaceholderGroupBytes(vectors []entity.Vector) ([]byte, error) {
	phv, err := vector2Placeholder(vectors)
	if err != nil {
		return nil, err
	}
	phg := &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{
			vector2Placeholder(vectors),
			phv,
		},
	}

	bs, _ := proto.Marshal(phg)
	return bs
	bs, err := proto.Marshal(phg)
	return bs, err
}

func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
func vector2Placeholder(vectors []entity.Vector) (*commonpb.PlaceholderValue, error) {
	var placeHolderType commonpb.PlaceholderType
	ph := &commonpb.PlaceholderValue{
		Tag:    "$0",
		Values: make([][]byte, 0, len(vectors)),
	}
	if len(vectors) == 0 {
		return ph
		return ph, nil
	}
	switch vectors[0].(type) {
	case entity.FloatVector:

@@ -190,12 +199,16 @@ func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
		placeHolderType = commonpb.PlaceholderType_Float16Vector
	case entity.SparseEmbedding:
		placeHolderType = commonpb.PlaceholderType_SparseFloatVector
	case entity.Text:
		placeHolderType = commonpb.PlaceholderType_VarChar
	default:
		return nil, errors.Newf("unsupported search data type: %T", vectors[0])
	}
	ph.Type = placeHolderType
	for _, vector := range vectors {
		ph.Values = append(ph.Values, vector.Serialize())
	}
	return ph
	return ph, nil
}

type QueryOption interface {
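Callers of Request() now have an error to handle that used to be swallowed. A fragment sketching the new behaviour, written in the milvusclient package context with illustrative collection/field names:

// entity.Text maps to the VarChar placeholder type; unsupported search data now
// returns an error from Request() instead of producing a silently broken placeholder.
req, err := NewSearchOption("documents", 10, []entity.Vector{entity.Text("query text")}).
	WithANNSField("sparse").Request()
if err != nil {
	// unsupported search data (or marshal failure) surfaces here
}
_ = req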
@@ -118,11 +118,14 @@ func (s *ReadSuite) TestSearch() {
		collectionName := fmt.Sprintf("coll_%s", s.randString(6))
		s.setupCache(collectionName, s.schemaDyn)

		_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{nonSupportData{}}))
		s.Error(err)

		s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
			return nil, merr.WrapErrServiceInternal("mocked")
		}).Once()

		_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
		_, err = s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
			entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
				return rand.Float32()
			})),
@@ -109,14 +109,15 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
		}
	}

	// check all fixed field pass value
	for _, field := range colSchema.Fields {
		_, has := mNameColumn[field.Name]
		if !has &&
			!field.AutoID && !field.IsDynamic {
			return nil, 0, fmt.Errorf("field %s not passed", field.Name)
		}
	}
	// missing fields shall be checked on the server side
	// // check all fixed field pass value
	// for _, field := range colSchema.Fields {
	// 	_, has := mNameColumn[field.Name]
	// 	if !has &&
	// 		!field.AutoID && !field.IsDynamic {
	// 		return nil, 0, fmt.Errorf("field %s not passed", field.Name)
	// 	}
	// }

	fieldsData := make([]*schemapb.FieldData, 0, len(mNameColumn)+1)
	for _, fixedColumn := range mNameColumn {
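One practical consequence of dropping the client-side check: an insert that omits a fixed field (for example the output field of a BM25 function) is no longer rejected locally. A fragment sketching this, in the style of the tests below, with cli an existing connected *Client and illustrative names:

// Previously a missing fixed column failed locally with "field %s not passed";
// now the request is sent and the server decides whether the omission is valid.
_, err := cli.Insert(ctx, NewColumnBasedInsertOption("documents").
	WithInt64Column("id", []int64{1}))
_ = err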
@@ -129,10 +129,6 @@ func (s *WriteSuite) TestInsert() {
	}

	cases := []badCase{
		{
			tag:   "missing_column",
			input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
		},
		{
			tag: "row_count_not_match",
			input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).

@@ -261,10 +257,6 @@ func (s *WriteSuite) TestUpsert() {
	}

	cases := []badCase{
		{
			tag:   "missing_column",
			input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
		},
		{
			tag: "row_count_not_match",
			input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
@@ -140,6 +140,12 @@ const (
	ConsistencyLevel = "consistency_level"
)

// Doc-in-doc-out
const (
	EnableAnalyzerKey = `enable_analyzer`
	AnalyzerParamKey  = `analyzer_params`
)

// Collection properties key

const (
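These keys mirror the literal "enable_analyzer" / "analyzer_params" strings written by the Field builders. A tiny fragment using them directly, written as if in the same package that declares the constants, with an illustrative field name:

// Setting the same type params by key, e.g. when not going through the dedicated builders.
field := entity.NewField().
	WithName("text").
	WithDataType(entity.FieldTypeVarChar).
	WithTypeParams(EnableAnalyzerKey, "true").
	WithTypeParams(AnalyzerParamKey, `{"type":"standard"}`) // illustrative analyzer params
_ = field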