mirror of https://github.com/milvus-io/milvus.git
enhance: Support Row-based insert for milvusclient (#33270)
See also #31293
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/33289/head
parent
39f56678a0
commit
33144a43d4
|
@ -62,10 +62,6 @@ func (c *Client) CreateCollection(ctx context.Context, option CreateCollectionOp
|
|||
return nil
|
||||
}
|
||||
|
||||
type ListCollectionOption interface {
|
||||
Request() *milvuspb.ShowCollectionsRequest
|
||||
}
|
||||
|
||||
func (c *Client) ListCollections(ctx context.Context, option ListCollectionOption, callOptions ...grpc.CallOption) (collectionNames []string, err error) {
|
||||
req := option.Request()
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
|
@ -82,7 +78,7 @@ func (c *Client) ListCollections(ctx context.Context, option ListCollectionOptio
|
|||
return collectionNames, err
|
||||
}
|
||||
|
||||
func (c *Client) DescribeCollection(ctx context.Context, option *describeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
|
||||
func (c *Client) DescribeCollection(ctx context.Context, option DescribeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
|
||||
req := option.Request()
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)
|
||||
|
|
|
@ -159,6 +159,10 @@ func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *cr
|
|||
}
|
||||
}
|
||||
|
||||
// ListCollectionOption is the option interface accepted by Client.ListCollections.
// Implementations build the ShowCollectionsRequest that is sent to the server.
type ListCollectionOption interface {
	// Request returns the ShowCollectionsRequest to send.
	Request() *milvuspb.ShowCollectionsRequest
}
|
||||
|
||||
// listCollectionOption is a field-less option type; its Request method
// returns a *milvuspb.ShowCollectionsRequest.
type listCollectionOption struct{}
|
||||
|
||||
func (opt *listCollectionOption) Request() *milvuspb.ShowCollectionsRequest {
|
||||
|
|
|
@ -19,6 +19,8 @@ package entity
|
|||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
@ -293,6 +295,18 @@ func (f *Field) WithDim(dim int64) *Field {
|
|||
return f
|
||||
}
|
||||
|
||||
func (f *Field) GetDim() (int64, error) {
|
||||
dimStr, has := f.TypeParams[TypeParamDim]
|
||||
if !has {
|
||||
return -1, errors.New("field with no dim")
|
||||
}
|
||||
dim, err := strconv.ParseInt(dimStr, 10, 64)
|
||||
if err != nil {
|
||||
return -1, errors.Newf("field with bad format dim: %s", err.Error())
|
||||
}
|
||||
return dim, nil
|
||||
}
|
||||
|
||||
func (f *Field) WithMaxLength(maxLen int64) *Field {
|
||||
if f.TypeParams == nil {
|
||||
f.TypeParams = make(map[string]string)
|
||||
|
|
|
@ -88,7 +88,7 @@ func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
|
|||
return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
|
||||
}
|
||||
|
||||
length = length / 8
|
||||
length /= 8
|
||||
|
||||
result := sliceSparseEmbedding{
|
||||
positions: make([]uint32, length),
|
||||
|
|
|
@ -0,0 +1,332 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package row
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
const (
	// MilvusTag is the struct tag key looked up on row struct fields.
	MilvusTag = `milvus`

	// MilvusSkipTagValue is the tag value that makes a struct field be skipped.
	MilvusSkipTagValue = `-`

	// MilvusTagSep is the separator between attributes inside one tag value.
	MilvusTagSep = `;`

	// MilvusTagName is the tag setting key that overrides the field name.
	// Setting keys are upper-cased by ParseTagSetting, hence the upper-case form.
	MilvusTagName = `NAME`

	// VectorDimTag is the tag setting key carrying the vector dimension.
	VectorDimTag = `DIM`

	// VectorTypeTag is the tag setting key selecting the vector type of a
	// byte-slice field: "fp16", "bf16", or (any other value) binary vector.
	VectorTypeTag = `VECTOR_TYPE`

	// MilvusPrimaryKey is the tag setting key marking the primary key field.
	MilvusPrimaryKey = `PRIMARY_KEY`

	// MilvusAutoID is the tag setting key marking the field as auto id.
	MilvusAutoID = `AUTO_ID`

	// DimMax is the largest dimension value ParseSchema accepts in a dim tag.
	DimMax = 65535
)
|
||||
|
||||
func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Column, error) {
|
||||
rowsLen := len(rows)
|
||||
if rowsLen == 0 {
|
||||
return []column.Column{}, errors.New("0 length column")
|
||||
}
|
||||
|
||||
var sch *entity.Schema
|
||||
var err error
|
||||
// if schema not provided, try to parse from row
|
||||
if len(schemas) == 0 {
|
||||
sch, err = ParseSchema(rows[0])
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
} else {
|
||||
// use first schema provided
|
||||
sch = schemas[0]
|
||||
}
|
||||
|
||||
isDynamic := sch.EnableDynamicField
|
||||
var dynamicCol *column.ColumnJSONBytes
|
||||
|
||||
nameColumns := make(map[string]column.Column)
|
||||
for _, field := range sch.Fields {
|
||||
// skip auto id pk field
|
||||
if field.PrimaryKey && field.AutoID {
|
||||
continue
|
||||
}
|
||||
switch field.DataType {
|
||||
case entity.FieldTypeBool:
|
||||
data := make([]bool, 0, rowsLen)
|
||||
col := column.NewColumnBool(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeInt8:
|
||||
data := make([]int8, 0, rowsLen)
|
||||
col := column.NewColumnInt8(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeInt16:
|
||||
data := make([]int16, 0, rowsLen)
|
||||
col := column.NewColumnInt16(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeInt32:
|
||||
data := make([]int32, 0, rowsLen)
|
||||
col := column.NewColumnInt32(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeInt64:
|
||||
data := make([]int64, 0, rowsLen)
|
||||
col := column.NewColumnInt64(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeFloat:
|
||||
data := make([]float32, 0, rowsLen)
|
||||
col := column.NewColumnFloat(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeDouble:
|
||||
data := make([]float64, 0, rowsLen)
|
||||
col := column.NewColumnDouble(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeString, entity.FieldTypeVarChar:
|
||||
data := make([]string, 0, rowsLen)
|
||||
col := column.NewColumnString(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeJSON:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
col := column.NewColumnJSONBytes(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeArray:
|
||||
col := NewArrayColumn(field)
|
||||
if col == nil {
|
||||
return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String())
|
||||
}
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeFloatVector:
|
||||
data := make([][]float32, 0, rowsLen)
|
||||
dimStr, has := field.TypeParams[entity.TypeParamDim]
|
||||
if !has {
|
||||
return []column.Column{}, errors.New("vector field with no dim")
|
||||
}
|
||||
dim, err := strconv.ParseInt(dimStr, 10, 64)
|
||||
if err != nil {
|
||||
return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error())
|
||||
}
|
||||
col := column.NewColumnFloatVector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeBinaryVector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnBinaryVector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeFloat16Vector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnFloat16Vector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeBFloat16Vector:
|
||||
data := make([][]byte, 0, rowsLen)
|
||||
dim, err := field.GetDim()
|
||||
if err != nil {
|
||||
return []column.Column{}, err
|
||||
}
|
||||
col := column.NewColumnBFloat16Vector(field.Name, int(dim), data)
|
||||
nameColumns[field.Name] = col
|
||||
case entity.FieldTypeSparseVector:
|
||||
data := make([]entity.SparseEmbedding, 0, rowsLen)
|
||||
col := column.NewColumnSparseVectors(field.Name, data)
|
||||
nameColumns[field.Name] = col
|
||||
}
|
||||
}
|
||||
|
||||
if isDynamic {
|
||||
dynamicCol = column.NewColumnJSONBytes("", make([][]byte, 0, rowsLen)).WithIsDynamic(true)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
// collection schema name need not to be same, since receiver could has other names
|
||||
v := reflect.ValueOf(row)
|
||||
set, err := reflectValueCandi(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for idx, field := range sch.Fields {
|
||||
// skip dynamic field if visible
|
||||
if isDynamic && field.IsDynamic {
|
||||
continue
|
||||
}
|
||||
// skip auto id pk field
|
||||
if field.PrimaryKey && field.AutoID {
|
||||
// remove pk field from candidates set, avoid adding it into dynamic column
|
||||
delete(set, field.Name)
|
||||
continue
|
||||
}
|
||||
column, ok := nameColumns[field.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected unhandled field %s", field.Name)
|
||||
}
|
||||
|
||||
candi, ok := set[field.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
||||
}
|
||||
err := column.AppendValue(candi.v.Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delete(set, field.Name)
|
||||
}
|
||||
|
||||
if isDynamic {
|
||||
m := make(map[string]interface{})
|
||||
for name, candi := range set {
|
||||
m[name] = candi.v.Interface()
|
||||
}
|
||||
bs, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal dynamic field %w", err)
|
||||
}
|
||||
err = dynamicCol.AppendValue(bs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to append value to dynamic field %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
columns := make([]column.Column, 0, len(nameColumns))
|
||||
for _, column := range nameColumns {
|
||||
columns = append(columns, column)
|
||||
}
|
||||
if isDynamic {
|
||||
columns = append(columns, dynamicCol)
|
||||
}
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func NewArrayColumn(f *entity.Field) column.Column {
|
||||
switch f.ElementType {
|
||||
case entity.FieldTypeBool:
|
||||
return column.NewColumnBoolArray(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeInt8:
|
||||
return column.NewColumnInt8Array(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeInt16:
|
||||
return column.NewColumnInt16Array(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeInt32:
|
||||
return column.NewColumnInt32Array(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeInt64:
|
||||
return column.NewColumnInt64Array(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeFloat:
|
||||
return column.NewColumnFloatArray(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeDouble:
|
||||
return column.NewColumnDoubleArray(f.Name, nil)
|
||||
|
||||
case entity.FieldTypeVarChar:
|
||||
return column.NewColumnVarCharArray(f.Name, nil)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// fieldCandi is a candidate value extracted from one row by
// reflectValueCandi, keyed by the column name it may be written to.
type fieldCandi struct {
	// name is the map key for map rows, or the struct field name unless the
	// milvus tag overrides it with a name setting.
	name string
	// v is the reflected value of the map entry or struct field.
	v reflect.Value
	// options holds the parsed milvus tag settings of a struct field;
	// it is left unset for map rows.
	options map[string]string
}
|
||||
|
||||
func reflectValueCandi(v reflect.Value) (map[string]fieldCandi, error) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
result := make(map[string]fieldCandi)
|
||||
switch v.Kind() {
|
||||
case reflect.Map: // map[string]any
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
key := iter.Key().String()
|
||||
result[key] = fieldCandi{
|
||||
name: key,
|
||||
v: iter.Value(),
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
case reflect.Struct:
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
ft := v.Type().Field(i)
|
||||
name := ft.Name
|
||||
tag, ok := ft.Tag.Lookup(MilvusTag)
|
||||
|
||||
settings := make(map[string]string)
|
||||
if ok {
|
||||
if tag == MilvusSkipTagValue {
|
||||
continue
|
||||
}
|
||||
settings = ParseTagSetting(tag, MilvusTagSep)
|
||||
fn, has := settings[MilvusTagName]
|
||||
if has {
|
||||
// overwrite column to tag name
|
||||
name = fn
|
||||
}
|
||||
}
|
||||
_, ok = result[name]
|
||||
// duplicated
|
||||
if ok {
|
||||
return nil, fmt.Errorf("column has duplicated name: %s when parsing field: %s", name, ft.Name)
|
||||
}
|
||||
|
||||
v := v.Field(i)
|
||||
if v.Kind() == reflect.Array {
|
||||
v = v.Slice(0, v.Len())
|
||||
}
|
||||
|
||||
result[name] = fieldCandi{
|
||||
name: name,
|
||||
v: v,
|
||||
options: settings,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupport row type: %s", v.Kind().String())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
package row
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ValidStruct is a row fixture with a primary key, scalar attributes and two
// slice vector fields whose dimensions come from the dim tag.
type ValidStruct struct {
	ID      int64 `milvus:"primary_key"`
	Attr1   int8
	Attr2   int16
	Attr3   int32
	Attr4   float32
	Attr5   float64
	Attr6   string
	Attr7   bool
	Vector  []float32 `milvus:"dim:16"`
	Vector2 []byte    `milvus:"dim:32"`
}

// ValidStruct2 is a row fixture using fixed-size arrays as vector fields and
// one field excluded with the "-" tag value.
type ValidStruct2 struct {
	ID      int64 `milvus:"primary_key"`
	Vector  [16]float32
	Vector2 [4]byte
	Ignored bool `milvus:"-"`
}

// ValidStructWithNamedTag is a row fixture whose fields are renamed through
// the name tag setting.
type ValidStructWithNamedTag struct {
	ID     int64       `milvus:"primary_key;name:id"`
	Vector [16]float32 `milvus:"name:vector"`
}

// RowsSuite is the testify suite holding the row conversion tests.
type RowsSuite struct {
	suite.Suite
}
|
||||
|
||||
func (s *RowsSuite) TestRowsToColumns() {
|
||||
s.Run("valid_cases", func() {
|
||||
columns, err := AnyToColumns([]any{&ValidStruct{}})
|
||||
s.Nil(err)
|
||||
s.Equal(10, len(columns))
|
||||
|
||||
columns, err = AnyToColumns([]any{&ValidStruct2{}})
|
||||
s.Nil(err)
|
||||
s.Equal(3, len(columns))
|
||||
})
|
||||
|
||||
s.Run("auto_id_pk", func() {
|
||||
type AutoPK struct {
|
||||
ID int64 `milvus:"primary_key;auto_id"`
|
||||
Vector []float32 `milvus:"dim:32"`
|
||||
}
|
||||
columns, err := AnyToColumns([]any{&AutoPK{}})
|
||||
s.Nil(err)
|
||||
s.Require().Equal(1, len(columns))
|
||||
s.Equal("Vector", columns[0].Name())
|
||||
})
|
||||
|
||||
s.Run("fp16", func() {
|
||||
type BF16Struct struct {
|
||||
ID int64 `milvus:"primary_key;auto_id"`
|
||||
Vector []byte `milvus:"dim:16;vector_type:bf16"`
|
||||
}
|
||||
columns, err := AnyToColumns([]any{&BF16Struct{}})
|
||||
s.Nil(err)
|
||||
s.Require().Equal(1, len(columns))
|
||||
s.Equal("Vector", columns[0].Name())
|
||||
s.Equal(entity.FieldTypeBFloat16Vector, columns[0].Type())
|
||||
})
|
||||
|
||||
s.Run("fp16", func() {
|
||||
type FP16Struct struct {
|
||||
ID int64 `milvus:"primary_key;auto_id"`
|
||||
Vector []byte `milvus:"dim:16;vector_type:fp16"`
|
||||
}
|
||||
columns, err := AnyToColumns([]any{&FP16Struct{}})
|
||||
s.Nil(err)
|
||||
s.Require().Equal(1, len(columns))
|
||||
s.Equal("Vector", columns[0].Name())
|
||||
s.Equal(entity.FieldTypeFloat16Vector, columns[0].Type())
|
||||
})
|
||||
|
||||
s.Run("invalid_cases", func() {
|
||||
// empty input
|
||||
_, err := AnyToColumns([]any{})
|
||||
s.NotNil(err)
|
||||
|
||||
// incompatible rows
|
||||
_, err = AnyToColumns([]any{&ValidStruct{}, &ValidStruct2{}})
|
||||
s.NotNil(err)
|
||||
|
||||
// schema & row not compatible
|
||||
_, err = AnyToColumns([]any{&ValidStruct{}}, &entity.Schema{
|
||||
Fields: []*entity.Field{
|
||||
{
|
||||
Name: "int64",
|
||||
DataType: entity.FieldTypeInt64,
|
||||
},
|
||||
},
|
||||
})
|
||||
s.NotNil(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RowsSuite) TestDynamicSchema() {
|
||||
s.Run("all_fallback_dynamic", func() {
|
||||
columns, err := AnyToColumns([]any{&ValidStruct{}},
|
||||
entity.NewSchema().WithDynamicFieldEnabled(true),
|
||||
)
|
||||
s.NoError(err)
|
||||
s.Equal(1, len(columns))
|
||||
})
|
||||
|
||||
s.Run("dynamic_not_found", func() {
|
||||
_, err := AnyToColumns([]any{&ValidStruct{}},
|
||||
entity.NewSchema().WithField(
|
||||
entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true),
|
||||
).WithDynamicFieldEnabled(true),
|
||||
)
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RowsSuite) TestReflectValueCandi() {
|
||||
cases := []struct {
|
||||
tag string
|
||||
v reflect.Value
|
||||
expect map[string]fieldCandi
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
tag: "MapRow",
|
||||
v: reflect.ValueOf(map[string]interface{}{
|
||||
"A": "abd", "B": int64(8),
|
||||
}),
|
||||
expect: map[string]fieldCandi{
|
||||
"A": {
|
||||
name: "A",
|
||||
v: reflect.ValueOf("abd"),
|
||||
},
|
||||
"B": {
|
||||
name: "B",
|
||||
v: reflect.ValueOf(int64(8)),
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.tag, func() {
|
||||
r, err := reflectValueCandi(c.v)
|
||||
if c.expectErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(len(c.expect), len(r))
|
||||
for k, v := range c.expect {
|
||||
rv, has := r[k]
|
||||
s.Require().True(has)
|
||||
s.Equal(v.name, rv.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRows(t *testing.T) {
|
||||
suite.Run(t, new(RowsSuite))
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package row
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ParseSchema parses schema from interface{}.
|
||||
func ParseSchema(r interface{}) (*entity.Schema, error) {
|
||||
sch := &entity.Schema{}
|
||||
t := reflect.TypeOf(r)
|
||||
if t.Kind() == reflect.Array || t.Kind() == reflect.Slice || t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
// MapRow is not supported for schema definition
|
||||
// TODO add PrimaryKey() interface later
|
||||
if t.Kind() == reflect.Map {
|
||||
return nil, fmt.Errorf("map row is not supported for schema definition")
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("unsupported data type: %+v", r)
|
||||
}
|
||||
|
||||
// Collection method not overwrited, try use Row type name
|
||||
if sch.CollectionName == "" {
|
||||
sch.CollectionName = t.Name()
|
||||
if sch.CollectionName == "" {
|
||||
return nil, errors.New("collection name not provided")
|
||||
}
|
||||
}
|
||||
sch.Fields = make([]*entity.Field, 0, t.NumField())
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
// ignore anonymous field for now
|
||||
if f.Anonymous || !ast.IsExported(f.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
field := &entity.Field{
|
||||
Name: f.Name,
|
||||
}
|
||||
ft := f.Type
|
||||
if f.Type.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
fv := reflect.New(ft)
|
||||
tag := f.Tag.Get(MilvusTag)
|
||||
if tag == MilvusSkipTagValue {
|
||||
continue
|
||||
}
|
||||
tagSettings := ParseTagSetting(tag, MilvusTagSep)
|
||||
if _, has := tagSettings[MilvusPrimaryKey]; has {
|
||||
field.PrimaryKey = true
|
||||
}
|
||||
if _, has := tagSettings[MilvusAutoID]; has {
|
||||
field.AutoID = true
|
||||
}
|
||||
if name, has := tagSettings[MilvusTagName]; has {
|
||||
field.Name = name
|
||||
}
|
||||
switch reflect.Indirect(fv).Kind() {
|
||||
case reflect.Bool:
|
||||
field.DataType = entity.FieldTypeBool
|
||||
case reflect.Int8:
|
||||
field.DataType = entity.FieldTypeInt8
|
||||
case reflect.Int16:
|
||||
field.DataType = entity.FieldTypeInt16
|
||||
case reflect.Int32:
|
||||
field.DataType = entity.FieldTypeInt32
|
||||
case reflect.Int64:
|
||||
field.DataType = entity.FieldTypeInt64
|
||||
case reflect.Float32:
|
||||
field.DataType = entity.FieldTypeFloat
|
||||
case reflect.Float64:
|
||||
field.DataType = entity.FieldTypeDouble
|
||||
case reflect.String:
|
||||
field.DataType = entity.FieldTypeString
|
||||
case reflect.Array:
|
||||
arrayLen := ft.Len()
|
||||
elemType := ft.Elem()
|
||||
switch elemType.Kind() {
|
||||
case reflect.Uint8:
|
||||
field.WithDataType(entity.FieldTypeBinaryVector)
|
||||
field.WithDim(int64(arrayLen) * 8)
|
||||
case reflect.Float32:
|
||||
field.WithDataType(entity.FieldTypeFloatVector)
|
||||
field.WithDim(int64(arrayLen))
|
||||
default:
|
||||
return nil, fmt.Errorf("field %s is array of %v, which is not supported", f.Name, elemType)
|
||||
}
|
||||
case reflect.Slice:
|
||||
dimStr, has := tagSettings[VectorDimTag]
|
||||
if !has {
|
||||
return nil, fmt.Errorf("field %s is slice but dim not provided", f.Name)
|
||||
}
|
||||
dim, err := strconv.ParseInt(dimStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dim value %s is not valid", dimStr)
|
||||
}
|
||||
if dim < 1 || dim > DimMax {
|
||||
return nil, fmt.Errorf("dim value %d is out of range", dim)
|
||||
}
|
||||
field.WithDim(dim)
|
||||
|
||||
elemType := ft.Elem()
|
||||
switch elemType.Kind() {
|
||||
case reflect.Uint8: // []byte, could be BinaryVector, fp16, bf 6
|
||||
switch tagSettings[VectorTypeTag] {
|
||||
case "fp16":
|
||||
field.DataType = entity.FieldTypeFloat16Vector
|
||||
case "bf16":
|
||||
field.DataType = entity.FieldTypeBFloat16Vector
|
||||
default:
|
||||
field.DataType = entity.FieldTypeBinaryVector
|
||||
}
|
||||
case reflect.Float32:
|
||||
field.DataType = entity.FieldTypeFloatVector
|
||||
default:
|
||||
return nil, fmt.Errorf("field %s is slice of %v, which is not supported", f.Name, elemType)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("field %s is %v, which is not supported", field.Name, ft)
|
||||
}
|
||||
sch.Fields = append(sch.Fields, field)
|
||||
}
|
||||
|
||||
return sch, nil
|
||||
}
|
||||
|
||||
// ParseTagSetting parses a struct tag value into a map of settings.
//
// str is split on sep into segments. A segment ending in a backslash escapes
// the separator: the backslash is dropped and the segment is joined with the
// following one using sep. A trailing backslash on the last segment has
// nothing to escape and is kept literally.
//
// Each segment is split at its first ':' into key and value. Keys are
// upper-cased and trimmed of surrounding whitespace. A segment without ':'
// is stored with its key as its value; empty segments are ignored.
func ParseTagSetting(str string, sep string) map[string]string {
	settings := map[string]string{}
	names := strings.Split(str, sep)

	for i := 0; i < len(names); i++ {
		j := i
		// Merge following segments while the current one ends with an escape.
		// The i+1 bound stops at the last segment instead of indexing past
		// the end of names.
		for len(names[j]) > 0 && names[j][len(names[j])-1] == '\\' && i+1 < len(names) {
			i++
			names[j] = names[j][:len(names[j])-1] + sep + names[i]
			names[i] = ""
		}

		rawKey, value, hasValue := strings.Cut(names[j], ":")
		k := strings.TrimSpace(strings.ToUpper(rawKey))

		if hasValue {
			settings[k] = value
		} else if k != "" {
			settings[k] = k
		}
	}

	return settings
}
|
|
@ -0,0 +1,213 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package row
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ArrayRow is a non-struct row type used as an invalid input for ParseSchema.
type ArrayRow [16]float32

// Collection returns an empty string.
func (ar *ArrayRow) Collection() string { return "" }

// Partition returns an empty string.
func (ar *ArrayRow) Partition() string { return "" }

// Description returns an empty string.
func (ar *ArrayRow) Description() string { return "" }
|
||||
|
||||
// Uint8Struct has a uint8 field, a type ParseSchema does not support.
type Uint8Struct struct {
	Attr uint8
}

// StringArrayStruct has a string array field, which ParseSchema rejects.
type StringArrayStruct struct {
	Vector [8]string
}

// StringSliceStruct has a string slice field, which ParseSchema rejects.
type StringSliceStruct struct {
	Vector []string `milvus:"dim:8"`
}

// SliceNoDimStruct has a slice vector field without a dim setting.
type SliceNoDimStruct struct {
	Vector []float32 `milvus:""`
}

// SliceBadDimStruct has a slice vector field with a non-numeric dim.
type SliceBadDimStruct struct {
	Vector []float32 `milvus:"dim:str"`
}

// SliceBadDimStruct2 has a slice vector field with an out-of-range dim.
type SliceBadDimStruct2 struct {
	Vector []float32 `milvus:"dim:0"`
}
|
||||
|
||||
func TestParseSchema(t *testing.T) {
|
||||
t.Run("invalid cases", func(t *testing.T) {
|
||||
// anonymous struct with default collection name ("") will cause error
|
||||
anonymusStruct := struct{}{}
|
||||
sch, err := ParseSchema(anonymusStruct)
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// non struct
|
||||
arrayRow := ArrayRow([16]float32{})
|
||||
sch, err = ParseSchema(&arrayRow)
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// uint8 not supported
|
||||
sch, err = ParseSchema(&Uint8Struct{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// string array not supported
|
||||
sch, err = ParseSchema(&StringArrayStruct{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// string slice not supported
|
||||
sch, err = ParseSchema(&StringSliceStruct{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// slice vector with no dim
|
||||
sch, err = ParseSchema(&SliceNoDimStruct{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// slice vector with bad format dim
|
||||
sch, err = ParseSchema(&SliceBadDimStruct{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// slice vector with bad format dim 2
|
||||
sch, err = ParseSchema(&SliceBadDimStruct2{})
|
||||
assert.Nil(t, sch)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid cases", func(t *testing.T) {
|
||||
getVectorField := func(schema *entity.Schema) *entity.Field {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == entity.FieldTypeFloatVector ||
|
||||
field.DataType == entity.FieldTypeBinaryVector ||
|
||||
field.DataType == entity.FieldTypeBFloat16Vector ||
|
||||
field.DataType == entity.FieldTypeFloat16Vector {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ValidStruct struct {
|
||||
ID int64 `milvus:"primary_key"`
|
||||
Attr1 int8
|
||||
Attr2 int16
|
||||
Attr3 int32
|
||||
Attr4 float32
|
||||
Attr5 float64
|
||||
Attr6 string
|
||||
Vector []float32 `milvus:"dim:128"`
|
||||
}
|
||||
vs := &ValidStruct{}
|
||||
sch, err := ParseSchema(vs)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
assert.Equal(t, "ValidStruct", sch.CollectionName)
|
||||
|
||||
type ValidFp16Struct struct {
|
||||
ID int64 `milvus:"primary_key"`
|
||||
Attr1 int8
|
||||
Attr2 int16
|
||||
Attr3 int32
|
||||
Attr4 float32
|
||||
Attr5 float64
|
||||
Attr6 string
|
||||
Vector []byte `milvus:"dim:128;vector_type:fp16"`
|
||||
}
|
||||
fp16Vs := &ValidFp16Struct{}
|
||||
sch, err = ParseSchema(fp16Vs)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
assert.Equal(t, "ValidFp16Struct", sch.CollectionName)
|
||||
vectorField := getVectorField(sch)
|
||||
assert.Equal(t, entity.FieldTypeFloat16Vector, vectorField.DataType)
|
||||
|
||||
type ValidBf16Struct struct {
|
||||
ID int64 `milvus:"primary_key"`
|
||||
Attr1 int8
|
||||
Attr2 int16
|
||||
Attr3 int32
|
||||
Attr4 float32
|
||||
Attr5 float64
|
||||
Attr6 string
|
||||
Vector []byte `milvus:"dim:128;vector_type:bf16"`
|
||||
}
|
||||
bf16Vs := &ValidBf16Struct{}
|
||||
sch, err = ParseSchema(bf16Vs)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
assert.Equal(t, "ValidBf16Struct", sch.CollectionName)
|
||||
vectorField = getVectorField(sch)
|
||||
assert.Equal(t, entity.FieldTypeBFloat16Vector, vectorField.DataType)
|
||||
|
||||
type ValidByteStruct struct {
|
||||
ID int64 `milvus:"primary_key"`
|
||||
Vector []byte `milvus:"dim:128"`
|
||||
}
|
||||
vs2 := &ValidByteStruct{}
|
||||
sch, err = ParseSchema(vs2)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
|
||||
type ValidArrayStruct struct {
|
||||
ID int64 `milvus:"primary_key"`
|
||||
Vector [64]float32
|
||||
}
|
||||
vs3 := &ValidArrayStruct{}
|
||||
sch, err = ParseSchema(vs3)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
|
||||
type ValidArrayStructByte struct {
|
||||
ID int64 `milvus:"primary_key;auto_id"`
|
||||
Data *string `milvus:"extra:test\\;false"`
|
||||
Vector [64]byte
|
||||
}
|
||||
vs4 := &ValidArrayStructByte{}
|
||||
sch, err = ParseSchema(vs4)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
|
||||
vs5 := &ValidStructWithNamedTag{}
|
||||
sch, err = ParseSchema(vs5)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, sch)
|
||||
i64f, vecf := false, false
|
||||
for _, field := range sch.Fields {
|
||||
if field.Name == "id" {
|
||||
i64f = true
|
||||
}
|
||||
if field.Name == "vector" {
|
||||
vecf = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, i64f)
|
||||
assert.True(t, vecf)
|
||||
})
|
||||
}
|
|
@ -28,6 +28,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/row"
|
||||
)
|
||||
|
||||
type InsertOption interface {
|
||||
|
@ -71,10 +72,8 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema,
|
|||
l := col.Len()
|
||||
if rowSize == 0 {
|
||||
rowSize = l
|
||||
} else {
|
||||
if rowSize != l {
|
||||
return nil, 0, errors.New("column size not match")
|
||||
}
|
||||
} else if rowSize != l {
|
||||
return nil, 0, errors.New("column size not match")
|
||||
}
|
||||
field, has := mNameField[col.Name()]
|
||||
if !has {
|
||||
|
@ -247,6 +246,56 @@ func NewColumnBasedInsertOption(collName string, columns ...column.Column) *colu
|
|||
}
|
||||
}
|
||||
|
||||
// rowBasedDataOption is an insert/upsert option that carries row-based data.
// It embeds columnBasedDataOption; the rows are converted to columns when a
// request is built.
type rowBasedDataOption struct {
	*columnBasedDataOption
	// rows holds the raw rows handed to row.AnyToColumns.
	rows []any
}
|
||||
|
||||
func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption {
|
||||
return &rowBasedDataOption{
|
||||
columnBasedDataOption: &columnBasedDataOption{
|
||||
collName: collName,
|
||||
},
|
||||
rows: rows,
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
|
||||
columns, err := row.AnyToColumns(opt.rows, coll.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opt.columnBasedDataOption.columns = columns
|
||||
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &milvuspb.InsertRequest{
|
||||
CollectionName: opt.collName,
|
||||
PartitionName: opt.partitionName,
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
|
||||
columns, err := row.AnyToColumns(opt.rows, coll.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opt.columnBasedDataOption.columns = columns
|
||||
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &milvuspb.UpsertRequest{
|
||||
CollectionName: opt.collName,
|
||||
PartitionName: opt.partitionName,
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteOption is the option interface for delete operations.
type DeleteOption interface {
	// Request returns the DeleteRequest to send.
	Request() *milvuspb.DeleteRequest
}
|
||||
|
|
Loading…
Reference in New Issue