enhance: Update latest sdk update to client pkg (#33105)

Related to #31293
See also milvus-io/milvus-sdk-go#704 milvus-io/milvus-sdk-go#711 
milvus-io/milvus-sdk-go#713 milvus-io/milvus-sdk-go#721
milvus-io/milvus-sdk-go#732 milvus-io/milvus-sdk-go#739 
milvus-io/milvus-sdk-go#748

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/32991/head^2
congqixia 2024-05-17 10:39:37 +08:00 committed by GitHub
parent f1c9986974
commit 1ef975d327
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 555 additions and 125 deletions

View File

@ -1,5 +1,6 @@
reviewers:
- congqixia
- ThreadDao
approvers:
- maintainers

View File

@ -18,14 +18,20 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"math"
"os"
"strconv"
"sync"
"time"
"github.com/gogo/status"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -39,6 +45,11 @@ type Client struct {
service milvuspb.MilvusServiceClient
config *ClientConfig
// mutable status
stateMut sync.RWMutex
currentDB string
identifier string
collCache *CollectionCache
}
@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
// Parse remote address.
addr := c.config.getParsedAddress()
// parse authentication parameters
c.config.parseAuthentication()
// Parse grpc options
options := c.config.getDialOption()
options := c.dialOptions()
// Connect the grpc server.
if err := c.connect(ctx, addr, options...); err != nil {
@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
return c, nil
}
func (c *Client) dialOptions() []grpc.DialOption {
var options []grpc.DialOption
// Construct dial option.
if c.config.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
if c.config.DialOptions == nil {
// Add default connection options.
options = append(options, DefaultGrpcOpts...)
} else {
options = append(options, c.config.DialOptions...)
}
options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))
options = append(options, grpc.WithChainUnaryInterceptor(
c.MetadataUnaryInterceptor(),
))
return options
}
func (c *Client) Close(ctx context.Context) error {
if c.conn == nil {
return nil
@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error {
return nil
}
func (c *Client) usingDatabase(dbName string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.currentDB = dbName
}
func (c *Client) setIdentifier(identifier string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.identifier = identifier
}
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
if addr == "" {
return fmt.Errorf("address is empty")
@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
req := &milvuspb.ConnectRequest{
ClientInfo: &commonpb.ClientInfo{
SdkType: "Golang",
SdkType: "GoMilvusClient",
SdkVersion: common.SDKVersion,
LocalTime: time.Now().String(),
User: c.config.Username,
@ -131,8 +190,8 @@ func (c *Client) connectInternal(ctx context.Context) error {
disableJSON |
disableParitionKey |
disableDynamicSchema)
return nil
}
return nil
}
return err
}
@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
}
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
return nil
}

View File

@ -1,7 +1,7 @@
package client
import (
"crypto/tls"
"context"
"fmt"
"math"
"net/url"
@ -10,12 +10,9 @@ import (
"time"
"github.com/cockroachdb/errors"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/milvus-io/milvus/pkg/util/crypto"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)
@ -59,16 +56,23 @@ type ClientConfig struct {
DialOptions []grpc.DialOption // Dial options for GRPC.
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
DisableConn bool
metadataHeaders map[string]string
identifier string // Identifier for this connection
ServerVersion string // ServerVersion
parsedAddress *url.URL
flags uint64 // internal flags
}
type RetryRateLimitOption struct {
MaxRetry uint
MaxBackoff time.Duration
}
func (cfg *ClientConfig) parse() error {
// Prepend default fake tcp:// scheme for remote address.
address := cfg.Address
@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) {
c.ServerVersion = serverInfo
}
// Get parsed grpc dial options, should be called after parse was called.
func (c *ClientConfig) getDialOption() []grpc.DialOption {
options := c.DialOptions
if c.DialOptions == nil {
// Add default connection options.
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
copy(options, DefaultGrpcOpts)
// parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key.
func (c *ClientConfig) parseAuthentication() {
c.metadataHeaders = make(map[string]string)
if c.Username != "" || c.Password != "" {
value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password))
c.metadataHeaders[authorizationHeader] = value
}
// Construct dial option.
if c.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
// API overwrites username & passwd
if c.APIKey != "" {
value := crypto.Base64Encode(c.APIKey)
c.metadataHeaders[authorizationHeader] = value
}
options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))
// options = append(options, grpc.WithChainUnaryInterceptor(
// createMetaDataUnaryInterceptor(c),
// ))
return options
}
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
// if c.RetryRateLimit == nil {
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
// }
func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
if c.RetryRateLimit == nil {
c.RetryRateLimit = c.defaultRetryRateLimitOption()
}
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
// })
// }
return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
})
}
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
// return &RetryRateLimitOption{
// MaxRetry: 75,
// MaxBackoff: 3 * time.Second,
// }
// }
func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
return &RetryRateLimitOption{
MaxRetry: 75,
MaxBackoff: 3 * time.Second,
}
}
// addFlags set internal flags
func (c *ClientConfig) addFlags(flags uint64) {

View File

@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect
VirtualChannels: resp.GetVirtualChannelNames(),
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
ShardNum: resp.GetShardsNum(),
Properties: entity.KvPairsMap(resp.GetProperties()),
}
collection.Name = collection.Schema.CollectionName
return nil

