enhance: [GoSDK] Sync API names and add missing APIs (#38603)

Related to #31293

- Rename `UsingDatabase` to `UseDatabase`
- Uncomment default value methods
- Add missing RBAC APIs
- Add some resource group APIs

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/38610/head
congqixia 2024-12-20 11:52:46 +08:00 committed by GitHub
parent ca7ec23198
commit c39db11509
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 1844 additions and 57 deletions

View File

@ -16,6 +16,8 @@
package entity
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
// MetricType metric type
type MetricType string
@ -31,3 +33,12 @@ const (
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
BM25 MetricType = "BM25"
)
// CompactionState enum type for compaction state
type CompactionState commonpb.CompactionState
// CompactionState Constants
const (
CompactionStateRunning CompactionState = CompactionState(commonpb.CompactionState_Executing)
CompactionStateCompleted CompactionState = CompactionState(commonpb.CompactionState_Completed)
)

View File

@ -193,6 +193,8 @@ type Field struct {
IsPartitionKey bool
IsClusteringKey bool
ElementType FieldType
DefaultValue *schemapb.ValueField
Nullable bool
}
// ProtoMessage generates corresponding FieldSchema
@ -261,7 +263,11 @@ func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field {
return f
}
/*
func (f *Field) WithNullable(nullable bool) *Field {
f.Nullable = nullable
return f
}
func (f *Field) WithDefaultValueBool(defaultValue bool) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_BoolData{
@ -314,7 +320,7 @@ func (f *Field) WithDefaultValueString(defaultValue string) *Field {
},
}
return f
}*/
}
func (f *Field) WithTypeParams(key string, value string) *Field {
if f.TypeParams == nil {

View File

@ -30,13 +30,13 @@ func TestFieldSchema(t *testing.T) {
NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128),
NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true),
NewField().WithName("varchar_text").WithDataType(FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true).WithAnalyzerParams(map[string]any{}),
/*
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),
}
for _, field := range fields {

View File

@ -0,0 +1,34 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
type LoadStateCode commonpb.LoadState
const (
// LoadStateNone LoadStateCode = LoadStateCode(commonpb.LoadState)
LoadStateLoading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoading)
LoadStateLoaded LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoaded)
LoadStateUnloading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotExist)
LoadStateNotLoad LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotLoad)
)
type LoadState struct {
State LoadStateCode
Progress int64
}

35
client/entity/rbac.go Normal file
View File

@ -0,0 +1,35 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
type User struct {
UserName string
Roles []string
}
type Role struct {
RoleName string
Privileges []GrantItem
}
type GrantItem struct {
Object string
ObjectName string
RoleName string
Grantor string
Privilege string
}

View File

@ -6,7 +6,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0

View File

@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=

View File

@ -147,3 +147,23 @@ func (c *Client) AlterCollection(ctx context.Context, option AlterCollectionOpti
return merr.CheckRPCCall(resp, err)
})
}
type GetCollectionOption interface {
Request() *milvuspb.GetCollectionStatisticsRequest
}
func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption) (map[string]string, error) {
var stats map[string]string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetCollectionStatistics(ctx, opt.Request())
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
stats = entity.KvPairsMap(resp.GetStats())
return nil
})
if err != nil {
return nil, err
}
return stats, nil
}

View File

@ -310,3 +310,17 @@ func (opt *alterCollectionOption) Request() *milvuspb.AlterCollectionRequest {
func NewAlterCollectionOption(collection string) *alterCollectionOption {
return &alterCollectionOption{collectionName: collection, properties: make(map[string]string)}
}
type getCollectionStatsOption struct {
collectionName string
}
func (opt *getCollectionStatsOption) Request() *milvuspb.GetCollectionStatisticsRequest {
return &milvuspb.GetCollectionStatisticsRequest{
CollectionName: opt.collectionName,
}
}
func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOption {
return &getCollectionStatsOption{collectionName: collectionName}
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"testing"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
@ -315,6 +316,38 @@ func (s *CollectionSuite) TestAlterCollection() {
})
}
func (s *CollectionSuite) TestGetCollectionStats() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) {
s.Equal(collName, gcsr.GetCollectionName())
return &milvuspb.GetCollectionStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Stats: []*commonpb.KeyValuePair{
{Key: "row_count", Value: "1000"},
},
}, nil
}).Once()
stats, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName))
s.NoError(err)
s.Len(stats, 1)
s.Equal("1000", stats["row_count"])
})
s.Run("failure", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
_, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName))
s.Error(err)
})
}
func TestCollection(t *testing.T) {
suite.Run(t, new(CollectionSuite))
}

View File

@ -25,7 +25,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
func (c *Client) UseDatabase(ctx context.Context, option UseDatabaseOption) error {
dbName := option.DbName()
c.usingDatabase(dbName)
return c.connectInternal(ctx)

View File

@ -18,20 +18,20 @@ package milvusclient
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
type UsingDatabaseOption interface {
type UseDatabaseOption interface {
DbName() string
}
type usingDatabaseNameOpt struct {
type useDatabaseNameOpt struct {
dbName string
}
func (opt *usingDatabaseNameOpt) DbName() string {
func (opt *useDatabaseNameOpt) DbName() string {
return opt.dbName
}
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
return &usingDatabaseNameOpt{
func NewUseDatabaseOption(dbName string) *useDatabaseNameOpt {
return &useDatabaseNameOpt{
dbName: dbName,
}
}

View File

@ -88,6 +88,26 @@ func (s *DatabaseSuite) TestDropDatabase() {
})
}
func (s *DatabaseSuite) TestUseDatabase() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
dbName := fmt.Sprintf("dt_%s", s.randString(6))
s.mock.EXPECT().Connect(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) {
return &milvuspb.ConnectResponse{
Status: merr.Success(),
ServerInfo: &commonpb.ServerInfo{},
}, nil
}).Once()
err := s.client.UseDatabase(ctx, NewUseDatabaseOption(dbName))
s.NoError(err)
s.Equal(dbName, s.client.currentDB)
})
}
func TestDatabase(t *testing.T) {
suite.Run(t, new(DatabaseSuite))
}

View File

@ -23,6 +23,7 @@ import (
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -31,6 +32,7 @@ type LoadTask struct {
collectionName string
partitionNames []string
interval time.Duration
refresh bool
}
func (t *LoadTask) Await(ctx context.Context) error {
@ -40,6 +42,7 @@ func (t *LoadTask) Await(ctx context.Context) error {
select {
case <-timer.C:
loaded := false
refreshed := false
err := t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: t.collectionName,
@ -49,12 +52,13 @@ func (t *LoadTask) Await(ctx context.Context) error {
return err
}
loaded = resp.GetProgress() == 100
refreshed = resp.GetRefreshProgress() == 100
return nil
})
if err != nil {
return err
}
if loaded {
if (loaded && !t.refresh) || (refreshed && t.refresh) {
return nil
}
if !timer.Stop() {
@ -85,6 +89,7 @@ func (c *Client) LoadCollection(ctx context.Context, option LoadCollectionOption
client: c,
collectionName: req.GetCollectionName(),
interval: option.CheckInterval(),
refresh: option.IsRefresh(),
}
return nil
@ -108,6 +113,7 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption
collectionName: req.GetCollectionName(),
partitionNames: req.GetPartitionNames(),
interval: option.CheckInterval(),
refresh: option.IsRefresh(),
}
return nil
@ -115,6 +121,35 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption
return task, err
}
func (c *Client) GetLoadState(ctx context.Context, option GetLoadStateOption, callOptions ...grpc.CallOption) (entity.LoadState, error) {
req := option.Request()
var state entity.LoadState
var err error
if err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetLoadState(ctx, req, callOptions...)
state.State = entity.LoadStateCode(resp.GetState())
return merr.CheckRPCCall(resp, err)
}); err != nil {
return state, err
}
// get progress if state is loading
if state.State == entity.LoadStateLoading {
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetLoadingProgress(ctx, option.ProgressRequest(), callOptions...)
if err := merr.CheckRPCCall(resp, err); err != nil {
return err
}
state.Progress = resp.GetProgress()
return nil
})
}
return state, err
}
func (c *Client) ReleaseCollection(ctx context.Context, option ReleaseCollectionOption, callOptions ...grpc.CallOption) error {
req := option.Request()
@ -134,6 +169,26 @@ func (c *Client) ReleasePartitions(ctx context.Context, option ReleasePartitions
})
}
func (c *Client) RefreshLoad(ctx context.Context, option RefreshLoadOption, callOptions ...grpc.CallOption) (LoadTask, error) {
req := option.Request()
var task LoadTask
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.LoadCollection(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
task = LoadTask{
client: c,
collectionName: req.GetCollectionName(),
interval: option.CheckInterval(),
refresh: true,
}
return nil
})
return task, err
}
type FlushTask struct {
client *Client
collectionName string
@ -206,3 +261,29 @@ func (c *Client) Flush(ctx context.Context, option FlushOption, callOptions ...g
})
return task, err
}
func (c *Client) Compact(ctx context.Context, option CompactOption, callOptions ...grpc.CallOption) (int64, error) {
req := option.Request()
var jobID int64
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ManualCompaction(ctx, req, callOptions...)
jobID = resp.GetCompactionID()
return merr.CheckRPCCall(resp, err)
})
return jobID, err
}
func (c *Client) GetCompactionState(ctx context.Context, option GetCompactionStateOption, callOptions ...grpc.CallOption) (entity.CompactionState, error) {
req := option.Request()
var status entity.CompactionState
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetCompactionState(ctx, req, callOptions...)
status = entity.CompactionState(resp.GetState())
return merr.CheckRPCCall(resp, err)
})
return status, err
}

View File

@ -25,6 +25,7 @@ import (
type LoadCollectionOption interface {
Request() *milvuspb.LoadCollectionRequest
CheckInterval() time.Duration
IsRefresh() bool
}
type loadCollectionOption struct {
@ -33,6 +34,8 @@ type loadCollectionOption struct {
replicaNum int
loadFields []string
skipLoadDynamicField bool
isRefresh bool
resourceGroups []string
}
func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
@ -41,6 +44,7 @@ func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
ReplicaNumber: int32(opt.replicaNum),
LoadFields: opt.loadFields,
SkipLoadDynamicField: opt.skipLoadDynamicField,
ResourceGroups: opt.resourceGroups,
}
}
@ -48,11 +52,20 @@ func (opt *loadCollectionOption) CheckInterval() time.Duration {
return opt.interval
}
func (opt *loadCollectionOption) IsRefresh() bool {
return opt.isRefresh
}
func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption {
opt.replicaNum = num
return opt
}
func (opt *loadCollectionOption) WithResourceGroup(resourceGroups ...string) *loadCollectionOption {
opt.resourceGroups = resourceGroups
return opt
}
func (opt *loadCollectionOption) WithLoadFields(loadFields ...string) *loadCollectionOption {
opt.loadFields = loadFields
return opt
@ -63,6 +76,11 @@ func (opt *loadCollectionOption) WithSkipLoadDynamicField(skipFlag bool) *loadCo
return opt
}
func (opt *loadCollectionOption) WithRefresh(isRefresh bool) *loadCollectionOption {
opt.isRefresh = isRefresh
return opt
}
func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
return &loadCollectionOption{
collectionName: collectionName,
@ -74,6 +92,7 @@ func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
type LoadPartitionsOption interface {
Request() *milvuspb.LoadPartitionsRequest
CheckInterval() time.Duration
IsRefresh() bool
}
var _ LoadPartitionsOption = (*loadPartitionsOption)(nil)
@ -83,8 +102,10 @@ type loadPartitionsOption struct {
partitionNames []string
interval time.Duration
replicaNum int
resourceGroups []string
loadFields []string
skipLoadDynamicField bool
isRefresh bool
}
func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
@ -94,6 +115,7 @@ func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
ReplicaNumber: int32(opt.replicaNum),
LoadFields: opt.loadFields,
SkipLoadDynamicField: opt.skipLoadDynamicField,
ResourceGroups: opt.resourceGroups,
}
}
@ -101,11 +123,20 @@ func (opt *loadPartitionsOption) CheckInterval() time.Duration {
return opt.interval
}
func (opt *loadPartitionsOption) IsRefresh() bool {
return opt.isRefresh
}
func (opt *loadPartitionsOption) WithReplica(num int) *loadPartitionsOption {
opt.replicaNum = num
return opt
}
func (opt *loadPartitionsOption) WithResourceGroup(resourceGroups ...string) *loadPartitionsOption {
opt.resourceGroups = resourceGroups
return opt
}
func (opt *loadPartitionsOption) WithLoadFields(loadFields ...string) *loadPartitionsOption {
opt.loadFields = loadFields
return opt
@ -116,6 +147,11 @@ func (opt *loadPartitionsOption) WithSkipLoadDynamicField(skipFlag bool) *loadPa
return opt
}
func (opt *loadPartitionsOption) WithRefresh(isRefresh bool) *loadPartitionsOption {
opt.isRefresh = isRefresh
return opt
}
func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *loadPartitionsOption {
return &loadPartitionsOption{
collectionName: collectionName,
@ -125,6 +161,65 @@ func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *
}
}
type GetLoadStateOption interface {
Request() *milvuspb.GetLoadStateRequest
ProgressRequest() *milvuspb.GetLoadingProgressRequest
}
type getLoadStateOption struct {
collectionName string
partitionNames []string
}
func (opt *getLoadStateOption) Request() *milvuspb.GetLoadStateRequest {
return &milvuspb.GetLoadStateRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
}
}
func (opt *getLoadStateOption) ProgressRequest() *milvuspb.GetLoadingProgressRequest {
return &milvuspb.GetLoadingProgressRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
}
}
func NewGetLoadStateOption(collectionName string, partitionNames ...string) *getLoadStateOption {
return &getLoadStateOption{
collectionName: collectionName,
partitionNames: partitionNames,
}
}
type RefreshLoadOption interface {
Request() *milvuspb.LoadCollectionRequest
CheckInterval() time.Duration
}
type refreshLoadOption struct {
collectionName string
checkInterval time.Duration
}
func (opt *refreshLoadOption) Request() *milvuspb.LoadCollectionRequest {
return &milvuspb.LoadCollectionRequest{
CollectionName: opt.collectionName,
Refresh: true,
}
}
func (opt *refreshLoadOption) CheckInterval() time.Duration {
return opt.checkInterval
}
func NewRefreshLoadOption(collectionName string) *refreshLoadOption {
return &refreshLoadOption{
collectionName: collectionName,
checkInterval: time.Millisecond * 200,
}
}
type ReleaseCollectionOption interface {
Request() *milvuspb.ReleaseCollectionRequest
}
@ -203,3 +298,43 @@ func NewFlushOption(collName string) *flushOption {
interval: time.Millisecond * 200,
}
}
type CompactOption interface {
Request() *milvuspb.ManualCompactionRequest
}
type compactOption struct {
collectionName string
}
func (opt *compactOption) Request() *milvuspb.ManualCompactionRequest {
return &milvuspb.ManualCompactionRequest{
CollectionName: opt.collectionName,
}
}
func NewCompactOption(collectionName string) *compactOption {
return &compactOption{
collectionName: collectionName,
}
}
type GetCompactionStateOption interface {
Request() *milvuspb.GetCompactionStateRequest
}
type getCompactionStateOption struct {
compactionID int64
}
func (opt *getCompactionStateOption) Request() *milvuspb.GetCompactionStateRequest {
return &milvuspb.GetCompactionStateRequest{
CompactionID: opt.compactionID,
}
}
func NewGetCompactionStateOption(compactionID int64) *getCompactionStateOption {
return &getCompactionStateOption{
compactionID: compactionID,
}
}

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -44,6 +45,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
fieldNames := []string{"id", "part", "vector"}
replicaNum := rand.Intn(3) + 1
rgs := []string{"rg1", "rg2"}
done := atomic.NewBool(false)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
@ -51,6 +53,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
s.ElementsMatch(fieldNames, lcr.GetLoadFields())
s.True(lcr.SkipLoadDynamicField)
s.EqualValues(replicaNum, lcr.GetReplicaNumber())
s.ElementsMatch(rgs, lcr.GetResourceGroups())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
@ -70,6 +73,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName).
WithReplica(replicaNum).
WithResourceGroup(rgs...).
WithLoadFields(fieldNames...).
WithSkipLoadDynamicField(true))
s.NoError(err)
@ -114,6 +118,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
partitionName := fmt.Sprintf("part_%s", s.randString(6))
fieldNames := []string{"id", "part", "vector"}
replicaNum := rand.Intn(3) + 1
rgs := []string{"rg1", "rg2"}
done := atomic.NewBool(false)
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) {
@ -122,6 +127,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
s.ElementsMatch(fieldNames, lpr.GetLoadFields())
s.True(lpr.SkipLoadDynamicField)
s.EqualValues(replicaNum, lpr.GetReplicaNumber())
s.ElementsMatch(rgs, lpr.GetResourceGroups())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
@ -142,6 +148,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, partitionName).
WithReplica(replicaNum).
WithResourceGroup(rgs...).
WithLoadFields(fieldNames...).
WithSkipLoadDynamicField(true))
s.NoError(err)
@ -293,6 +300,167 @@ func (s *MaintenanceSuite) TestFlush() {
})
}
func (s *MaintenanceSuite) TestRefreshLoad() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
done := atomic.NewBool(false)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, lcr.GetCollectionName())
s.True(lcr.GetRefresh())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
s.Equal(collectionName, glpr.GetCollectionName())
progress := int64(50)
if done.Load() {
progress = 100
}
return &milvuspb.GetLoadingProgressResponse{
Status: merr.Success(),
RefreshProgress: progress,
}, nil
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
task, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName))
s.NoError(err)
ch := make(chan struct{})
go func() {
defer close(ch)
err := task.Await(ctx)
s.NoError(err)
}()
select {
case <-ch:
s.FailNow("task done before index state set to finish")
case <-time.After(time.Second):
}
done.Store(true)
select {
case <-ch:
case <-time.After(time.Second):
s.FailNow("task not done after index set finished")
}
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName))
s.Error(err)
})
}
func (s *MaintenanceSuite) TestCompact() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
compactID := rand.Int63()
s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) {
s.Equal(collectionName, cr.GetCollectionName())
return &milvuspb.ManualCompactionResponse{
CompactionID: compactID,
}, nil
}).Once()
id, err := s.client.Compact(ctx, NewCompactOption(collectionName))
s.NoError(err)
s.Equal(compactID, id)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.Compact(ctx, NewCompactOption(collectionName))
s.Error(err)
})
}
func (s *MaintenanceSuite) TestGetCompactionState() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
compactID := rand.Int63()
s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) {
s.Equal(compactID, gcsr.GetCompactionID())
return &milvuspb.GetCompactionStateResponse{
Status: merr.Success(),
State: commonpb.CompactionState_Completed,
}, nil
}).Once()
state, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID))
s.NoError(err)
s.Equal(entity.CompactionStateCompleted, state)
})
s.Run("failure", func() {
compactID := rand.Int63()
s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID))
s.Error(err)
})
}
func (s *MaintenanceSuite) TestGetLoadState() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
progress := rand.Int63n(100)
s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glsr *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) {
s.Equal(collectionName, glsr.GetCollectionName())
return &milvuspb.GetLoadStateResponse{
Status: merr.Success(),
State: commonpb.LoadState_LoadStateLoading,
}, nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
s.Equal(collectionName, glpr.GetCollectionName())
return &milvuspb.GetLoadingProgressResponse{
Status: merr.Success(),
Progress: progress,
}, nil
}).Once()
state, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName))
s.NoError(err)
s.Equal(entity.LoadStateLoading, state.State)
s.Equal(progress, state.Progress)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName))
s.Error(err)
})
}
func TestMaintenance(t *testing.T) {
suite.Run(t, new(MaintenanceSuite))
}

View File

@ -22,6 +22,7 @@ import (
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
@ -76,3 +77,20 @@ func (c *Client) ListPartitions(ctx context.Context, opt ListPartitionsOption, c
})
return partitionNames, err
}
func (c *Client) GetPartitionStats(ctx context.Context, opt GetPartitionStatsOption, callOptions ...grpc.CallOption) (map[string]string, error) {
req := opt.Request()
var result map[string]string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetPartitionStatistics(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
result = entity.KvPairsMap(resp.GetStats())
return nil
})
return result, err
}

View File

@ -117,3 +117,26 @@ func NewListPartitionOption(collectionName string) *listPartitionsOpt {
collectionName: collectionName,
}
}
type GetPartitionStatsOption interface {
Request() *milvuspb.GetPartitionStatisticsRequest
}
type getPartitionStatsOpt struct {
collectionName string
partitionName string
}
func (opt *getPartitionStatsOpt) Request() *milvuspb.GetPartitionStatisticsRequest {
return &milvuspb.GetPartitionStatisticsRequest{
CollectionName: opt.collectionName,
PartitionName: opt.partitionName,
}
}
func NewGetPartitionStatsOption(collectionName string, partitionName string) *getPartitionStatsOpt {
return &getPartitionStatsOpt{
collectionName: collectionName,
partitionName: partitionName,
}
}

View File

@ -162,6 +162,39 @@ func (s *PartitionSuite) TestDropPartition() {
})
}
func (s *PartitionSuite) TestGetPartitionStats() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gpsr *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) {
s.Equal(collectionName, gpsr.GetCollectionName())
s.Equal(partitionName, gpsr.GetPartitionName())
return &milvuspb.GetPartitionStatisticsResponse{
Status: merr.Success(),
Stats: []*commonpb.KeyValuePair{
{Key: "rows", Value: "100"},
},
}, nil
}).Once()
stats, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName))
s.NoError(err)
s.Equal("100", stats["rows"])
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName))
s.Error(err)
})
}
func TestPartition(t *testing.T) {
suite.Run(t, new(PartitionSuite))
}

View File

@ -19,6 +19,7 @@ package milvusclient
import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"google.golang.org/grpc"
@ -27,6 +28,153 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) ListUsers(ctx context.Context, opt ListUserOption, callOpts ...grpc.CallOption) ([]string, error) {
var users []string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ListCredUsers(ctx, opt.Request(), callOpts...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
users = resp.GetUsernames()
return nil
})
return users, err
}
func (c *Client) DescribeUser(ctx context.Context, opt DescribeUserOption, callOpts ...grpc.CallOption) (*entity.User, error) {
var user *entity.User
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.SelectUser(ctx, opt.Request(), callOpts...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
if len(resp.GetResults()) == 0 {
return errors.New("not user found")
}
result := resp.GetResults()[0]
user = &entity.User{
UserName: result.GetUser().GetName(),
Roles: lo.Map(result.GetRoles(), func(r *milvuspb.RoleEntity, _ int) string { return r.GetName() }),
}
return nil
})
return user, err
}
func (c *Client) CreateUser(ctx context.Context, opt CreateUserOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateCredential(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) UpdatePassword(ctx context.Context, opt UpdatePasswordOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.UpdateCredential(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) DropUser(ctx context.Context, opt DropUserOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DeleteCredential(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) ListRoles(ctx context.Context, opt ListRoleOption, callOpts ...grpc.CallOption) ([]string, error) {
var roles []string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.SelectRole(ctx, opt.Request(), callOpts...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
roles = lo.Map(resp.GetResults(), func(r *milvuspb.RoleResult, _ int) string {
return r.GetRole().GetName()
})
return nil
})
return roles, err
}
func (c *Client) CreateRole(ctx context.Context, opt CreateRoleOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateRole(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) GrantRole(ctx context.Context, opt GrantRoleOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) RevokeRole(ctx context.Context, opt RevokeRoleOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) DropRole(ctx context.Context, opt DropRoleOption, callOpts ...grpc.CallOption) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropRole(ctx, opt.Request(), callOpts...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) DescribeRole(ctx context.Context, option DescribeRoleOption, callOptions ...grpc.CallOption) (*entity.Role, error) {
req := option.Request()
var role *entity.Role
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.SelectGrant(ctx, req, callOptions...)
if err := merr.CheckRPCCall(resp, err); err != nil {
return err
}
if len(resp.GetEntities()) == 0 {
return errors.New("role not found")
}
role = &entity.Role{
RoleName: req.GetEntity().GetRole().GetName(),
Privileges: lo.Map(resp.GetEntities(), func(g *milvuspb.GrantEntity, _ int) entity.GrantItem {
return entity.GrantItem{
Object: g.Object.GetName(),
ObjectName: g.GetObjectName(),
RoleName: g.GetRole().GetName(),
Grantor: g.GetGrantor().GetUser().GetName(),
Privilege: g.GetGrantor().GetPrivilege().GetName(),
}
}),
}
return nil
})
return role, err
}
func (c *Client) GrantPrivilege(ctx context.Context, option GrantPrivilegeOption, callOptions ...grpc.CallOption) error {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) RevokePrivilege(ctx context.Context, option RevokePrivilegeOption, callOptions ...grpc.CallOption) error {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) GrantV2(ctx context.Context, option GrantV2Option, callOptions ...grpc.CallOption) error {
req := option.Request()

View File

@ -20,6 +20,314 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
type ListUserOption interface {
Request() *milvuspb.ListCredUsersRequest
}
// listUserOption is the struct to build ListCredUsersRequest
// left empty for not attribute needed right now
type listUserOption struct{}
func (opt *listUserOption) Request() *milvuspb.ListCredUsersRequest {
return &milvuspb.ListCredUsersRequest{}
}
func NewListUserOption() *listUserOption {
return &listUserOption{}
}
type DescribeUserOption interface {
Request() *milvuspb.SelectUserRequest
}
type describeUserOption struct {
userName string
}
func (opt *describeUserOption) Request() *milvuspb.SelectUserRequest {
return &milvuspb.SelectUserRequest{
User: &milvuspb.UserEntity{
Name: opt.userName,
},
IncludeRoleInfo: true,
}
}
func NewDescribeUserOption(userName string) *describeUserOption {
return &describeUserOption{
userName: userName,
}
}
type CreateUserOption interface {
Request() *milvuspb.CreateCredentialRequest
}
type createUserOption struct {
userName string
password string
}
func (opt *createUserOption) Request() *milvuspb.CreateCredentialRequest {
return &milvuspb.CreateCredentialRequest{
Username: opt.userName,
Password: opt.password,
}
}
func NewCreateUserOption(userName, password string) *createUserOption {
return &createUserOption{
userName: userName,
password: password,
}
}
type UpdatePasswordOption interface {
Request() *milvuspb.UpdateCredentialRequest
}
type updatePasswordOption struct {
userName string
oldPassword string
newPassword string
}
func (opt *updatePasswordOption) Request() *milvuspb.UpdateCredentialRequest {
return &milvuspb.UpdateCredentialRequest{
Username: opt.userName,
OldPassword: opt.oldPassword,
NewPassword: opt.newPassword,
}
}
func NewUpdatePasswordOption(userName, oldPassword, newPassword string) *updatePasswordOption {
return &updatePasswordOption{
userName: userName,
oldPassword: oldPassword,
newPassword: newPassword,
}
}
type DropUserOption interface {
Request() *milvuspb.DeleteCredentialRequest
}
type dropUserOption struct {
userName string
}
func (opt *dropUserOption) Request() *milvuspb.DeleteCredentialRequest {
return &milvuspb.DeleteCredentialRequest{
Username: opt.userName,
}
}
func NewDropUserOption(userName string) *dropUserOption {
return &dropUserOption{
userName: userName,
}
}
type ListRoleOption interface {
Request() *milvuspb.SelectRoleRequest
}
type listRoleOption struct{}
func (opt *listRoleOption) Request() *milvuspb.SelectRoleRequest {
return &milvuspb.SelectRoleRequest{
IncludeUserInfo: false,
}
}
func NewListRoleOption() *listRoleOption {
return &listRoleOption{}
}
type CreateRoleOption interface {
Request() *milvuspb.CreateRoleRequest
}
type createRoleOption struct {
roleName string
}
func (opt *createRoleOption) Request() *milvuspb.CreateRoleRequest {
return &milvuspb.CreateRoleRequest{
Entity: &milvuspb.RoleEntity{Name: opt.roleName},
}
}
func NewCreateRoleOption(roleName string) *createRoleOption {
return &createRoleOption{
roleName: roleName,
}
}
type GrantRoleOption interface {
Request() *milvuspb.OperateUserRoleRequest
}
type grantRoleOption struct {
roleName string
userName string
}
func (opt *grantRoleOption) Request() *milvuspb.OperateUserRoleRequest {
return &milvuspb.OperateUserRoleRequest{
Username: opt.userName,
RoleName: opt.roleName,
Type: milvuspb.OperateUserRoleType_AddUserToRole,
}
}
func NewGrantRoleOption(userName, roleName string) *grantRoleOption {
return &grantRoleOption{
roleName: roleName,
userName: userName,
}
}
type RevokeRoleOption interface {
Request() *milvuspb.OperateUserRoleRequest
}
type revokeRoleOption struct {
roleName string
userName string
}
func (opt *revokeRoleOption) Request() *milvuspb.OperateUserRoleRequest {
return &milvuspb.OperateUserRoleRequest{
Username: opt.userName,
RoleName: opt.roleName,
Type: milvuspb.OperateUserRoleType_RemoveUserFromRole,
}
}
func NewRevokeRoleOption(userName, roleName string) *revokeRoleOption {
return &revokeRoleOption{
roleName: roleName,
userName: userName,
}
}
type DropRoleOption interface {
Request() *milvuspb.DropRoleRequest
}
type dropDropRoleOption struct {
roleName string
}
func (opt *dropDropRoleOption) Request() *milvuspb.DropRoleRequest {
return &milvuspb.DropRoleRequest{
RoleName: opt.roleName,
}
}
func NewDropRoleOption(roleName string) *dropDropRoleOption {
return &dropDropRoleOption{
roleName: roleName,
}
}
type DescribeRoleOption interface {
Request() *milvuspb.SelectGrantRequest
}
type describeRoleOption struct {
roleName string
}
func (opt *describeRoleOption) Request() *milvuspb.SelectGrantRequest {
return &milvuspb.SelectGrantRequest{
Entity: &milvuspb.GrantEntity{
Role: &milvuspb.RoleEntity{Name: opt.roleName},
},
}
}
func NewDescribeRoleOption(roleName string) *describeRoleOption {
return &describeRoleOption{
roleName: roleName,
}
}
type GrantPrivilegeOption interface {
Request() *milvuspb.OperatePrivilegeRequest
}
type grantPrivilegeOption struct {
roleName string
privilegeName string
objectName string
objectType string
}
func (opt *grantPrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest {
return &milvuspb.OperatePrivilegeRequest{
Entity: &milvuspb.GrantEntity{
Role: &milvuspb.RoleEntity{Name: opt.roleName},
Grantor: &milvuspb.GrantorEntity{
Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName},
},
Object: &milvuspb.ObjectEntity{
Name: opt.objectType,
},
ObjectName: opt.objectName,
},
Type: milvuspb.OperatePrivilegeType_Grant,
}
}
func NewGrantPrivilegeOption(roleName, objectType, privilegeName, objectName string) *grantPrivilegeOption {
return &grantPrivilegeOption{
roleName: roleName,
privilegeName: privilegeName,
objectName: objectName,
objectType: objectType,
}
}
type RevokePrivilegeOption interface {
Request() *milvuspb.OperatePrivilegeRequest
}
type revokePrivilegeOption struct {
roleName string
privilegeName string
objectName string
objectType string
}
func (opt *revokePrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest {
return &milvuspb.OperatePrivilegeRequest{
Entity: &milvuspb.GrantEntity{
Role: &milvuspb.RoleEntity{Name: opt.roleName},
Grantor: &milvuspb.GrantorEntity{
Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName},
},
Object: &milvuspb.ObjectEntity{
Name: opt.objectType,
},
ObjectName: opt.objectName,
},
Type: milvuspb.OperatePrivilegeType_Revoke,
}
}
func NewRevokePrivilegeOption(roleName, objectType, privilegeName, objectName string) *revokePrivilegeOption {
return &revokePrivilegeOption{
roleName: roleName,
privilegeName: privilegeName,
objectName: objectName,
objectType: objectType,
}
}
// GrantV2Option is the interface builds OperatePrivilegeV2Request
type GrantV2Option interface {
Request() *milvuspb.OperatePrivilegeV2Request

View File

@ -29,6 +29,376 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr"
)
type UserSuite struct {
MockSuiteBase
}
func (s *UserSuite) TestListUsers() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) {
return &milvuspb.ListCredUsersResponse{
Usernames: []string{"user1", "user2"},
}, nil
}).Once()
users, err := s.client.ListUsers(ctx, NewListUserOption())
s.NoError(err)
s.Equal([]string{"user1", "user2"}, users)
})
s.Run("failure", func() {
s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListUsers(ctx, NewListUserOption())
s.Error(err)
})
}
func (s *UserSuite) TestDescribeUser() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
userName := fmt.Sprintf("user_%s", s.randString(5))
s.Run("success", func() {
s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) {
s.Equal(userName, r.GetUser().GetName())
return &milvuspb.SelectUserResponse{
Results: []*milvuspb.UserResult{
{
User: &milvuspb.UserEntity{Name: userName},
Roles: []*milvuspb.RoleEntity{
{Name: "role1"},
{Name: "role2"},
},
},
},
}, nil
}).Once()
user, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName))
s.NoError(err)
s.Equal(userName, user.UserName)
s.Equal([]string{"role1", "role2"}, user.Roles)
})
s.Run("failure", func() {
s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName))
s.Error(err)
})
}
func (s *UserSuite) TestCreateUser() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
userName := fmt.Sprintf("user_%s", s.randString(5))
password := s.randString(12)
s.mock.EXPECT().CreateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ccr *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) {
s.Equal(userName, ccr.GetUsername())
s.Equal(password, ccr.GetPassword())
return merr.Success(), nil
}).Once()
err := s.client.CreateUser(ctx, NewCreateUserOption(userName, password))
s.NoError(err)
})
}
func (s *UserSuite) TestUpdatePassword() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
userName := fmt.Sprintf("user_%s", s.randString(5))
oldPassword := s.randString(12)
newPassword := s.randString(12)
s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ucr *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) {
s.Equal(userName, ucr.GetUsername())
s.Equal(oldPassword, ucr.GetOldPassword())
s.Equal(newPassword, ucr.GetNewPassword())
return merr.Success(), nil
}).Once()
err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption(userName, oldPassword, newPassword))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption("user", "old", "new"))
s.Error(err)
})
}
func (s *UserSuite) TestDropUser() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
userName := fmt.Sprintf("user_%s", s.randString(5))
s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dcr *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) {
s.Equal(userName, dcr.GetUsername())
return merr.Success(), nil
}).Once()
err := s.client.DropUser(ctx, NewDropUserOption(userName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropUser(ctx, NewDropUserOption("user"))
s.Error(err)
})
}
func TestUserRBAC(t *testing.T) {
suite.Run(t, new(UserSuite))
}
type RoleSuite struct {
MockSuiteBase
}
func (s *RoleSuite) TestListRoles() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) {
return &milvuspb.SelectRoleResponse{
Results: []*milvuspb.RoleResult{
{Role: &milvuspb.RoleEntity{Name: "role1"}},
{Role: &milvuspb.RoleEntity{Name: "role2"}},
},
}, nil
}).Once()
roles, err := s.client.ListRoles(ctx, NewListRoleOption())
s.NoError(err)
s.Equal([]string{"role1", "role2"}, roles)
})
s.Run("failure", func() {
s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListRoles(ctx, NewListRoleOption())
s.Error(err)
})
}
func (s *RoleSuite) TestCreateRole() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
roleName := fmt.Sprintf("role_%s", s.randString(5))
s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.CreateRoleRequest) (*commonpb.Status, error) {
s.Equal(roleName, r.GetEntity().GetName())
return merr.Success(), nil
}).Once()
err := s.client.CreateRole(ctx, NewCreateRoleOption(roleName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.CreateRole(ctx, NewCreateRoleOption("role"))
s.Error(err)
})
}
func (s *RoleSuite) TestGrantRole() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
userName := fmt.Sprintf("user_%s", s.randString(5))
roleName := fmt.Sprintf("role_%s", s.randString(5))
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) {
s.Equal(userName, r.GetUsername())
s.Equal(roleName, r.GetRoleName())
return merr.Success(), nil
}).Once()
err := s.client.GrantRole(ctx, NewGrantRoleOption(userName, roleName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.GrantRole(ctx, NewGrantRoleOption("user", "role"))
s.Error(err)
})
}
func (s *RoleSuite) TestRevokeRole() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
userName := fmt.Sprintf("user_%s", s.randString(5))
roleName := fmt.Sprintf("role_%s", s.randString(5))
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) {
s.Equal(userName, r.GetUsername())
s.Equal(roleName, r.GetRoleName())
return merr.Success(), nil
}).Once()
err := s.client.RevokeRole(ctx, NewRevokeRoleOption(userName, roleName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.RevokeRole(ctx, NewRevokeRoleOption("user", "role"))
s.Error(err)
})
}
func (s *RoleSuite) TestDropRole() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
roleName := fmt.Sprintf("role_%s", s.randString(5))
s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.DropRoleRequest) (*commonpb.Status, error) {
s.Equal(roleName, r.GetRoleName())
return merr.Success(), nil
}).Once()
err := s.client.DropRole(ctx, NewDropRoleOption(roleName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropRole(ctx, NewDropRoleOption("role"))
s.Error(err)
})
}
func (s *RoleSuite) TestDescribeRole() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
roleName := fmt.Sprintf("role_%s", s.randString(5))
s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) {
s.Equal(roleName, r.GetEntity().GetRole().GetName())
return &milvuspb.SelectGrantResponse{
Entities: []*milvuspb.GrantEntity{
{
ObjectName: "*",
Object: &milvuspb.ObjectEntity{
Name: "collection",
},
Role: &milvuspb.RoleEntity{Name: roleName},
Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Insert"}},
},
{
ObjectName: "*",
Object: &milvuspb.ObjectEntity{
Name: "collection",
},
Role: &milvuspb.RoleEntity{Name: roleName},
Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Query"}},
},
},
}, nil
}).Once()
role, err := s.client.DescribeRole(ctx, NewDescribeRoleOption(roleName))
s.NoError(err)
s.Equal(roleName, role.RoleName)
})
s.Run("failure", func() {
s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.DescribeRole(ctx, NewDescribeRoleOption("role"))
s.Error(err)
})
}
func (s *RoleSuite) TestGrantPrivilege() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
roleName := fmt.Sprintf("role_%s", s.randString(5))
privilegeName := "Insert"
collectionName := fmt.Sprintf("collection_%s", s.randString(6))
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) {
s.Equal(roleName, r.GetEntity().GetRole().GetName())
s.Equal("collection", r.GetEntity().GetObject().GetName())
s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName())
s.Equal(collectionName, r.GetEntity().GetObjectName())
s.Equal(milvuspb.OperatePrivilegeType_Grant, r.GetType())
return merr.Success(), nil
}).Once()
err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption(roleName, "collection", privilegeName, collectionName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption("role", "collection", "privilege", "coll_1"))
s.Error(err)
})
}
func (s *RoleSuite) TestRevokePrivilege() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
roleName := fmt.Sprintf("role_%s", s.randString(5))
privilegeName := "Insert"
collectionName := fmt.Sprintf("collection_%s", s.randString(6))
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) {
s.Equal(roleName, r.GetEntity().GetRole().GetName())
s.Equal("collection", r.GetEntity().GetObject().GetName())
s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName())
s.Equal(collectionName, r.GetEntity().GetObjectName())
s.Equal(milvuspb.OperatePrivilegeType_Revoke, r.GetType())
return merr.Success(), nil
}).Once()
err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption(roleName, "collection", privilegeName, collectionName))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption("role", "collection", "privilege", "coll_1"))
s.Error(err)
})
}
func TestRoleRBAC(t *testing.T) {
suite.Run(t, new(RoleSuite))
}
type PrivilgeGroupSuite struct {
MockSuiteBase
}

View File

@ -192,6 +192,10 @@ func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...g
return resultSet, err
}
func (c *Client) Get(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
return c.Query(ctx, option, callOptions...)
}
func (c *Client) HybridSearch(ctx context.Context, option HybridSearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
req, err := option.HybridRequest()
if err != nil {

View File

@ -52,8 +52,7 @@ func (s *SearchOptionSuite) TestBasic() {
topK := rand.Intn(100) + 1
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000")
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true)
req, err := opt.Request()
s.Require().NoError(err)
@ -64,6 +63,15 @@ func (s *SearchOptionSuite) TestBasic() {
annField, ok := searchParams[spAnnsField]
s.Require().True(ok)
s.Equal("test_field", annField)
groupField, ok := searchParams[spGroupBy]
s.Require().True(ok)
s.Equal("group_field", groupField)
groupSize, ok := searchParams[spGroupSize]
s.Require().True(ok)
s.Equal("10", groupSize)
spStrictGroupSize, ok := searchParams[spStrictGroupSize]
s.Require().True(ok)
s.Equal("true", spStrictGroupSize)
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
_, err = opt.Request()

View File

@ -21,6 +21,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
@ -28,20 +29,23 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
)
const (
spAnnsField = `anns_field`
spTopK = `topk`
spOffset = `offset`
spLimit = `limit`
spParams = `params`
spMetricsType = `metric_type`
spRoundDecimal = `round_decimal`
spIgnoreGrowing = `ignore_growing`
spGroupBy = `group_by_field`
spAnnsField = `anns_field`
spTopK = `topk`
spOffset = `offset`
spLimit = `limit`
spParams = `params`
spMetricsType = `metric_type`
spRoundDecimal = `round_decimal`
spIgnoreGrowing = `ignore_growing`
spGroupBy = `group_by_field`
spGroupSize = `group_size`
spStrictGroupSize = `strict_group_size`
)
type SearchOption interface {
@ -62,16 +66,18 @@ type searchOption struct {
type annRequest struct {
vectors []entity.Vector
annField string
metricsType entity.MetricType
searchParam map[string]string
groupByField string
annParam index.AnnParam
ignoreGrowing bool
expr string
topK int
offset int
templateParams map[string]any
annField string
metricsType entity.MetricType
searchParam map[string]string
groupByField string
groupSize int
strictGroupSize bool
annParam index.AnnParam
ignoreGrowing bool
expr string
topK int
offset int
templateParams map[string]any
}
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest {
@ -108,6 +114,12 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) {
if r.groupByField != "" {
params[spGroupBy] = r.groupByField
}
if r.groupSize != 0 {
params[spGroupSize] = strconv.Itoa(r.groupSize)
}
if r.strictGroupSize {
params[spStrictGroupSize] = "true"
}
// ann param
if r.annParam != nil {
bs, _ := json.Marshal(r.annParam.Params())
@ -223,6 +235,16 @@ func (r *annRequest) WithGroupByField(groupByField string) *annRequest {
return r
}
func (r *annRequest) WithGroupSize(groupSize int) *annRequest {
r.groupSize = groupSize
return r
}
func (r *annRequest) WithStrictGroupSize(strictGroupSize bool) *annRequest {
r.strictGroupSize = strictGroupSize
return r
}
func (r *annRequest) WithSearchParam(key, value string) *annRequest {
r.searchParam[key] = value
return r
@ -309,6 +331,16 @@ func (opt *searchOption) WithGroupByField(groupByField string) *searchOption {
return opt
}
func (opt *searchOption) WithGroupSize(groupSize int) *searchOption {
opt.annRequest.WithGroupSize(groupSize)
return opt
}
func (opt *searchOption) WithStrictGroupSize(strictGroupSize bool) *searchOption {
opt.annRequest.WithStrictGroupSize(strictGroupSize)
return opt
}
func (opt *searchOption) WithIgnoreGrowing(ignoreGrowing bool) *searchOption {
opt.annRequest.WithIgnoreGrowing(ignoreGrowing)
return opt
@ -550,6 +582,27 @@ func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption {
return opt
}
func (opt *queryOption) WithIDs(ids column.Column) *queryOption {
opt.expr = pks2Expr(ids)
return opt
}
func pks2Expr(ids column.Column) string {
var expr string
pkName := ids.Name()
switch ids.Type() {
case entity.FieldTypeInt64:
expr = fmt.Sprintf("%s in %s", pkName, strings.Join(strings.Fields(fmt.Sprint(ids.FieldData().GetScalars().GetLongData().GetData())), ","))
case entity.FieldTypeVarChar:
data := ids.FieldData().GetScalars().GetData().(*schemapb.ScalarField_StringData).StringData.GetData()
for i := range data {
data[i] = fmt.Sprintf("\"%s\"", data[i])
}
expr = fmt.Sprintf("%s in [%s]", pkName, strings.Join(data, ","))
}
return expr
}
func NewQueryOption(collectionName string) *queryOption {
return &queryOption{
collectionName: collectionName,

View File

@ -0,0 +1,65 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) ListResourceGroups(ctx context.Context, opt ListResourceGroupsOption, callOptions ...grpc.CallOption) ([]string, error) {
req := opt.Request()
var rgs []string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ListResourceGroups(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
rgs = resp.GetResourceGroups()
return nil
})
return rgs, err
}
func (c *Client) CreateResourceGroup(ctx context.Context, opt CreateResourceGroupOption, callOptions ...grpc.CallOption) error {
req := opt.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateResourceGroup(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}
func (c *Client) DropResourceGroup(ctx context.Context, opt DropResourceGroupOption, callOptions ...grpc.CallOption) error {
req := opt.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropResourceGroup(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}

View File

@ -0,0 +1,92 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
)
type ListResourceGroupsOption interface {
Request() *milvuspb.ListResourceGroupsRequest
}
type listResourceGroupsOption struct{}
func (opt *listResourceGroupsOption) Request() *milvuspb.ListResourceGroupsRequest {
return &milvuspb.ListResourceGroupsRequest{}
}
func NewListResourceGroupsOption() *listResourceGroupsOption {
return &listResourceGroupsOption{}
}
type CreateResourceGroupOption interface {
Request() *milvuspb.CreateResourceGroupRequest
}
type createResourceGroupOption struct {
name string
nodeRequest int
nodeLimit int
}
func (opt *createResourceGroupOption) WithNodeRequest(nodeRequest int) *createResourceGroupOption {
opt.nodeRequest = nodeRequest
return opt
}
func (opt *createResourceGroupOption) WithNodeLimit(nodeLimit int) *createResourceGroupOption {
opt.nodeLimit = nodeLimit
return opt
}
func (opt *createResourceGroupOption) Request() *milvuspb.CreateResourceGroupRequest {
return &milvuspb.CreateResourceGroupRequest{
ResourceGroup: opt.name,
Config: &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: int32(opt.nodeRequest),
},
Limits: &rgpb.ResourceGroupLimit{
NodeNum: int32(opt.nodeLimit),
},
},
}
}
func NewCreateResourceGroupOption(name string) *createResourceGroupOption {
return &createResourceGroupOption{name: name}
}
type DropResourceGroupOption interface {
Request() *milvuspb.DropResourceGroupRequest
}
type dropResourceGroupOption struct {
name string
}
func (opt *dropResourceGroupOption) Request() *milvuspb.DropResourceGroupRequest {
return &milvuspb.DropResourceGroupRequest{
ResourceGroup: opt.name,
}
}
func NewDropResourceGroupOption(name string) *dropResourceGroupOption {
return &dropResourceGroupOption{name: name}
}

View File

@ -0,0 +1,108 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"fmt"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
type ResourceGroupSuite struct {
MockSuiteBase
}
func (s *ResourceGroupSuite) TestListResourceGroups() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{
ResourceGroups: []string{"rg1", "rg2"},
}, nil).Once()
rgs, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption())
s.NoError(err)
s.Equal([]string{"rg1", "rg2"}, rgs)
})
s.Run("failure", func() {
s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
_, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption())
s.Error(err)
})
}
func (s *ResourceGroupSuite) TestCreateResourceGroup() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
rgName := fmt.Sprintf("rg_%s", s.randString(6))
s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, crgr *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
s.Equal(rgName, crgr.GetResourceGroup())
s.Equal(int32(5), crgr.GetConfig().GetRequests().GetNodeNum())
s.Equal(int32(10), crgr.GetConfig().GetLimits().GetNodeNum())
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}).Once()
opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5)
err := s.client.CreateResourceGroup(ctx, opt)
s.NoError(err)
})
s.Run("failure", func() {
rgName := fmt.Sprintf("rg_%s", s.randString(6))
s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5)
err := s.client.CreateResourceGroup(ctx, opt)
s.Error(err)
})
}
func (s *ResourceGroupSuite) TestDropResourceGroup() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
rgName := fmt.Sprintf("rg_%s", s.randString(6))
s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, drgr *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) {
s.Equal(rgName, drgr.GetResourceGroup())
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}).Once()
opt := NewDropResourceGroupOption(rgName)
err := s.client.DropResourceGroup(ctx, opt)
s.NoError(err)
})
s.Run("failure", func() {
rgName := fmt.Sprintf("rg_%s", s.randString(6))
s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
opt := NewDropResourceGroupOption(rgName)
err := s.client.DropResourceGroup(ctx, opt)
s.Error(err)
})
}
func TestResourceGroup(t *testing.T) {
suite.Run(t, new(ResourceGroupSuite))
}

View File

@ -89,8 +89,8 @@ func (mc *MilvusClient) Close(ctx context.Context) error {
// -- database --
// UsingDatabase list all database in milvus cluster.
func (mc *MilvusClient) UsingDatabase(ctx context.Context, option client.UsingDatabaseOption) error {
err := mc.mClient.UsingDatabase(ctx, option)
func (mc *MilvusClient) UsingDatabase(ctx context.Context, option client.UseDatabaseOption) error {
err := mc.mClient.UseDatabase(ctx, option)
return err
}

View File

@ -52,7 +52,7 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/opencontainers/runtime-spec v1.0.2 // indirect

View File

@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=

View File

@ -26,7 +26,7 @@ func teardownTest(t *testing.T) func(t *testing.T) {
dbs, _ := mc.ListDatabases(ctx, client.NewListDatabaseOption())
for _, db := range dbs {
if db != common.DefaultDb {
_ = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(db))
_ = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(db))
collections, _ := mc.ListCollections(ctx, client.NewListCollectionOption())
for _, coll := range collections {
_ = mc.DropCollection(ctx, client.NewDropCollectionOption(coll))
@ -57,7 +57,7 @@ func TestDatabase(t *testing.T) {
// new client with db1 -> using db
clientDB1 := createMilvusClient(ctx, t, &client.ClientConfig{Address: *addr, DBName: dbName1})
t.Log("https://github.com/milvus-io/milvus/issues/34137")
err = clientDB1.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName1))
err = clientDB1.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName1))
common.CheckErr(t, err, true)
// create collections -> verify collections contains
@ -77,14 +77,14 @@ func TestDatabase(t *testing.T) {
require.Containsf(t, dbs, dbName2, fmt.Sprintf("%s db not in dbs: %v", dbName2, dbs))
// using db2 -> create collection -> drop collection
err = clientDefault.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName2))
err = clientDefault.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName2))
common.CheckErr(t, err, true)
_, db2Col1 := hp.CollPrepare.CreateCollection(ctx, t, clientDefault, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption())
err = clientDefault.DropCollection(ctx, client.NewDropCollectionOption(db2Col1.CollectionName))
common.CheckErr(t, err, true)
// using empty db -> drop db2
clientDefault.UsingDatabase(ctx, client.NewUsingDatabaseOption(""))
clientDefault.UsingDatabase(ctx, client.NewUseDatabaseOption(""))
err = clientDefault.DropDatabase(ctx, client.NewDropDatabaseOption(dbName2))
common.CheckErr(t, err, true)
@ -98,7 +98,7 @@ func TestDatabase(t *testing.T) {
common.CheckErr(t, err, false, "must drop all collections before drop database")
// drop all db1's collections -> drop db1
clientDB1.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName1))
clientDB1.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName1))
err = clientDB1.DropCollection(ctx, client.NewDropCollectionOption(db1Col1.CollectionName))
common.CheckErr(t, err, true)
@ -160,7 +160,7 @@ func TestDropDb(t *testing.T) {
common.CheckErr(t, err, true)
// using db and drop the db
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName))
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName))
common.CheckErr(t, err, true)
err = mc.DropDatabase(ctx, client.NewDropDatabaseOption(dbName))
common.CheckErr(t, err, true)
@ -170,7 +170,7 @@ func TestDropDb(t *testing.T) {
common.CheckErr(t, err, false, fmt.Sprintf("database not found[database=%s]", dbName))
// using default db and verify collections
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
common.CheckErr(t, err, true)
collections, _ = mc.ListCollections(ctx, listCollOpt)
require.Contains(t, collections, defCol.CollectionName)
@ -205,17 +205,17 @@ func TestUsingDb(t *testing.T) {
// using not existed db
dbName := common.GenRandomString("db", 4)
err := mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName))
err := mc.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName))
common.CheckErr(t, err, false, fmt.Sprintf("database not found[database=%s]", dbName))
// using empty db
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(""))
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(""))
common.CheckErr(t, err, true)
collections, _ = mc.ListCollections(ctx, listCollOpt)
require.Contains(t, collections, col.CollectionName)
// using current db
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
common.CheckErr(t, err, true)
collections, _ = mc.ListCollections(ctx, listCollOpt)
require.Contains(t, collections, col.CollectionName)
@ -262,7 +262,7 @@ func TestClientWithDb(t *testing.T) {
require.Containsf(t, dbCollections, dbCol1.CollectionName, fmt.Sprintf("The collection %s not in: %v", dbCol1.CollectionName, dbCollections))
// using default db and collection not in
_ = mcDb.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
_ = mcDb.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
defCollections, _ = mcDb.ListCollections(ctx, listCollOpt)
require.NotContains(t, defCollections, dbCol1.CollectionName)

View File

@ -36,7 +36,7 @@ func teardown() {
dbs, _ := mc.ListDatabases(ctx, clientv2.NewListDatabaseOption())
for _, db := range dbs {
if db != common.DefaultDb {
_ = mc.UsingDatabase(ctx, clientv2.NewUsingDatabaseOption(db))
_ = mc.UsingDatabase(ctx, clientv2.NewUseDatabaseOption(db))
collections, _ := mc.ListCollections(ctx, clientv2.NewListCollectionOption())
for _, coll := range collections {
_ = mc.DropCollection(ctx, clientv2.NewDropCollectionOption(coll))