mirror of https://github.com/milvus-io/milvus.git
enhance: [GoSDK] Sync API names and add missing APIs (#38603)
Related to #31293 - Rename `UsingDatabase` to `UseDatabase` - Uncomment default value methods - Add missing RBAC APIs - Add some resource group APIs --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/38610/head
parent
ca7ec23198
commit
c39db11509
|
@ -16,6 +16,8 @@
|
|||
|
||||
package entity
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
||||
// MetricType metric type
|
||||
type MetricType string
|
||||
|
||||
|
@ -31,3 +33,12 @@ const (
|
|||
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
|
||||
BM25 MetricType = "BM25"
|
||||
)
|
||||
|
||||
// CompactionState enum type for compaction state
|
||||
type CompactionState commonpb.CompactionState
|
||||
|
||||
// CompactionState Constants
|
||||
const (
|
||||
CompactionStateRunning CompactionState = CompactionState(commonpb.CompactionState_Executing)
|
||||
CompactionStateCompleted CompactionState = CompactionState(commonpb.CompactionState_Completed)
|
||||
)
|
||||
|
|
|
@ -193,6 +193,8 @@ type Field struct {
|
|||
IsPartitionKey bool
|
||||
IsClusteringKey bool
|
||||
ElementType FieldType
|
||||
DefaultValue *schemapb.ValueField
|
||||
Nullable bool
|
||||
}
|
||||
|
||||
// ProtoMessage generates corresponding FieldSchema
|
||||
|
@ -261,7 +263,11 @@ func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field {
|
|||
return f
|
||||
}
|
||||
|
||||
/*
|
||||
func (f *Field) WithNullable(nullable bool) *Field {
|
||||
f.Nullable = nullable
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueBool(defaultValue bool) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_BoolData{
|
||||
|
@ -314,7 +320,7 @@ func (f *Field) WithDefaultValueString(defaultValue string) *Field {
|
|||
},
|
||||
}
|
||||
return f
|
||||
}*/
|
||||
}
|
||||
|
||||
func (f *Field) WithTypeParams(key string, value string) *Field {
|
||||
if f.TypeParams == nil {
|
||||
|
|
|
@ -30,13 +30,13 @@ func TestFieldSchema(t *testing.T) {
|
|||
NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128),
|
||||
NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true),
|
||||
NewField().WithName("varchar_text").WithDataType(FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true).WithAnalyzerParams(map[string]any{}),
|
||||
/*
|
||||
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
|
||||
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
|
||||
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
|
||||
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
|
||||
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
|
||||
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/
|
||||
|
||||
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
|
||||
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
|
||||
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
|
||||
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
|
||||
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
|
||||
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
||||
type LoadStateCode commonpb.LoadState
|
||||
|
||||
const (
|
||||
// LoadStateNone LoadStateCode = LoadStateCode(commonpb.LoadState)
|
||||
LoadStateLoading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoading)
|
||||
LoadStateLoaded LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateLoaded)
|
||||
LoadStateUnloading LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotExist)
|
||||
LoadStateNotLoad LoadStateCode = LoadStateCode(commonpb.LoadState_LoadStateNotLoad)
|
||||
)
|
||||
|
||||
type LoadState struct {
|
||||
State LoadStateCode
|
||||
Progress int64
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
type User struct {
|
||||
UserName string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
RoleName string
|
||||
Privileges []GrantItem
|
||||
}
|
||||
|
||||
type GrantItem struct {
|
||||
Object string
|
||||
ObjectName string
|
||||
RoleName string
|
||||
Grantor string
|
||||
Privilege string
|
||||
}
|
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/blang/semver/v4 v4.0.0
|
||||
github.com/cockroachdb/errors v1.9.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22
|
||||
github.com/samber/lo v1.27.0
|
||||
|
|
|
@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
|
|||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
|
||||
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
|
||||
|
|
|
@ -147,3 +147,23 @@ func (c *Client) AlterCollection(ctx context.Context, option AlterCollectionOpti
|
|||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
type GetCollectionOption interface {
|
||||
Request() *milvuspb.GetCollectionStatisticsRequest
|
||||
}
|
||||
|
||||
func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption) (map[string]string, error) {
|
||||
var stats map[string]string
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetCollectionStatistics(ctx, opt.Request())
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
stats = entity.KvPairsMap(resp.GetStats())
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
|
|
@ -310,3 +310,17 @@ func (opt *alterCollectionOption) Request() *milvuspb.AlterCollectionRequest {
|
|||
func NewAlterCollectionOption(collection string) *alterCollectionOption {
|
||||
return &alterCollectionOption{collectionName: collection, properties: make(map[string]string)}
|
||||
}
|
||||
|
||||
type getCollectionStatsOption struct {
|
||||
collectionName string
|
||||
}
|
||||
|
||||
func (opt *getCollectionStatsOption) Request() *milvuspb.GetCollectionStatisticsRequest {
|
||||
return &milvuspb.GetCollectionStatisticsRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOption {
|
||||
return &getCollectionStatsOption{collectionName: collectionName}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
@ -315,6 +316,38 @@ func (s *CollectionSuite) TestAlterCollection() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestGetCollectionStats() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) {
|
||||
s.Equal(collName, gcsr.GetCollectionName())
|
||||
return &milvuspb.GetCollectionStatisticsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Stats: []*commonpb.KeyValuePair{
|
||||
{Key: "row_count", Value: "1000"},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
stats, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName))
|
||||
s.NoError(err)
|
||||
|
||||
s.Len(stats, 1)
|
||||
s.Equal("1000", stats["row_count"])
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.mock.EXPECT().GetCollectionStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
|
||||
|
||||
_, err := s.client.GetCollectionStats(ctx, NewGetCollectionStatsOption(collName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollection(t *testing.T) {
|
||||
suite.Run(t, new(CollectionSuite))
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
|
||||
func (c *Client) UseDatabase(ctx context.Context, option UseDatabaseOption) error {
|
||||
dbName := option.DbName()
|
||||
c.usingDatabase(dbName)
|
||||
return c.connectInternal(ctx)
|
||||
|
|
|
@ -18,20 +18,20 @@ package milvusclient
|
|||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
||||
type UsingDatabaseOption interface {
|
||||
type UseDatabaseOption interface {
|
||||
DbName() string
|
||||
}
|
||||
|
||||
type usingDatabaseNameOpt struct {
|
||||
type useDatabaseNameOpt struct {
|
||||
dbName string
|
||||
}
|
||||
|
||||
func (opt *usingDatabaseNameOpt) DbName() string {
|
||||
func (opt *useDatabaseNameOpt) DbName() string {
|
||||
return opt.dbName
|
||||
}
|
||||
|
||||
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
|
||||
return &usingDatabaseNameOpt{
|
||||
func NewUseDatabaseOption(dbName string) *useDatabaseNameOpt {
|
||||
return &useDatabaseNameOpt{
|
||||
dbName: dbName,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,6 +88,26 @@ func (s *DatabaseSuite) TestDropDatabase() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *DatabaseSuite) TestUseDatabase() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
dbName := fmt.Sprintf("dt_%s", s.randString(6))
|
||||
s.mock.EXPECT().Connect(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) {
|
||||
return &milvuspb.ConnectResponse{
|
||||
Status: merr.Success(),
|
||||
ServerInfo: &commonpb.ServerInfo{},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.UseDatabase(ctx, NewUseDatabaseOption(dbName))
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(dbName, s.client.currentDB)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabase(t *testing.T) {
|
||||
suite.Run(t, new(DatabaseSuite))
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -31,6 +32,7 @@ type LoadTask struct {
|
|||
collectionName string
|
||||
partitionNames []string
|
||||
interval time.Duration
|
||||
refresh bool
|
||||
}
|
||||
|
||||
func (t *LoadTask) Await(ctx context.Context) error {
|
||||
|
@ -40,6 +42,7 @@ func (t *LoadTask) Await(ctx context.Context) error {
|
|||
select {
|
||||
case <-timer.C:
|
||||
loaded := false
|
||||
refreshed := false
|
||||
err := t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: t.collectionName,
|
||||
|
@ -49,12 +52,13 @@ func (t *LoadTask) Await(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
loaded = resp.GetProgress() == 100
|
||||
refreshed = resp.GetRefreshProgress() == 100
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if loaded {
|
||||
if (loaded && !t.refresh) || (refreshed && t.refresh) {
|
||||
return nil
|
||||
}
|
||||
if !timer.Stop() {
|
||||
|
@ -85,6 +89,7 @@ func (c *Client) LoadCollection(ctx context.Context, option LoadCollectionOption
|
|||
client: c,
|
||||
collectionName: req.GetCollectionName(),
|
||||
interval: option.CheckInterval(),
|
||||
refresh: option.IsRefresh(),
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -108,6 +113,7 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption
|
|||
collectionName: req.GetCollectionName(),
|
||||
partitionNames: req.GetPartitionNames(),
|
||||
interval: option.CheckInterval(),
|
||||
refresh: option.IsRefresh(),
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -115,6 +121,35 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption
|
|||
return task, err
|
||||
}
|
||||
|
||||
func (c *Client) GetLoadState(ctx context.Context, option GetLoadStateOption, callOptions ...grpc.CallOption) (entity.LoadState, error) {
|
||||
req := option.Request()
|
||||
|
||||
var state entity.LoadState
|
||||
var err error
|
||||
|
||||
if err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetLoadState(ctx, req, callOptions...)
|
||||
state.State = entity.LoadStateCode(resp.GetState())
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
}); err != nil {
|
||||
return state, err
|
||||
}
|
||||
|
||||
// get progress if state is loading
|
||||
if state.State == entity.LoadStateLoading {
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetLoadingProgress(ctx, option.ProgressRequest(), callOptions...)
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state.Progress = resp.GetProgress()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return state, err
|
||||
}
|
||||
|
||||
func (c *Client) ReleaseCollection(ctx context.Context, option ReleaseCollectionOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
|
@ -134,6 +169,26 @@ func (c *Client) ReleasePartitions(ctx context.Context, option ReleasePartitions
|
|||
})
|
||||
}
|
||||
|
||||
func (c *Client) RefreshLoad(ctx context.Context, option RefreshLoadOption, callOptions ...grpc.CallOption) (LoadTask, error) {
|
||||
req := option.Request()
|
||||
var task LoadTask
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.LoadCollection(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
task = LoadTask{
|
||||
client: c,
|
||||
collectionName: req.GetCollectionName(),
|
||||
interval: option.CheckInterval(),
|
||||
refresh: true,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return task, err
|
||||
}
|
||||
|
||||
type FlushTask struct {
|
||||
client *Client
|
||||
collectionName string
|
||||
|
@ -206,3 +261,29 @@ func (c *Client) Flush(ctx context.Context, option FlushOption, callOptions ...g
|
|||
})
|
||||
return task, err
|
||||
}
|
||||
|
||||
func (c *Client) Compact(ctx context.Context, option CompactOption, callOptions ...grpc.CallOption) (int64, error) {
|
||||
req := option.Request()
|
||||
|
||||
var jobID int64
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ManualCompaction(ctx, req, callOptions...)
|
||||
jobID = resp.GetCompactionID()
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
return jobID, err
|
||||
}
|
||||
|
||||
func (c *Client) GetCompactionState(ctx context.Context, option GetCompactionStateOption, callOptions ...grpc.CallOption) (entity.CompactionState, error) {
|
||||
req := option.Request()
|
||||
|
||||
var status entity.CompactionState
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetCompactionState(ctx, req, callOptions...)
|
||||
status = entity.CompactionState(resp.GetState())
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
return status, err
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
type LoadCollectionOption interface {
|
||||
Request() *milvuspb.LoadCollectionRequest
|
||||
CheckInterval() time.Duration
|
||||
IsRefresh() bool
|
||||
}
|
||||
|
||||
type loadCollectionOption struct {
|
||||
|
@ -33,6 +34,8 @@ type loadCollectionOption struct {
|
|||
replicaNum int
|
||||
loadFields []string
|
||||
skipLoadDynamicField bool
|
||||
isRefresh bool
|
||||
resourceGroups []string
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
|
||||
|
@ -41,6 +44,7 @@ func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
|
|||
ReplicaNumber: int32(opt.replicaNum),
|
||||
LoadFields: opt.loadFields,
|
||||
SkipLoadDynamicField: opt.skipLoadDynamicField,
|
||||
ResourceGroups: opt.resourceGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,11 +52,20 @@ func (opt *loadCollectionOption) CheckInterval() time.Duration {
|
|||
return opt.interval
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) IsRefresh() bool {
|
||||
return opt.isRefresh
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption {
|
||||
opt.replicaNum = num
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) WithResourceGroup(resourceGroups ...string) *loadCollectionOption {
|
||||
opt.resourceGroups = resourceGroups
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) WithLoadFields(loadFields ...string) *loadCollectionOption {
|
||||
opt.loadFields = loadFields
|
||||
return opt
|
||||
|
@ -63,6 +76,11 @@ func (opt *loadCollectionOption) WithSkipLoadDynamicField(skipFlag bool) *loadCo
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) WithRefresh(isRefresh bool) *loadCollectionOption {
|
||||
opt.isRefresh = isRefresh
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
|
||||
return &loadCollectionOption{
|
||||
collectionName: collectionName,
|
||||
|
@ -74,6 +92,7 @@ func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
|
|||
type LoadPartitionsOption interface {
|
||||
Request() *milvuspb.LoadPartitionsRequest
|
||||
CheckInterval() time.Duration
|
||||
IsRefresh() bool
|
||||
}
|
||||
|
||||
var _ LoadPartitionsOption = (*loadPartitionsOption)(nil)
|
||||
|
@ -83,8 +102,10 @@ type loadPartitionsOption struct {
|
|||
partitionNames []string
|
||||
interval time.Duration
|
||||
replicaNum int
|
||||
resourceGroups []string
|
||||
loadFields []string
|
||||
skipLoadDynamicField bool
|
||||
isRefresh bool
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
|
||||
|
@ -94,6 +115,7 @@ func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
|
|||
ReplicaNumber: int32(opt.replicaNum),
|
||||
LoadFields: opt.loadFields,
|
||||
SkipLoadDynamicField: opt.skipLoadDynamicField,
|
||||
ResourceGroups: opt.resourceGroups,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,11 +123,20 @@ func (opt *loadPartitionsOption) CheckInterval() time.Duration {
|
|||
return opt.interval
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) IsRefresh() bool {
|
||||
return opt.isRefresh
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) WithReplica(num int) *loadPartitionsOption {
|
||||
opt.replicaNum = num
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) WithResourceGroup(resourceGroups ...string) *loadPartitionsOption {
|
||||
opt.resourceGroups = resourceGroups
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) WithLoadFields(loadFields ...string) *loadPartitionsOption {
|
||||
opt.loadFields = loadFields
|
||||
return opt
|
||||
|
@ -116,6 +147,11 @@ func (opt *loadPartitionsOption) WithSkipLoadDynamicField(skipFlag bool) *loadPa
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) WithRefresh(isRefresh bool) *loadPartitionsOption {
|
||||
opt.isRefresh = isRefresh
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *loadPartitionsOption {
|
||||
return &loadPartitionsOption{
|
||||
collectionName: collectionName,
|
||||
|
@ -125,6 +161,65 @@ func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *
|
|||
}
|
||||
}
|
||||
|
||||
type GetLoadStateOption interface {
|
||||
Request() *milvuspb.GetLoadStateRequest
|
||||
ProgressRequest() *milvuspb.GetLoadingProgressRequest
|
||||
}
|
||||
|
||||
type getLoadStateOption struct {
|
||||
collectionName string
|
||||
partitionNames []string
|
||||
}
|
||||
|
||||
func (opt *getLoadStateOption) Request() *milvuspb.GetLoadStateRequest {
|
||||
return &milvuspb.GetLoadStateRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *getLoadStateOption) ProgressRequest() *milvuspb.GetLoadingProgressRequest {
|
||||
return &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGetLoadStateOption(collectionName string, partitionNames ...string) *getLoadStateOption {
|
||||
return &getLoadStateOption{
|
||||
collectionName: collectionName,
|
||||
partitionNames: partitionNames,
|
||||
}
|
||||
}
|
||||
|
||||
type RefreshLoadOption interface {
|
||||
Request() *milvuspb.LoadCollectionRequest
|
||||
CheckInterval() time.Duration
|
||||
}
|
||||
|
||||
type refreshLoadOption struct {
|
||||
collectionName string
|
||||
checkInterval time.Duration
|
||||
}
|
||||
|
||||
func (opt *refreshLoadOption) Request() *milvuspb.LoadCollectionRequest {
|
||||
return &milvuspb.LoadCollectionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
Refresh: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *refreshLoadOption) CheckInterval() time.Duration {
|
||||
return opt.checkInterval
|
||||
}
|
||||
|
||||
func NewRefreshLoadOption(collectionName string) *refreshLoadOption {
|
||||
return &refreshLoadOption{
|
||||
collectionName: collectionName,
|
||||
checkInterval: time.Millisecond * 200,
|
||||
}
|
||||
}
|
||||
|
||||
type ReleaseCollectionOption interface {
|
||||
Request() *milvuspb.ReleaseCollectionRequest
|
||||
}
|
||||
|
@ -203,3 +298,43 @@ func NewFlushOption(collName string) *flushOption {
|
|||
interval: time.Millisecond * 200,
|
||||
}
|
||||
}
|
||||
|
||||
type CompactOption interface {
|
||||
Request() *milvuspb.ManualCompactionRequest
|
||||
}
|
||||
|
||||
type compactOption struct {
|
||||
collectionName string
|
||||
}
|
||||
|
||||
func (opt *compactOption) Request() *milvuspb.ManualCompactionRequest {
|
||||
return &milvuspb.ManualCompactionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCompactOption(collectionName string) *compactOption {
|
||||
return &compactOption{
|
||||
collectionName: collectionName,
|
||||
}
|
||||
}
|
||||
|
||||
type GetCompactionStateOption interface {
|
||||
Request() *milvuspb.GetCompactionStateRequest
|
||||
}
|
||||
|
||||
type getCompactionStateOption struct {
|
||||
compactionID int64
|
||||
}
|
||||
|
||||
func (opt *getCompactionStateOption) Request() *milvuspb.GetCompactionStateRequest {
|
||||
return &milvuspb.GetCompactionStateRequest{
|
||||
CompactionID: opt.compactionID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGetCompactionStateOption(compactionID int64) *getCompactionStateOption {
|
||||
return &getCompactionStateOption{
|
||||
compactionID: compactionID,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -44,6 +45,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
|
|||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
fieldNames := []string{"id", "part", "vector"}
|
||||
replicaNum := rand.Intn(3) + 1
|
||||
rgs := []string{"rg1", "rg2"}
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
|
@ -51,6 +53,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
|
|||
s.ElementsMatch(fieldNames, lcr.GetLoadFields())
|
||||
s.True(lcr.SkipLoadDynamicField)
|
||||
s.EqualValues(replicaNum, lcr.GetReplicaNumber())
|
||||
s.ElementsMatch(rgs, lcr.GetResourceGroups())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
|
@ -70,6 +73,7 @@ func (s *MaintenanceSuite) TestLoadCollection() {
|
|||
|
||||
task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName).
|
||||
WithReplica(replicaNum).
|
||||
WithResourceGroup(rgs...).
|
||||
WithLoadFields(fieldNames...).
|
||||
WithSkipLoadDynamicField(true))
|
||||
s.NoError(err)
|
||||
|
@ -114,6 +118,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
|
|||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
fieldNames := []string{"id", "part", "vector"}
|
||||
replicaNum := rand.Intn(3) + 1
|
||||
rgs := []string{"rg1", "rg2"}
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
|
@ -122,6 +127,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
|
|||
s.ElementsMatch(fieldNames, lpr.GetLoadFields())
|
||||
s.True(lpr.SkipLoadDynamicField)
|
||||
s.EqualValues(replicaNum, lpr.GetReplicaNumber())
|
||||
s.ElementsMatch(rgs, lpr.GetResourceGroups())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
|
@ -142,6 +148,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
|
|||
|
||||
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, partitionName).
|
||||
WithReplica(replicaNum).
|
||||
WithResourceGroup(rgs...).
|
||||
WithLoadFields(fieldNames...).
|
||||
WithSkipLoadDynamicField(true))
|
||||
s.NoError(err)
|
||||
|
@ -293,6 +300,167 @@ func (s *MaintenanceSuite) TestFlush() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestRefreshLoad() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, lcr.GetCollectionName())
|
||||
s.True(lcr.GetRefresh())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
s.Equal(collectionName, glpr.GetCollectionName())
|
||||
|
||||
progress := int64(50)
|
||||
if done.Load() {
|
||||
progress = 100
|
||||
}
|
||||
|
||||
return &milvuspb.GetLoadingProgressResponse{
|
||||
Status: merr.Success(),
|
||||
RefreshProgress: progress,
|
||||
}, nil
|
||||
})
|
||||
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
|
||||
|
||||
task, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName))
|
||||
s.NoError(err)
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := task.Await(ctx)
|
||||
s.NoError(err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
s.FailNow("task done before index state set to finish")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
done.Store(true)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
s.FailNow("task not done after index set finished")
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.RefreshLoad(ctx, NewRefreshLoadOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestCompact() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
compactID := rand.Int63()
|
||||
|
||||
s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cr *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) {
|
||||
s.Equal(collectionName, cr.GetCollectionName())
|
||||
return &milvuspb.ManualCompactionResponse{
|
||||
CompactionID: compactID,
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
id, err := s.client.Compact(ctx, NewCompactOption(collectionName))
|
||||
s.NoError(err)
|
||||
s.Equal(compactID, id)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().ManualCompaction(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.Compact(ctx, NewCompactOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestGetCompactionState() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
compactID := rand.Int63()
|
||||
|
||||
s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gcsr *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) {
|
||||
s.Equal(compactID, gcsr.GetCompactionID())
|
||||
return &milvuspb.GetCompactionStateResponse{
|
||||
Status: merr.Success(),
|
||||
State: commonpb.CompactionState_Completed,
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
state, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID))
|
||||
s.NoError(err)
|
||||
s.Equal(entity.CompactionStateCompleted, state)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
compactID := rand.Int63()
|
||||
|
||||
s.mock.EXPECT().GetCompactionState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.GetCompactionState(ctx, NewGetCompactionStateOption(compactID))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestGetLoadState() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
progress := rand.Int63n(100)
|
||||
|
||||
s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glsr *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) {
|
||||
s.Equal(collectionName, glsr.GetCollectionName())
|
||||
return &milvuspb.GetLoadStateResponse{
|
||||
Status: merr.Success(),
|
||||
State: commonpb.LoadState_LoadStateLoading,
|
||||
}, nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
s.Equal(collectionName, glpr.GetCollectionName())
|
||||
return &milvuspb.GetLoadingProgressResponse{
|
||||
Status: merr.Success(),
|
||||
Progress: progress,
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
state, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName))
|
||||
s.NoError(err)
|
||||
s.Equal(entity.LoadStateLoading, state.State)
|
||||
s.Equal(progress, state.Progress)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().GetLoadState(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.GetLoadState(ctx, NewGetLoadStateOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaintenance(t *testing.T) {
|
||||
suite.Run(t, new(MaintenanceSuite))
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
|
@ -76,3 +77,20 @@ func (c *Client) ListPartitions(ctx context.Context, opt ListPartitionsOption, c
|
|||
})
|
||||
return partitionNames, err
|
||||
}
|
||||
|
||||
func (c *Client) GetPartitionStats(ctx context.Context, opt GetPartitionStatsOption, callOptions ...grpc.CallOption) (map[string]string, error) {
|
||||
req := opt.Request()
|
||||
|
||||
var result map[string]string
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetPartitionStatistics(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result = entity.KvPairsMap(resp.GetStats())
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
|
|
@ -117,3 +117,26 @@ func NewListPartitionOption(collectionName string) *listPartitionsOpt {
|
|||
collectionName: collectionName,
|
||||
}
|
||||
}
|
||||
|
||||
type GetPartitionStatsOption interface {
|
||||
Request() *milvuspb.GetPartitionStatisticsRequest
|
||||
}
|
||||
|
||||
type getPartitionStatsOpt struct {
|
||||
collectionName string
|
||||
partitionName string
|
||||
}
|
||||
|
||||
func (opt *getPartitionStatsOpt) Request() *milvuspb.GetPartitionStatisticsRequest {
|
||||
return &milvuspb.GetPartitionStatisticsRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionName: opt.partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGetPartitionStatsOption(collectionName string, partitionName string) *getPartitionStatsOpt {
|
||||
return &getPartitionStatsOpt{
|
||||
collectionName: collectionName,
|
||||
partitionName: partitionName,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -162,6 +162,39 @@ func (s *PartitionSuite) TestDropPartition() {
|
|||
})
|
||||
}
|
||||
|
||||
func (s *PartitionSuite) TestGetPartitionStats() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gpsr *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) {
|
||||
s.Equal(collectionName, gpsr.GetCollectionName())
|
||||
s.Equal(partitionName, gpsr.GetPartitionName())
|
||||
return &milvuspb.GetPartitionStatisticsResponse{
|
||||
Status: merr.Success(),
|
||||
Stats: []*commonpb.KeyValuePair{
|
||||
{Key: "rows", Value: "100"},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
stats, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName))
|
||||
s.NoError(err)
|
||||
s.Equal("100", stats["rows"])
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.mock.EXPECT().GetPartitionStatistics(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.GetPartitionStats(ctx, NewGetPartitionStatsOption(collectionName, partitionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPartition(t *testing.T) {
|
||||
suite.Run(t, new(PartitionSuite))
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package milvusclient
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
|
@ -27,6 +28,153 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) ListUsers(ctx context.Context, opt ListUserOption, callOpts ...grpc.CallOption) ([]string, error) {
|
||||
var users []string
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ListCredUsers(ctx, opt.Request(), callOpts...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
users = resp.GetUsernames()
|
||||
return nil
|
||||
})
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (c *Client) DescribeUser(ctx context.Context, opt DescribeUserOption, callOpts ...grpc.CallOption) (*entity.User, error) {
|
||||
var user *entity.User
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.SelectUser(ctx, opt.Request(), callOpts...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.GetResults()) == 0 {
|
||||
return errors.New("not user found")
|
||||
}
|
||||
result := resp.GetResults()[0]
|
||||
user = &entity.User{
|
||||
UserName: result.GetUser().GetName(),
|
||||
Roles: lo.Map(result.GetRoles(), func(r *milvuspb.RoleEntity, _ int) string { return r.GetName() }),
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateUser(ctx context.Context, opt CreateUserOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateCredential(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) UpdatePassword(ctx context.Context, opt UpdatePasswordOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.UpdateCredential(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) DropUser(ctx context.Context, opt DropUserOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DeleteCredential(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ListRoles(ctx context.Context, opt ListRoleOption, callOpts ...grpc.CallOption) ([]string, error) {
|
||||
var roles []string
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.SelectRole(ctx, opt.Request(), callOpts...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
roles = lo.Map(resp.GetResults(), func(r *milvuspb.RoleResult, _ int) string {
|
||||
return r.GetRole().GetName()
|
||||
})
|
||||
return nil
|
||||
})
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateRole(ctx context.Context, opt CreateRoleOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateRole(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) GrantRole(ctx context.Context, opt GrantRoleOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) RevokeRole(ctx context.Context, opt RevokeRoleOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.OperateUserRole(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) DropRole(ctx context.Context, opt DropRoleOption, callOpts ...grpc.CallOption) error {
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropRole(ctx, opt.Request(), callOpts...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) DescribeRole(ctx context.Context, option DescribeRoleOption, callOptions ...grpc.CallOption) (*entity.Role, error) {
|
||||
req := option.Request()
|
||||
|
||||
var role *entity.Role
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.SelectGrant(ctx, req, callOptions...)
|
||||
if err := merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.GetEntities()) == 0 {
|
||||
return errors.New("role not found")
|
||||
}
|
||||
|
||||
role = &entity.Role{
|
||||
RoleName: req.GetEntity().GetRole().GetName(),
|
||||
Privileges: lo.Map(resp.GetEntities(), func(g *milvuspb.GrantEntity, _ int) entity.GrantItem {
|
||||
return entity.GrantItem{
|
||||
Object: g.Object.GetName(),
|
||||
ObjectName: g.GetObjectName(),
|
||||
RoleName: g.GetRole().GetName(),
|
||||
Grantor: g.GetGrantor().GetUser().GetName(),
|
||||
Privilege: g.GetGrantor().GetPrivilege().GetName(),
|
||||
}
|
||||
}),
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return role, err
|
||||
}
|
||||
|
||||
func (c *Client) GrantPrivilege(ctx context.Context, option GrantPrivilegeOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) RevokePrivilege(ctx context.Context, option RevokePrivilegeOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.OperatePrivilege(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) GrantV2(ctx context.Context, option GrantV2Option, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
|
|
|
@ -20,6 +20,314 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
type ListUserOption interface {
|
||||
Request() *milvuspb.ListCredUsersRequest
|
||||
}
|
||||
|
||||
// listUserOption is the struct to build ListCredUsersRequest
|
||||
// left empty for not attribute needed right now
|
||||
type listUserOption struct{}
|
||||
|
||||
func (opt *listUserOption) Request() *milvuspb.ListCredUsersRequest {
|
||||
return &milvuspb.ListCredUsersRequest{}
|
||||
}
|
||||
|
||||
func NewListUserOption() *listUserOption {
|
||||
return &listUserOption{}
|
||||
}
|
||||
|
||||
type DescribeUserOption interface {
|
||||
Request() *milvuspb.SelectUserRequest
|
||||
}
|
||||
|
||||
type describeUserOption struct {
|
||||
userName string
|
||||
}
|
||||
|
||||
func (opt *describeUserOption) Request() *milvuspb.SelectUserRequest {
|
||||
return &milvuspb.SelectUserRequest{
|
||||
User: &milvuspb.UserEntity{
|
||||
Name: opt.userName,
|
||||
},
|
||||
IncludeRoleInfo: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDescribeUserOption(userName string) *describeUserOption {
|
||||
return &describeUserOption{
|
||||
userName: userName,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateUserOption interface {
|
||||
Request() *milvuspb.CreateCredentialRequest
|
||||
}
|
||||
|
||||
type createUserOption struct {
|
||||
userName string
|
||||
password string
|
||||
}
|
||||
|
||||
func (opt *createUserOption) Request() *milvuspb.CreateCredentialRequest {
|
||||
return &milvuspb.CreateCredentialRequest{
|
||||
Username: opt.userName,
|
||||
Password: opt.password,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateUserOption(userName, password string) *createUserOption {
|
||||
return &createUserOption{
|
||||
userName: userName,
|
||||
password: password,
|
||||
}
|
||||
}
|
||||
|
||||
type UpdatePasswordOption interface {
|
||||
Request() *milvuspb.UpdateCredentialRequest
|
||||
}
|
||||
|
||||
type updatePasswordOption struct {
|
||||
userName string
|
||||
oldPassword string
|
||||
newPassword string
|
||||
}
|
||||
|
||||
func (opt *updatePasswordOption) Request() *milvuspb.UpdateCredentialRequest {
|
||||
return &milvuspb.UpdateCredentialRequest{
|
||||
Username: opt.userName,
|
||||
OldPassword: opt.oldPassword,
|
||||
NewPassword: opt.newPassword,
|
||||
}
|
||||
}
|
||||
|
||||
func NewUpdatePasswordOption(userName, oldPassword, newPassword string) *updatePasswordOption {
|
||||
return &updatePasswordOption{
|
||||
userName: userName,
|
||||
oldPassword: oldPassword,
|
||||
newPassword: newPassword,
|
||||
}
|
||||
}
|
||||
|
||||
type DropUserOption interface {
|
||||
Request() *milvuspb.DeleteCredentialRequest
|
||||
}
|
||||
|
||||
type dropUserOption struct {
|
||||
userName string
|
||||
}
|
||||
|
||||
func (opt *dropUserOption) Request() *milvuspb.DeleteCredentialRequest {
|
||||
return &milvuspb.DeleteCredentialRequest{
|
||||
Username: opt.userName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropUserOption(userName string) *dropUserOption {
|
||||
return &dropUserOption{
|
||||
userName: userName,
|
||||
}
|
||||
}
|
||||
|
||||
type ListRoleOption interface {
|
||||
Request() *milvuspb.SelectRoleRequest
|
||||
}
|
||||
|
||||
type listRoleOption struct{}
|
||||
|
||||
func (opt *listRoleOption) Request() *milvuspb.SelectRoleRequest {
|
||||
return &milvuspb.SelectRoleRequest{
|
||||
IncludeUserInfo: false,
|
||||
}
|
||||
}
|
||||
|
||||
func NewListRoleOption() *listRoleOption {
|
||||
return &listRoleOption{}
|
||||
}
|
||||
|
||||
type CreateRoleOption interface {
|
||||
Request() *milvuspb.CreateRoleRequest
|
||||
}
|
||||
|
||||
type createRoleOption struct {
|
||||
roleName string
|
||||
}
|
||||
|
||||
func (opt *createRoleOption) Request() *milvuspb.CreateRoleRequest {
|
||||
return &milvuspb.CreateRoleRequest{
|
||||
Entity: &milvuspb.RoleEntity{Name: opt.roleName},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateRoleOption(roleName string) *createRoleOption {
|
||||
return &createRoleOption{
|
||||
roleName: roleName,
|
||||
}
|
||||
}
|
||||
|
||||
type GrantRoleOption interface {
|
||||
Request() *milvuspb.OperateUserRoleRequest
|
||||
}
|
||||
|
||||
type grantRoleOption struct {
|
||||
roleName string
|
||||
userName string
|
||||
}
|
||||
|
||||
func (opt *grantRoleOption) Request() *milvuspb.OperateUserRoleRequest {
|
||||
return &milvuspb.OperateUserRoleRequest{
|
||||
Username: opt.userName,
|
||||
RoleName: opt.roleName,
|
||||
Type: milvuspb.OperateUserRoleType_AddUserToRole,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGrantRoleOption(userName, roleName string) *grantRoleOption {
|
||||
return &grantRoleOption{
|
||||
roleName: roleName,
|
||||
userName: userName,
|
||||
}
|
||||
}
|
||||
|
||||
type RevokeRoleOption interface {
|
||||
Request() *milvuspb.OperateUserRoleRequest
|
||||
}
|
||||
|
||||
type revokeRoleOption struct {
|
||||
roleName string
|
||||
userName string
|
||||
}
|
||||
|
||||
func (opt *revokeRoleOption) Request() *milvuspb.OperateUserRoleRequest {
|
||||
return &milvuspb.OperateUserRoleRequest{
|
||||
Username: opt.userName,
|
||||
RoleName: opt.roleName,
|
||||
Type: milvuspb.OperateUserRoleType_RemoveUserFromRole,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRevokeRoleOption(userName, roleName string) *revokeRoleOption {
|
||||
return &revokeRoleOption{
|
||||
roleName: roleName,
|
||||
userName: userName,
|
||||
}
|
||||
}
|
||||
|
||||
type DropRoleOption interface {
|
||||
Request() *milvuspb.DropRoleRequest
|
||||
}
|
||||
|
||||
type dropDropRoleOption struct {
|
||||
roleName string
|
||||
}
|
||||
|
||||
func (opt *dropDropRoleOption) Request() *milvuspb.DropRoleRequest {
|
||||
return &milvuspb.DropRoleRequest{
|
||||
RoleName: opt.roleName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropRoleOption(roleName string) *dropDropRoleOption {
|
||||
return &dropDropRoleOption{
|
||||
roleName: roleName,
|
||||
}
|
||||
}
|
||||
|
||||
type DescribeRoleOption interface {
|
||||
Request() *milvuspb.SelectGrantRequest
|
||||
}
|
||||
|
||||
type describeRoleOption struct {
|
||||
roleName string
|
||||
}
|
||||
|
||||
func (opt *describeRoleOption) Request() *milvuspb.SelectGrantRequest {
|
||||
return &milvuspb.SelectGrantRequest{
|
||||
Entity: &milvuspb.GrantEntity{
|
||||
Role: &milvuspb.RoleEntity{Name: opt.roleName},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewDescribeRoleOption(roleName string) *describeRoleOption {
|
||||
return &describeRoleOption{
|
||||
roleName: roleName,
|
||||
}
|
||||
}
|
||||
|
||||
type GrantPrivilegeOption interface {
|
||||
Request() *milvuspb.OperatePrivilegeRequest
|
||||
}
|
||||
|
||||
type grantPrivilegeOption struct {
|
||||
roleName string
|
||||
privilegeName string
|
||||
objectName string
|
||||
objectType string
|
||||
}
|
||||
|
||||
func (opt *grantPrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest {
|
||||
return &milvuspb.OperatePrivilegeRequest{
|
||||
Entity: &milvuspb.GrantEntity{
|
||||
Role: &milvuspb.RoleEntity{Name: opt.roleName},
|
||||
Grantor: &milvuspb.GrantorEntity{
|
||||
Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName},
|
||||
},
|
||||
Object: &milvuspb.ObjectEntity{
|
||||
Name: opt.objectType,
|
||||
},
|
||||
ObjectName: opt.objectName,
|
||||
},
|
||||
|
||||
Type: milvuspb.OperatePrivilegeType_Grant,
|
||||
}
|
||||
}
|
||||
|
||||
func NewGrantPrivilegeOption(roleName, objectType, privilegeName, objectName string) *grantPrivilegeOption {
|
||||
return &grantPrivilegeOption{
|
||||
roleName: roleName,
|
||||
privilegeName: privilegeName,
|
||||
objectName: objectName,
|
||||
objectType: objectType,
|
||||
}
|
||||
}
|
||||
|
||||
type RevokePrivilegeOption interface {
|
||||
Request() *milvuspb.OperatePrivilegeRequest
|
||||
}
|
||||
|
||||
type revokePrivilegeOption struct {
|
||||
roleName string
|
||||
privilegeName string
|
||||
objectName string
|
||||
objectType string
|
||||
}
|
||||
|
||||
func (opt *revokePrivilegeOption) Request() *milvuspb.OperatePrivilegeRequest {
|
||||
return &milvuspb.OperatePrivilegeRequest{
|
||||
Entity: &milvuspb.GrantEntity{
|
||||
Role: &milvuspb.RoleEntity{Name: opt.roleName},
|
||||
Grantor: &milvuspb.GrantorEntity{
|
||||
Privilege: &milvuspb.PrivilegeEntity{Name: opt.privilegeName},
|
||||
},
|
||||
Object: &milvuspb.ObjectEntity{
|
||||
Name: opt.objectType,
|
||||
},
|
||||
ObjectName: opt.objectName,
|
||||
},
|
||||
|
||||
Type: milvuspb.OperatePrivilegeType_Revoke,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRevokePrivilegeOption(roleName, objectType, privilegeName, objectName string) *revokePrivilegeOption {
|
||||
return &revokePrivilegeOption{
|
||||
roleName: roleName,
|
||||
privilegeName: privilegeName,
|
||||
objectName: objectName,
|
||||
objectType: objectType,
|
||||
}
|
||||
}
|
||||
|
||||
// GrantV2Option is the interface builds OperatePrivilegeV2Request
|
||||
type GrantV2Option interface {
|
||||
Request() *milvuspb.OperatePrivilegeV2Request
|
||||
|
|
|
@ -29,6 +29,376 @@ import (
|
|||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type UserSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *UserSuite) TestListUsers() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) {
|
||||
return &milvuspb.ListCredUsersResponse{
|
||||
Usernames: []string{"user1", "user2"},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
users, err := s.client.ListUsers(ctx, NewListUserOption())
|
||||
s.NoError(err)
|
||||
s.Equal([]string{"user1", "user2"}, users)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().ListCredUsers(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListUsers(ctx, NewListUserOption())
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserSuite) TestDescribeUser() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) {
|
||||
s.Equal(userName, r.GetUser().GetName())
|
||||
return &milvuspb.SelectUserResponse{
|
||||
Results: []*milvuspb.UserResult{
|
||||
{
|
||||
User: &milvuspb.UserEntity{Name: userName},
|
||||
Roles: []*milvuspb.RoleEntity{
|
||||
{Name: "role1"},
|
||||
{Name: "role2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
user, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName))
|
||||
s.NoError(err)
|
||||
s.Equal(userName, user.UserName)
|
||||
s.Equal([]string{"role1", "role2"}, user.Roles)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().SelectUser(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.DescribeUser(ctx, NewDescribeUserOption(userName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserSuite) TestCreateUser() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
password := s.randString(12)
|
||||
s.mock.EXPECT().CreateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ccr *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) {
|
||||
s.Equal(userName, ccr.GetUsername())
|
||||
s.Equal(password, ccr.GetPassword())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.CreateUser(ctx, NewCreateUserOption(userName, password))
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserSuite) TestUpdatePassword() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
oldPassword := s.randString(12)
|
||||
newPassword := s.randString(12)
|
||||
s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ucr *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) {
|
||||
s.Equal(userName, ucr.GetUsername())
|
||||
s.Equal(oldPassword, ucr.GetOldPassword())
|
||||
s.Equal(newPassword, ucr.GetNewPassword())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption(userName, oldPassword, newPassword))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.UpdatePassword(ctx, NewUpdatePasswordOption("user", "old", "new"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UserSuite) TestDropUser() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dcr *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) {
|
||||
s.Equal(userName, dcr.GetUsername())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.DropUser(ctx, NewDropUserOption(userName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DeleteCredential(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropUser(ctx, NewDropUserOption("user"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserRBAC(t *testing.T) {
|
||||
suite.Run(t, new(UserSuite))
|
||||
}
|
||||
|
||||
type RoleSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestListRoles() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) {
|
||||
return &milvuspb.SelectRoleResponse{
|
||||
Results: []*milvuspb.RoleResult{
|
||||
{Role: &milvuspb.RoleEntity{Name: "role1"}},
|
||||
{Role: &milvuspb.RoleEntity{Name: "role2"}},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
roles, err := s.client.ListRoles(ctx, NewListRoleOption())
|
||||
s.NoError(err)
|
||||
s.Equal([]string{"role1", "role2"}, roles)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListRoles(ctx, NewListRoleOption())
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestCreateRole() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.CreateRoleRequest) (*commonpb.Status, error) {
|
||||
s.Equal(roleName, r.GetEntity().GetName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.CreateRole(ctx, NewCreateRoleOption(roleName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().CreateRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.CreateRole(ctx, NewCreateRoleOption("role"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestGrantRole() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) {
|
||||
s.Equal(userName, r.GetUsername())
|
||||
s.Equal(roleName, r.GetRoleName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.GrantRole(ctx, NewGrantRoleOption(userName, roleName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.GrantRole(ctx, NewGrantRoleOption("user", "role"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestRevokeRole() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
userName := fmt.Sprintf("user_%s", s.randString(5))
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) {
|
||||
s.Equal(userName, r.GetUsername())
|
||||
s.Equal(roleName, r.GetRoleName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.RevokeRole(ctx, NewRevokeRoleOption(userName, roleName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().OperateUserRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.RevokeRole(ctx, NewRevokeRoleOption("user", "role"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestDropRole() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.DropRoleRequest) (*commonpb.Status, error) {
|
||||
s.Equal(roleName, r.GetRoleName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.DropRole(ctx, NewDropRoleOption(roleName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DropRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropRole(ctx, NewDropRoleOption("role"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestDescribeRole() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) {
|
||||
s.Equal(roleName, r.GetEntity().GetRole().GetName())
|
||||
return &milvuspb.SelectGrantResponse{
|
||||
Entities: []*milvuspb.GrantEntity{
|
||||
{
|
||||
ObjectName: "*",
|
||||
Object: &milvuspb.ObjectEntity{
|
||||
Name: "collection",
|
||||
},
|
||||
Role: &milvuspb.RoleEntity{Name: roleName},
|
||||
Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Insert"}},
|
||||
},
|
||||
{
|
||||
ObjectName: "*",
|
||||
Object: &milvuspb.ObjectEntity{
|
||||
Name: "collection",
|
||||
},
|
||||
Role: &milvuspb.RoleEntity{Name: roleName},
|
||||
Grantor: &milvuspb.GrantorEntity{User: &milvuspb.UserEntity{Name: "admin"}, Privilege: &milvuspb.PrivilegeEntity{Name: "Query"}},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
role, err := s.client.DescribeRole(ctx, NewDescribeRoleOption(roleName))
|
||||
s.NoError(err)
|
||||
s.Equal(roleName, role.RoleName)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.DescribeRole(ctx, NewDescribeRoleOption("role"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestGrantPrivilege() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
privilegeName := "Insert"
|
||||
collectionName := fmt.Sprintf("collection_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) {
|
||||
s.Equal(roleName, r.GetEntity().GetRole().GetName())
|
||||
s.Equal("collection", r.GetEntity().GetObject().GetName())
|
||||
s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName())
|
||||
s.Equal(collectionName, r.GetEntity().GetObjectName())
|
||||
s.Equal(milvuspb.OperatePrivilegeType_Grant, r.GetType())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption(roleName, "collection", privilegeName, collectionName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.GrantPrivilege(ctx, NewGrantPrivilegeOption("role", "collection", "privilege", "coll_1"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RoleSuite) TestRevokePrivilege() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
roleName := fmt.Sprintf("role_%s", s.randString(5))
|
||||
privilegeName := "Insert"
|
||||
collectionName := fmt.Sprintf("collection_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) {
|
||||
s.Equal(roleName, r.GetEntity().GetRole().GetName())
|
||||
s.Equal("collection", r.GetEntity().GetObject().GetName())
|
||||
s.Equal(privilegeName, r.GetEntity().GetGrantor().GetPrivilege().GetName())
|
||||
s.Equal(collectionName, r.GetEntity().GetObjectName())
|
||||
s.Equal(milvuspb.OperatePrivilegeType_Revoke, r.GetType())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption(roleName, "collection", privilegeName, collectionName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().OperatePrivilege(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.RevokePrivilege(ctx, NewRevokePrivilegeOption("role", "collection", "privilege", "coll_1"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoleRBAC(t *testing.T) {
|
||||
suite.Run(t, new(RoleSuite))
|
||||
}
|
||||
|
||||
type PrivilgeGroupSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
|
|
@ -192,6 +192,10 @@ func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...g
|
|||
return resultSet, err
|
||||
}
|
||||
|
||||
func (c *Client) Get(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
|
||||
return c.Query(ctx, option, callOptions...)
|
||||
}
|
||||
|
||||
func (c *Client) HybridSearch(ctx context.Context, option HybridSearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
|
||||
req, err := option.HybridRequest()
|
||||
if err != nil {
|
||||
|
|
|
@ -52,8 +52,7 @@ func (s *SearchOptionSuite) TestBasic() {
|
|||
topK := rand.Intn(100) + 1
|
||||
opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})})
|
||||
|
||||
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000")
|
||||
|
||||
opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true)
|
||||
req, err := opt.Request()
|
||||
s.Require().NoError(err)
|
||||
|
||||
|
@ -64,6 +63,15 @@ func (s *SearchOptionSuite) TestBasic() {
|
|||
annField, ok := searchParams[spAnnsField]
|
||||
s.Require().True(ok)
|
||||
s.Equal("test_field", annField)
|
||||
groupField, ok := searchParams[spGroupBy]
|
||||
s.Require().True(ok)
|
||||
s.Equal("group_field", groupField)
|
||||
groupSize, ok := searchParams[spGroupSize]
|
||||
s.Require().True(ok)
|
||||
s.Equal("10", groupSize)
|
||||
spStrictGroupSize, ok := searchParams[spStrictGroupSize]
|
||||
s.Require().True(ok)
|
||||
s.Equal("true", spStrictGroupSize)
|
||||
|
||||
opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}})
|
||||
_, err = opt.Request()
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
@ -28,20 +29,23 @@ import (
|
|||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
)
|
||||
|
||||
const (
|
||||
spAnnsField = `anns_field`
|
||||
spTopK = `topk`
|
||||
spOffset = `offset`
|
||||
spLimit = `limit`
|
||||
spParams = `params`
|
||||
spMetricsType = `metric_type`
|
||||
spRoundDecimal = `round_decimal`
|
||||
spIgnoreGrowing = `ignore_growing`
|
||||
spGroupBy = `group_by_field`
|
||||
spAnnsField = `anns_field`
|
||||
spTopK = `topk`
|
||||
spOffset = `offset`
|
||||
spLimit = `limit`
|
||||
spParams = `params`
|
||||
spMetricsType = `metric_type`
|
||||
spRoundDecimal = `round_decimal`
|
||||
spIgnoreGrowing = `ignore_growing`
|
||||
spGroupBy = `group_by_field`
|
||||
spGroupSize = `group_size`
|
||||
spStrictGroupSize = `strict_group_size`
|
||||
)
|
||||
|
||||
type SearchOption interface {
|
||||
|
@ -62,16 +66,18 @@ type searchOption struct {
|
|||
type annRequest struct {
|
||||
vectors []entity.Vector
|
||||
|
||||
annField string
|
||||
metricsType entity.MetricType
|
||||
searchParam map[string]string
|
||||
groupByField string
|
||||
annParam index.AnnParam
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
topK int
|
||||
offset int
|
||||
templateParams map[string]any
|
||||
annField string
|
||||
metricsType entity.MetricType
|
||||
searchParam map[string]string
|
||||
groupByField string
|
||||
groupSize int
|
||||
strictGroupSize bool
|
||||
annParam index.AnnParam
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
topK int
|
||||
offset int
|
||||
templateParams map[string]any
|
||||
}
|
||||
|
||||
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest {
|
||||
|
@ -108,6 +114,12 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) {
|
|||
if r.groupByField != "" {
|
||||
params[spGroupBy] = r.groupByField
|
||||
}
|
||||
if r.groupSize != 0 {
|
||||
params[spGroupSize] = strconv.Itoa(r.groupSize)
|
||||
}
|
||||
if r.strictGroupSize {
|
||||
params[spStrictGroupSize] = "true"
|
||||
}
|
||||
// ann param
|
||||
if r.annParam != nil {
|
||||
bs, _ := json.Marshal(r.annParam.Params())
|
||||
|
@ -223,6 +235,16 @@ func (r *annRequest) WithGroupByField(groupByField string) *annRequest {
|
|||
return r
|
||||
}
|
||||
|
||||
func (r *annRequest) WithGroupSize(groupSize int) *annRequest {
|
||||
r.groupSize = groupSize
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *annRequest) WithStrictGroupSize(strictGroupSize bool) *annRequest {
|
||||
r.strictGroupSize = strictGroupSize
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *annRequest) WithSearchParam(key, value string) *annRequest {
|
||||
r.searchParam[key] = value
|
||||
return r
|
||||
|
@ -309,6 +331,16 @@ func (opt *searchOption) WithGroupByField(groupByField string) *searchOption {
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithGroupSize(groupSize int) *searchOption {
|
||||
opt.annRequest.WithGroupSize(groupSize)
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithStrictGroupSize(strictGroupSize bool) *searchOption {
|
||||
opt.annRequest.WithStrictGroupSize(strictGroupSize)
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithIgnoreGrowing(ignoreGrowing bool) *searchOption {
|
||||
opt.annRequest.WithIgnoreGrowing(ignoreGrowing)
|
||||
return opt
|
||||
|
@ -550,6 +582,27 @@ func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption {
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithIDs(ids column.Column) *queryOption {
|
||||
opt.expr = pks2Expr(ids)
|
||||
return opt
|
||||
}
|
||||
|
||||
func pks2Expr(ids column.Column) string {
|
||||
var expr string
|
||||
pkName := ids.Name()
|
||||
switch ids.Type() {
|
||||
case entity.FieldTypeInt64:
|
||||
expr = fmt.Sprintf("%s in %s", pkName, strings.Join(strings.Fields(fmt.Sprint(ids.FieldData().GetScalars().GetLongData().GetData())), ","))
|
||||
case entity.FieldTypeVarChar:
|
||||
data := ids.FieldData().GetScalars().GetData().(*schemapb.ScalarField_StringData).StringData.GetData()
|
||||
for i := range data {
|
||||
data[i] = fmt.Sprintf("\"%s\"", data[i])
|
||||
}
|
||||
expr = fmt.Sprintf("%s in [%s]", pkName, strings.Join(data, ","))
|
||||
}
|
||||
return expr
|
||||
}
|
||||
|
||||
func NewQueryOption(collectionName string) *queryOption {
|
||||
return &queryOption{
|
||||
collectionName: collectionName,
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package milvusclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) ListResourceGroups(ctx context.Context, opt ListResourceGroupsOption, callOptions ...grpc.CallOption) ([]string, error) {
|
||||
req := opt.Request()
|
||||
|
||||
var rgs []string
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ListResourceGroups(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rgs = resp.GetResourceGroups()
|
||||
return nil
|
||||
})
|
||||
|
||||
return rgs, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateResourceGroup(ctx context.Context, opt CreateResourceGroupOption, callOptions ...grpc.CallOption) error {
|
||||
req := opt.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateResourceGroup(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) DropResourceGroup(ctx context.Context, opt DropResourceGroupOption, callOptions ...grpc.CallOption) error {
|
||||
req := opt.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropResourceGroup(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package milvusclient
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
|
||||
)
|
||||
|
||||
type ListResourceGroupsOption interface {
|
||||
Request() *milvuspb.ListResourceGroupsRequest
|
||||
}
|
||||
|
||||
type listResourceGroupsOption struct{}
|
||||
|
||||
func (opt *listResourceGroupsOption) Request() *milvuspb.ListResourceGroupsRequest {
|
||||
return &milvuspb.ListResourceGroupsRequest{}
|
||||
}
|
||||
|
||||
func NewListResourceGroupsOption() *listResourceGroupsOption {
|
||||
return &listResourceGroupsOption{}
|
||||
}
|
||||
|
||||
type CreateResourceGroupOption interface {
|
||||
Request() *milvuspb.CreateResourceGroupRequest
|
||||
}
|
||||
|
||||
type createResourceGroupOption struct {
|
||||
name string
|
||||
nodeRequest int
|
||||
nodeLimit int
|
||||
}
|
||||
|
||||
func (opt *createResourceGroupOption) WithNodeRequest(nodeRequest int) *createResourceGroupOption {
|
||||
opt.nodeRequest = nodeRequest
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createResourceGroupOption) WithNodeLimit(nodeLimit int) *createResourceGroupOption {
|
||||
opt.nodeLimit = nodeLimit
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createResourceGroupOption) Request() *milvuspb.CreateResourceGroupRequest {
|
||||
return &milvuspb.CreateResourceGroupRequest{
|
||||
ResourceGroup: opt.name,
|
||||
Config: &rgpb.ResourceGroupConfig{
|
||||
Requests: &rgpb.ResourceGroupLimit{
|
||||
NodeNum: int32(opt.nodeRequest),
|
||||
},
|
||||
Limits: &rgpb.ResourceGroupLimit{
|
||||
NodeNum: int32(opt.nodeLimit),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateResourceGroupOption(name string) *createResourceGroupOption {
|
||||
return &createResourceGroupOption{name: name}
|
||||
}
|
||||
|
||||
type DropResourceGroupOption interface {
|
||||
Request() *milvuspb.DropResourceGroupRequest
|
||||
}
|
||||
|
||||
type dropResourceGroupOption struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (opt *dropResourceGroupOption) Request() *milvuspb.DropResourceGroupRequest {
|
||||
return &milvuspb.DropResourceGroupRequest{
|
||||
ResourceGroup: opt.name,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropResourceGroupOption(name string) *dropResourceGroupOption {
|
||||
return &dropResourceGroupOption{name: name}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package milvusclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
type ResourceGroupSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *ResourceGroupSuite) TestListResourceGroups() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{
|
||||
ResourceGroups: []string{"rg1", "rg2"},
|
||||
}, nil).Once()
|
||||
rgs, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption())
|
||||
s.NoError(err)
|
||||
s.Equal([]string{"rg1", "rg2"}, rgs)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
|
||||
_, err := s.client.ListResourceGroups(ctx, NewListResourceGroupsOption())
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ResourceGroupSuite) TestCreateResourceGroup() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
rgName := fmt.Sprintf("rg_%s", s.randString(6))
|
||||
s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, crgr *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) {
|
||||
s.Equal(rgName, crgr.GetResourceGroup())
|
||||
s.Equal(int32(5), crgr.GetConfig().GetRequests().GetNodeNum())
|
||||
s.Equal(int32(10), crgr.GetConfig().GetLimits().GetNodeNum())
|
||||
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
|
||||
}).Once()
|
||||
opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5)
|
||||
err := s.client.CreateResourceGroup(ctx, opt)
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
rgName := fmt.Sprintf("rg_%s", s.randString(6))
|
||||
s.mock.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
|
||||
opt := NewCreateResourceGroupOption(rgName).WithNodeLimit(10).WithNodeRequest(5)
|
||||
err := s.client.CreateResourceGroup(ctx, opt)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ResourceGroupSuite) TestDropResourceGroup() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
rgName := fmt.Sprintf("rg_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, drgr *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) {
|
||||
s.Equal(rgName, drgr.GetResourceGroup())
|
||||
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
|
||||
}).Once()
|
||||
opt := NewDropResourceGroupOption(rgName)
|
||||
err := s.client.DropResourceGroup(ctx, opt)
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
rgName := fmt.Sprintf("rg_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(nil, errors.New("mocked")).Once()
|
||||
opt := NewDropResourceGroupOption(rgName)
|
||||
err := s.client.DropResourceGroup(ctx, opt)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResourceGroup(t *testing.T) {
|
||||
suite.Run(t, new(ResourceGroupSuite))
|
||||
}
|
|
@ -89,8 +89,8 @@ func (mc *MilvusClient) Close(ctx context.Context) error {
|
|||
// -- database --
|
||||
|
||||
// UsingDatabase list all database in milvus cluster.
|
||||
func (mc *MilvusClient) UsingDatabase(ctx context.Context, option client.UsingDatabaseOption) error {
|
||||
err := mc.mClient.UsingDatabase(ctx, option)
|
||||
func (mc *MilvusClient) UsingDatabase(ctx context.Context, option client.UseDatabaseOption) error {
|
||||
err := mc.mClient.UseDatabase(ctx, option)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ require (
|
|||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 // indirect
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/opencontainers/runtime-spec v1.0.2 // indirect
|
||||
|
|
|
@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
|
|||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
|
||||
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
|
||||
|
|
|
@ -26,7 +26,7 @@ func teardownTest(t *testing.T) func(t *testing.T) {
|
|||
dbs, _ := mc.ListDatabases(ctx, client.NewListDatabaseOption())
|
||||
for _, db := range dbs {
|
||||
if db != common.DefaultDb {
|
||||
_ = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(db))
|
||||
_ = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(db))
|
||||
collections, _ := mc.ListCollections(ctx, client.NewListCollectionOption())
|
||||
for _, coll := range collections {
|
||||
_ = mc.DropCollection(ctx, client.NewDropCollectionOption(coll))
|
||||
|
@ -57,7 +57,7 @@ func TestDatabase(t *testing.T) {
|
|||
// new client with db1 -> using db
|
||||
clientDB1 := createMilvusClient(ctx, t, &client.ClientConfig{Address: *addr, DBName: dbName1})
|
||||
t.Log("https://github.com/milvus-io/milvus/issues/34137")
|
||||
err = clientDB1.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName1))
|
||||
err = clientDB1.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName1))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// create collections -> verify collections contains
|
||||
|
@ -77,14 +77,14 @@ func TestDatabase(t *testing.T) {
|
|||
require.Containsf(t, dbs, dbName2, fmt.Sprintf("%s db not in dbs: %v", dbName2, dbs))
|
||||
|
||||
// using db2 -> create collection -> drop collection
|
||||
err = clientDefault.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName2))
|
||||
err = clientDefault.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName2))
|
||||
common.CheckErr(t, err, true)
|
||||
_, db2Col1 := hp.CollPrepare.CreateCollection(ctx, t, clientDefault, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption())
|
||||
err = clientDefault.DropCollection(ctx, client.NewDropCollectionOption(db2Col1.CollectionName))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// using empty db -> drop db2
|
||||
clientDefault.UsingDatabase(ctx, client.NewUsingDatabaseOption(""))
|
||||
clientDefault.UsingDatabase(ctx, client.NewUseDatabaseOption(""))
|
||||
err = clientDefault.DropDatabase(ctx, client.NewDropDatabaseOption(dbName2))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
|
@ -98,7 +98,7 @@ func TestDatabase(t *testing.T) {
|
|||
common.CheckErr(t, err, false, "must drop all collections before drop database")
|
||||
|
||||
// drop all db1's collections -> drop db1
|
||||
clientDB1.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName1))
|
||||
clientDB1.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName1))
|
||||
err = clientDB1.DropCollection(ctx, client.NewDropCollectionOption(db1Col1.CollectionName))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestDropDb(t *testing.T) {
|
|||
common.CheckErr(t, err, true)
|
||||
|
||||
// using db and drop the db
|
||||
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName))
|
||||
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName))
|
||||
common.CheckErr(t, err, true)
|
||||
err = mc.DropDatabase(ctx, client.NewDropDatabaseOption(dbName))
|
||||
common.CheckErr(t, err, true)
|
||||
|
@ -170,7 +170,7 @@ func TestDropDb(t *testing.T) {
|
|||
common.CheckErr(t, err, false, fmt.Sprintf("database not found[database=%s]", dbName))
|
||||
|
||||
// using default db and verify collections
|
||||
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
|
||||
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
|
||||
common.CheckErr(t, err, true)
|
||||
collections, _ = mc.ListCollections(ctx, listCollOpt)
|
||||
require.Contains(t, collections, defCol.CollectionName)
|
||||
|
@ -205,17 +205,17 @@ func TestUsingDb(t *testing.T) {
|
|||
|
||||
// using not existed db
|
||||
dbName := common.GenRandomString("db", 4)
|
||||
err := mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(dbName))
|
||||
err := mc.UsingDatabase(ctx, client.NewUseDatabaseOption(dbName))
|
||||
common.CheckErr(t, err, false, fmt.Sprintf("database not found[database=%s]", dbName))
|
||||
|
||||
// using empty db
|
||||
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(""))
|
||||
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(""))
|
||||
common.CheckErr(t, err, true)
|
||||
collections, _ = mc.ListCollections(ctx, listCollOpt)
|
||||
require.Contains(t, collections, col.CollectionName)
|
||||
|
||||
// using current db
|
||||
err = mc.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
|
||||
err = mc.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
|
||||
common.CheckErr(t, err, true)
|
||||
collections, _ = mc.ListCollections(ctx, listCollOpt)
|
||||
require.Contains(t, collections, col.CollectionName)
|
||||
|
@ -262,7 +262,7 @@ func TestClientWithDb(t *testing.T) {
|
|||
require.Containsf(t, dbCollections, dbCol1.CollectionName, fmt.Sprintf("The collection %s not in: %v", dbCol1.CollectionName, dbCollections))
|
||||
|
||||
// using default db and collection not in
|
||||
_ = mcDb.UsingDatabase(ctx, client.NewUsingDatabaseOption(common.DefaultDb))
|
||||
_ = mcDb.UsingDatabase(ctx, client.NewUseDatabaseOption(common.DefaultDb))
|
||||
defCollections, _ = mcDb.ListCollections(ctx, listCollOpt)
|
||||
require.NotContains(t, defCollections, dbCol1.CollectionName)
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ func teardown() {
|
|||
dbs, _ := mc.ListDatabases(ctx, clientv2.NewListDatabaseOption())
|
||||
for _, db := range dbs {
|
||||
if db != common.DefaultDb {
|
||||
_ = mc.UsingDatabase(ctx, clientv2.NewUsingDatabaseOption(db))
|
||||
_ = mc.UsingDatabase(ctx, clientv2.NewUseDatabaseOption(db))
|
||||
collections, _ := mc.ListCollections(ctx, clientv2.NewListCollectionOption())
|
||||
for _, coll := range collections {
|
||||
_ = mc.DropCollection(ctx, clientv2.NewDropCollectionOption(coll))
|
||||
|
|
Loading…
Reference in New Issue