View File

@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
autoID: true,
dim: dim,
enabledDynamicSchema: true,
consistencyLevel: entity.DefaultConsistencyLevel,
isFast: true,
metricType: entity.COSINE,
@ -149,9 +150,10 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
return &createCollectionOption{
name: name,
shardNum: 1,
schema: collectionSchema,
name: name,
shardNum: 1,
schema: collectionSchema,
consistencyLevel: entity.DefaultConsistencyLevel,
metricType: entity.COSINE,
}

View File

@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
// IDColumns converts schemapb.IDs to corresponding column
// currently Int64 / string may be in IDs
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) {
var idColumn Column
if idField == nil {
pkField := schema.PKField()
if pkField == nil {
return nil, errors.New("PK Field not found")
}
if ids == nil {
return nil, errors.New("nil Ids from response")
}
switch field := idField.GetIdField().(type) {
case *schemapb.IDs_IntId:
if end >= 0 {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
} else {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
switch pkField.DataType {
case entity.FieldTypeInt64:
data := ids.GetIntId().GetData()
if data == nil {
return NewColumnInt64(pkField.Name, nil), nil
}
case *schemapb.IDs_StrId:
if end >= 0 {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
idColumn = NewColumnInt64(pkField.Name, data[begin:end])
} else {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
idColumn = NewColumnInt64(pkField.Name, data[begin:])
}
case entity.FieldTypeVarChar, entity.FieldTypeString:
data := ids.GetStrId().GetData()
if data == nil {
return NewColumnVarChar(pkField.Name, nil), nil
}
if end >= 0 {
idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
} else {
idColumn = NewColumnVarChar(pkField.Name, data[begin:])
}
default:
return nil, fmt.Errorf("unsupported id type %v", field)
return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
}
return idColumn, nil
}

View File

@ -24,18 +24,34 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
func TestIDColumns(t *testing.T) {
dataLen := rand.Intn(100) + 1
base := rand.Intn(5000) // id start point
intPKCol := entity.NewSchema().WithField(
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64),
)
strPKCol := entity.NewSchema().WithField(
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar),
)
t.Run("nil id", func(t *testing.T) {
_, err := IDColumns(nil, 0, -1)
assert.NotNil(t, err)
col, err := IDColumns(intPKCol, nil, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
col, err = IDColumns(strPKCol, nil, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
idField := &schemapb.IDs{}
_, err = IDColumns(idField, 0, -1)
assert.NotNil(t, err)
col, err = IDColumns(intPKCol, idField, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
col, err = IDColumns(strPKCol, idField, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
})
t.Run("int ids", func(t *testing.T) {
@ -50,12 +66,12 @@ func TestIDColumns(t *testing.T) {
},
},
}
column, err := IDColumns(idField, 0, dataLen)
column, err := IDColumns(intPKCol, idField, 0, dataLen)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method
column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
@ -72,12 +88,12 @@ func TestIDColumns(t *testing.T) {
},
},
}
column, err := IDColumns(idField, 0, dataLen)
column, err := IDColumns(strPKCol, idField, 0, dataLen)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method
column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())

View File

@ -25,6 +25,12 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
dbName := option.DbName()
c.usingDatabase(dbName)
return c.connectInternal(ctx)
}
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
req := option.Request()

View File

@ -18,6 +18,24 @@ package client
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
type UsingDatabaseOption interface {
DbName() string
}
type usingDatabaseNameOpt struct {
dbName string
}
func (opt *usingDatabaseNameOpt) DbName() string {
return opt.dbName
}
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
return &usingDatabaseNameOpt{
dbName: dbName,
}
}
// ListDatabaseOption is a builder interface for ListDatabase request.
type ListDatabaseOption interface {
Request() *milvuspb.ListDatabasesRequest

View File

@ -32,6 +32,7 @@ type Collection struct {
Loaded bool
ConsistencyLevel ConsistencyLevel
ShardNum int32
Properties map[string]string
}
// Partition represent partition meta in Milvus

View File

@ -60,6 +60,8 @@ type Schema struct {
AutoID bool
Fields []*Field
EnableDynamicField bool
pkField *Field
}
// NewSchema creates an empty schema object.
@ -91,6 +93,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
// WithField adds a field into schema and returns schema itself.
func (s *Schema) WithField(f *Field) *Schema {
if f.PrimaryKey {
s.pkField = f
}
s.Fields = append(s.Fields, f)
return s
}
@ -116,10 +121,14 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
s.CollectionName = p.GetName()
s.Fields = make([]*Field, 0, len(p.GetFields()))
for _, fp := range p.GetFields() {
field := NewField().ReadProto(fp)
if fp.GetAutoID() {
s.AutoID = true
}
s.Fields = append(s.Fields, NewField().ReadProto(fp))
if field.PrimaryKey {
s.pkField = field
}
s.Fields = append(s.Fields, field)
}
s.EnableDynamicField = p.GetEnableDynamicField()
return s
@ -127,12 +136,15 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
// PKFieldName returns pk field name for this schemapb.
func (s *Schema) PKFieldName() string {
for _, field := range s.Fields {
if field.PrimaryKey {
return field.Name
}
if s.pkField == nil {
return ""
}
return ""
return s.pkField.Name
}
// PKField returns PK Field schema for this schema.
func (s *Schema) PKField() *Field {
return s.pkField
}
// Field represent field schema in milvus

View File

@ -17,6 +17,8 @@
package client
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
@ -31,15 +33,27 @@ type createIndexOption struct {
fieldName string
indexName string
indexDef index.Index
extraParams map[string]any
}
func (opt *createIndexOption) WithExtraParam(key string, value any) {
opt.extraParams[key] = value
}
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
return &milvuspb.CreateIndexRequest{
params := opt.indexDef.Params()
for key, value := range opt.extraParams {
params[key] = fmt.Sprintf("%v", value)
}
req := &milvuspb.CreateIndexRequest{
CollectionName: opt.collectionName,
FieldName: opt.fieldName,
IndexName: opt.indexName,
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()),
ExtraParams: entity.MapKvPairs(params),
}
return req
}
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
@ -52,6 +66,7 @@ func NewCreateIndexOption(collectionName string, fieldName string, index index.I
collectionName: collectionName,
fieldName: fieldName,
indexDef: index,
extraParams: make(map[string]any),
}
}

159
client/interceptors.go Normal file
View File

@ -0,0 +1,159 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
const (
authorizationHeader = `authorization`
identifierHeader = `identifier`
databaseHeader = `dbname`
)
func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = c.metadata(ctx)
ctx = c.state(ctx)
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func (c *Client) metadata(ctx context.Context) context.Context {
for k, v := range c.config.metadataHeaders {
ctx = metadata.AppendToOutgoingContext(ctx, k, v)
}
return ctx
}
func (c *Client) state(ctx context.Context) context.Context {
c.stateMut.RLock()
defer c.stateMut.RUnlock()
if c.currentDB != "" {
ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB)
}
if c.identifier != "" {
ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier)
}
return ctx
}
// ref: https://github.com/grpc-ecosystem/go-grpc-middleware
type ctxKey int
const (
RetryOnRateLimit ctxKey = iota
)
// RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if maxRetry == 0 {
return invoker(parentCtx, method, req, reply, cc, opts...)
}
var lastErr error
for attempt := uint(0); attempt < maxRetry; attempt++ {
_, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
if err != nil {
return err
}
lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
rspStatus := getResultStatus(reply)
if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
continue
}
return lastErr
}
return lastErr
}
}
func retryOnRateLimit(ctx context.Context) bool {
retry, ok := ctx.Value(RetryOnRateLimit).(bool)
if !ok {
return true // default true
}
return retry
}
// getResultStatus returns status of response.
func getResultStatus(reply interface{}) *commonpb.Status {
switch r := reply.(type) {
case *commonpb.Status:
return r
case *milvuspb.MutationResult:
return r.GetStatus()
case *milvuspb.BoolResponse:
return r.GetStatus()
case *milvuspb.SearchResults:
return r.GetStatus()
case *milvuspb.QueryResults:
return r.GetStatus()
case *milvuspb.FlushResponse:
return r.GetStatus()
default:
return nil
}
}
func contextErrToGrpcErr(err error) error {
switch err {
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
default:
return status.Error(codes.Unknown, err.Error())
}
}
func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
var waitTime time.Duration
if attempt > 0 {
waitTime = backoffFunc(parentCtx, attempt)
}
if waitTime > 0 {
if waitTime > maxBackoff {
waitTime = maxBackoff
}
timer := time.NewTimer(waitTime)
select {
case <-parentCtx.Done():
timer.Stop()
return waitTime, contextErrToGrpcErr(parentCtx.Err())
case <-timer.C:
}
}
return waitTime, nil
}

