mirror of https://github.com/milvus-io/milvus.git
enhance: Update latest sdk update to client pkg (#33105)
Related to #31293 See also milvus-io/milvus-sdk-go#704 milvus-io/milvus-sdk-go#711 milvus-io/milvus-sdk-go#713 milvus-io/milvus-sdk-go#721 milvus-io/milvus-sdk-go#732 milvus-io/milvus-sdk-go#739 milvus-io/milvus-sdk-go#748 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/32991/head^2
parent
f1c9986974
commit
1ef975d327
|
@ -1,5 +1,6 @@
|
|||
reviewers:
|
||||
- congqixia
|
||||
- ThreadDao
|
||||
|
||||
approvers:
|
||||
- maintainers
|
||||
|
|
|
@ -18,14 +18,20 @@ package client
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/status"
|
||||
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
@ -39,6 +45,11 @@ type Client struct {
|
|||
service milvuspb.MilvusServiceClient
|
||||
config *ClientConfig
|
||||
|
||||
// mutable status
|
||||
stateMut sync.RWMutex
|
||||
currentDB string
|
||||
identifier string
|
||||
|
||||
collCache *CollectionCache
|
||||
}
|
||||
|
||||
|
@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
|
|||
// Parse remote address.
|
||||
addr := c.config.getParsedAddress()
|
||||
|
||||
// parse authentication parameters
|
||||
c.config.parseAuthentication()
|
||||
// Parse grpc options
|
||||
options := c.config.getDialOption()
|
||||
options := c.dialOptions()
|
||||
|
||||
// Connect the grpc server.
|
||||
if err := c.connect(ctx, addr, options...); err != nil {
|
||||
|
@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) dialOptions() []grpc.DialOption {
|
||||
var options []grpc.DialOption
|
||||
// Construct dial option.
|
||||
if c.config.EnableTLSAuth {
|
||||
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
|
||||
} else {
|
||||
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
}
|
||||
|
||||
if c.config.DialOptions == nil {
|
||||
// Add default connection options.
|
||||
options = append(options, DefaultGrpcOpts...)
|
||||
} else {
|
||||
options = append(options, c.config.DialOptions...)
|
||||
}
|
||||
|
||||
options = append(options,
|
||||
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
|
||||
grpc_retry.WithMax(6),
|
||||
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
|
||||
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
}),
|
||||
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
|
||||
|
||||
// c.getRetryOnRateLimitInterceptor(),
|
||||
))
|
||||
|
||||
options = append(options, grpc.WithChainUnaryInterceptor(
|
||||
c.MetadataUnaryInterceptor(),
|
||||
))
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func (c *Client) Close(ctx context.Context) error {
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
|
@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) usingDatabase(dbName string) {
|
||||
c.stateMut.Lock()
|
||||
defer c.stateMut.Unlock()
|
||||
c.currentDB = dbName
|
||||
}
|
||||
|
||||
func (c *Client) setIdentifier(identifier string) {
|
||||
c.stateMut.Lock()
|
||||
defer c.stateMut.Unlock()
|
||||
c.identifier = identifier
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
|
||||
if addr == "" {
|
||||
return fmt.Errorf("address is empty")
|
||||
|
@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||
|
||||
req := &milvuspb.ConnectRequest{
|
||||
ClientInfo: &commonpb.ClientInfo{
|
||||
SdkType: "Golang",
|
||||
SdkType: "GoMilvusClient",
|
||||
SdkVersion: common.SDKVersion,
|
||||
LocalTime: time.Now().String(),
|
||||
User: c.config.Username,
|
||||
|
@ -131,8 +190,8 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||
disableJSON |
|
||||
disableParitionKey |
|
||||
disableDynamicSchema)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||
}
|
||||
|
||||
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
|
||||
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
|
||||
c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
|
@ -10,12 +10,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/backoff"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
|
@ -59,16 +56,23 @@ type ClientConfig struct {
|
|||
|
||||
DialOptions []grpc.DialOption // Dial options for GRPC.
|
||||
|
||||
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
|
||||
RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
|
||||
|
||||
DisableConn bool
|
||||
|
||||
metadataHeaders map[string]string
|
||||
|
||||
identifier string // Identifier for this connection
|
||||
ServerVersion string // ServerVersion
|
||||
parsedAddress *url.URL
|
||||
flags uint64 // internal flags
|
||||
}
|
||||
|
||||
type RetryRateLimitOption struct {
|
||||
MaxRetry uint
|
||||
MaxBackoff time.Duration
|
||||
}
|
||||
|
||||
func (cfg *ClientConfig) parse() error {
|
||||
// Prepend default fake tcp:// scheme for remote address.
|
||||
address := cfg.Address
|
||||
|
@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) {
|
|||
c.ServerVersion = serverInfo
|
||||
}
|
||||
|
||||
// Get parsed grpc dial options, should be called after parse was called.
|
||||
func (c *ClientConfig) getDialOption() []grpc.DialOption {
|
||||
options := c.DialOptions
|
||||
if c.DialOptions == nil {
|
||||
// Add default connection options.
|
||||
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
|
||||
copy(options, DefaultGrpcOpts)
|
||||
// parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key.
|
||||
func (c *ClientConfig) parseAuthentication() {
|
||||
c.metadataHeaders = make(map[string]string)
|
||||
if c.Username != "" || c.Password != "" {
|
||||
value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password))
|
||||
c.metadataHeaders[authorizationHeader] = value
|
||||
}
|
||||
|
||||
// Construct dial option.
|
||||
if c.EnableTLSAuth {
|
||||
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
|
||||
} else {
|
||||
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
// API overwrites username & passwd
|
||||
if c.APIKey != "" {
|
||||
value := crypto.Base64Encode(c.APIKey)
|
||||
c.metadataHeaders[authorizationHeader] = value
|
||||
}
|
||||
|
||||
options = append(options,
|
||||
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
|
||||
grpc_retry.WithMax(6),
|
||||
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
|
||||
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
}),
|
||||
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
|
||||
// c.getRetryOnRateLimitInterceptor(),
|
||||
))
|
||||
|
||||
// options = append(options, grpc.WithChainUnaryInterceptor(
|
||||
// createMetaDataUnaryInterceptor(c),
|
||||
// ))
|
||||
return options
|
||||
}
|
||||
|
||||
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
|
||||
// if c.RetryRateLimit == nil {
|
||||
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
|
||||
// }
|
||||
func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
|
||||
if c.RetryRateLimit == nil {
|
||||
c.RetryRateLimit = c.defaultRetryRateLimitOption()
|
||||
}
|
||||
|
||||
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
// })
|
||||
// }
|
||||
return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||
return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
})
|
||||
}
|
||||
|
||||
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
|
||||
// return &RetryRateLimitOption{
|
||||
// MaxRetry: 75,
|
||||
// MaxBackoff: 3 * time.Second,
|
||||
// }
|
||||
// }
|
||||
func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
|
||||
return &RetryRateLimitOption{
|
||||
MaxRetry: 75,
|
||||
MaxBackoff: 3 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// addFlags set internal flags
|
||||
func (c *ClientConfig) addFlags(flags uint64) {
|
||||
|
|
|
@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect
|
|||
VirtualChannels: resp.GetVirtualChannelNames(),
|
||||
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
|
||||
ShardNum: resp.GetShardsNum(),
|
||||
Properties: entity.KvPairsMap(resp.GetProperties()),
|
||||
}
|
||||
collection.Name = collection.Schema.CollectionName
|
||||
return nil
|
||||
|
|
|
@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
|
|||
autoID: true,
|
||||
dim: dim,
|
||||
enabledDynamicSchema: true,
|
||||
consistencyLevel: entity.DefaultConsistencyLevel,
|
||||
|
||||
isFast: true,
|
||||
metricType: entity.COSINE,
|
||||
|
@ -149,9 +150,10 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
|
|||
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
|
||||
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
|
||||
return &createCollectionOption{
|
||||
name: name,
|
||||
shardNum: 1,
|
||||
schema: collectionSchema,
|
||||
name: name,
|
||||
shardNum: 1,
|
||||
schema: collectionSchema,
|
||||
consistencyLevel: entity.DefaultConsistencyLevel,
|
||||
|
||||
metricType: entity.COSINE,
|
||||
}
|
||||
|
|
|
@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
|
|||
|
||||
// IDColumns converts schemapb.IDs to corresponding column
|
||||
// currently Int64 / string may be in IDs
|
||||
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
|
||||
func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) {
|
||||
var idColumn Column
|
||||
if idField == nil {
|
||||
pkField := schema.PKField()
|
||||
if pkField == nil {
|
||||
return nil, errors.New("PK Field not found")
|
||||
}
|
||||
if ids == nil {
|
||||
return nil, errors.New("nil Ids from response")
|
||||
}
|
||||
switch field := idField.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
if end >= 0 {
|
||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
|
||||
} else {
|
||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
|
||||
switch pkField.DataType {
|
||||
case entity.FieldTypeInt64:
|
||||
data := ids.GetIntId().GetData()
|
||||
if data == nil {
|
||||
return NewColumnInt64(pkField.Name, nil), nil
|
||||
}
|
||||
case *schemapb.IDs_StrId:
|
||||
if end >= 0 {
|
||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
|
||||
idColumn = NewColumnInt64(pkField.Name, data[begin:end])
|
||||
} else {
|
||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
|
||||
idColumn = NewColumnInt64(pkField.Name, data[begin:])
|
||||
}
|
||||
case entity.FieldTypeVarChar, entity.FieldTypeString:
|
||||
data := ids.GetStrId().GetData()
|
||||
if data == nil {
|
||||
return NewColumnVarChar(pkField.Name, nil), nil
|
||||
}
|
||||
if end >= 0 {
|
||||
idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
|
||||
} else {
|
||||
idColumn = NewColumnVarChar(pkField.Name, data[begin:])
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported id type %v", field)
|
||||
return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
|
||||
}
|
||||
return idColumn, nil
|
||||
}
|
||||
|
|
|
@ -24,18 +24,34 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
func TestIDColumns(t *testing.T) {
|
||||
dataLen := rand.Intn(100) + 1
|
||||
base := rand.Intn(5000) // id start point
|
||||
|
||||
intPKCol := entity.NewSchema().WithField(
|
||||
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64),
|
||||
)
|
||||
strPKCol := entity.NewSchema().WithField(
|
||||
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar),
|
||||
)
|
||||
|
||||
t.Run("nil id", func(t *testing.T) {
|
||||
_, err := IDColumns(nil, 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
col, err := IDColumns(intPKCol, nil, 0, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, col.Len())
|
||||
col, err = IDColumns(strPKCol, nil, 0, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, col.Len())
|
||||
idField := &schemapb.IDs{}
|
||||
_, err = IDColumns(idField, 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
col, err = IDColumns(intPKCol, idField, 0, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, col.Len())
|
||||
col, err = IDColumns(strPKCol, idField, 0, -1)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, col.Len())
|
||||
})
|
||||
|
||||
t.Run("int ids", func(t *testing.T) {
|
||||
|
@ -50,12 +66,12 @@ func TestIDColumns(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
column, err := IDColumns(idField, 0, dataLen)
|
||||
column, err := IDColumns(intPKCol, idField, 0, dataLen)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
||||
column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
@ -72,12 +88,12 @@ func TestIDColumns(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
column, err := IDColumns(idField, 0, dataLen)
|
||||
column, err := IDColumns(strPKCol, idField, 0, dataLen)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
||||
column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
|
|
@ -25,6 +25,12 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
|
||||
dbName := option.DbName()
|
||||
c.usingDatabase(dbName)
|
||||
return c.connectInternal(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
|
||||
req := option.Request()
|
||||
|
||||
|
|
|
@ -18,6 +18,24 @@ package client
|
|||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
||||
type UsingDatabaseOption interface {
|
||||
DbName() string
|
||||
}
|
||||
|
||||
type usingDatabaseNameOpt struct {
|
||||
dbName string
|
||||
}
|
||||
|
||||
func (opt *usingDatabaseNameOpt) DbName() string {
|
||||
return opt.dbName
|
||||
}
|
||||
|
||||
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
|
||||
return &usingDatabaseNameOpt{
|
||||
dbName: dbName,
|
||||
}
|
||||
}
|
||||
|
||||
// ListDatabaseOption is a builder interface for ListDatabase request.
|
||||
type ListDatabaseOption interface {
|
||||
Request() *milvuspb.ListDatabasesRequest
|
||||
|
|
|
@ -32,6 +32,7 @@ type Collection struct {
|
|||
Loaded bool
|
||||
ConsistencyLevel ConsistencyLevel
|
||||
ShardNum int32
|
||||
Properties map[string]string
|
||||
}
|
||||
|
||||
// Partition represent partition meta in Milvus
|
||||
|
|
|
@ -60,6 +60,8 @@ type Schema struct {
|
|||
AutoID bool
|
||||
Fields []*Field
|
||||
EnableDynamicField bool
|
||||
|
||||
pkField *Field
|
||||
}
|
||||
|
||||
// NewSchema creates an empty schema object.
|
||||
|
@ -91,6 +93,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
|
|||
|
||||
// WithField adds a field into schema and returns schema itself.
|
||||
func (s *Schema) WithField(f *Field) *Schema {
|
||||
if f.PrimaryKey {
|
||||
s.pkField = f
|
||||
}
|
||||
s.Fields = append(s.Fields, f)
|
||||
return s
|
||||
}
|
||||
|
@ -116,10 +121,14 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
|
|||
s.CollectionName = p.GetName()
|
||||
s.Fields = make([]*Field, 0, len(p.GetFields()))
|
||||
for _, fp := range p.GetFields() {
|
||||
field := NewField().ReadProto(fp)
|
||||
if fp.GetAutoID() {
|
||||
s.AutoID = true
|
||||
}
|
||||
s.Fields = append(s.Fields, NewField().ReadProto(fp))
|
||||
if field.PrimaryKey {
|
||||
s.pkField = field
|
||||
}
|
||||
s.Fields = append(s.Fields, field)
|
||||
}
|
||||
s.EnableDynamicField = p.GetEnableDynamicField()
|
||||
return s
|
||||
|
@ -127,12 +136,15 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
|
|||
|
||||
// PKFieldName returns pk field name for this schemapb.
|
||||
func (s *Schema) PKFieldName() string {
|
||||
for _, field := range s.Fields {
|
||||
if field.PrimaryKey {
|
||||
return field.Name
|
||||
}
|
||||
if s.pkField == nil {
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
return s.pkField.Name
|
||||
}
|
||||
|
||||
// PKField returns PK Field schema for this schema.
|
||||
func (s *Schema) PKField() *Field {
|
||||
return s.pkField
|
||||
}
|
||||
|
||||
// Field represent field schema in milvus
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
|
@ -31,15 +33,27 @@ type createIndexOption struct {
|
|||
fieldName string
|
||||
indexName string
|
||||
indexDef index.Index
|
||||
|
||||
extraParams map[string]any
|
||||
}
|
||||
|
||||
func (opt *createIndexOption) WithExtraParam(key string, value any) {
|
||||
opt.extraParams[key] = value
|
||||
}
|
||||
|
||||
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
|
||||
return &milvuspb.CreateIndexRequest{
|
||||
params := opt.indexDef.Params()
|
||||
for key, value := range opt.extraParams {
|
||||
params[key] = fmt.Sprintf("%v", value)
|
||||
}
|
||||
req := &milvuspb.CreateIndexRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
FieldName: opt.fieldName,
|
||||
IndexName: opt.indexName,
|
||||
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()),
|
||||
ExtraParams: entity.MapKvPairs(params),
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
|
||||
|
@ -52,6 +66,7 @@ func NewCreateIndexOption(collectionName string, fieldName string, index index.I
|
|||
collectionName: collectionName,
|
||||
fieldName: fieldName,
|
||||
indexDef: index,
|
||||
extraParams: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = `authorization`
|
||||
|
||||
identifierHeader = `identifier`
|
||||
|
||||
databaseHeader = `dbname`
|
||||
)
|
||||
|
||||
func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
ctx = c.metadata(ctx)
|
||||
ctx = c.state(ctx)
|
||||
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) metadata(ctx context.Context) context.Context {
|
||||
for k, v := range c.config.metadataHeaders {
|
||||
ctx = metadata.AppendToOutgoingContext(ctx, k, v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (c *Client) state(ctx context.Context) context.Context {
|
||||
c.stateMut.RLock()
|
||||
defer c.stateMut.RUnlock()
|
||||
|
||||
if c.currentDB != "" {
|
||||
ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB)
|
||||
}
|
||||
if c.identifier != "" {
|
||||
ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ref: https://github.com/grpc-ecosystem/go-grpc-middleware
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
RetryOnRateLimit ctxKey = iota
|
||||
)
|
||||
|
||||
// RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
|
||||
func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
|
||||
return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
if maxRetry == 0 {
|
||||
return invoker(parentCtx, method, req, reply, cc, opts...)
|
||||
}
|
||||
var lastErr error
|
||||
for attempt := uint(0); attempt < maxRetry; attempt++ {
|
||||
_, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
|
||||
rspStatus := getResultStatus(reply)
|
||||
if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
|
||||
func retryOnRateLimit(ctx context.Context) bool {
|
||||
retry, ok := ctx.Value(RetryOnRateLimit).(bool)
|
||||
if !ok {
|
||||
return true // default true
|
||||
}
|
||||
return retry
|
||||
}
|
||||
|
||||
// getResultStatus returns status of response.
|
||||
func getResultStatus(reply interface{}) *commonpb.Status {
|
||||
switch r := reply.(type) {
|
||||
case *commonpb.Status:
|
||||
return r
|
||||
case *milvuspb.MutationResult:
|
||||
return r.GetStatus()
|
||||
case *milvuspb.BoolResponse:
|
||||
return r.GetStatus()
|
||||
case *milvuspb.SearchResults:
|
||||
return r.GetStatus()
|
||||
case *milvuspb.QueryResults:
|
||||
return r.GetStatus()
|
||||
case *milvuspb.FlushResponse:
|
||||
return r.GetStatus()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func contextErrToGrpcErr(err error) error {
|
||||
switch err {
|
||||
case context.DeadlineExceeded:
|
||||
return status.Error(codes.DeadlineExceeded, err.Error())
|
||||
case context.Canceled:
|
||||
return status.Error(codes.Canceled, err.Error())
|
||||
default:
|
||||
return status.Error(codes.Unknown, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
|
||||
var waitTime time.Duration
|
||||
if attempt > 0 {
|
||||
waitTime = backoffFunc(parentCtx, attempt)
|
||||
}
|
||||
if waitTime > 0 {
|
||||
if waitTime > maxBackoff {
|
||||
waitTime = maxBackoff
|
||||
}
|
||||
timer := time.NewTimer(waitTime)
|
||||
select {
|
||||
case <-parentCtx.Done():
|
||||
timer.Stop()
|
||||
return waitTime, contextErrToGrpcErr(parentCtx.Err())
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return waitTime, nil
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
var mockInvokerError error
|
||||
var mockInvokerReply interface{}
|
||||
var mockInvokeTimes = 0
|
||||
|
||||
var mockInvoker grpc.UnaryInvoker = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
|
||||
mockInvokeTimes++
|
||||
return mockInvokerError
|
||||
}
|
||||
|
||||
func resetMockInvokeTimes() {
|
||||
mockInvokeTimes = 0
|
||||
}
|
||||
|
||||
func TestRateLimitInterceptor(t *testing.T) {
|
||||
maxRetry := uint(3)
|
||||
maxBackoff := 3 * time.Second
|
||||
inter := RetryOnRateLimitInterceptor(maxRetry, maxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||
return 60 * time.Millisecond * time.Duration(math.Pow(2, float64(attempt)))
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// with retry
|
||||
mockInvokerReply = &commonpb.Status{ErrorCode: commonpb.ErrorCode_RateLimit}
|
||||
resetMockInvokeTimes()
|
||||
err := inter(ctx, "", nil, mockInvokerReply, nil, mockInvoker)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxRetry, uint(mockInvokeTimes))
|
||||
|
||||
// without retry
|
||||
ctx1 := context.WithValue(ctx, RetryOnRateLimit, false)
|
||||
resetMockInvokeTimes()
|
||||
err = inter(ctx1, "", nil, mockInvokerReply, nil, mockInvoker)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint(1), uint(mockInvokeTimes))
|
||||
}
|
|
@ -33,7 +33,7 @@ type ResultSets struct{}
|
|||
|
||||
type ResultSet struct {
|
||||
ResultCount int // the returning entry count
|
||||
GroupByValue any
|
||||
GroupByValue column.Column
|
||||
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
|
||||
Fields DataSet // output field data
|
||||
Scores []float32 // distance to the target vector
|
||||
|
@ -67,35 +67,32 @@ func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ..
|
|||
}
|
||||
|
||||
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
|
||||
var err error
|
||||
sr := make([]ResultSet, 0, nq)
|
||||
results := resp.GetResults()
|
||||
offset := 0
|
||||
fieldDataList := results.GetFieldsData()
|
||||
gb := results.GetGroupByFieldValue()
|
||||
var gbc column.Column
|
||||
if gb != nil {
|
||||
gbc, err = column.FieldDataColumn(gb, 0, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < int(results.GetNumQueries()); i++ {
|
||||
rc := int(results.GetTopks()[i]) // result entry count for current query
|
||||
entry := ResultSet{
|
||||
ResultCount: rc,
|
||||
Scores: results.GetScores()[offset : offset+rc],
|
||||
}
|
||||
if gbc != nil {
|
||||
entry.GroupByValue, _ = gbc.Get(i)
|
||||
}
|
||||
// parse result set if current nq is not empty
|
||||
if rc > 0 {
|
||||
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc)
|
||||
entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc)
|
||||
if entry.Err != nil {
|
||||
offset += rc
|
||||
continue
|
||||
}
|
||||
// parse group-by values
|
||||
if gb != nil {
|
||||
entry.GroupByValue, entry.Err = column.FieldDataColumn(gb, offset, offset+rc)
|
||||
if entry.Err != nil {
|
||||
offset += rc
|
||||
continue
|
||||
}
|
||||
}
|
||||
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
|
||||
sr = append(sr, entry)
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
|
|||
|
||||
// search param
|
||||
bs, _ := json.Marshal(annRequest.searchParam)
|
||||
request.SearchParams = entity.MapKvPairs(map[string]string{
|
||||
params := map[string]string{
|
||||
spAnnsField: annRequest.annField,
|
||||
spTopK: strconv.Itoa(opt.topK),
|
||||
spOffset: strconv.Itoa(opt.offset),
|
||||
|
@ -95,8 +95,11 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
|
|||
spMetricsType: string(annRequest.metricsType),
|
||||
spRoundDecimal: "-1",
|
||||
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
|
||||
spGroupBy: annRequest.groupByField,
|
||||
})
|
||||
}
|
||||
if annRequest.groupByField != "" {
|
||||
params[spGroupBy] = annRequest.groupByField
|
||||
}
|
||||
request.SearchParams = entity.MapKvPairs(params)
|
||||
|
||||
// placeholder group
|
||||
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
|
||||
|
|
|
@ -22,53 +22,90 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error {
|
||||
type InsertResult struct {
|
||||
InsertCount int64
|
||||
IDs column.Column
|
||||
}
|
||||
|
||||
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) {
|
||||
result := InsertResult{}
|
||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||
if err != nil {
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
req, err := option.InsertRequest(collection)
|
||||
if err != nil {
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Insert(ctx, req, callOptions...)
|
||||
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result.InsertCount = resp.GetInsertCnt()
|
||||
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error {
|
||||
type DeleteResult struct {
|
||||
DeleteCount int64
|
||||
}
|
||||
|
||||
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) (DeleteResult, error) {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
result := DeleteResult{}
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Delete(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
result.DeleteCount = resp.GetDeleteCnt()
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error {
|
||||
type UpsertResult struct {
|
||||
UpsertCount int64
|
||||
IDs column.Column
|
||||
}
|
||||
|
||||
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) {
|
||||
result := UpsertResult{}
|
||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||
if err != nil {
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
req, err := option.UpsertRequest(collection)
|
||||
if err != nil {
|
||||
return err
|
||||
return result, err
|
||||
}
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Upsert(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
result.UpsertCount = resp.GetUpsertCnt()
|
||||
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/samber/lo"
|
||||
|
@ -63,16 +64,25 @@ func (s *WriteSuite) TestInsert() {
|
|||
s.Require().Len(ir.GetFieldsData(), 2)
|
||||
s.EqualValues(3, ir.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
Status: merr.Success(),
|
||||
InsertCnt: 3,
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
s.EqualValues(3, result.InsertCount)
|
||||
})
|
||||
|
||||
s.Run("dynamic_schema", func() {
|
||||
|
@ -86,17 +96,26 @@ func (s *WriteSuite) TestInsert() {
|
|||
s.Require().Len(ir.GetFieldsData(), 3)
|
||||
s.EqualValues(3, ir.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
Status: merr.Success(),
|
||||
InsertCnt: 3,
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
s.EqualValues(3, result.InsertCount)
|
||||
})
|
||||
|
||||
s.Run("bad_input", func() {
|
||||
|
@ -141,7 +160,7 @@ func (s *WriteSuite) TestInsert() {
|
|||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
err := s.client.Insert(ctx, tc.input)
|
||||
_, err := s.client.Insert(ctx, tc.input)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
@ -153,7 +172,7 @@ func (s *WriteSuite) TestInsert() {
|
|||
|
||||
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
_, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
|
@ -177,16 +196,25 @@ func (s *WriteSuite) TestUpsert() {
|
|||
s.Require().Len(ur.GetFieldsData(), 2)
|
||||
s.EqualValues(3, ur.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
Status: merr.Success(),
|
||||
UpsertCnt: 3,
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
s.EqualValues(3, result.UpsertCount)
|
||||
})
|
||||
|
||||
s.Run("dynamic_schema", func() {
|
||||
|
@ -200,17 +228,26 @@ func (s *WriteSuite) TestUpsert() {
|
|||
s.Require().Len(ur.GetFieldsData(), 3)
|
||||
s.EqualValues(3, ur.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
Status: merr.Success(),
|
||||
UpsertCnt: 3,
|
||||
IDs: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
s.EqualValues(3, result.UpsertCount)
|
||||
})
|
||||
|
||||
s.Run("bad_input", func() {
|
||||
|
@ -255,7 +292,7 @@ func (s *WriteSuite) TestUpsert() {
|
|||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
err := s.client.Upsert(ctx, tc.input)
|
||||
_, err := s.client.Upsert(ctx, tc.input)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
@ -267,7 +304,7 @@ func (s *WriteSuite) TestUpsert() {
|
|||
|
||||
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
_, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
|
@ -315,11 +352,13 @@ func (s *WriteSuite) TestDelete() {
|
|||
s.Equal(partName, dr.GetPartitionName())
|
||||
s.Equal(tc.expectExpr, dr.GetExpr())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
Status: merr.Success(),
|
||||
DeleteCnt: 100,
|
||||
}, nil
|
||||
}).Once()
|
||||
err := s.client.Delete(ctx, tc.input)
|
||||
result, err := s.client.Delete(ctx, tc.input)
|
||||
s.NoError(err)
|
||||
s.EqualValues(100, result.DeleteCount)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue