From 629da61617bfc2246254956939ad68483455ec45 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 6 Mar 2025 17:48:03 +0800 Subject: [PATCH] enhance: [2.5][GoSDK] sync milvusclient patches for 2.5.1 (#40410) Cherry pick from master pr: #40268 #40284 #40328 #40373 #40381 ------------------------------------------ #### fix: [GoSDK] Pass base64 passwd content instead of raw data (#40268) Related to #40261 Also add some options for create collection options and refine some behavior ------------------------------------------ #### fix: [GoSDK] Return role without grants (#40284) Related to #40274 Previousy DescribeRole returns only roles with grants, this PR add select role action to check role existence. Also added database properties related option ----------------------------------------- #### fix: [GoSDK] Pass only valid data for nullable column (#40328) Related to #40327 ----------------------------------------- #### enhance: [GoSDK] Add DescribeReplica API & sync rbac v2 (#40373) Related to #31293 #37031 This PR: - Add DescribeReplica API - Add unified RBAC v2 API names(AddPrivilegesToGroup, RemovePrivilegesFromGroup, GrantPrivilegeV2, RevokePrivilegeV2) - Mark old ones deprecated ----------------------------------------- #### enhance: [GoSDK] support update ts caching policy(#40381) Related to #39093 This PR add update timestamp check and retry policy according to the design of the related issue ----------------------------------------- --------- Signed-off-by: Congqi Xia --- client/column/array.go | 1 + client/column/generic_base.go | 7 +- client/column/nullable.go | 13 +- client/column/nullable_test.go | 14 +- client/entity/collection.go | 3 + client/entity/field.go | 1 + client/entity/resource_group.go | 14 ++ client/index/index.go | 6 +- client/milvusclient/collection.go | 1 + .../milvusclient/collection_example_test.go | 63 ++++++++- client/milvusclient/collection_options.go | 26 +++- client/milvusclient/collection_test.go | 4 +- client/milvusclient/common.go | 31 +++++ client/milvusclient/common_test.go | 83 ++++++++++++ client/milvusclient/database_example_test.go | 80 +++++++++++ client/milvusclient/database_options.go | 14 +- client/milvusclient/rbac.go | 104 +++------------ client/milvusclient/rbac_options.go | 108 ++++++++++++--- client/milvusclient/rbac_test.go | 71 +++++++++- client/milvusclient/rbac_v2.go | 125 ++++++++++++++++++ client/milvusclient/resource_group.go | 33 +++++ client/milvusclient/resource_group_option.go | 18 +++ client/milvusclient/resource_group_test.go | 54 ++++++++ client/milvusclient/write.go | 83 ++++++------ client/milvusclient/write_options.go | 18 +-- 25 files changed, 805 insertions(+), 170 deletions(-) create mode 100644 client/milvusclient/common_test.go create mode 100644 client/milvusclient/database_example_test.go create mode 100644 client/milvusclient/rbac_v2.go diff --git a/client/column/array.go b/client/column/array.go index 90251fe689..4f25e41793 100644 --- a/client/column/array.go +++ b/client/column/array.go @@ -32,6 +32,7 @@ func (c *columnArrayBase[T]) FieldData() *schemapb.FieldData { fd := &schemapb.FieldData{ Type: schemapb.DataType_Array, FieldName: c.name, + ValidData: c.validData, } data := make([]*schemapb.ScalarField, 0, c.Len()) diff --git a/client/column/generic_base.go b/client/column/generic_base.go index e2e001828f..1d5021a865 100644 --- a/client/column/generic_base.go +++ b/client/column/generic_base.go @@ -49,6 +49,9 @@ func (c *genericColumnBase[T]) Type() entity.FieldType { } func (c *genericColumnBase[T]) Len() int { + if c.validData != nil { + return len(c.validData) + } return len(c.values) } @@ -166,9 +169,9 @@ func (c *genericColumnBase[T]) AppendNull() error { if !c.nullable { return errors.New("append null to not nullable column") } - var v T + // var v T c.validData = append(c.validData, true) - c.values = append(c.values, v) + // c.values = append(c.values, v) return nil } diff --git a/client/column/nullable.go b/client/column/nullable.go index 3ff68d083b..52f2b80e32 100644 --- a/client/column/nullable.go +++ b/client/column/nullable.go @@ -16,7 +16,10 @@ package column -import "github.com/cockroachdb/errors" +import ( + "github.com/cockroachdb/errors" + "github.com/samber/lo" +) var ( // scalars @@ -55,9 +58,13 @@ type NullableColumnCreator[col interface { func (c NullableColumnCreator[col, T]) New(name string, values []T, validData []bool) (col, error) { var result col - if len(values) != len(validData) { - return result, errors.New("values & validData slice has different length") + validCnt := lo.CountBy(validData, func(v bool) bool { + return v + }) + if validCnt != len(values) { + return result, errors.Newf("values number(%d) does not match valid count(%d)", len(values), validCnt) } + result = c.base(name, values) result.withValidData(validData) diff --git a/client/column/nullable_test.go b/client/column/nullable_test.go index eec2cbb489..c22ede658b 100644 --- a/client/column/nullable_test.go +++ b/client/column/nullable_test.go @@ -33,7 +33,7 @@ type NullableScalarSuite struct { func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_bool", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []bool{true, false} + data := []bool{false} validData := []bool{true, false} column, err := NewNullableColumnBool(name, data, validData) s.NoError(err) @@ -63,7 +63,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_int8", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []int8{1, 2, 3} + data := []int8{1, 3} validData := []bool{true, false, true} column, err := NewNullableColumnInt8(name, data, validData) s.NoError(err) @@ -93,7 +93,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_int16", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []int16{1, 2, 3} + data := []int16{1, 3} validData := []bool{true, false, true} column, err := NewNullableColumnInt16(name, data, validData) s.NoError(err) @@ -123,7 +123,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_int32", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []int32{1, 2, 3} + data := []int32{1, 3} validData := []bool{true, false, true} column, err := NewNullableColumnInt32(name, data, validData) s.NoError(err) @@ -153,7 +153,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_int64", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []int64{1, 2, 3} + data := []int64{1, 3} validData := []bool{true, false, true} column, err := NewNullableColumnInt64(name, data, validData) s.NoError(err) @@ -183,7 +183,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_float", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []float32{0.1, 0.2, 0.3} + data := []float32{0.1, 0.3} validData := []bool{true, false, true} column, err := NewNullableColumnFloat(name, data, validData) s.NoError(err) @@ -213,7 +213,7 @@ func (s *NullableScalarSuite) TestBasic() { s.Run("nullable_double", func() { name := fmt.Sprintf("field_%d", rand.Intn(1000)) - data := []float64{0.1, 0.2, 0.3} + data := []float64{0.1, 0.3} validData := []bool{true, false, true} column, err := NewNullableColumnDouble(name, data, validData) s.NoError(err) diff --git a/client/entity/collection.go b/client/entity/collection.go index f30cc05f59..cd4acc5030 100644 --- a/client/entity/collection.go +++ b/client/entity/collection.go @@ -33,6 +33,9 @@ type Collection struct { ConsistencyLevel ConsistencyLevel ShardNum int32 Properties map[string]string + + // collection update timestamp, usually used for internal change detection + UpdateTimestamp uint64 } // Partition represent partition meta in Milvus diff --git a/client/entity/field.go b/client/entity/field.go index b734b7c2f0..847fbcbc0d 100644 --- a/client/entity/field.go +++ b/client/entity/field.go @@ -212,6 +212,7 @@ func (f *Field) ProtoMessage() *schemapb.FieldSchema { IsPartitionKey: f.IsPartitionKey, IsClusteringKey: f.IsClusteringKey, ElementType: schemapb.DataType(f.ElementType), + Nullable: f.Nullable, } } diff --git a/client/entity/resource_group.go b/client/entity/resource_group.go index 09d5af6c36..7fabb3a687 100644 --- a/client/entity/resource_group.go +++ b/client/entity/resource_group.go @@ -52,3 +52,17 @@ type ResourceGroupConfig struct { TransferTo []*ResourceGroupTransfer NodeFilter ResourceGroupNodeFilter } + +type ReplicaInfo struct { + ReplicaID int64 + Shards []*Shard + Nodes []int64 + ResourceGroupName string + NumOutboundNode map[string]int32 +} + +type Shard struct { + ChannelName string + ShardNodes []int64 + ShardLeader int64 +} diff --git a/client/index/index.go b/client/index/index.go index e04b92b3f6..95553dd2ca 100644 --- a/client/index/index.go +++ b/client/index/index.go @@ -68,8 +68,12 @@ func (gi GenericIndex) Params() map[string]string { return m } +func (gi GenericIndex) WithMetricType(metricType MetricType) { + gi.baseIndex.metricType = metricType +} + // NewGenericIndex create generic index instance -func NewGenericIndex(name string, params map[string]string) Index { +func NewGenericIndex(name string, params map[string]string) GenericIndex { return GenericIndex{ baseIndex: baseIndex{ name: name, diff --git a/client/milvusclient/collection.go b/client/milvusclient/collection.go index 69f5da3dc2..94ad0830b9 100644 --- a/client/milvusclient/collection.go +++ b/client/milvusclient/collection.go @@ -95,6 +95,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option DescribeCollecti ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel), ShardNum: resp.GetShardsNum(), Properties: entity.KvPairsMap(resp.GetProperties()), + UpdateTimestamp: resp.GetUpdateTimestamp(), } collection.Name = collection.Schema.CollectionName return nil diff --git a/client/milvusclient/collection_example_test.go b/client/milvusclient/collection_example_test.go index ec6420d200..3621f698f0 100644 --- a/client/milvusclient/collection_example_test.go +++ b/client/milvusclient/collection_example_test.go @@ -154,6 +154,67 @@ func ExampleClient_CreateCollection_ttl() { } } +func ExampleClient_CreateCollection_quickSetup() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionName := `quick_setup_1` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + err = cli.CreateCollection(ctx, milvusclient.SimpleCreateCollectionOptions(collectionName, 512)) + if err != nil { + // handle error + } +} + +func ExampleClient_CreateCollection_quickSetupWithIndexParams() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionName := `quick_setup_2` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + err = cli.CreateCollection(ctx, milvusclient.SimpleCreateCollectionOptions(collectionName, 512).WithIndexOptions( + milvusclient.NewCreateIndexOption(collectionName, "vector", index.NewHNSWIndex(entity.L2, 64, 128)), + )) + if err != nil { + log.Println(err.Error()) + // handle error + } +} + +func ExampleClient_CreateCollection_quickSetupCustomize() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionName := `quick_setup_3` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + err = cli.CreateCollection(ctx, milvusclient.SimpleCreateCollectionOptions(collectionName, 512). + WithVarcharPK(true, 64). + WithShardNum(1), + ) + if err != nil { + log.Println(err.Error()) + // handle error + } +} + func ExampleClient_CreateCollection_consistencyLevel() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -245,7 +306,7 @@ func ExampleClient_RenameCollection() { } } -func ExampleClient_AlterCollection_setTTL() { +func ExampleClient_AlterCollectionProperties_setTTL() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/client/milvusclient/collection_options.go b/client/milvusclient/collection_options.go index 1bd8029d63..fc271aa1f2 100644 --- a/client/milvusclient/collection_options.go +++ b/client/milvusclient/collection_options.go @@ -88,7 +88,7 @@ func (opt *createCollectionOption) WithVarcharPK(varcharPK bool, maxLen int) *cr } func (opt *createCollectionOption) WithIndexOptions(indexOpts ...CreateIndexOption) *createCollectionOption { - opt.indexOptions = append(opt.indexOptions, indexOpts...) + opt.indexOptions = indexOpts return opt } @@ -102,6 +102,26 @@ func (opt *createCollectionOption) WithConsistencyLevel(cl entity.ConsistencyLev return opt } +func (opt *createCollectionOption) WithMetricType(metricType entity.MetricType) *createCollectionOption { + opt.metricType = metricType + return opt +} + +func (opt *createCollectionOption) WithPKFieldName(name string) *createCollectionOption { + opt.pkFieldName = name + return opt +} + +func (opt *createCollectionOption) WithVectorFieldName(name string) *createCollectionOption { + opt.vectorFieldName = name + return opt +} + +func (opt *createCollectionOption) WithNumPartitions(numPartitions int64) *createCollectionOption { + opt.numPartitions = numPartitions + return opt +} + func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest { // fast create collection if opt.isFast { @@ -140,12 +160,12 @@ func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest { func (opt *createCollectionOption) Indexes() []CreateIndexOption { // fast create - if opt.isFast { + if opt.isFast && opt.indexOptions == nil { return []CreateIndexOption{ NewCreateIndexOption(opt.name, opt.vectorFieldName, index.NewGenericIndex("", map[string]string{})), } } - return nil + return opt.indexOptions } func (opt *createCollectionOption) IsFast() bool { diff --git a/client/milvusclient/collection_test.go b/client/milvusclient/collection_test.go index fc207ae825..327467ad97 100644 --- a/client/milvusclient/collection_test.go +++ b/client/milvusclient/collection_test.go @@ -104,7 +104,9 @@ func (s *CollectionSuite) TestCreateCollectionOptions() { s.True(collSchema.GetEnableDynamicField()) collectionName = fmt.Sprintf("test_collection_%s", s.randString(6)) - opt = SimpleCreateCollectionOptions(collectionName, 128).WithVarcharPK(true, 64).WithAutoID(false).WithDynamicSchema(false) + opt = SimpleCreateCollectionOptions(collectionName, 128).WithVarcharPK(true, 64).WithAutoID(false). + WithPKFieldName("pk").WithVectorFieldName("embedding").WithMetricType(entity.L2). + WithDynamicSchema(false) req = opt.Request() s.Equal(collectionName, req.GetCollectionName()) s.EqualValues(1, req.GetShardsNum()) diff --git a/client/milvusclient/common.go b/client/milvusclient/common.go index ea3e8e6027..7b9eb6ae2d 100644 --- a/client/milvusclient/common.go +++ b/client/milvusclient/common.go @@ -2,9 +2,14 @@ package milvusclient import ( "context" + "math" + + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/v2/util/conc" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/retry" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -32,6 +37,11 @@ func (c *CollectionCache) GetCollection(ctx context.Context, collName string) (* return coll, err } +// Evict removes the collection cache related to the provided collection name. +func (c *CollectionCache) Evict(collName string) { + c.collections.Remove(collName) +} + // Reset clears all cached info, used when client switching env. func (c *CollectionCache) Reset() { c.collections = typeutil.NewConcurrentMap[string, *entity.Collection]() @@ -47,3 +57,24 @@ func NewCollectionCache(fetcher func(context.Context, string) (*entity.Collectio func (c *Client) getCollection(ctx context.Context, collName string) (*entity.Collection, error) { return c.collCache.GetCollection(ctx, collName) } + +func (c *Client) retryIfSchemaError(ctx context.Context, collName string, work func(ctx context.Context) (uint64, error)) error { + var lastTs uint64 = math.MaxUint64 + return retry.Handle(ctx, func() (bool, error) { + ts, err := work(ctx) + if err != nil { + // if schema error + if errors.Is(err, merr.ErrCollectionSchemaMismatch) { + sameTs := ts == lastTs + lastTs = ts + if !sameTs { + c.collCache.Evict(collName) + } + // retry if not same ts + return !sameTs, err + } + return false, err + } + return false, nil + }) +} diff --git a/client/milvusclient/common_test.go b/client/milvusclient/common_test.go new file mode 100644 index 0000000000..4128cedea6 --- /dev/null +++ b/client/milvusclient/common_test.go @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +type CommonSuite struct { + MockSuiteBase +} + +func (s *CommonSuite) TestRetryIfSchemaError() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.Run("normal_no_error", func() { + counter := atomic.Int32{} + err := s.client.retryIfSchemaError(ctx, "test_coll", func(ctx context.Context) (uint64, error) { + counter.Add(1) + return 10, nil + }) + s.NoError(err) + s.EqualValues(1, counter.Load()) + }) + + s.Run("other_error", func() { + counter := atomic.Int32{} + err := s.client.retryIfSchemaError(ctx, "test_coll", func(ctx context.Context) (uint64, error) { + counter.Add(1) + return 10, merr.WrapErrServiceInternal("mocked") + }) + s.Error(err) + s.EqualValues(1, counter.Load()) + }) + + s.Run("transient_schema_err", func() { + counter := atomic.Int32{} + err := s.client.retryIfSchemaError(ctx, "test_coll", func(ctx context.Context) (uint64, error) { + epoch := counter.Load() + counter.Add(1) + if epoch == 0 { + return 10, merr.WrapErrCollectionSchemaMisMatch("mocked") + } + return 11, nil + }) + s.NoError(err) + s.EqualValues(2, counter.Load()) + }) + + s.Run("consistent_schema_err", func() { + counter := atomic.Int32{} + err := s.client.retryIfSchemaError(ctx, "test_coll", func(ctx context.Context) (uint64, error) { + counter.Add(1) + return 10, merr.WrapErrCollectionSchemaMisMatch("mocked") + }) + s.Error(err) + s.EqualValues(2, counter.Load()) + }) +} + +func TestCommonFunc(t *testing.T) { + suite.Run(t, new(CommonSuite)) +} diff --git a/client/milvusclient/database_example_test.go b/client/milvusclient/database_example_test.go new file mode 100644 index 0000000000..d3fa1bd399 --- /dev/null +++ b/client/milvusclient/database_example_test.go @@ -0,0 +1,80 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// nolint +package milvusclient_test + +import ( + "context" + "log" + + "github.com/milvus-io/milvus/client/v2/milvusclient" +) + +func ExampleClient_CreateDatabase() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dbName := `test_db` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + err = cli.CreateDatabase(ctx, milvusclient.NewCreateDatabaseOption(dbName)) + if err != nil { + // handle err + } +} + +func ExampleClient_CreateDatabase_withProperties() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dbName := `test_db_2` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + err = cli.CreateDatabase(ctx, milvusclient.NewCreateDatabaseOption(dbName).WithProperty("database.replica.number", 3)) + if err != nil { + // handle err + } +} + +func ExampleClient_DescribeDatabase() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dbName := `test_db` + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + // handle err + } + + db, err := cli.DescribeDatabase(ctx, milvusclient.NewDescribeDatabaseOption(dbName)) + if err != nil { + // handle err + } + log.Println(db) +} diff --git a/client/milvusclient/database_options.go b/client/milvusclient/database_options.go index 4d644467ee..e0c42adef5 100644 --- a/client/milvusclient/database_options.go +++ b/client/milvusclient/database_options.go @@ -61,18 +61,26 @@ type CreateDatabaseOption interface { } type createDatabaseOption struct { - dbName string + dbName string + Properties map[string]string } func (opt *createDatabaseOption) Request() *milvuspb.CreateDatabaseRequest { return &milvuspb.CreateDatabaseRequest{ - DbName: opt.dbName, + DbName: opt.dbName, + Properties: entity.MapKvPairs(opt.Properties), } } +func (opt *createDatabaseOption) WithProperty(key string, val any) *createDatabaseOption { + opt.Properties[key] = fmt.Sprintf("%v", val) + return opt +} + func NewCreateDatabaseOption(dbName string) *createDatabaseOption { return &createDatabaseOption{ - dbName: dbName, + dbName: dbName, + Properties: make(map[string]string), } } diff --git a/client/milvusclient/rbac.go b/client/milvusclient/rbac.go index c1cff9984e..ba3892d0c9 100644 --- a/client/milvusclient/rbac.go +++ b/client/milvusclient/rbac.go @@ -128,30 +128,36 @@ func (c *Client) DropRole(ctx context.Context, opt DropRoleOption, callOpts ...g } func (c *Client) DescribeRole(ctx context.Context, option DescribeRoleOption, callOptions ...grpc.CallOption) (*entity.Role, error) { - req := option.Request() - var role *entity.Role err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.SelectGrant(ctx, req, callOptions...) - if err := merr.CheckRPCCall(resp, err); err != nil { + roleResp, err := milvusService.SelectRole(ctx, option.SelectRoleRequest(), callOptions...) + if err := merr.CheckRPCCall(roleResp, err); err != nil { return err } - if len(resp.GetEntities()) == 0 { + + if len(roleResp.GetResults()) == 0 { return errors.New("role not found") } role = &entity.Role{ - RoleName: req.GetEntity().GetRole().GetName(), - Privileges: lo.Map(resp.GetEntities(), func(g *milvuspb.GrantEntity, _ int) entity.GrantItem { - return entity.GrantItem{ - Object: g.Object.GetName(), - ObjectName: g.GetObjectName(), - RoleName: g.GetRole().GetName(), - Grantor: g.GetGrantor().GetUser().GetName(), - Privilege: g.GetGrantor().GetPrivilege().GetName(), - } - }), + RoleName: roleResp.GetResults()[0].GetRole().GetName(), } + + resp, err := milvusService.SelectGrant(ctx, option.Request(), callOptions...) + if err := merr.CheckRPCCall(resp, err); err != nil { + return err + } + + role.Privileges = lo.Map(resp.GetEntities(), func(g *milvuspb.GrantEntity, _ int) entity.GrantItem { + return entity.GrantItem{ + Object: g.GetObject().GetName(), + ObjectName: g.GetObjectName(), + RoleName: g.GetRole().GetName(), + Grantor: g.GetGrantor().GetUser().GetName(), + Privilege: g.GetGrantor().GetPrivilege().GetName(), + } + }) + return nil }) return role, err @@ -174,71 +180,3 @@ func (c *Client) RevokePrivilege(ctx context.Context, option RevokePrivilegeOpti return merr.CheckRPCCall(resp, err) }) } - -func (c *Client) GrantV2(ctx context.Context, option GrantV2Option, callOptions ...grpc.CallOption) error { - req := option.Request() - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.OperatePrivilegeV2(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) - }) -} - -func (c *Client) RevokeV2(ctx context.Context, option RevokeV2Option, callOptions ...grpc.CallOption) error { - req := option.Request() - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.OperatePrivilegeV2(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) - }) -} - -func (c *Client) CreatePrivilegeGroup(ctx context.Context, option CreatePrivilegeGroupOption, callOptions ...grpc.CallOption) error { - req := option.Request() - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.CreatePrivilegeGroup(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) - }) -} - -func (c *Client) DropPrivilegeGroup(ctx context.Context, option DropPrivilegeGroupOption, callOptions ...grpc.CallOption) error { - req := option.Request() - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.DropPrivilegeGroup(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) - }) -} - -func (c *Client) ListPrivilegeGroups(ctx context.Context, option ListPrivilegeGroupsOption, callOptions ...grpc.CallOption) ([]*entity.PrivilegeGroup, error) { - req := option.Request() - - var privilegeGroups []*entity.PrivilegeGroup - err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - r, err := milvusService.ListPrivilegeGroups(ctx, req, callOptions...) - if err != nil { - return err - } - for _, pg := range r.PrivilegeGroups { - privileges := lo.Map(pg.Privileges, func(p *milvuspb.PrivilegeEntity, _ int) string { - return p.Name - }) - privilegeGroups = append(privilegeGroups, &entity.PrivilegeGroup{ - GroupName: pg.GroupName, - Privileges: privileges, - }) - } - return nil - }) - return privilegeGroups, err -} - -func (c *Client) OperatePrivilegeGroup(ctx context.Context, option OperatePrivilegeGroupOption, callOptions ...grpc.CallOption) error { - req := option.Request() - - return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.OperatePrivilegeGroup(ctx, req, callOptions...) - return merr.CheckRPCCall(resp, err) - }) -} diff --git a/client/milvusclient/rbac_options.go b/client/milvusclient/rbac_options.go index cacb72718d..e071ba2883 100644 --- a/client/milvusclient/rbac_options.go +++ b/client/milvusclient/rbac_options.go @@ -17,7 +17,10 @@ package milvusclient import ( + "github.com/samber/lo" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/v2/util/crypto" ) type ListUserOption interface { @@ -71,7 +74,7 @@ type createUserOption struct { func (opt *createUserOption) Request() *milvuspb.CreateCredentialRequest { return &milvuspb.CreateCredentialRequest{ Username: opt.userName, - Password: opt.password, + Password: crypto.Base64Encode(opt.password), } } @@ -95,8 +98,8 @@ type updatePasswordOption struct { func (opt *updatePasswordOption) Request() *milvuspb.UpdateCredentialRequest { return &milvuspb.UpdateCredentialRequest{ Username: opt.userName, - OldPassword: opt.oldPassword, - NewPassword: opt.newPassword, + OldPassword: crypto.Base64Encode(opt.oldPassword), + NewPassword: crypto.Base64Encode(opt.newPassword), } } @@ -233,6 +236,7 @@ func NewDropRoleOption(roleName string) *dropDropRoleOption { } type DescribeRoleOption interface { + SelectRoleRequest() *milvuspb.SelectRoleRequest Request() *milvuspb.SelectGrantRequest } @@ -240,6 +244,14 @@ type describeRoleOption struct { roleName string } +func (opt *describeRoleOption) SelectRoleRequest() *milvuspb.SelectRoleRequest { + return &milvuspb.SelectRoleRequest{ + Role: &milvuspb.RoleEntity{ + Name: opt.roleName, + }, + } +} + func (opt *describeRoleOption) Request() *milvuspb.SelectGrantRequest { return &milvuspb.SelectGrantRequest{ Entity: &milvuspb.GrantEntity{ @@ -328,19 +340,21 @@ func NewRevokePrivilegeOption(roleName, objectType, privilegeName, objectName st } } -// GrantV2Option is the interface builds OperatePrivilegeV2Request -type GrantV2Option interface { +type GrantV2Option GrantPrivilegeV2Option + +// GrantPrivilegeV2Option is the interface builds OperatePrivilegeV2Request +type GrantPrivilegeV2Option interface { Request() *milvuspb.OperatePrivilegeV2Request } -type grantV2Option struct { +type grantPrivilegeV2Option struct { roleName string privilegeName string dbName string collectionName string } -func (opt *grantV2Option) Request() *milvuspb.OperatePrivilegeV2Request { +func (opt *grantPrivilegeV2Option) Request() *milvuspb.OperatePrivilegeV2Request { return &milvuspb.OperatePrivilegeV2Request{ Role: &milvuspb.RoleEntity{Name: opt.roleName}, Grantor: &milvuspb.GrantorEntity{ @@ -352,8 +366,13 @@ func (opt *grantV2Option) Request() *milvuspb.OperatePrivilegeV2Request { } } -func NewGrantV2Option(roleName, privilegeName, dbName, collectionName string) *grantV2Option { - return &grantV2Option{ +// Deprecated, use `NewGrantPrivilegeV2Option` instead +func NewGrantV2Option(roleName, privilegeName, dbName, collectionName string) *grantPrivilegeV2Option { + return NewGrantPrivilegeV2Option(roleName, privilegeName, dbName, collectionName) +} + +func NewGrantPrivilegeV2Option(roleName, privilegeName, dbName, collectionName string) *grantPrivilegeV2Option { + return &grantPrivilegeV2Option{ roleName: roleName, privilegeName: privilegeName, dbName: dbName, @@ -361,19 +380,21 @@ func NewGrantV2Option(roleName, privilegeName, dbName, collectionName string) *g } } -// RevokeV2Option is the interface builds OperatePrivilegeV2Request -type RevokeV2Option interface { +type RevokeV2Option RevokePrivilegeV2Option + +// RevokePrivilegeV2Option is the interface builds OperatePrivilegeV2Request +type RevokePrivilegeV2Option interface { Request() *milvuspb.OperatePrivilegeV2Request } -type revokeV2Option struct { +type revokePrivilegeV2Option struct { roleName string privilegeName string dbName string collectionName string } -func (opt *revokeV2Option) Request() *milvuspb.OperatePrivilegeV2Request { +func (opt *revokePrivilegeV2Option) Request() *milvuspb.OperatePrivilegeV2Request { return &milvuspb.OperatePrivilegeV2Request{ Role: &milvuspb.RoleEntity{Name: opt.roleName}, Grantor: &milvuspb.GrantorEntity{ @@ -385,8 +406,8 @@ func (opt *revokeV2Option) Request() *milvuspb.OperatePrivilegeV2Request { } } -func NewRevokeV2Option(roleName, privilegeName, dbName, collectionName string) *revokeV2Option { - return &revokeV2Option{ +func NewRevokeV2Option(roleName, privilegeName, dbName, collectionName string) *revokePrivilegeV2Option { + return &revokePrivilegeV2Option{ roleName: roleName, privilegeName: privilegeName, dbName: dbName, @@ -470,6 +491,7 @@ func (opt *operatePrivilegeGroupOption) Request() *milvuspb.OperatePrivilegeGrou } } +// Deprecated, use AddPrivilegeToGroupOption/ RemovePrivilegeFromGroupOption instead func NewOperatePrivilegeGroupOption(groupName string, privileges []*milvuspb.PrivilegeEntity, operateType milvuspb.OperatePrivilegeGroupType) *operatePrivilegeGroupOption { return &operatePrivilegeGroupOption{ groupName: groupName, @@ -477,3 +499,59 @@ func NewOperatePrivilegeGroupOption(groupName string, privileges []*milvuspb.Pri operateType: operateType, } } + +type AddPrivilegeToGroupOption interface { + Request() *milvuspb.OperatePrivilegeGroupRequest +} + +type addPrivilegeToGroupOption struct { + privileges []string + groupName string +} + +func (opt *addPrivilegeToGroupOption) Request() *milvuspb.OperatePrivilegeGroupRequest { + return &milvuspb.OperatePrivilegeGroupRequest{ + GroupName: opt.groupName, + Privileges: lo.Map(opt.privileges, func(privilege string, _ int) *milvuspb.PrivilegeEntity { + return &milvuspb.PrivilegeEntity{ + Name: privilege, + } + }), + Type: milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup, + } +} + +func NewAddPrivilegesToGroupOption(groupName string, privileges ...string) *addPrivilegeToGroupOption { + return &addPrivilegeToGroupOption{ + groupName: groupName, + privileges: privileges, + } +} + +type RemovePrivilegeFromGroupOption interface { + Request() *milvuspb.OperatePrivilegeGroupRequest +} + +type removePrivilegeFromGroupOption struct { + privileges []string + groupName string +} + +func (opt *removePrivilegeFromGroupOption) Request() *milvuspb.OperatePrivilegeGroupRequest { + return &milvuspb.OperatePrivilegeGroupRequest{ + GroupName: opt.groupName, + Privileges: lo.Map(opt.privileges, func(privilege string, _ int) *milvuspb.PrivilegeEntity { + return &milvuspb.PrivilegeEntity{ + Name: privilege, + } + }), + Type: milvuspb.OperatePrivilegeGroupType_RemovePrivilegesFromGroup, + } +} + +func NewRemovePrivilegesFromGroupOption(groupName string, privileges ...string) *removePrivilegeFromGroupOption { + return &removePrivilegeFromGroupOption{ + groupName: groupName, + privileges: privileges, + } +} diff --git a/client/milvusclient/rbac_test.go b/client/milvusclient/rbac_test.go index 5a91c28944..d82cd5017b 100644 --- a/client/milvusclient/rbac_test.go +++ b/client/milvusclient/rbac_test.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/merr" ) @@ -102,7 +103,7 @@ func (s *UserSuite) TestCreateUser() { password := s.randString(12) s.mock.EXPECT().CreateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ccr *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { s.Equal(userName, ccr.GetUsername()) - s.Equal(password, ccr.GetPassword()) + s.Equal(crypto.Base64Encode(password), ccr.GetPassword()) return merr.Success(), nil }).Once() @@ -121,8 +122,8 @@ func (s *UserSuite) TestUpdatePassword() { newPassword := s.randString(12) s.mock.EXPECT().UpdateCredential(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ucr *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) { s.Equal(userName, ucr.GetUsername()) - s.Equal(oldPassword, ucr.GetOldPassword()) - s.Equal(newPassword, ucr.GetNewPassword()) + s.Equal(crypto.Base64Encode(oldPassword), ucr.GetOldPassword()) + s.Equal(crypto.Base64Encode(newPassword), ucr.GetNewPassword()) return merr.Success(), nil }).Once() @@ -298,6 +299,16 @@ func (s *RoleSuite) TestDescribeRole() { s.Run("success", func() { roleName := fmt.Sprintf("role_%s", s.randString(5)) + s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { + s.Equal(roleName, r.GetRole().GetName()) + return &milvuspb.SelectRoleResponse{ + Results: []*milvuspb.RoleResult{ + { + Role: &milvuspb.RoleEntity{Name: roleName}, + }, + }, + }, nil + }).Once() s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { s.Equal(roleName, r.GetEntity().GetRole().GetName()) return &milvuspb.SelectGrantResponse{ @@ -328,7 +339,7 @@ func (s *RoleSuite) TestDescribeRole() { }) s.Run("failure", func() { - s.mock.EXPECT().SelectGrant(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + s.mock.EXPECT().SelectRole(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() _, err := s.client.DescribeRole(ctx, NewDescribeRoleOption("role")) s.Error(err) @@ -574,6 +585,58 @@ func (s *PrivilegeGroupSuite) TestOperatePrivilegeGroup() { }) } +func (s *PrivilegeGroupSuite) TestAddPrivilegesToGroup() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + groupName := fmt.Sprintf("test_pg_%s", s.randString(6)) + privileges := []string{"Insert", "Query"} + + s.Run("success", func() { + s.mock.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeGroupRequest) (*commonpb.Status, error) { + s.Equal(groupName, r.GetGroupName()) + s.Equal(milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup, r.GetType()) + return merr.Success(), nil + }).Once() + + err := s.client.AddPrivilegesToGroup(ctx, NewAddPrivilegesToGroupOption(groupName, privileges...)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.AddPrivilegesToGroup(ctx, NewAddPrivilegesToGroupOption(groupName, privileges...)) + s.Error(err) + }) +} + +func (s *PrivilegeGroupSuite) TestRemovePrivilegesFromGroup() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + groupName := fmt.Sprintf("test_pg_%s", s.randString(6)) + privileges := []string{"Insert", "Query"} + + s.Run("success", func() { + s.mock.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, r *milvuspb.OperatePrivilegeGroupRequest) (*commonpb.Status, error) { + s.Equal(groupName, r.GetGroupName()) + s.Equal(milvuspb.OperatePrivilegeGroupType_RemovePrivilegesFromGroup, r.GetType()) + return merr.Success(), nil + }).Once() + + err := s.client.RemovePrivilegesFromGroup(ctx, NewRemovePrivilegesFromGroupOption(groupName, privileges...)) + s.NoError(err) + }) + + s.Run("failure", func() { + s.mock.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.RemovePrivilegesFromGroup(ctx, NewRemovePrivilegesFromGroupOption(groupName, privileges...)) + s.Error(err) + }) +} + func TestPrivilegeGroup(t *testing.T) { suite.Run(t, new(PrivilegeGroupSuite)) } diff --git a/client/milvusclient/rbac_v2.go b/client/milvusclient/rbac_v2.go new file mode 100644 index 0000000000..e260c89afb --- /dev/null +++ b/client/milvusclient/rbac_v2.go @@ -0,0 +1,125 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + + "github.com/samber/lo" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +func (c *Client) CreatePrivilegeGroup(ctx context.Context, option CreatePrivilegeGroupOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.CreatePrivilegeGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) DropPrivilegeGroup(ctx context.Context, option DropPrivilegeGroupOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DropPrivilegeGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) ListPrivilegeGroups(ctx context.Context, option ListPrivilegeGroupsOption, callOptions ...grpc.CallOption) ([]*entity.PrivilegeGroup, error) { + req := option.Request() + + var privilegeGroups []*entity.PrivilegeGroup + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + r, err := milvusService.ListPrivilegeGroups(ctx, req, callOptions...) + if err != nil { + return err + } + for _, pg := range r.PrivilegeGroups { + privileges := lo.Map(pg.Privileges, func(p *milvuspb.PrivilegeEntity, _ int) string { + return p.Name + }) + privilegeGroups = append(privilegeGroups, &entity.PrivilegeGroup{ + GroupName: pg.GroupName, + Privileges: privileges, + }) + } + return nil + }) + return privilegeGroups, err +} + +func (c *Client) AddPrivilegesToGroup(ctx context.Context, option AddPrivilegeToGroupOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilegeGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) RemovePrivilegesFromGroup(ctx context.Context, option RemovePrivilegeFromGroupOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilegeGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) GrantPrivilegeV2(ctx context.Context, option GrantPrivilegeV2Option, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilegeV2(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) RevokePrivilegeV2(ctx context.Context, option RevokePrivilegeV2Option, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilegeV2(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +// Deprecated, use `AddPrivilegesToGroup` or `RemovePrivilegesFromGroup` instead +func (c *Client) OperatePrivilegeGroup(ctx context.Context, option OperatePrivilegeGroupOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.OperatePrivilegeGroup(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + +// Deprecated, use `GrantPrivilegeV2` instead +func (c *Client) GrantV2(ctx context.Context, option GrantV2Option, callOptions ...grpc.CallOption) error { + return c.GrantPrivilegeV2(ctx, option, callOptions...) +} + +// Deprecated, use `RevokePrivilegeV2` instead +func (c *Client) RevokeV2(ctx context.Context, option RevokeV2Option, callOptions ...grpc.CallOption) error { + return c.RevokePrivilegeV2(ctx, option, callOptions...) +} diff --git a/client/milvusclient/resource_group.go b/client/milvusclient/resource_group.go index 8f75917b48..9efeafa0c8 100644 --- a/client/milvusclient/resource_group.go +++ b/client/milvusclient/resource_group.go @@ -142,3 +142,36 @@ func (c *Client) TransferReplica(ctx context.Context, opt TransferReplicaOption, return err } + +func (c *Client) DescribeReplica(ctx context.Context, opt DescribeReplicaOption, callOptions ...grpc.CallOption) ([]*entity.ReplicaInfo, error) { + req := opt.Request() + + var result []*entity.ReplicaInfo + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.GetReplicas(ctx, req, callOptions...) + + if err := merr.CheckRPCCall(resp, err); err != nil { + return err + } + + result = lo.Map(resp.GetReplicas(), func(replica *milvuspb.ReplicaInfo, _ int) *entity.ReplicaInfo { + return &entity.ReplicaInfo{ + ReplicaID: replica.GetReplicaID(), + Shards: lo.Map(replica.GetShardReplicas(), func(shardReplica *milvuspb.ShardReplica, _ int) *entity.Shard { + return &entity.Shard{ + ChannelName: shardReplica.GetDmChannelName(), + ShardNodes: shardReplica.GetNodeIds(), + ShardLeader: shardReplica.GetLeaderID(), + } + }), + Nodes: replica.GetNodeIds(), + ResourceGroupName: replica.GetResourceGroupName(), + NumOutboundNode: replica.GetNumOutboundNode(), + } + }) + + return nil + }) + return result, err +} diff --git a/client/milvusclient/resource_group_option.go b/client/milvusclient/resource_group_option.go index 6c71405591..1e403521d2 100644 --- a/client/milvusclient/resource_group_option.go +++ b/client/milvusclient/resource_group_option.go @@ -191,3 +191,21 @@ func NewTransferReplicaOption(collectionName, sourceGroup, targetGroup string, r replicaNum: replicaNum, } } + +type DescribeReplicaOption interface { + Request() *milvuspb.GetReplicasRequest +} + +type describeReplicaOption struct { + collectionName string +} + +func (opt *describeReplicaOption) Request() *milvuspb.GetReplicasRequest { + return &milvuspb.GetReplicasRequest{ + CollectionName: opt.collectionName, + } +} + +func NewDescribeReplicaOption(collectionName string) *describeReplicaOption { + return &describeReplicaOption{collectionName: collectionName} +} diff --git a/client/milvusclient/resource_group_test.go b/client/milvusclient/resource_group_test.go index 2eac1478cf..86879b957e 100644 --- a/client/milvusclient/resource_group_test.go +++ b/client/milvusclient/resource_group_test.go @@ -271,6 +271,60 @@ func (s *ResourceGroupSuite) TestTransferReplica() { }) } +func (s *ResourceGroupSuite) TestDescribeReplica() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collName := fmt.Sprintf("rg_%s", s.randString(6)) + replicas := map[int64]*entity.ReplicaInfo{ + 1: { + ReplicaID: 1, + ResourceGroupName: "rg_1", + Shards: []*entity.Shard{ + {ChannelName: "dml_1", ShardNodes: []int64{1, 2, 3}, ShardLeader: 2}, + }, + Nodes: []int64{1, 2, 3}, + NumOutboundNode: map[string]int32{"dml_1": 1}, + }, + } + s.mock.EXPECT().GetReplicas(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, grr *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { + return &milvuspb.GetReplicasResponse{ + Replicas: lo.MapToSlice(replicas, func(_ int64, r *entity.ReplicaInfo) *milvuspb.ReplicaInfo { + return &milvuspb.ReplicaInfo{ + ReplicaID: r.ReplicaID, + ShardReplicas: lo.Map(r.Shards, func(shard *entity.Shard, _ int) *milvuspb.ShardReplica { + return &milvuspb.ShardReplica{ + DmChannelName: shard.ChannelName, + NodeIds: shard.ShardNodes, + LeaderID: shard.ShardLeader, + } + }), + ResourceGroupName: r.ResourceGroupName, + NodeIds: r.Nodes, + NumOutboundNode: r.NumOutboundNode, + } + }), + }, nil + }).Once() + result, err := s.client.DescribeReplica(ctx, NewDescribeReplicaOption(collName)) + s.NoError(err) + for _, replica := range result { + expect, ok := replicas[replica.ReplicaID] + if s.True(ok) { + s.Equal(expect, replica) + } + } + }) + + s.Run("failure", func() { + collName := fmt.Sprintf("rg_%s", s.randString(6)) + s.mock.EXPECT().GetReplicas(mock.Anything, mock.Anything).Return(nil, errors.New("mock")).Once() + _, err := s.client.DescribeReplica(ctx, NewDescribeReplicaOption(collName)) + s.Error(err) + }) +} + func TestResourceGroup(t *testing.T) { suite.Run(t, new(ResourceGroupSuite)) } diff --git a/client/milvusclient/write.go b/client/milvusclient/write.go index 863d52f790..f4dc640a76 100644 --- a/client/milvusclient/write.go +++ b/client/milvusclient/write.go @@ -18,6 +18,7 @@ package milvusclient import ( "context" + "math" "google.golang.org/grpc" @@ -33,32 +34,34 @@ type InsertResult struct { func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) { result := InsertResult{} - collection, err := c.getCollection(ctx, option.CollectionName()) - if err != nil { - return result, err - } - req, err := option.InsertRequest(collection) - if err != nil { - return result, err - } - - err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.Insert(ctx, req, callOptions...) - - err = merr.CheckRPCCall(resp, err) + err := c.retryIfSchemaError(ctx, option.CollectionName(), func(ctx context.Context) (uint64, error) { + collection, err := c.getCollection(ctx, option.CollectionName()) if err != nil { - return err + return math.MaxUint64, err + } + req, err := option.InsertRequest(collection) + if err != nil { + return collection.UpdateTimestamp, err } - result.InsertCount = resp.GetInsertCnt() - result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) - if err != nil { - return err - } + return collection.UpdateTimestamp, c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Insert(ctx, req, callOptions...) - // write back pks if needed - // pks values shall be written back to struct if receiver field exists - return option.WriteBackPKs(collection.Schema, result.IDs) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + result.InsertCount = resp.GetInsertCnt() + result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) + if err != nil { + return err + } + + // write back pks if needed + // pks values shall be written back to struct if receiver field exists + return option.WriteBackPKs(collection.Schema, result.IDs) + }) }) return result, err } @@ -89,25 +92,27 @@ type UpsertResult struct { func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) { result := UpsertResult{} - collection, err := c.getCollection(ctx, option.CollectionName()) - if err != nil { - return result, err - } - req, err := option.UpsertRequest(collection) - if err != nil { - return result, err - } - err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { - resp, err := milvusService.Upsert(ctx, req, callOptions...) - if err = merr.CheckRPCCall(resp, err); err != nil { - return err - } - result.UpsertCount = resp.GetUpsertCnt() - result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) + err := c.retryIfSchemaError(ctx, option.CollectionName(), func(ctx context.Context) (uint64, error) { + collection, err := c.getCollection(ctx, option.CollectionName()) if err != nil { - return err + return math.MaxUint64, err } - return nil + req, err := option.UpsertRequest(collection) + if err != nil { + return collection.UpdateTimestamp, err + } + return collection.UpdateTimestamp, c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Upsert(ctx, req, callOptions...) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + result.UpsertCount = resp.GetUpsertCnt() + result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1) + if err != nil { + return err + } + return nil + }) }) return result, err } diff --git a/client/milvusclient/write_options.go b/client/milvusclient/write_options.go index b05c6c30c3..3480c4b6c6 100644 --- a/client/milvusclient/write_options.go +++ b/client/milvusclient/write_options.go @@ -249,10 +249,11 @@ func (opt *columnBasedDataOption) InsertRequest(coll *entity.Collection) (*milvu return nil, err } return &milvuspb.InsertRequest{ - CollectionName: opt.collName, - PartitionName: opt.partitionName, - FieldsData: fieldsData, - NumRows: uint32(rowNum), + CollectionName: opt.collName, + PartitionName: opt.partitionName, + FieldsData: fieldsData, + NumRows: uint32(rowNum), + SchemaTimestamp: coll.UpdateTimestamp, }, nil } @@ -262,10 +263,11 @@ func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvu return nil, err } return &milvuspb.UpsertRequest{ - CollectionName: opt.collName, - PartitionName: opt.partitionName, - FieldsData: fieldsData, - NumRows: uint32(rowNum), + CollectionName: opt.collName, + PartitionName: opt.partitionName, + FieldsData: fieldsData, + NumRows: uint32(rowNum), + SchemaTimestamp: coll.UpdateTimestamp, }, nil }