View File

@ -0,0 +1,66 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"math"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
var mockInvokerError error
var mockInvokerReply interface{}
var mockInvokeTimes = 0
var mockInvoker grpc.UnaryInvoker = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
mockInvokeTimes++
return mockInvokerError
}
func resetMockInvokeTimes() {
mockInvokeTimes = 0
}
func TestRateLimitInterceptor(t *testing.T) {
maxRetry := uint(3)
maxBackoff := 3 * time.Second
inter := RetryOnRateLimitInterceptor(maxRetry, maxBackoff, func(ctx context.Context, attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(2, float64(attempt)))
})
ctx := context.Background()
// with retry
mockInvokerReply = &commonpb.Status{ErrorCode: commonpb.ErrorCode_RateLimit}
resetMockInvokeTimes()
err := inter(ctx, "", nil, mockInvokerReply, nil, mockInvoker)
assert.NoError(t, err)
assert.Equal(t, maxRetry, uint(mockInvokeTimes))
// without retry
ctx1 := context.WithValue(ctx, RetryOnRateLimit, false)
resetMockInvokeTimes()
err = inter(ctx1, "", nil, mockInvokerReply, nil, mockInvoker)
assert.NoError(t, err)
assert.Equal(t, uint(1), uint(mockInvokeTimes))
}

