enhance: Support Row-based insert for milvusclient (#33270)

See also #31293

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/33289/head
congqixia 2024-05-22 19:15:40 +08:00 committed by GitHub
parent 39f56678a0
commit 33144a43d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 977 additions and 10 deletions

View File

@ -62,10 +62,6 @@ func (c *Client) CreateCollection(ctx context.Context, option CreateCollectionOp
return nil
}
// ListCollectionOption is the option type accepted by Client.ListCollections.
type ListCollectionOption interface {
	// Request builds the ShowCollectionsRequest sent to the service.
	Request() *milvuspb.ShowCollectionsRequest
}
func (c *Client) ListCollections(ctx context.Context, option ListCollectionOption, callOptions ...grpc.CallOption) (collectionNames []string, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
@ -82,7 +78,7 @@ func (c *Client) ListCollections(ctx context.Context, option ListCollectionOptio
return collectionNames, err
}
func (c *Client) DescribeCollection(ctx context.Context, option *describeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
func (c *Client) DescribeCollection(ctx context.Context, option DescribeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)

View File

@ -159,6 +159,10 @@ func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *cr
}
}
// ListCollectionOption describes an option that can produce the request used
// for listing collections.
type ListCollectionOption interface {
	// Request builds the ShowCollectionsRequest sent to the service.
	Request() *milvuspb.ShowCollectionsRequest
}
// listCollectionOption is a field-less option type; its Request method
// (pointer receiver) returns a ShowCollectionsRequest.
type listCollectionOption struct{}
func (opt *listCollectionOption) Request() *milvuspb.ShowCollectionsRequest {

View File

@ -19,6 +19,8 @@ package entity
import (
"strconv"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
@ -293,6 +295,18 @@ func (f *Field) WithDim(dim int64) *Field {
return f
}
// GetDim returns the vector dimension stored in the field type params.
// It returns -1 and an error when the dim type param is absent or is not a
// base-10 integer.
func (f *Field) GetDim() (int64, error) {
	raw, ok := f.TypeParams[TypeParamDim]
	if !ok {
		return -1, errors.New("field with no dim")
	}

	parsed, parseErr := strconv.ParseInt(raw, 10, 64)
	if parseErr == nil {
		return parsed, nil
	}
	return -1, errors.Newf("field with bad format dim: %s", parseErr.Error())
}
func (f *Field) WithMaxLength(maxLen int64) *Field {
if f.TypeParams == nil {
f.TypeParams = make(map[string]string)

View File

@ -88,7 +88,7 @@ func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
}
length = length / 8
length /= 8
result := sliceSparseEmbedding{
positions: make([]uint32, length),

332
client/row/data.go Normal file
View File

@ -0,0 +1,332 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package row
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
)
// Struct tag vocabulary for row-based structs.
//
// Setting keys are compared in upper case because ParseTagSetting upper-cases
// every key, so `milvus:"primary_key"` and `milvus:"PRIMARY_KEY"` are
// equivalent.
const (
	// MilvusTag is the struct tag key read from row struct fields.
	MilvusTag = `milvus`

	// MilvusSkipTagValue is the tag value that excludes a struct field.
	MilvusSkipTagValue = `-`

	// MilvusTagSep separates the individual settings inside one tag value.
	MilvusTagSep = `;`

	// MilvusTagName is the setting key that overrides the field name.
	MilvusTagName = `NAME`

	// VectorDimTag is the setting key carrying the vector dimension.
	VectorDimTag = `DIM`

	// VectorTypeTag is the setting key selecting the vector type of a byte
	// slice field (ParseSchema recognizes "fp16" and "bf16").
	VectorTypeTag = `VECTOR_TYPE`

	// MilvusPrimaryKey is the setting key marking the primary key field.
	MilvusPrimaryKey = `PRIMARY_KEY`

	// MilvusAutoID is the setting key marking a field as auto id.
	MilvusAutoID = `AUTO_ID`

	// DimMax is the largest dimension ParseSchema accepts from a DIM setting.
	DimMax = 65535
)
// AnyToColumns converts row-based data into column-based data.
//
// Each element of rows must be a kind accepted by reflectValueCandi (a struct,
// a map, or a pointer to one). When no schema is passed, the schema is parsed
// from the first row with ParseSchema; otherwise the first provided schema is
// used and any further ones are ignored.
//
// Primary key fields flagged as auto id are never emitted. When the schema
// enables the dynamic field, row values that match no schema field are
// JSON-encoded into one extra dynamic column, which is returned last.
// Schema-backed columns are returned in schema field order.
//
// An error is returned when rows is empty, the schema is nil, a row has an
// unsupported kind, a row lacks a value for a schema field, or a value cannot
// be appended to its column.
func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Column, error) {
	rowsLen := len(rows)
	if rowsLen == 0 {
		return []column.Column{}, errors.New("0 length column")
	}

	var sch *entity.Schema
	var err error
	// if schema not provided, try to parse from row
	if len(schemas) == 0 {
		sch, err = ParseSchema(rows[0])
		if err != nil {
			return []column.Column{}, err
		}
	} else {
		// use first schema provided
		sch = schemas[0]
	}
	// an explicitly passed nil schema would otherwise panic on first access
	if sch == nil {
		return []column.Column{}, errors.New("nil schema provided")
	}

	isDynamic := sch.EnableDynamicField
	var dynamicCol *column.ColumnJSONBytes

	// allocate one empty column per schema field, keyed by field name
	nameColumns := make(map[string]column.Column)
	for _, field := range sch.Fields {
		// skip auto id pk field
		if field.PrimaryKey && field.AutoID {
			continue
		}
		switch field.DataType {
		case entity.FieldTypeBool:
			data := make([]bool, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnBool(field.Name, data)
		case entity.FieldTypeInt8:
			data := make([]int8, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnInt8(field.Name, data)
		case entity.FieldTypeInt16:
			data := make([]int16, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnInt16(field.Name, data)
		case entity.FieldTypeInt32:
			data := make([]int32, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnInt32(field.Name, data)
		case entity.FieldTypeInt64:
			data := make([]int64, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnInt64(field.Name, data)
		case entity.FieldTypeFloat:
			data := make([]float32, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnFloat(field.Name, data)
		case entity.FieldTypeDouble:
			data := make([]float64, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnDouble(field.Name, data)
		case entity.FieldTypeString, entity.FieldTypeVarChar:
			data := make([]string, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnString(field.Name, data)
		case entity.FieldTypeJSON:
			data := make([][]byte, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnJSONBytes(field.Name, data)
		case entity.FieldTypeArray:
			col := NewArrayColumn(field)
			if col == nil {
				return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String())
			}
			nameColumns[field.Name] = col
		case entity.FieldTypeFloatVector:
			data := make([][]float32, 0, rowsLen)
			dimStr, has := field.TypeParams[entity.TypeParamDim]
			if !has {
				return []column.Column{}, errors.New("vector field with no dim")
			}
			dim, err := strconv.ParseInt(dimStr, 10, 64)
			if err != nil {
				return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error())
			}
			nameColumns[field.Name] = column.NewColumnFloatVector(field.Name, int(dim), data)
		case entity.FieldTypeBinaryVector:
			data := make([][]byte, 0, rowsLen)
			dim, err := field.GetDim()
			if err != nil {
				return []column.Column{}, err
			}
			nameColumns[field.Name] = column.NewColumnBinaryVector(field.Name, int(dim), data)
		case entity.FieldTypeFloat16Vector:
			data := make([][]byte, 0, rowsLen)
			dim, err := field.GetDim()
			if err != nil {
				return []column.Column{}, err
			}
			nameColumns[field.Name] = column.NewColumnFloat16Vector(field.Name, int(dim), data)
		case entity.FieldTypeBFloat16Vector:
			data := make([][]byte, 0, rowsLen)
			dim, err := field.GetDim()
			if err != nil {
				return []column.Column{}, err
			}
			nameColumns[field.Name] = column.NewColumnBFloat16Vector(field.Name, int(dim), data)
		case entity.FieldTypeSparseVector:
			data := make([]entity.SparseEmbedding, 0, rowsLen)
			nameColumns[field.Name] = column.NewColumnSparseVectors(field.Name, data)
		}
	}

	if isDynamic {
		dynamicCol = column.NewColumnJSONBytes("", make([][]byte, 0, rowsLen)).WithIsDynamic(true)
	}

	for rowIdx, row := range rows {
		// collection schema name need not be the same, since the receiver could have other names
		set, err := reflectValueCandi(reflect.ValueOf(row))
		if err != nil {
			return nil, err
		}

		for _, field := range sch.Fields {
			// skip dynamic field if visible
			if isDynamic && field.IsDynamic {
				continue
			}
			// skip auto id pk field
			if field.PrimaryKey && field.AutoID {
				// remove pk field from candidates set, avoid adding it into dynamic column
				delete(set, field.Name)
				continue
			}

			col, ok := nameColumns[field.Name]
			if !ok {
				return nil, fmt.Errorf("expected unhandled field %s", field.Name)
			}

			candi, ok := set[field.Name]
			if !ok {
				// report the index of the offending row, not of the field
				return nil, fmt.Errorf("row %d does not has field %s", rowIdx, field.Name)
			}
			if err := col.AppendValue(candi.v.Interface()); err != nil {
				return nil, err
			}
			// consumed; whatever is left in set afterwards goes to the dynamic column
			delete(set, field.Name)
		}

		if isDynamic {
			m := make(map[string]interface{})
			for name, candi := range set {
				m[name] = candi.v.Interface()
			}
			bs, err := json.Marshal(m)
			if err != nil {
				return nil, fmt.Errorf("failed to marshal dynamic field %w", err)
			}
			err = dynamicCol.AppendValue(bs)
			if err != nil {
				return nil, fmt.Errorf("failed to append value to dynamic field %w", err)
			}
		}
	}

	// emit columns in schema order so the result does not depend on map iteration order
	columns := make([]column.Column, 0, len(nameColumns)+1)
	for _, field := range sch.Fields {
		col, ok := nameColumns[field.Name]
		if !ok {
			continue
		}
		columns = append(columns, col)
		// guards against emitting a column twice for a repeated field name
		delete(nameColumns, field.Name)
	}
	if isDynamic {
		columns = append(columns, dynamicCol)
	}
	return columns, nil
}
// NewArrayColumn creates an empty array column named after f whose element
// type follows f.ElementType. It returns nil when the element type is not one
// of bool, int8, int16, int32, int64, float, double or varchar.
func NewArrayColumn(f *entity.Field) column.Column {
	name := f.Name

	var col column.Column
	switch f.ElementType {
	case entity.FieldTypeBool:
		col = column.NewColumnBoolArray(name, nil)
	case entity.FieldTypeInt8:
		col = column.NewColumnInt8Array(name, nil)
	case entity.FieldTypeInt16:
		col = column.NewColumnInt16Array(name, nil)
	case entity.FieldTypeInt32:
		col = column.NewColumnInt32Array(name, nil)
	case entity.FieldTypeInt64:
		col = column.NewColumnInt64Array(name, nil)
	case entity.FieldTypeFloat:
		col = column.NewColumnFloatArray(name, nil)
	case entity.FieldTypeDouble:
		col = column.NewColumnDoubleArray(name, nil)
	case entity.FieldTypeVarChar:
		col = column.NewColumnVarCharArray(name, nil)
	}
	// col stays a nil interface for unsupported element types
	return col
}
// fieldCandi is one candidate value extracted from a row by reflectValueCandi.
type fieldCandi struct {
	// name is the column name: the map key, or the struct field name unless
	// overridden by the NAME tag setting.
	name string
	// v is the reflected value taken from the row.
	v reflect.Value
	// options holds the parsed tag settings of a struct field; it is left
	// unset for map rows.
	options map[string]string
}
// reflectValueCandi collects the candidate values of one row, keyed by the
// column name each value is a candidate for.
//
// v may be a map with string keys, a struct, or a pointer to either.
//   - Map rows: every entry becomes a candidate named after its key.
//   - Struct rows: every exported field becomes a candidate named after the
//     field, or after the NAME setting of its `milvus` tag. Fields tagged with
//     the skip value and unexported fields are ignored. Array fields are
//     exposed as slices over the same elements.
//
// An error is returned for any other kind (including a nil pointer), for maps
// whose key kind is not string, and for structs producing the same column name
// twice.
func reflectValueCandi(v reflect.Value) (map[string]fieldCandi, error) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	result := make(map[string]fieldCandi)
	switch v.Kind() {
	case reflect.Map: // map[string]any
		// Value.String on a non-string key yields a placeholder such as
		// "<int Value>" rather than a usable name, so reject such maps early.
		if keyKind := v.Type().Key().Kind(); keyKind != reflect.String {
			return nil, fmt.Errorf("unsupport row type: map with %s key", keyKind.String())
		}
		iter := v.MapRange()
		for iter.Next() {
			key := iter.Key().String()
			result[key] = fieldCandi{
				name: key,
				v:    iter.Value(),
			}
		}
		return result, nil
	case reflect.Struct:
		rowType := v.Type()
		for i := 0; i < rowType.NumField(); i++ {
			ft := rowType.Field(i)
			// PkgPath is only set for unexported fields; reading those through
			// Value.Interface would panic later, so they are never candidates.
			if ft.PkgPath != "" {
				continue
			}

			name := ft.Name
			settings := make(map[string]string)
			if tag, ok := ft.Tag.Lookup(MilvusTag); ok {
				if tag == MilvusSkipTagValue {
					continue
				}
				settings = ParseTagSetting(tag, MilvusTagSep)
				if fn, has := settings[MilvusTagName]; has {
					// overwrite column to tag name
					name = fn
				}
			}

			// duplicated
			if _, dup := result[name]; dup {
				return nil, fmt.Errorf("column has duplicated name: %s when parsing field: %s", name, ft.Name)
			}

			fv := v.Field(i)
			if fv.Kind() == reflect.Array {
				// Value.Slice panics on an unaddressable array, which is what a
				// row passed by value yields; copy it into addressable storage.
				if !fv.CanAddr() {
					addressable := reflect.New(fv.Type()).Elem()
					addressable.Set(fv)
					fv = addressable
				}
				fv = fv.Slice(0, fv.Len())
			}

			result[name] = fieldCandi{
				name:    name,
				v:       fv,
				options: settings,
			}
		}
		return result, nil
	default:
		return nil, fmt.Errorf("unsupport row type: %s", v.Kind().String())
	}
}

174
client/row/data_test.go Normal file
View File

@ -0,0 +1,174 @@
package row
import (
"reflect"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ValidStruct is a row fixture with one field per supported scalar kind plus
// two slice-backed vector fields whose dim comes from the struct tag.
// TestRowsToColumns expects it to convert into 10 columns.
type ValidStruct struct {
	ID      int64 `milvus:"primary_key"`
	Attr1   int8
	Attr2   int16
	Attr3   int32
	Attr4   float32
	Attr5   float64
	Attr6   string
	Attr7   bool
	Vector  []float32 `milvus:"dim:16"`
	Vector2 []byte    `milvus:"dim:32"`
}
// ValidStruct2 is a row fixture using fixed-size arrays as vector fields and
// one field excluded through the skip tag value.
// TestRowsToColumns expects it to convert into 3 columns.
type ValidStruct2 struct {
	ID      int64 `milvus:"primary_key"`
	Vector  [16]float32
	Vector2 [4]byte
	Ignored bool `milvus:"-"`
}
// ValidStructWithNamedTag is a row fixture whose field names are overridden
// with the name tag setting ("id" and "vector").
type ValidStructWithNamedTag struct {
	ID     int64       `milvus:"primary_key;name:id"`
	Vector [16]float32 `milvus:"name:vector"`
}
// RowsSuite is the testify suite grouping the row conversion tests; it is run
// by TestRows.
type RowsSuite struct {
	suite.Suite
}
// TestRowsToColumns checks AnyToColumns with schemas parsed from the rows
// themselves: plain structs, auto id primary keys, bf16/fp16 vectors, and a
// set of inputs that must be rejected.
func (s *RowsSuite) TestRowsToColumns() {
	s.Run("valid_cases", func() {
		columns, err := AnyToColumns([]any{&ValidStruct{}})
		s.NoError(err)
		s.Equal(10, len(columns))

		columns, err = AnyToColumns([]any{&ValidStruct2{}})
		s.NoError(err)
		s.Equal(3, len(columns))
	})

	s.Run("auto_id_pk", func() {
		type AutoPK struct {
			ID     int64     `milvus:"primary_key;auto_id"`
			Vector []float32 `milvus:"dim:32"`
		}
		columns, err := AnyToColumns([]any{&AutoPK{}})
		s.NoError(err)
		// the auto id primary key must not produce a column
		s.Require().Equal(1, len(columns))
		s.Equal("Vector", columns[0].Name())
	})

	// named bf16 so it no longer collides with the fp16 sub-test below
	s.Run("bf16", func() {
		type BF16Struct struct {
			ID     int64  `milvus:"primary_key;auto_id"`
			Vector []byte `milvus:"dim:16;vector_type:bf16"`
		}
		columns, err := AnyToColumns([]any{&BF16Struct{}})
		s.NoError(err)
		s.Require().Equal(1, len(columns))
		s.Equal("Vector", columns[0].Name())
		s.Equal(entity.FieldTypeBFloat16Vector, columns[0].Type())
	})

	s.Run("fp16", func() {
		type FP16Struct struct {
			ID     int64  `milvus:"primary_key;auto_id"`
			Vector []byte `milvus:"dim:16;vector_type:fp16"`
		}
		columns, err := AnyToColumns([]any{&FP16Struct{}})
		s.NoError(err)
		s.Require().Equal(1, len(columns))
		s.Equal("Vector", columns[0].Name())
		s.Equal(entity.FieldTypeFloat16Vector, columns[0].Type())
	})

	s.Run("invalid_cases", func() {
		// empty input
		_, err := AnyToColumns([]any{})
		s.Error(err)

		// incompatible rows
		_, err = AnyToColumns([]any{&ValidStruct{}, &ValidStruct2{}})
		s.Error(err)

		// schema & row not compatible
		_, err = AnyToColumns([]any{&ValidStruct{}}, &entity.Schema{
			Fields: []*entity.Field{
				{
					Name:     "int64",
					DataType: entity.FieldTypeInt64,
				},
			},
		})
		s.Error(err)
	})
}
// TestDynamicSchema exercises AnyToColumns against explicitly provided schemas
// that have the dynamic field enabled.
func (s *RowsSuite) TestDynamicSchema() {
	rows := func() []any { return []any{&ValidStruct{}} }

	s.Run("all_fallback_dynamic", func() {
		// no declared fields: every struct field lands in the dynamic column
		dynamicOnly := entity.NewSchema().WithDynamicFieldEnabled(true)

		columns, err := AnyToColumns(rows(), dynamicOnly)
		s.NoError(err)
		s.Equal(1, len(columns))
	})

	s.Run("dynamic_not_found", func() {
		pkField := entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
		withPK := entity.NewSchema().WithField(pkField).WithDynamicFieldEnabled(true)

		_, err := AnyToColumns(rows(), withPK)
		s.NoError(err)
	})
}
// TestReflectValueCandi is a table-driven test for reflectValueCandi. Each
// case supplies a reflected row and the candidates expected from it; both the
// candidate names and the candidate values are compared.
func (s *RowsSuite) TestReflectValueCandi() {
	cases := []struct {
		tag       string
		v         reflect.Value
		expect    map[string]fieldCandi
		expectErr bool
	}{
		{
			tag: "MapRow",
			v: reflect.ValueOf(map[string]interface{}{
				"A": "abd", "B": int64(8),
			}),
			expect: map[string]fieldCandi{
				"A": {
					name: "A",
					v:    reflect.ValueOf("abd"),
				},
				"B": {
					name: "B",
					v:    reflect.ValueOf(int64(8)),
				},
			},
			expectErr: false,
		},
	}

	for _, c := range cases {
		s.Run(c.tag, func() {
			r, err := reflectValueCandi(c.v)
			if c.expectErr {
				s.Error(err)
				return
			}
			s.NoError(err)
			s.Equal(len(c.expect), len(r))
			for key, want := range c.expect {
				got, has := r[key]
				s.Require().True(has)
				s.Equal(want.name, got.name)
				// compare the underlying values; the reflect.Value wrappers
				// themselves differ (map element vs. freshly reflected value)
				s.Equal(want.v.Interface(), got.v.Interface())
			}
		})
	}
}
func TestRows(t *testing.T) {
suite.Run(t, new(RowsSuite))
}

185
client/row/schema.go Normal file
View File

@ -0,0 +1,185 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package row
import (
"fmt"
"go/ast"
"reflect"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ParseSchema parses a collection schema from a row value.
//
// r must be a struct, or a pointer, slice or array whose element type is a
// struct. The struct type name becomes the collection name. Every exported,
// non-embedded field that is not tagged with the skip value becomes a schema
// field:
//   - bool, int8/16/32/64, float32/64 and string map to the matching scalar type;
//   - [N]float32 maps to a float vector of dim N, [N]byte to a binary vector
//     of dim N*8;
//   - []float32 and []byte map to vectors whose dim comes from the DIM tag
//     setting (1..DimMax); for []byte the VECTOR_TYPE setting selects fp16 or
//     bf16, otherwise a binary vector is used.
//
// Pointer fields are handled as their element type. The PRIMARY_KEY, AUTO_ID
// and NAME tag settings are applied to the field.
//
// An error is returned for nil input, map rows, non-struct rows, unnamed
// struct types, unsupported field types and missing or invalid dims.
func ParseSchema(r interface{}) (*entity.Schema, error) {
	t := reflect.TypeOf(r)
	// reflect.TypeOf yields a nil Type for a nil interface; calling Kind on it would panic
	if t == nil {
		return nil, fmt.Errorf("unsupported data type: %+v", r)
	}
	if t.Kind() == reflect.Array || t.Kind() == reflect.Slice || t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	// MapRow is not supported for schema definition
	// TODO add PrimaryKey() interface later
	if t.Kind() == reflect.Map {
		return nil, fmt.Errorf("map row is not supported for schema definition")
	}

	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("unsupported data type: %+v", r)
	}

	// The row type name is used as collection name; anonymous struct types
	// have no name and are rejected.
	sch := &entity.Schema{CollectionName: t.Name()}
	if sch.CollectionName == "" {
		return nil, errors.New("collection name not provided")
	}

	sch.Fields = make([]*entity.Field, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		// ignore anonymous field for now
		if f.Anonymous || !ast.IsExported(f.Name) {
			continue
		}

		tag := f.Tag.Get(MilvusTag)
		if tag == MilvusSkipTagValue {
			continue
		}

		field := &entity.Field{
			Name: f.Name,
		}
		// pointer fields are described by the type they point to
		ft := f.Type
		if ft.Kind() == reflect.Ptr {
			ft = ft.Elem()
		}

		tagSettings := ParseTagSetting(tag, MilvusTagSep)
		if _, has := tagSettings[MilvusPrimaryKey]; has {
			field.PrimaryKey = true
		}
		if _, has := tagSettings[MilvusAutoID]; has {
			field.AutoID = true
		}
		if name, has := tagSettings[MilvusTagName]; has {
			field.Name = name
		}

		switch ft.Kind() {
		case reflect.Bool:
			field.DataType = entity.FieldTypeBool
		case reflect.Int8:
			field.DataType = entity.FieldTypeInt8
		case reflect.Int16:
			field.DataType = entity.FieldTypeInt16
		case reflect.Int32:
			field.DataType = entity.FieldTypeInt32
		case reflect.Int64:
			field.DataType = entity.FieldTypeInt64
		case reflect.Float32:
			field.DataType = entity.FieldTypeFloat
		case reflect.Float64:
			field.DataType = entity.FieldTypeDouble
		case reflect.String:
			field.DataType = entity.FieldTypeString
		case reflect.Array:
			// fixed-size arrays carry their dimension in the type itself
			arrayLen := ft.Len()
			elemType := ft.Elem()
			switch elemType.Kind() {
			case reflect.Uint8:
				field.WithDataType(entity.FieldTypeBinaryVector)
				// each byte packs 8 binary dimensions
				field.WithDim(int64(arrayLen) * 8)
			case reflect.Float32:
				field.WithDataType(entity.FieldTypeFloatVector)
				field.WithDim(int64(arrayLen))
			default:
				return nil, fmt.Errorf("field %s is array of %v, which is not supported", f.Name, elemType)
			}
		case reflect.Slice:
			// slices have no static length, so the dim must come from the tag
			dimStr, has := tagSettings[VectorDimTag]
			if !has {
				return nil, fmt.Errorf("field %s is slice but dim not provided", f.Name)
			}
			dim, err := strconv.ParseInt(dimStr, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("dim value %s is not valid", dimStr)
			}
			if dim < 1 || dim > DimMax {
				return nil, fmt.Errorf("dim value %d is out of range", dim)
			}
			field.WithDim(dim)

			elemType := ft.Elem()
			switch elemType.Kind() {
			case reflect.Uint8: // []byte, could be BinaryVector, fp16, bf16
				switch tagSettings[VectorTypeTag] {
				case "fp16":
					field.DataType = entity.FieldTypeFloat16Vector
				case "bf16":
					field.DataType = entity.FieldTypeBFloat16Vector
				default:
					field.DataType = entity.FieldTypeBinaryVector
				}
			case reflect.Float32:
				field.DataType = entity.FieldTypeFloatVector
			default:
				return nil, fmt.Errorf("field %s is slice of %v, which is not supported", f.Name, elemType)
			}
		default:
			return nil, fmt.Errorf("field %s is %v, which is not supported", field.Name, ft)
		}
		sch.Fields = append(sch.Fields, field)
	}

	return sch, nil
}
// ParseTagSetting parses a struct tag value into a settings map.
//
// str is split on sep into settings. A setting of the form "key:value" is
// stored under the upper-cased, space-trimmed key; the value is everything
// after the first colon, unmodified. A setting without a colon is stored with
// the upper-cased key as both key and value. Empty settings are ignored.
//
// A setting ending in a backslash escapes the separator: the backslash is
// replaced by sep and the following setting is appended to it. A trailing
// backslash with no following setting is kept literally.
func ParseTagSetting(str string, sep string) map[string]string {
	settings := map[string]string{}
	names := strings.Split(str, sep)

	for i := 0; i < len(names); i++ {
		j := i
		// Merge escaped separators back into names[j]. The i+1 bound keeps a
		// backslash at the very end of the tag from indexing past the slice.
		for len(names[j]) > 0 && names[j][len(names[j])-1] == '\\' && i+1 < len(names) {
			i++
			names[j] = names[j][:len(names[j])-1] + sep + names[i]
			names[i] = ""
		}

		// split on the first colon only, so values may themselves contain colons
		values := strings.SplitN(names[j], ":", 2)
		k := strings.TrimSpace(strings.ToUpper(values[0]))

		if len(values) == 2 {
			settings[k] = values[1]
		} else if k != "" {
			settings[k] = k
		}
	}

	return settings
}

213
client/row/schema_test.go Normal file
View File

@ -0,0 +1,213 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package row
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ArrayRow test case type
// It is an array rather than a struct; TestParseSchema passes a pointer to it
// as the "non struct" case and expects ParseSchema to fail.
type ArrayRow [16]float32

// Collection returns an empty string.
func (ar *ArrayRow) Collection() string { return "" }

// Partition returns an empty string.
func (ar *ArrayRow) Partition() string { return "" }

// Description returns an empty string.
func (ar *ArrayRow) Description() string { return "" }
// Uint8Struct is an invalid fixture: uint8 scalar fields are not supported by
// ParseSchema.
type Uint8Struct struct {
	Attr uint8
}
// StringArrayStruct is an invalid fixture: arrays of string are not supported
// by ParseSchema.
type StringArrayStruct struct {
	Vector [8]string
}
// StringSliceStruct is an invalid fixture: slices of string are not supported
// by ParseSchema even when a dim is given.
type StringSliceStruct struct {
	Vector []string `milvus:"dim:8"`
}
// SliceNoDimStruct is an invalid fixture: a slice vector without a dim
// setting in its tag.
type SliceNoDimStruct struct {
	Vector []float32 `milvus:""`
}
// SliceBadDimStruct is an invalid fixture: the dim setting is not a number.
type SliceBadDimStruct struct {
	Vector []float32 `milvus:"dim:str"`
}
// SliceBadDimStruct2 is an invalid fixture: the dim setting is 0, which is
// below the minimum ParseSchema accepts.
type SliceBadDimStruct2 struct {
	Vector []float32 `milvus:"dim:0"`
}
// TestParseSchema covers ParseSchema with inputs that must be rejected and
// with struct definitions that must yield a schema, including fp16/bf16
// vectors, array-backed vectors, pointer fields and renamed fields.
func TestParseSchema(t *testing.T) {
	t.Run("invalid cases", func(t *testing.T) {
		// anonymous struct with default collection name ("") will cause error
		anonymousStruct := struct{}{}
		sch, err := ParseSchema(anonymousStruct)
		assert.Nil(t, sch)
		assert.Error(t, err)

		// non struct
		arrayRow := ArrayRow([16]float32{})
		sch, err = ParseSchema(&arrayRow)
		assert.Nil(t, sch)
		assert.Error(t, err)

		// uint8 not supported
		sch, err = ParseSchema(&Uint8Struct{})
		assert.Nil(t, sch)
		assert.Error(t, err)

		// string array not supported
		sch, err = ParseSchema(&StringArrayStruct{})
		assert.Nil(t, sch)
		assert.Error(t, err)

		// string slice not supported
		sch, err = ParseSchema(&StringSliceStruct{})
		assert.Nil(t, sch)
		assert.Error(t, err)

		// slice vector with no dim
		sch, err = ParseSchema(&SliceNoDimStruct{})
		assert.Nil(t, sch)
		assert.Error(t, err)

		// slice vector with bad format dim
		sch, err = ParseSchema(&SliceBadDimStruct{})
		assert.Nil(t, sch)
		assert.Error(t, err)

		// slice vector with bad format dim 2
		sch, err = ParseSchema(&SliceBadDimStruct2{})
		assert.Nil(t, sch)
		assert.Error(t, err)
	})

	t.Run("valid cases", func(t *testing.T) {
		// getVectorField returns the first vector field of schema, or nil when
		// there is none (or schema itself is nil).
		getVectorField := func(schema *entity.Schema) *entity.Field {
			if schema == nil {
				return nil
			}
			for _, field := range schema.Fields {
				if field.DataType == entity.FieldTypeFloatVector ||
					field.DataType == entity.FieldTypeBinaryVector ||
					field.DataType == entity.FieldTypeBFloat16Vector ||
					field.DataType == entity.FieldTypeFloat16Vector {
					return field
				}
			}
			return nil
		}

		type ValidStruct struct {
			ID     int64 `milvus:"primary_key"`
			Attr1  int8
			Attr2  int16
			Attr3  int32
			Attr4  float32
			Attr5  float64
			Attr6  string
			Vector []float32 `milvus:"dim:128"`
		}
		vs := &ValidStruct{}
		sch, err := ParseSchema(vs)
		assert.NoError(t, err)
		if assert.NotNil(t, sch) {
			assert.Equal(t, "ValidStruct", sch.CollectionName)
		}

		type ValidFp16Struct struct {
			ID     int64 `milvus:"primary_key"`
			Attr1  int8
			Attr2  int16
			Attr3  int32
			Attr4  float32
			Attr5  float64
			Attr6  string
			Vector []byte `milvus:"dim:128;vector_type:fp16"`
		}
		fp16Vs := &ValidFp16Struct{}
		sch, err = ParseSchema(fp16Vs)
		assert.NoError(t, err)
		if assert.NotNil(t, sch) {
			assert.Equal(t, "ValidFp16Struct", sch.CollectionName)
		}
		// guard the dereference so a missing vector field fails the assertion
		// instead of panicking the whole test
		vectorField := getVectorField(sch)
		if assert.NotNil(t, vectorField) {
			assert.Equal(t, entity.FieldTypeFloat16Vector, vectorField.DataType)
		}

		type ValidBf16Struct struct {
			ID     int64 `milvus:"primary_key"`
			Attr1  int8
			Attr2  int16
			Attr3  int32
			Attr4  float32
			Attr5  float64
			Attr6  string
			Vector []byte `milvus:"dim:128;vector_type:bf16"`
		}
		bf16Vs := &ValidBf16Struct{}
		sch, err = ParseSchema(bf16Vs)
		assert.NoError(t, err)
		if assert.NotNil(t, sch) {
			assert.Equal(t, "ValidBf16Struct", sch.CollectionName)
		}
		vectorField = getVectorField(sch)
		if assert.NotNil(t, vectorField) {
			assert.Equal(t, entity.FieldTypeBFloat16Vector, vectorField.DataType)
		}

		type ValidByteStruct struct {
			ID     int64  `milvus:"primary_key"`
			Vector []byte `milvus:"dim:128"`
		}
		vs2 := &ValidByteStruct{}
		sch, err = ParseSchema(vs2)
		assert.NoError(t, err)
		assert.NotNil(t, sch)

		type ValidArrayStruct struct {
			ID     int64 `milvus:"primary_key"`
			Vector [64]float32
		}
		vs3 := &ValidArrayStruct{}
		sch, err = ParseSchema(vs3)
		assert.NoError(t, err)
		assert.NotNil(t, sch)

		type ValidArrayStructByte struct {
			ID     int64   `milvus:"primary_key;auto_id"`
			Data   *string `milvus:"extra:test\\;false"`
			Vector [64]byte
		}
		vs4 := &ValidArrayStructByte{}
		sch, err = ParseSchema(vs4)
		assert.NoError(t, err)
		assert.NotNil(t, sch)

		vs5 := &ValidStructWithNamedTag{}
		sch, err = ParseSchema(vs5)
		assert.NoError(t, err)
		if assert.NotNil(t, sch) {
			// both fields must be exposed under their tag-provided names
			i64f, vecf := false, false
			for _, field := range sch.Fields {
				if field.Name == "id" {
					i64f = true
				}
				if field.Name == "vector" {
					vecf = true
				}
			}
			assert.True(t, i64f)
			assert.True(t, vecf)
		}
	})
}

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/row"
)
type InsertOption interface {
@ -71,10 +72,8 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
l := col.Len()
if rowSize == 0 {
rowSize = l
} else {
if rowSize != l {
return nil, 0, errors.New("column size not match")
}
} else if rowSize != l {
return nil, 0, errors.New("column size not match")
}
field, has := mNameField[col.Name()]
if !has {
@ -247,6 +246,56 @@ func NewColumnBasedInsertOption(collName string, columns ...column.Column) *colu
}
}
// rowBasedDataOption is a write option fed with row-based data. It embeds
// columnBasedDataOption; its InsertRequest and UpsertRequest methods convert
// rows into columns with row.AnyToColumns before building the request.
type rowBasedDataOption struct {
	*columnBasedDataOption
	// rows holds the raw rows as passed to NewRowBasedInsertOption.
	rows []any
}
// NewRowBasedInsertOption creates a row-based write option for the collection
// named collName holding the given rows. The rows are not inspected here.
func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption {
	base := &columnBasedDataOption{collName: collName}
	opt := &rowBasedDataOption{
		columnBasedDataOption: base,
		rows:                  rows,
	}
	return opt
}
// InsertRequest converts the stored rows into columns according to the schema
// of coll, records them on the embedded column based option, and builds the
// InsertRequest from them. Conversion and column processing errors are
// returned unchanged.
func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
	converted, convErr := row.AnyToColumns(opt.rows, coll.Schema)
	if convErr != nil {
		return nil, convErr
	}
	// keep the converted columns on the embedded option
	opt.columnBasedDataOption.columns = converted

	data, numRows, procErr := opt.processInsertColumns(coll.Schema, opt.columns...)
	if procErr != nil {
		return nil, procErr
	}

	req := &milvuspb.InsertRequest{
		CollectionName: opt.collName,
		PartitionName:  opt.partitionName,
		FieldsData:     data,
		NumRows:        uint32(numRows),
	}
	return req, nil
}
// UpsertRequest converts the stored rows into columns according to the schema
// of coll, records them on the embedded column based option, and builds the
// UpsertRequest from them. Conversion and column processing errors are
// returned unchanged.
func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
	converted, convErr := row.AnyToColumns(opt.rows, coll.Schema)
	if convErr != nil {
		return nil, convErr
	}
	// keep the converted columns on the embedded option
	opt.columnBasedDataOption.columns = converted

	data, numRows, procErr := opt.processInsertColumns(coll.Schema, opt.columns...)
	if procErr != nil {
		return nil, procErr
	}

	req := &milvuspb.UpsertRequest{
		CollectionName: opt.collName,
		PartitionName:  opt.partitionName,
		FieldsData:     data,
		NumRows:        uint32(numRows),
	}
	return req, nil
}
// DeleteOption describes an option that can produce a delete request.
type DeleteOption interface {
	// Request builds the DeleteRequest sent to the service.
	Request() *milvuspb.DeleteRequest
}