View File

@ -33,7 +33,7 @@ type ResultSets struct{}
type ResultSet struct {
ResultCount int // the returning entry count
GroupByValue any
GroupByValue column.Column
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields DataSet // output field data
Scores []float32 // distance to the target vector
@ -67,35 +67,32 @@ func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ..
}
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
var err error
sr := make([]ResultSet, 0, nq)
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
gb := results.GetGroupByFieldValue()
var gbc column.Column
if gb != nil {
gbc, err = column.FieldDataColumn(gb, 0, -1)
if err != nil {
return nil, err
}
}
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := ResultSet{
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
if gbc != nil {
entry.GroupByValue, _ = gbc.Get(i)
}
// parse result set if current nq is not empty
if rc > 0 {
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc)
entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc)
if entry.Err != nil {
offset += rc
continue
}
// parse group-by values
if gb != nil {
entry.GroupByValue, entry.Err = column.FieldDataColumn(gb, offset, offset+rc)
if entry.Err != nil {
offset += rc
continue
}
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
sr = append(sr, entry)
}

View File

@ -87,7 +87,7 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
// search param
bs, _ := json.Marshal(annRequest.searchParam)
request.SearchParams = entity.MapKvPairs(map[string]string{
params := map[string]string{
spAnnsField: annRequest.annField,
spTopK: strconv.Itoa(opt.topK),
spOffset: strconv.Itoa(opt.offset),
@ -95,8 +95,11 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
spMetricsType: string(annRequest.metricsType),
spRoundDecimal: "-1",
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
spGroupBy: annRequest.groupByField,
})
}
if annRequest.groupByField != "" {
params[spGroupBy] = annRequest.groupByField
}
request.SearchParams = entity.MapKvPairs(params)
// placeholder group
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)

View File

@ -22,53 +22,90 @@ import (
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error {
type InsertResult struct {
InsertCount int64
IDs column.Column
}
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) {
result := InsertResult{}
collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil {
return err
return result, err
}
req, err := option.InsertRequest(collection)
if err != nil {
return err
return result, err
}
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Insert(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
result.InsertCount = resp.GetInsertCnt()
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
if err != nil {
return err
}
return nil
})
return err
return result, err
}
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error {
type DeleteResult struct {
DeleteCount int64
}
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) (DeleteResult, error) {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
result := DeleteResult{}
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Delete(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
result.DeleteCount = resp.GetDeleteCnt()
return nil
})
return result, err
}
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error {
type UpsertResult struct {
UpsertCount int64
IDs column.Column
}
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) {
result := UpsertResult{}
collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil {
return err
return result, err
}
req, err := option.UpsertRequest(collection)
if err != nil {
return err
return result, err
}
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Upsert(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
result.UpsertCount = resp.GetUpsertCnt()
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
if err != nil {
return err
}
return nil
})
return result, err
}

View File

@ -23,6 +23,7 @@ import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/samber/lo"
@ -63,16 +64,25 @@ func (s *WriteSuite) TestInsert() {
s.Require().Len(ir.GetFieldsData(), 2)
s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
Status: merr.Success(),
InsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil
}).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
s.EqualValues(3, result.InsertCount)
})
s.Run("dynamic_schema", func() {
@ -86,17 +96,26 @@ func (s *WriteSuite) TestInsert() {
s.Require().Len(ir.GetFieldsData(), 3)
s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
Status: merr.Success(),
InsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil
}).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
s.EqualValues(3, result.InsertCount)
})
s.Run("bad_input", func() {
@ -141,7 +160,7 @@ func (s *WriteSuite) TestInsert() {
for _, tc := range cases {
s.Run(tc.tag, func() {
err := s.client.Insert(ctx, tc.input)
_, err := s.client.Insert(ctx, tc.input)
s.Error(err)
})
}
@ -153,7 +172,7 @@ func (s *WriteSuite) TestInsert() {
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
_, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
@ -177,16 +196,25 @@ func (s *WriteSuite) TestUpsert() {
s.Require().Len(ur.GetFieldsData(), 2)
s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
Status: merr.Success(),
UpsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil
}).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
s.EqualValues(3, result.UpsertCount)
})
s.Run("dynamic_schema", func() {
@ -200,17 +228,26 @@ func (s *WriteSuite) TestUpsert() {
s.Require().Len(ur.GetFieldsData(), 3)
s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
Status: merr.Success(),
UpsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil
}).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
s.EqualValues(3, result.UpsertCount)
})
s.Run("bad_input", func() {
@ -255,7 +292,7 @@ func (s *WriteSuite) TestUpsert() {
for _, tc := range cases {
s.Run(tc.tag, func() {
err := s.client.Upsert(ctx, tc.input)
_, err := s.client.Upsert(ctx, tc.input)
s.Error(err)
})
}
@ -267,7 +304,7 @@ func (s *WriteSuite) TestUpsert() {
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
_, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
@ -315,11 +352,13 @@ func (s *WriteSuite) TestDelete() {
s.Equal(partName, dr.GetPartitionName())
s.Equal(tc.expectExpr, dr.GetExpr())
return &milvuspb.MutationResult{
Status: merr.Success(),
Status: merr.Success(),
DeleteCnt: 100,
}, nil
}).Once()
err := s.client.Delete(ctx, tc.input)
result, err := s.client.Delete(ctx, tc.input)
s.NoError(err)
s.EqualValues(100, result.DeleteCount)
})
}
})