enhance: streaming service grpc utilities (#34436)

issue: #33285

- add two grpc resolvers (by session and by streaming coord assignment
service); see the sketch below
- add one grpc balancer (by serverID and round-robin)
- add a lazy conn to avoid blocking on the first service discovery
- add some utility functions for the streaming service
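
A minimal sketch of the moving parts, using only the public grpc-go API rather than the packages added in this PR: a custom `resolver.Builder` pushes discovered addresses into the ClientConn, the default service config picks the balancing policy, and the dial itself is non-blocking. The scheme name, address, and builder below are hypothetical placeholders, not the session/assignment resolvers or the serverID balancer from this PR.

```go
package main

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/resolver"
)

// exampleBuilder is a hypothetical resolver.Builder. The real resolvers in this
// PR feed addresses from session discovery and from the streaming coord
// assignment service instead of a hard-coded list.
type exampleBuilder struct{}

func (exampleBuilder) Scheme() string { return "example" }

func (exampleBuilder) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
	// Push an initial address list; a real implementation keeps watching
	// service discovery and calls UpdateState on every change.
	_ = cc.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "127.0.0.1:19530"}}})
	return nopResolver{}, nil
}

type nopResolver struct{}

func (nopResolver) ResolveNow(resolver.ResolveNowOptions) {}
func (nopResolver) Close()                                {}

func main() {
	resolver.Register(exampleBuilder{})
	// The dial returns immediately; name resolution and connection setup happen
	// in the background, so the caller is not blocked on the first discovery.
	conn, err := grpc.Dial(
		"example:///streaming",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`),
	)
	if err != nil {
		panic(err)
	}
	defer conn.Close()
}
```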

Signed-off-by: chyezh <chyezh@outlook.com>
pull/34694/head
chyezh 2024-07-15 20:49:38 +08:00 committed by GitHub
parent eb472b7f08
commit fda720b880
65 changed files with 3343 additions and 212 deletions

View File

@ -322,6 +322,10 @@ test-metastore:
@echo "Running go unittests..."
@(env bash $(PWD)/scripts/run_go_unittest.sh -t metastore)
test-streaming:
@echo "Running go unittests..."
@(env bash $(PWD)/scripts/run_go_unittest.sh -t streaming)
test-go: build-cpp-with-unittest
@echo "Running go unittests..."
@(env bash $(PWD)/scripts/run_go_unittest.sh)
@ -517,10 +521,10 @@ generate-mockery-chunk-manager: getdeps
generate-mockery-pkg:
$(MAKE) -C pkg generate-mockery
generate-mockery-streaming:
$(INSTALL_PATH)/mockery --config $(PWD)/internal/streamingservice/.mockery.yaml
generate-mockery-internal:
$(INSTALL_PATH)/mockery --config $(PWD)/internal/.mockery.yaml
generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg generate-mockery-log
generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg generate-mockery-internal
generate-yaml: milvus-tools
@echo "Updating milvus config yaml"

View File

@ -36,3 +36,16 @@ packages:
github.com/milvus-io/milvus/internal/metastore:
interfaces:
StreamingCoordCataLog:
github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer:
interfaces:
Discoverer:
AssignmentDiscoverWatcher:
github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver:
interfaces:
Resolver:
google.golang.org/grpc/resolver:
interfaces:
ClientConn:
google.golang.org/grpc/balancer:
interfaces:
SubConn:

View File

@ -9,6 +9,8 @@ import (
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/kv"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/etcd"
)
// NewCataLog creates a new catalog instance
@ -53,7 +55,9 @@ func (c *catalog) SavePChannels(ctx context.Context, infos []*streamingpb.PChann
}
kvs[key] = string(v)
}
return c.metaKV.MultiSave(kvs)
return etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, func(partialKvs map[string]string) error {
return c.metaKV.MultiSave(partialKvs)
})
}
// buildPChannelInfoPath builds the path for pchannel info.
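
The SavePChannels change above stops issuing one big MultiSave and instead persists the kv pairs in batches of at most util.MaxEtcdTxnNum entries, so each etcd transaction stays under the server's per-txn operation limit. A minimal sketch of that batching pattern; saveByBatch is a hypothetical stand-in for the real etcd.SaveByBatchWithLimit helper:

```go
package example

// saveByBatch splits kvs into chunks of at most limit entries and persists each
// chunk with the given save callback, stopping at the first error.
func saveByBatch(kvs map[string]string, limit int, save func(map[string]string) error) error {
	batch := make(map[string]string, limit)
	for k, v := range kvs {
		batch[k] = v
		if len(batch) >= limit {
			if err := save(batch); err != nil {
				return err
			}
			batch = make(map[string]string, limit)
		}
	}
	if len(batch) > 0 {
		return save(batch)
	}
	return nil
}
```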

View File

@ -0,0 +1,158 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_balancer
import (
mock "github.com/stretchr/testify/mock"
balancer "google.golang.org/grpc/balancer"
resolver "google.golang.org/grpc/resolver"
)
// MockSubConn is an autogenerated mock type for the SubConn type
type MockSubConn struct {
mock.Mock
}
type MockSubConn_Expecter struct {
mock *mock.Mock
}
func (_m *MockSubConn) EXPECT() *MockSubConn_Expecter {
return &MockSubConn_Expecter{mock: &_m.Mock}
}
// Connect provides a mock function with given fields:
func (_m *MockSubConn) Connect() {
_m.Called()
}
// MockSubConn_Connect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Connect'
type MockSubConn_Connect_Call struct {
*mock.Call
}
// Connect is a helper method to define mock.On call
func (_e *MockSubConn_Expecter) Connect() *MockSubConn_Connect_Call {
return &MockSubConn_Connect_Call{Call: _e.mock.On("Connect")}
}
func (_c *MockSubConn_Connect_Call) Run(run func()) *MockSubConn_Connect_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSubConn_Connect_Call) Return() *MockSubConn_Connect_Call {
_c.Call.Return()
return _c
}
func (_c *MockSubConn_Connect_Call) RunAndReturn(run func()) *MockSubConn_Connect_Call {
_c.Call.Return(run)
return _c
}
// GetOrBuildProducer provides a mock function with given fields: _a0
func (_m *MockSubConn) GetOrBuildProducer(_a0 balancer.ProducerBuilder) (balancer.Producer, func()) {
ret := _m.Called(_a0)
var r0 balancer.Producer
var r1 func()
if rf, ok := ret.Get(0).(func(balancer.ProducerBuilder) (balancer.Producer, func())); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(balancer.ProducerBuilder) balancer.Producer); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(balancer.Producer)
}
}
if rf, ok := ret.Get(1).(func(balancer.ProducerBuilder) func()); ok {
r1 = rf(_a0)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(func())
}
}
return r0, r1
}
// MockSubConn_GetOrBuildProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOrBuildProducer'
type MockSubConn_GetOrBuildProducer_Call struct {
*mock.Call
}
// GetOrBuildProducer is a helper method to define mock.On call
// - _a0 balancer.ProducerBuilder
func (_e *MockSubConn_Expecter) GetOrBuildProducer(_a0 interface{}) *MockSubConn_GetOrBuildProducer_Call {
return &MockSubConn_GetOrBuildProducer_Call{Call: _e.mock.On("GetOrBuildProducer", _a0)}
}
func (_c *MockSubConn_GetOrBuildProducer_Call) Run(run func(_a0 balancer.ProducerBuilder)) *MockSubConn_GetOrBuildProducer_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(balancer.ProducerBuilder))
})
return _c
}
func (_c *MockSubConn_GetOrBuildProducer_Call) Return(p balancer.Producer, close func()) *MockSubConn_GetOrBuildProducer_Call {
_c.Call.Return(p, close)
return _c
}
func (_c *MockSubConn_GetOrBuildProducer_Call) RunAndReturn(run func(balancer.ProducerBuilder) (balancer.Producer, func())) *MockSubConn_GetOrBuildProducer_Call {
_c.Call.Return(run)
return _c
}
// UpdateAddresses provides a mock function with given fields: _a0
func (_m *MockSubConn) UpdateAddresses(_a0 []resolver.Address) {
_m.Called(_a0)
}
// MockSubConn_UpdateAddresses_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAddresses'
type MockSubConn_UpdateAddresses_Call struct {
*mock.Call
}
// UpdateAddresses is a helper method to define mock.On call
// - _a0 []resolver.Address
func (_e *MockSubConn_Expecter) UpdateAddresses(_a0 interface{}) *MockSubConn_UpdateAddresses_Call {
return &MockSubConn_UpdateAddresses_Call{Call: _e.mock.On("UpdateAddresses", _a0)}
}
func (_c *MockSubConn_UpdateAddresses_Call) Run(run func(_a0 []resolver.Address)) *MockSubConn_UpdateAddresses_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]resolver.Address))
})
return _c
}
func (_c *MockSubConn_UpdateAddresses_Call) Return() *MockSubConn_UpdateAddresses_Call {
_c.Call.Return()
return _c
}
func (_c *MockSubConn_UpdateAddresses_Call) RunAndReturn(run func([]resolver.Address)) *MockSubConn_UpdateAddresses_Call {
_c.Call.Return(run)
return _c
}
// NewMockSubConn creates a new instance of MockSubConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockSubConn(t interface {
mock.TestingT
Cleanup(func())
}) *MockSubConn {
mock := &MockSubConn{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,222 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_resolver
import (
mock "github.com/stretchr/testify/mock"
resolver "google.golang.org/grpc/resolver"
serviceconfig "google.golang.org/grpc/serviceconfig"
)
// MockClientConn is an autogenerated mock type for the ClientConn type
type MockClientConn struct {
mock.Mock
}
type MockClientConn_Expecter struct {
mock *mock.Mock
}
func (_m *MockClientConn) EXPECT() *MockClientConn_Expecter {
return &MockClientConn_Expecter{mock: &_m.Mock}
}
// NewAddress provides a mock function with given fields: addresses
func (_m *MockClientConn) NewAddress(addresses []resolver.Address) {
_m.Called(addresses)
}
// MockClientConn_NewAddress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewAddress'
type MockClientConn_NewAddress_Call struct {
*mock.Call
}
// NewAddress is a helper method to define mock.On call
// - addresses []resolver.Address
func (_e *MockClientConn_Expecter) NewAddress(addresses interface{}) *MockClientConn_NewAddress_Call {
return &MockClientConn_NewAddress_Call{Call: _e.mock.On("NewAddress", addresses)}
}
func (_c *MockClientConn_NewAddress_Call) Run(run func(addresses []resolver.Address)) *MockClientConn_NewAddress_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]resolver.Address))
})
return _c
}
func (_c *MockClientConn_NewAddress_Call) Return() *MockClientConn_NewAddress_Call {
_c.Call.Return()
return _c
}
func (_c *MockClientConn_NewAddress_Call) RunAndReturn(run func([]resolver.Address)) *MockClientConn_NewAddress_Call {
_c.Call.Return(run)
return _c
}
// NewServiceConfig provides a mock function with given fields: serviceConfig
func (_m *MockClientConn) NewServiceConfig(serviceConfig string) {
_m.Called(serviceConfig)
}
// MockClientConn_NewServiceConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewServiceConfig'
type MockClientConn_NewServiceConfig_Call struct {
*mock.Call
}
// NewServiceConfig is a helper method to define mock.On call
// - serviceConfig string
func (_e *MockClientConn_Expecter) NewServiceConfig(serviceConfig interface{}) *MockClientConn_NewServiceConfig_Call {
return &MockClientConn_NewServiceConfig_Call{Call: _e.mock.On("NewServiceConfig", serviceConfig)}
}
func (_c *MockClientConn_NewServiceConfig_Call) Run(run func(serviceConfig string)) *MockClientConn_NewServiceConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockClientConn_NewServiceConfig_Call) Return() *MockClientConn_NewServiceConfig_Call {
_c.Call.Return()
return _c
}
func (_c *MockClientConn_NewServiceConfig_Call) RunAndReturn(run func(string)) *MockClientConn_NewServiceConfig_Call {
_c.Call.Return(run)
return _c
}
// ParseServiceConfig provides a mock function with given fields: serviceConfigJSON
func (_m *MockClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
ret := _m.Called(serviceConfigJSON)
var r0 *serviceconfig.ParseResult
if rf, ok := ret.Get(0).(func(string) *serviceconfig.ParseResult); ok {
r0 = rf(serviceConfigJSON)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*serviceconfig.ParseResult)
}
}
return r0
}
// MockClientConn_ParseServiceConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ParseServiceConfig'
type MockClientConn_ParseServiceConfig_Call struct {
*mock.Call
}
// ParseServiceConfig is a helper method to define mock.On call
// - serviceConfigJSON string
func (_e *MockClientConn_Expecter) ParseServiceConfig(serviceConfigJSON interface{}) *MockClientConn_ParseServiceConfig_Call {
return &MockClientConn_ParseServiceConfig_Call{Call: _e.mock.On("ParseServiceConfig", serviceConfigJSON)}
}
func (_c *MockClientConn_ParseServiceConfig_Call) Run(run func(serviceConfigJSON string)) *MockClientConn_ParseServiceConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockClientConn_ParseServiceConfig_Call) Return(_a0 *serviceconfig.ParseResult) *MockClientConn_ParseServiceConfig_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockClientConn_ParseServiceConfig_Call) RunAndReturn(run func(string) *serviceconfig.ParseResult) *MockClientConn_ParseServiceConfig_Call {
_c.Call.Return(run)
return _c
}
// ReportError provides a mock function with given fields: _a0
func (_m *MockClientConn) ReportError(_a0 error) {
_m.Called(_a0)
}
// MockClientConn_ReportError_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportError'
type MockClientConn_ReportError_Call struct {
*mock.Call
}
// ReportError is a helper method to define mock.On call
// - _a0 error
func (_e *MockClientConn_Expecter) ReportError(_a0 interface{}) *MockClientConn_ReportError_Call {
return &MockClientConn_ReportError_Call{Call: _e.mock.On("ReportError", _a0)}
}
func (_c *MockClientConn_ReportError_Call) Run(run func(_a0 error)) *MockClientConn_ReportError_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(error))
})
return _c
}
func (_c *MockClientConn_ReportError_Call) Return() *MockClientConn_ReportError_Call {
_c.Call.Return()
return _c
}
func (_c *MockClientConn_ReportError_Call) RunAndReturn(run func(error)) *MockClientConn_ReportError_Call {
_c.Call.Return(run)
return _c
}
// UpdateState provides a mock function with given fields: _a0
func (_m *MockClientConn) UpdateState(_a0 resolver.State) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(resolver.State) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockClientConn_UpdateState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateState'
type MockClientConn_UpdateState_Call struct {
*mock.Call
}
// UpdateState is a helper method to define mock.On call
// - _a0 resolver.State
func (_e *MockClientConn_Expecter) UpdateState(_a0 interface{}) *MockClientConn_UpdateState_Call {
return &MockClientConn_UpdateState_Call{Call: _e.mock.On("UpdateState", _a0)}
}
func (_c *MockClientConn_UpdateState_Call) Run(run func(_a0 resolver.State)) *MockClientConn_UpdateState_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(resolver.State))
})
return _c
}
func (_c *MockClientConn_UpdateState_Call) Return(_a0 error) *MockClientConn_UpdateState_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockClientConn_UpdateState_Call) RunAndReturn(run func(resolver.State) error) *MockClientConn_UpdateState_Call {
_c.Call.Return(run)
return _c
}
// NewMockClientConn creates a new instance of MockClientConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockClientConn(t interface {
mock.TestingT
Cleanup(func())
}) *MockClientConn {
mock := &MockClientConn{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -141,8 +141,8 @@ func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) erro
return _c
}
// WatchBalanceResult provides a mock function with given fields: ctx, cb
func (_m *MockBalancer) WatchBalanceResult(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
// WatchChannelAssignments provides a mock function with given fields: ctx, cb
func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
ret := _m.Called(ctx, cb)
var r0 error
@ -155,31 +155,31 @@ func (_m *MockBalancer) WatchBalanceResult(ctx context.Context, cb func(typeutil
return r0
}
// MockBalancer_WatchBalanceResult_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchBalanceResult'
type MockBalancer_WatchBalanceResult_Call struct {
// MockBalancer_WatchChannelAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannelAssignments'
type MockBalancer_WatchChannelAssignments_Call struct {
*mock.Call
}
// WatchBalanceResult is a helper method to define mock.On call
// WatchChannelAssignments is a helper method to define mock.On call
// - ctx context.Context
// - cb func(typeutil.VersionInt64Pair , []types.PChannelInfoAssigned) error
func (_e *MockBalancer_Expecter) WatchBalanceResult(ctx interface{}, cb interface{}) *MockBalancer_WatchBalanceResult_Call {
return &MockBalancer_WatchBalanceResult_Call{Call: _e.mock.On("WatchBalanceResult", ctx, cb)}
func (_e *MockBalancer_Expecter) WatchChannelAssignments(ctx interface{}, cb interface{}) *MockBalancer_WatchChannelAssignments_Call {
return &MockBalancer_WatchChannelAssignments_Call{Call: _e.mock.On("WatchChannelAssignments", ctx, cb)}
}
func (_c *MockBalancer_WatchBalanceResult_Call) Run(run func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) *MockBalancer_WatchBalanceResult_Call {
func (_c *MockBalancer_WatchChannelAssignments_Call) Run(run func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) *MockBalancer_WatchChannelAssignments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error))
})
return _c
}
func (_c *MockBalancer_WatchBalanceResult_Call) Return(_a0 error) *MockBalancer_WatchBalanceResult_Call {
func (_c *MockBalancer_WatchChannelAssignments_Call) Return(_a0 error) *MockBalancer_WatchChannelAssignments_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBalancer_WatchBalanceResult_Call) RunAndReturn(run func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error) *MockBalancer_WatchBalanceResult_Call {
func (_c *MockBalancer_WatchChannelAssignments_Call) RunAndReturn(run func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error) *MockBalancer_WatchChannelAssignments_Call {
_c.Call.Return(run)
return _c
}

View File

@ -7,8 +7,6 @@ import (
mock "github.com/stretchr/testify/mock"
sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil"
types "github.com/milvus-io/milvus/pkg/streaming/util/types"
)
@ -101,19 +99,19 @@ func (_c *MockManagerClient_Close_Call) RunAndReturn(run func()) *MockManagerCli
}
// CollectAllStatus provides a mock function with given fields: ctx
func (_m *MockManagerClient) CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error) {
func (_m *MockManagerClient) CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error) {
ret := _m.Called(ctx)
var r0 map[int64]types.StreamingNodeStatus
var r0 map[int64]*types.StreamingNodeStatus
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]types.StreamingNodeStatus, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*types.StreamingNodeStatus, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) map[int64]types.StreamingNodeStatus); ok {
if rf, ok := ret.Get(0).(func(context.Context) map[int64]*types.StreamingNodeStatus); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]types.StreamingNodeStatus)
r0 = ret.Get(0).(map[int64]*types.StreamingNodeStatus)
}
}
@ -144,12 +142,12 @@ func (_c *MockManagerClient_CollectAllStatus_Call) Run(run func(ctx context.Cont
return _c
}
func (_c *MockManagerClient_CollectAllStatus_Call) Return(_a0 map[int64]types.StreamingNodeStatus, _a1 error) *MockManagerClient_CollectAllStatus_Call {
func (_c *MockManagerClient_CollectAllStatus_Call) Return(_a0 map[int64]*types.StreamingNodeStatus, _a1 error) *MockManagerClient_CollectAllStatus_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context.Context) (map[int64]types.StreamingNodeStatus, error)) *MockManagerClient_CollectAllStatus_Call {
func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context.Context) (map[int64]*types.StreamingNodeStatus, error)) *MockManagerClient_CollectAllStatus_Call {
_c.Call.Return(run)
return _c
}
@ -198,19 +196,29 @@ func (_c *MockManagerClient_Remove_Call) RunAndReturn(run func(context.Context,
}
// WatchNodeChanged provides a mock function with given fields: ctx
func (_m *MockManagerClient) WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw {
func (_m *MockManagerClient) WatchNodeChanged(ctx context.Context) (<-chan struct{}, error) {
ret := _m.Called(ctx)
var r0 <-chan map[int64]*sessionutil.SessionRaw
if rf, ok := ret.Get(0).(func(context.Context) <-chan map[int64]*sessionutil.SessionRaw); ok {
var r0 <-chan struct{}
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (<-chan struct{}, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) <-chan struct{}); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan map[int64]*sessionutil.SessionRaw)
r0 = ret.Get(0).(<-chan struct{})
}
}
return r0
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockManagerClient_WatchNodeChanged_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchNodeChanged'
@ -231,12 +239,12 @@ func (_c *MockManagerClient_WatchNodeChanged_Call) Run(run func(ctx context.Cont
return _c
}
func (_c *MockManagerClient_WatchNodeChanged_Call) Return(_a0 <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call {
_c.Call.Return(_a0)
func (_c *MockManagerClient_WatchNodeChanged_Call) Return(_a0 <-chan struct{}, _a1 error) *MockManagerClient_WatchNodeChanged_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockManagerClient_WatchNodeChanged_Call) RunAndReturn(run func(context.Context) <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call {
func (_c *MockManagerClient_WatchNodeChanged_Call) RunAndReturn(run func(context.Context) (<-chan struct{}, error)) *MockManagerClient_WatchNodeChanged_Call {
_c.Call.Return(run)
return _c
}

View File

@ -0,0 +1,121 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_discoverer
import (
context "context"
discoverer "github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
mock "github.com/stretchr/testify/mock"
)
// MockDiscoverer is an autogenerated mock type for the Discoverer type
type MockDiscoverer struct {
mock.Mock
}
type MockDiscoverer_Expecter struct {
mock *mock.Mock
}
func (_m *MockDiscoverer) EXPECT() *MockDiscoverer_Expecter {
return &MockDiscoverer_Expecter{mock: &_m.Mock}
}
// Discover provides a mock function with given fields: ctx, cb
func (_m *MockDiscoverer) Discover(ctx context.Context, cb func(discoverer.VersionedState) error) error {
ret := _m.Called(ctx, cb)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(discoverer.VersionedState) error) error); ok {
r0 = rf(ctx, cb)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockDiscoverer_Discover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Discover'
type MockDiscoverer_Discover_Call struct {
*mock.Call
}
// Discover is a helper method to define mock.On call
// - ctx context.Context
// - cb func(discoverer.VersionedState) error
func (_e *MockDiscoverer_Expecter) Discover(ctx interface{}, cb interface{}) *MockDiscoverer_Discover_Call {
return &MockDiscoverer_Discover_Call{Call: _e.mock.On("Discover", ctx, cb)}
}
func (_c *MockDiscoverer_Discover_Call) Run(run func(ctx context.Context, cb func(discoverer.VersionedState) error)) *MockDiscoverer_Discover_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(discoverer.VersionedState) error))
})
return _c
}
func (_c *MockDiscoverer_Discover_Call) Return(_a0 error) *MockDiscoverer_Discover_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockDiscoverer_Discover_Call) RunAndReturn(run func(context.Context, func(discoverer.VersionedState) error) error) *MockDiscoverer_Discover_Call {
_c.Call.Return(run)
return _c
}
// NewVersionedState provides a mock function with given fields:
func (_m *MockDiscoverer) NewVersionedState() discoverer.VersionedState {
ret := _m.Called()
var r0 discoverer.VersionedState
if rf, ok := ret.Get(0).(func() discoverer.VersionedState); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(discoverer.VersionedState)
}
return r0
}
// MockDiscoverer_NewVersionedState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewVersionedState'
type MockDiscoverer_NewVersionedState_Call struct {
*mock.Call
}
// NewVersionedState is a helper method to define mock.On call
func (_e *MockDiscoverer_Expecter) NewVersionedState() *MockDiscoverer_NewVersionedState_Call {
return &MockDiscoverer_NewVersionedState_Call{Call: _e.mock.On("NewVersionedState")}
}
func (_c *MockDiscoverer_NewVersionedState_Call) Run(run func()) *MockDiscoverer_NewVersionedState_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockDiscoverer_NewVersionedState_Call) Return(_a0 discoverer.VersionedState) *MockDiscoverer_NewVersionedState_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockDiscoverer_NewVersionedState_Call) RunAndReturn(run func() discoverer.VersionedState) *MockDiscoverer_NewVersionedState_Call {
_c.Call.Return(run)
return _c
}
// NewMockDiscoverer creates a new instance of MockDiscoverer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockDiscoverer(t interface {
mock.TestingT
Cleanup(func())
}) *MockDiscoverer {
mock := &MockDiscoverer{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,121 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_resolver
import (
context "context"
discoverer "github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
mock "github.com/stretchr/testify/mock"
)
// MockResolver is an autogenerated mock type for the Resolver type
type MockResolver struct {
mock.Mock
}
type MockResolver_Expecter struct {
mock *mock.Mock
}
func (_m *MockResolver) EXPECT() *MockResolver_Expecter {
return &MockResolver_Expecter{mock: &_m.Mock}
}
// GetLatestState provides a mock function with given fields:
func (_m *MockResolver) GetLatestState() discoverer.VersionedState {
ret := _m.Called()
var r0 discoverer.VersionedState
if rf, ok := ret.Get(0).(func() discoverer.VersionedState); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(discoverer.VersionedState)
}
return r0
}
// MockResolver_GetLatestState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestState'
type MockResolver_GetLatestState_Call struct {
*mock.Call
}
// GetLatestState is a helper method to define mock.On call
func (_e *MockResolver_Expecter) GetLatestState() *MockResolver_GetLatestState_Call {
return &MockResolver_GetLatestState_Call{Call: _e.mock.On("GetLatestState")}
}
func (_c *MockResolver_GetLatestState_Call) Run(run func()) *MockResolver_GetLatestState_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockResolver_GetLatestState_Call) Return(_a0 discoverer.VersionedState) *MockResolver_GetLatestState_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockResolver_GetLatestState_Call) RunAndReturn(run func() discoverer.VersionedState) *MockResolver_GetLatestState_Call {
_c.Call.Return(run)
return _c
}
// Watch provides a mock function with given fields: ctx, cb
func (_m *MockResolver) Watch(ctx context.Context, cb func(discoverer.VersionedState) error) error {
ret := _m.Called(ctx, cb)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(discoverer.VersionedState) error) error); ok {
r0 = rf(ctx, cb)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockResolver_Watch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Watch'
type MockResolver_Watch_Call struct {
*mock.Call
}
// Watch is a helper method to define mock.On call
// - ctx context.Context
// - cb func(discoverer.VersionedState) error
func (_e *MockResolver_Expecter) Watch(ctx interface{}, cb interface{}) *MockResolver_Watch_Call {
return &MockResolver_Watch_Call{Call: _e.mock.On("Watch", ctx, cb)}
}
func (_c *MockResolver_Watch_Call) Run(run func(ctx context.Context, cb func(discoverer.VersionedState) error)) *MockResolver_Watch_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(discoverer.VersionedState) error))
})
return _c
}
func (_c *MockResolver_Watch_Call) Return(_a0 error) *MockResolver_Watch_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockResolver_Watch_Call) RunAndReturn(run func(context.Context, func(discoverer.VersionedState) error) error) *MockResolver_Watch_Call {
_c.Call.Return(run)
return _c
}
// NewMockResolver creates a new instance of MockResolver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockResolver(t interface {
mock.TestingT
Cleanup(func())
}) *MockResolver {
mock := &MockResolver{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -18,20 +18,22 @@ message MessageID {
// Message is the basic unit of communication between publisher and consumer.
message Message {
bytes payload = 1; // message body
bytes payload = 1; // message body
map<string, string> properties = 2; // message properties
}
// PChannelInfo is the information of a pchannel info, should only keep the basic info of a pchannel.
// It's used in many rpc and meta, so keep it simple.
// PChannelInfo is the information of a pchannel info, should only keep the
// basic info of a pchannel. It's used in many rpc and meta, so keep it simple.
message PChannelInfo {
string name = 1; // channel name
int64 term =
2; // A monotonic increasing term, every time the channel is recovered or moved to another streamingnode, the term will increase by meta server.
int64 term = 2; // A monotonic increasing term, every time the channel is
// recovered or moved to another streamingnode, the term
// will increase by meta server.
}
// PChannelMetaHistory is the history meta information of a pchannel, should only keep the data that is necessary to persistent.
message PChannelMetaHistory {
// PChannelAssignmentLog is the log of meta information of a pchannel, should
// only keep the data that is necessary to persistent.
message PChannelAssignmentLog {
int64 term = 1; // term when server assigned.
StreamingNodeInfo node =
2; // streaming node that the channel is assigned to.
@ -50,20 +52,21 @@ enum PChannelMetaState {
4; // channel is unavailable at this term.
}
// PChannelMeta is the meta information of a pchannel, should only keep the data that is necessary to persistent.
// It's only used in meta, so do not use it in rpc.
// PChannelMeta is the meta information of a pchannel, should only keep the data
// that is necessary to persistent. It's only used in meta, so do not use it in
// rpc.
message PChannelMeta {
PChannelInfo channel = 1; // keep the meta info that current assigned to.
PChannelInfo channel = 1; // keep the meta info that current assigned to.
StreamingNodeInfo node = 2; // nil if channel is not uninitialized.
PChannelMetaState state = 3; // state of the channel.
repeated PChannelMetaHistory histories =
4; // keep the meta info history that used to be assigned to.
repeated PChannelAssignmentLog histories =
4; // keep the meta info assignment log that used to be assigned to.
}
// VersionPair is the version pair of global and local.
message VersionPair {
int64 global = 1;
int64 local = 2;
int64 local = 2;
}
//
@ -72,14 +75,12 @@ message VersionPair {
service StreamingCoordStateService {
rpc GetComponentStates(milvus.GetComponentStatesRequest)
returns (milvus.ComponentStates) {
}
returns (milvus.ComponentStates) {}
}
service StreamingNodeStateService {
rpc GetComponentStates(milvus.GetComponentStatesRequest)
returns (milvus.ComponentStates) {
}
returns (milvus.ComponentStates) {}
}
//
@ -90,11 +91,11 @@ service StreamingNodeStateService {
// Server: log coord. Running on every log node.
// Client: all log publish/consuming node.
service StreamingCoordAssignmentService {
// AssignmentDiscover is used to discover all log nodes managed by the streamingcoord.
// Channel assignment information will be pushed to client by stream.
// AssignmentDiscover is used to discover all log nodes managed by the
// streamingcoord. Channel assignment information will be pushed to client
// by stream.
rpc AssignmentDiscover(stream AssignmentDiscoverRequest)
returns (stream AssignmentDiscoverResponse) {
}
returns (stream AssignmentDiscoverResponse) {}
}
// AssignmentDiscoverRequest is the request of Discovery
@ -106,15 +107,15 @@ message AssignmentDiscoverRequest {
}
}
// ReportAssignmentErrorRequest is the request to report assignment error happens.
// ReportAssignmentErrorRequest is the request to report assignment error
// happens.
message ReportAssignmentErrorRequest {
PChannelInfo pchannel = 1; // channel
StreamingError err = 2; // error happend on log node
StreamingError err = 2; // error happend on log node
}
// CloseAssignmentDiscoverRequest is the request to close the stream.
message CloseAssignmentDiscoverRequest {
}
message CloseAssignmentDiscoverRequest {}
// AssignmentDiscoverResponse is the response of Discovery
message AssignmentDiscoverResponse {
@ -126,31 +127,31 @@ message AssignmentDiscoverResponse {
}
}
// FullStreamingNodeAssignmentWithVersion is the full assignment info of a log node with version.
// FullStreamingNodeAssignmentWithVersion is the full assignment info of a log
// node with version.
message FullStreamingNodeAssignmentWithVersion {
VersionPair version = 1;
VersionPair version = 1;
repeated StreamingNodeAssignment assignments = 2;
}
message CloseAssignmentDiscoverResponse {
}
message CloseAssignmentDiscoverResponse {}
// StreamingNodeInfo is the information of a streaming node.
message StreamingNodeInfo {
int64 server_id = 1;
string address = 2;
string address = 2;
}
// StreamingNodeAssignment is the assignment info of a streaming node.
message StreamingNodeAssignment {
StreamingNodeInfo node = 1;
StreamingNodeInfo node = 1;
repeated PChannelInfo channels = 2;
}
// DeliverPolicy is the policy to deliver message.
message DeliverPolicy {
oneof policy {
google.protobuf.Empty all = 1; // deliver all messages.
google.protobuf.Empty all = 1; // deliver all messages.
google.protobuf.Empty latest = 2; // deliver the latest message.
MessageID start_from =
3; // deliver message from this message id. [startFrom, ...]
@ -162,22 +163,24 @@ message DeliverPolicy {
// DeliverFilter is the filter to deliver message.
message DeliverFilter {
oneof filter {
DeliverFilterTimeTickGT time_tick_gt = 1;
DeliverFilterTimeTickGT time_tick_gt = 1;
DeliverFilterTimeTickGTE time_tick_gte = 2;
DeliverFilterVChannel vchannel = 3;
DeliverFilterVChannel vchannel = 3;
}
}
// DeliverFilterTimeTickGT is the filter to deliver message with time tick greater than this value.
// DeliverFilterTimeTickGT is the filter to deliver message with time tick
// greater than this value.
message DeliverFilterTimeTickGT {
uint64 time_tick =
1; // deliver message with time tick greater than this value.
}
// DeliverFilterTimeTickGTE is the filter to deliver message with time tick greater than or equal to this value.
// DeliverFilterTimeTickGTE is the filter to deliver message with time tick
// greater than or equal to this value.
message DeliverFilterTimeTickGTE {
uint64 time_tick =
1; // deliver message with time tick greater than or equal to this value.
uint64 time_tick = 1; // deliver message with time tick greater than or
// equal to this value.
}
// DeliverFilterVChannel is the filter to deliver message with vchannel name.
@ -187,24 +190,22 @@ message DeliverFilterVChannel {
// StreamingCode is the error code for log internal component.
enum StreamingCode {
STREAMING_CODE_OK = 0;
STREAMING_CODE_CHANNEL_EXIST = 1; // channel already exist
STREAMING_CODE_CHANNEL_NOT_EXIST = 2; // channel not exist
STREAMING_CODE_CHANNEL_FENCED = 3; // channel is fenced
STREAMING_CODE_ON_SHUTDOWN = 4; // component is on shutdown
STREAMING_CODE_INVALID_REQUEST_SEQ = 5; // invalid request sequence
STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 6; // unmatched channel term
STREAMING_CODE_IGNORED_OPERATION = 7; // ignored operation
STREAMING_CODE_INNER = 8; // underlying service failure.
STREAMING_CODE_EOF = 9; // end of stream, generated by grpc status.
STREAMING_CODE_INVAILD_ARGUMENT = 10; // invalid argument
STREAMING_CODE_UNKNOWN = 999; // unknown error
STREAMING_CODE_OK = 0;
STREAMING_CODE_CHANNEL_NOT_EXIST = 1; // channel not exist
STREAMING_CODE_CHANNEL_FENCED = 2; // channel is fenced
STREAMING_CODE_ON_SHUTDOWN = 3; // component is on shutdown
STREAMING_CODE_INVALID_REQUEST_SEQ = 4; // invalid request sequence
STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 5; // unmatched channel term
STREAMING_CODE_IGNORED_OPERATION = 6; // ignored operation
STREAMING_CODE_INNER = 7; // underlying service failure.
STREAMING_CODE_INVAILD_ARGUMENT = 8; // invalid argument
STREAMING_CODE_UNKNOWN = 999; // unknown error
}
// StreamingError is the error type for log internal component.
message StreamingError {
StreamingCode code = 1;
string cause = 2;
string cause = 2;
}
//
@ -212,36 +213,35 @@ message StreamingError {
//
// StreamingNodeHandlerService is the service to handle log messages.
// All handler operation will be blocked until the channel is ready read or write on that log node.
// Server: all log node. Running on every log node.
// All handler operation will be blocked until the channel is ready read or
// write on that log node. Server: all log node. Running on every log node.
// Client: all log produce or consuming node.
service StreamingNodeHandlerService {
// Produce is a bi-directional streaming RPC to send messages to a channel.
// All messages sent to a channel will be assigned a unique messageID.
// The messageID is used to identify the message in the channel.
// The messageID isn't promised to be monotonous increasing with the sequence of responsing.
// Error:
// If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST.
// If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED.
rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {
};
// The messageID isn't promised to be monotonous increasing with the
// sequence of responsing. Error: If channel isn't assign to this log node,
// the RPC will return error CHANNEL_NOT_EXIST. If channel is moving away to
// other log node, the RPC will return error CHANNEL_FENCED.
rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {};
// Consume is a server streaming RPC to receive messages from a channel.
// All message after given startMessageID and excluding will be sent to the client by stream.
// If no more message in the channel, the stream will be blocked until new message coming.
// Error:
// If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST.
// If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED.
rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {
};
// All message after given startMessageID and excluding will be sent to the
// client by stream. If no more message in the channel, the stream will be
// blocked until new message coming. Error: If channel isn't assign to this
// log node, the RPC will return error CHANNEL_NOT_EXIST. If channel is
// moving away to other log node, the RPC will return error CHANNEL_FENCED.
rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {};
}
// ProduceRequest is the request of the Produce RPC.
// Channel name will be passthrough in the header of stream bu not in the request body.
// Channel name will be passthrough in the header of stream bu not in the
// request body.
message ProduceRequest {
oneof request {
ProduceMessageRequest produce = 2;
CloseProducerRequest close = 3;
ProduceMessageRequest produce = 1;
CloseProducerRequest close = 2;
}
}
@ -254,46 +254,47 @@ message CreateProducerRequest {
// ProduceMessageRequest is the request of the Produce RPC.
message ProduceMessageRequest {
int64 request_id = 1; // request id for reply.
Message message = 2; // message to be sent.
Message message = 2; // message to be sent.
}
// CloseProducerRequest is the request of the CloseProducer RPC.
// After CloseProducerRequest is requested, no more ProduceRequest can be sent.
message CloseProducerRequest {
}
message CloseProducerRequest {}
// ProduceResponse is the response of the Produce RPC.
message ProduceResponse {
oneof response {
CreateProducerResponse create = 1;
CreateProducerResponse create = 1;
ProduceMessageResponse produce = 2;
CloseProducerResponse close = 3;
CloseProducerResponse close = 3;
}
}
// CreateProducerResponse is the result of the CreateProducer RPC.
message CreateProducerResponse {
int64 producer_id =
1; // A unique producer id on streamingnode for this producer in streamingnode lifetime.
// Is used to identify the producer in streamingnode for other unary grpc call at producer level.
string wal_name = 1; // wal name at server side.
int64 producer_id = 2; // A unique producer id on streamingnode for this
// producer in streamingnode lifetime.
// Is used to identify the producer in streamingnode for other unary grpc
// call at producer level.
}
message ProduceMessageResponse {
int64 request_id = 1;
oneof response {
ProduceMessageResponseResult result = 2;
StreamingError error = 3;
StreamingError error = 3;
}
}
// ProduceMessageResponseResult is the result of the produce message streaming RPC.
// ProduceMessageResponseResult is the result of the produce message streaming
// RPC.
message ProduceMessageResponseResult {
MessageID id = 1; // the offset of the message in the channel
}
// CloseProducerResponse is the result of the CloseProducer RPC.
message CloseProducerResponse {
}
message CloseProducerResponse {}
// ConsumeRequest is the request of the Consume RPC.
// Add more control block in future.
@ -305,14 +306,13 @@ message ConsumeRequest {
// CloseConsumerRequest is the request of the CloseConsumer RPC.
// After CloseConsumerRequest is requested, no more ConsumeRequest can be sent.
message CloseConsumerRequest {
}
message CloseConsumerRequest {}
// CreateConsumerRequest is the request of the CreateConsumer RPC.
// CreateConsumerRequest is passed in the header of stream.
message CreateConsumerRequest {
PChannelInfo pchannel = 1;
DeliverPolicy deliver_policy = 2; // deliver policy.
PChannelInfo pchannel = 1;
DeliverPolicy deliver_policy = 2; // deliver policy.
repeated DeliverFilter deliver_filters = 3; // deliver filter.
}
@ -321,20 +321,20 @@ message ConsumeResponse {
oneof response {
CreateConsumerResponse create = 1;
ConsumeMessageReponse consume = 2;
CloseConsumerResponse close = 3;
CloseConsumerResponse close = 3;
}
}
message CreateConsumerResponse {
string wal_name = 1; // wal name at server side.
}
message ConsumeMessageReponse {
MessageID id = 1; // message id of message.
MessageID id = 1; // message id of message.
Message message = 2; // message to be consumed.
}
message CloseConsumerResponse {
}
message CloseConsumerResponse {}
//
// StreamingNodeManagerService
@ -342,32 +342,31 @@ message CloseConsumerResponse {
// StreamingNodeManagerService is the log manage operation on log node.
// Server: all log node. Running on every log node.
// Client: log coord. There should be only one client globally to call this service on all streamingnode.
// Client: log coord. There should be only one client globally to call this
// service on all streamingnode.
service StreamingNodeManagerService {
// Assign is a unary RPC to assign a channel on a log node.
// Block until the channel assignd is ready to read or write on the log node.
// Error:
// If the channel already exists, return error with code CHANNEL_EXIST.
// Block until the channel assignd is ready to read or write on the log
// node. Error: If the channel already exists, return error with code
// CHANNEL_EXIST.
rpc Assign(StreamingNodeManagerAssignRequest)
returns (StreamingNodeManagerAssignResponse) {
};
returns (StreamingNodeManagerAssignResponse) {};
// Remove is unary RPC to remove a channel on a log node.
// Data of the channel on flying would be sent or flused as much as possible.
// Block until the resource of channel is released on the log node.
// New incoming request of handler of this channel will be rejected with special error.
// Error:
// If the channel does not exist, return error with code CHANNEL_NOT_EXIST.
// Data of the channel on flying would be sent or flused as much as
// possible. Block until the resource of channel is released on the log
// node. New incoming request of handler of this channel will be rejected
// with special error. Error: If the channel does not exist, return error
// with code CHANNEL_NOT_EXIST.
rpc Remove(StreamingNodeManagerRemoveRequest)
returns (StreamingNodeManagerRemoveResponse) {
};
returns (StreamingNodeManagerRemoveResponse) {};
// rpc CollectStatus() ...
// CollectStatus is unary RPC to collect all avaliable channel info and load balance info on a log node.
// Used to recover channel info on log coord, collect balance info and health check.
// CollectStatus is unary RPC to collect all avaliable channel info and load
// balance info on a log node. Used to recover channel info on log coord,
// collect balance info and health check.
rpc CollectStatus(StreamingNodeManagerCollectStatusRequest)
returns (StreamingNodeManagerCollectStatusResponse) {
};
returns (StreamingNodeManagerCollectStatusResponse) {};
}
// StreamingManagerAssignRequest is the request message of Assign RPC.
@ -375,18 +374,15 @@ message StreamingNodeManagerAssignRequest {
PChannelInfo pchannel = 1;
}
message StreamingNodeManagerAssignResponse {
}
message StreamingNodeManagerAssignResponse {}
message StreamingNodeManagerRemoveRequest {
PChannelInfo pchannel = 1;
}
message StreamingNodeManagerRemoveResponse {
}
message StreamingNodeManagerRemoveResponse {}
message StreamingNodeManagerCollectStatusRequest {
}
message StreamingNodeManagerCollectStatusRequest {}
message StreamingNodeBalanceAttributes {
// TODO: traffic of pchannel or other things.
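
As the comments above describe, Produce is a bi-directional stream: every ProduceMessageRequest carries a request_id, and the matching ProduceMessageResponse echoes it back, so responses may interleave and must be correlated by the client. A rough client-side sketch of that correlation, assuming standard protoc-gen-go-grpc generated names and omitting the CreateProducerRequest handshake that the proto says travels in the stream header:

```go
package example

import (
	"context"

	"google.golang.org/grpc"

	"github.com/milvus-io/milvus/internal/proto/streamingpb"
)

// produceOne sends a single message on the Produce stream and waits for the
// response carrying the same request_id. Generated identifiers below follow the
// usual protoc-gen-go-grpc conventions and are assumptions, not verified names.
func produceOne(ctx context.Context, conn *grpc.ClientConn, payload []byte) (*streamingpb.ProduceMessageResponse, error) {
	stream, err := streamingpb.NewStreamingNodeHandlerServiceClient(conn).Produce(ctx)
	if err != nil {
		return nil, err
	}
	req := &streamingpb.ProduceRequest{
		Request: &streamingpb.ProduceRequest_Produce{
			Produce: &streamingpb.ProduceMessageRequest{
				RequestId: 1,
				Message:   &streamingpb.Message{Payload: payload},
			},
		},
	}
	if err := stream.Send(req); err != nil {
		return nil, err
	}
	for {
		resp, err := stream.Recv()
		if err != nil {
			return nil, err
		}
		// Responses are asynchronous; match the produce response by request_id.
		if produce := resp.GetProduce(); produce != nil && produce.GetRequestId() == 1 {
			return produce, nil
		}
	}
}
```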

View File

@ -25,8 +25,10 @@ type balanceTimer struct {
// EnableBackoffOrNot enables or disables backoff
func (t *balanceTimer) EnableBackoff() {
t.enableBackoff = true
t.newIncomingBackOff = true
if !t.enableBackoff {
t.enableBackoff = true
t.newIncomingBackOff = true
}
}
// DisableBackoff disables backoff

View File

@ -14,8 +14,8 @@ var _ Balancer = (*balancerImpl)(nil)
// Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency.
// Balancer should be thread safe.
type Balancer interface {
// WatchBalanceResult watches the balance result.
WatchBalanceResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error
// WatchChannelAssignments watches the balance result.
WatchChannelAssignments(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error
// MarkAsAvailable marks the pchannels as available, and trigger a rebalance.
MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error

View File

@ -8,7 +8,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingnode/client/manager"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
@ -21,7 +21,6 @@ import (
func RecoverBalancer(
ctx context.Context,
policy string,
streamingNodeManager manager.ManagerClient,
incomingNewChannel ...string, // Concurrent incoming new channel directly from the configuration.
// we should add a rpc interface for creating new incoming new channel.
) (Balancer, error) {
@ -33,7 +32,6 @@ func RecoverBalancer(
b := &balancerImpl{
lifetime: lifetime.NewLifetime(lifetime.Working),
logger: log.With(zap.String("policy", policy)),
streamingNodeManager: streamingNodeManager, // TODO: fill it up.
channelMetaManager: manager,
policy: mustGetPolicy(policy),
reqCh: make(chan *request, 5),
@ -47,15 +45,14 @@ func RecoverBalancer(
type balancerImpl struct {
lifetime lifetime.Lifetime[lifetime.State]
logger *log.MLogger
streamingNodeManager manager.ManagerClient
channelMetaManager *channel.ChannelManager
policy Policy // policy is the balance policy, TODO: should be dynamic in future.
reqCh chan *request // reqCh is the request channel, send the operation to background task.
backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to conmunicate with the background task.
}
// WatchBalanceResult watches the balance result.
func (b *balancerImpl) WatchBalanceResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error {
// WatchChannelAssignments watches the balance result.
func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error {
if b.lifetime.Add(lifetime.IsWorking) != nil {
return status.NewOnShutdownError("balancer is closing")
}
@ -110,6 +107,11 @@ func (b *balancerImpl) execute() {
}()
balanceTimer := newBalanceTimer()
nodeChanged, err := resource.Resource().StreamingNodeManagerClient().WatchNodeChanged(b.backgroundTaskNotifier.Context())
if err != nil {
b.logger.Error("fail to watch node changed", zap.Error(err))
return
}
for {
// Wait for next balance trigger.
// Maybe trigger by timer or by request.
@ -122,6 +124,13 @@ func (b *balancerImpl) execute() {
newReq.apply(b)
b.applyAllRequest()
case <-nextTimer:
// balance triggered by timer.
case _, ok := <-nodeChanged:
if !ok {
return // nodeChanged is only closed if context cancel.
// in other word, balancer is closed.
}
// balance triggered by new streaming node changed.
}
if err := b.balance(b.backgroundTaskNotifier.Context()); err != nil {
@ -159,7 +168,7 @@ func (b *balancerImpl) balance(ctx context.Context) error {
pchannelView := b.channelMetaManager.CurrentPChannelsView()
b.logger.Info("collect all status...")
nodeStatus, err := b.streamingNodeManager.CollectAllStatus(ctx)
nodeStatus, err := resource.Resource().StreamingNodeManagerClient().CollectAllStatus(ctx)
if err != nil {
return errors.Wrap(err, "fail to collect all status")
}
@ -197,15 +206,15 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo
g.Go(func() error {
// all history channels should be remove from related nodes.
for _, assignment := range channel.AssignHistories() {
if err := b.streamingNodeManager.Remove(ctx, assignment); err != nil {
b.logger.Warn("fail to remove channel", zap.Any("assignment", assignment))
if err := resource.Resource().StreamingNodeManagerClient().Remove(ctx, assignment); err != nil {
b.logger.Warn("fail to remove channel", zap.Any("assignment", assignment), zap.Error(err))
return err
}
b.logger.Info("remove channel success", zap.Any("assignment", assignment))
}
// assign the channel to the target node.
if err := b.streamingNodeManager.Assign(ctx, channel.CurrentAssignment()); err != nil {
if err := resource.Resource().StreamingNodeManagerClient().Assign(ctx, channel.CurrentAssignment()); err != nil {
b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment()))
return err
}
@ -223,7 +232,7 @@ func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, mo
}
// generateCurrentLayout generate layout from all nodes info and meta.
func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allNodesStatus map[int64]types.StreamingNodeStatus) (layout CurrentLayout) {
func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allNodesStatus map[int64]*types.StreamingNodeStatus) (layout CurrentLayout) {
activeRelations := make(map[int64][]types.PChannelInfo, len(allNodesStatus))
incomingChannels := make([]string, 0)
channelsToNodes := make(map[string]int64, len(channelsInMeta))
@ -255,7 +264,7 @@ func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allN
zap.String("channel", meta.Name()),
zap.Int64("term", meta.CurrentTerm()),
zap.Int64("serverID", meta.CurrentServerID()),
zap.Error(nodeStatus.Err),
zap.Error(nodeStatus.ErrorOfNode()),
)
}
}

View File

@ -23,9 +23,10 @@ func TestBalancer(t *testing.T) {
paramtable.Init()
streamingNodeManager := mock_manager.NewMockManagerClient(t)
streamingNodeManager.EXPECT().WatchNodeChanged(mock.Anything).Return(make(chan struct{}), nil)
streamingNodeManager.EXPECT().Assign(mock.Anything, mock.Anything).Return(nil)
streamingNodeManager.EXPECT().Remove(mock.Anything, mock.Anything).Return(nil)
streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]types.StreamingNodeStatus{
streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]*types.StreamingNodeStatus{
1: {
StreamingNodeInfo: types.StreamingNodeInfo{
ServerID: 1,
@ -54,7 +55,7 @@ func TestBalancer(t *testing.T) {
}, nil)
catalog := mock_metastore.NewMockStreamingCoordCataLog(t)
resource.InitForTest(resource.OptStreamingCatalog(catalog))
resource.InitForTest(resource.OptStreamingCatalog(catalog), resource.OptStreamingManagerClient(streamingNodeManager))
catalog.EXPECT().ListPChannel(mock.Anything).Unset()
catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) {
return []*streamingpb.PChannelMeta{
@ -87,7 +88,7 @@ func TestBalancer(t *testing.T) {
catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil).Maybe()
ctx := context.Background()
b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair", streamingNodeManager)
b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair")
assert.NoError(t, err)
assert.NotNil(t, b)
defer b.Close()
@ -99,7 +100,7 @@ func TestBalancer(t *testing.T) {
b.Trigger(ctx)
doneErr := errors.New("done")
err = b.WatchBalanceResult(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
err = b.WatchChannelAssignments(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error {
// should one pchannel be assigned to per nodes
nodeIDs := typeutil.NewSet[int64]()
if len(relations) == 3 {

View File

@ -18,7 +18,7 @@ func newPChannelMeta(name string) *PChannelMeta {
},
Node: nil,
State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED,
Histories: make([]*streamingpb.PChannelMetaHistory, 0),
Histories: make([]*streamingpb.PChannelAssignmentLog, 0),
},
}
}
@ -114,7 +114,7 @@ func (m *mutablePChannel) TryAssignToServerID(streamingNode types.StreamingNodeI
}
if m.inner.State != streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED {
// if the channel is already initialized, add the history.
m.inner.Histories = append(m.inner.Histories, &streamingpb.PChannelMetaHistory{
m.inner.Histories = append(m.inner.Histories, &streamingpb.PChannelAssignmentLog{
Term: m.inner.Channel.Term,
Node: m.inner.Node,
})
@ -130,7 +130,7 @@ func (m *mutablePChannel) TryAssignToServerID(streamingNode types.StreamingNodeI
// AssignToServerDone assigns the channel to the server done.
func (m *mutablePChannel) AssignToServerDone() {
if m.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING {
m.inner.Histories = make([]*streamingpb.PChannelMetaHistory, 0)
m.inner.Histories = make([]*streamingpb.PChannelAssignmentLog, 0)
m.inner.State = streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED
}
}

View File

@ -1,9 +1,12 @@
package resource
import (
"reflect"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/streamingnode/client/manager"
)
var r *resourceImpl // singleton resource instance
@ -28,12 +31,15 @@ func OptStreamingCatalog(catalog metastore.StreamingCoordCataLog) optResourceIni
// Init initializes the singleton of resources.
// Should be call when streaming node startup.
func Init(opts ...optResourceInit) {
r = &resourceImpl{}
newR := &resourceImpl{}
for _, opt := range opts {
opt(r)
opt(newR)
}
assertNotNil(r.ETCD())
assertNotNil(r.StreamingCatalog())
assertNotNil(newR.ETCD())
assertNotNil(newR.StreamingCatalog())
// TODO: after add streaming node manager client, remove this line.
// assertNotNil(r.StreamingNodeManagerClient())
r = newR
}
// Resource access the underlying singleton of resources.
@ -44,8 +50,9 @@ func Resource() *resourceImpl {
// resourceImpl is a basic resource dependency for streamingnode server.
// All utility on it is concurrent-safe and singleton.
type resourceImpl struct {
etcdClient *clientv3.Client
streamingCatalog metastore.StreamingCoordCataLog
etcdClient *clientv3.Client
streamingCatalog metastore.StreamingCoordCataLog
streamingNodeManagerClient manager.ManagerClient
}
// StreamingCatalog returns the StreamingCatalog client.
@ -58,9 +65,21 @@ func (r *resourceImpl) ETCD() *clientv3.Client {
return r.etcdClient
}
// StreamingNodeManagerClient returns the streaming node manager client.
func (r *resourceImpl) StreamingNodeManagerClient() manager.ManagerClient {
return r.streamingNodeManagerClient
}
// assertNotNil panics if the resource is nil.
func assertNotNil(v interface{}) {
if v == nil {
iv := reflect.ValueOf(v)
if !iv.IsValid() {
panic("nil resource")
}
switch iv.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface:
if iv.IsNil() {
panic("nil resource")
}
}
}
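The reflection here matters because of a Go subtlety: an interface holding a typed nil pointer is itself non-nil, so a plain `v == nil` check would let a nil concrete pointer slip through and the panic would never fire. A minimal, self-contained sketch of the difference (the `catalog` type below is hypothetical, purely for illustration):

```go
package main

import (
	"fmt"
	"reflect"
)

type catalog struct{} // hypothetical resource type, for illustration only

func main() {
	var c *catalog        // typed nil pointer
	var v interface{} = c // interface now holds (type=*catalog, value=nil)

	// The naive check does not catch a typed nil stored in an interface.
	fmt.Println(v == nil) // false

	// The reflection-based check used by assertNotNil above does.
	iv := reflect.ValueOf(v)
	fmt.Println(iv.Kind() == reflect.Ptr && iv.IsNil()) // true
}
```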

View File

@ -17,7 +17,9 @@ func TestInit(t *testing.T) {
Init(OptETCD(&clientv3.Client{}))
})
assert.Panics(t, func() {
Init(OptETCD(&clientv3.Client{}))
Init(OptStreamingCatalog(
mock_metastore.NewMockStreamingCoordCataLog(t),
))
})
Init(OptETCD(&clientv3.Client{}), OptStreamingCatalog(
mock_metastore.NewMockStreamingCoordCataLog(t),

View File

@ -3,6 +3,17 @@
package resource
import (
"github.com/milvus-io/milvus/internal/streamingnode/client/manager"
)
// OptStreamingManagerClient provides streaming manager client to the resource.
func OptStreamingManagerClient(c manager.ManagerClient) optResourceInit {
return func(r *resourceImpl) {
r.streamingNodeManagerClient = c
}
}
// InitForTest initializes the singleton of resources for test.
func InitForTest(opts ...optResourceInit) {
r = &resourceImpl{}

View File

@ -90,7 +90,7 @@ func (s *AssignmentDiscoverServer) recvLoop() (err error) {
// sendLoop sends the message to client.
func (s *AssignmentDiscoverServer) sendLoop() error {
err := s.balancer.WatchBalanceResult(s.ctx, s.streamServer.SendFullAssignment)
err := s.balancer.WatchChannelAssignments(s.ctx, s.streamServer.SendFullAssignment)
if errors.Is(err, errClosedByUser) {
return s.streamServer.SendCloseResponse()
}

View File

@ -16,7 +16,7 @@ import (
func TestAssignmentDiscover(t *testing.T) {
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WatchBalanceResult(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error {
versions := []typeutil.VersionInt64Pair{
{Global: 1, Local: 2},
{Global: 1, Local: 3},
@ -59,7 +59,7 @@ func TestAssignmentDiscover(t *testing.T) {
Term: 1,
},
Err: &streamingpb.StreamingError{
Code: streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST,
Code: streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST,
},
},
},

View File

@ -3,16 +3,15 @@ package manager
import (
"context"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
)
type ManagerClient interface {
// WatchNodeChanged returns a channel that receives node change notifications.
WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw
WatchNodeChanged(ctx context.Context) (<-chan struct{}, error)
// CollectAllStatus collects the status of all wal instances on all streaming nodes.
CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error)
CollectAllStatus(ctx context.Context) (map[int64]*types.StreamingNodeStatus, error)
// Assign assigns a wal instance for the channel on the streaming node with the given server id.
Assign(ctx context.Context, pchannel types.PChannelInfoAssigned) error

View File

@ -1,6 +1,8 @@
package resource
import (
"reflect"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
@ -70,7 +72,14 @@ func (r *resourceImpl) RootCoordClient() types.RootCoordClient {
// assertNotNil panics if the resource is nil.
func assertNotNil(v interface{}) {
if v == nil {
iv := reflect.ValueOf(v)
if !iv.IsValid() {
panic("nil resource")
}
switch iv.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface:
if iv.IsNil() {
panic("nil resource")
}
}
}

View File

@ -17,7 +17,7 @@ func TestInit(t *testing.T) {
Init(OptETCD(&clientv3.Client{}))
})
assert.Panics(t, func() {
Init(OptETCD(&clientv3.Client{}))
Init(OptRootCoordClient(mocks.NewMockRootCoordClient(t)))
})
Init(OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t)))

View File

@ -56,7 +56,9 @@ func CreateConsumeServer(walManager walmanager.Manager, streamServer streamingpb
consumeServer := &consumeGrpcServerHelper{
StreamingNodeHandlerService_ConsumeServer: streamServer,
}
if err := consumeServer.SendCreated(&streamingpb.CreateConsumerResponse{}); err != nil {
if err := consumeServer.SendCreated(&streamingpb.CreateConsumerResponse{
WalName: l.WALName(),
}); err != nil {
// release the scanner to avoid resource leak.
if err := scanner.Close(); err != nil {
log.Warn("close scanner failed at create consume server", zap.Error(err))

View File

@ -19,10 +19,12 @@ func (p *produceGrpcServerHelper) SendProduceMessage(resp *streamingpb.ProduceMe
}
// SendCreated sends the create response to client.
func (p *produceGrpcServerHelper) SendCreated() error {
func (p *produceGrpcServerHelper) SendCreated(walName string) error {
return p.Send(&streamingpb.ProduceResponse{
Response: &streamingpb.ProduceResponse_Create{
Create: &streamingpb.CreateProducerResponse{},
Create: &streamingpb.CreateProducerResponse{
WalName: walName,
},
},
})
}

View File

@ -41,7 +41,7 @@ func CreateProduceServer(walManager walmanager.Manager, streamServer streamingpb
produceServer := &produceGrpcServerHelper{
StreamingNodeHandlerService_ProduceServer: streamServer,
}
if err := produceServer.SendCreated(); err != nil {
if err := produceServer.SendCreated(l.WALName()); err != nil {
return nil, errors.Wrap(err, "at send created")
}
return &ProduceServer{

View File

@ -54,6 +54,7 @@ func TestCreateProduceServer(t *testing.T) {
// Return error if create scanner failed.
l := mock_wal.NewMockWAL(t)
l.EXPECT().WALName().Return("test")
manager.ExpectedCalls = nil
manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: 1}).Return(l, nil)
grpcProduceServer.EXPECT().Send(mock.Anything).Return(errors.New("send created failed"))

View File

@ -8,7 +8,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
@ -24,7 +23,7 @@ func adaptImplsToOpener(opener walimpls.OpenerImpls, builders []interceptors.Int
return &openerAdaptorImpl{
lifetime: lifetime.NewLifetime(lifetime.Working),
opener: opener,
idAllocator: util.NewIDAllocator(),
idAllocator: typeutil.NewIDAllocator(),
walInstances: typeutil.NewConcurrentMap[int64, wal.WAL](),
interceptorBuilders: builders,
}
@ -34,7 +33,7 @@ func adaptImplsToOpener(opener walimpls.OpenerImpls, builders []interceptors.Int
type openerAdaptorImpl struct {
lifetime lifetime.Lifetime[lifetime.State]
opener walimpls.OpenerImpls
idAllocator *util.IDAllocator
idAllocator *typeutil.IDAllocator
walInstances *typeutil.ConcurrentMap[int64, wal.WAL] // stores all wal instances allocated by this allocator.
interceptorBuilders []interceptors.InterceptorBuilder
}

View File

@ -4,13 +4,13 @@ import (
"fmt"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type scannerRegistry struct {
channel types.PChannelInfo
idAllocator *util.IDAllocator
idAllocator *typeutil.IDAllocator
}
// AllocateScannerName allocates a scanner name for a scanner.

View File

@ -8,7 +8,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
@ -35,14 +34,14 @@ func adaptImplsToWAL(
wal := &walAdaptorImpl{
lifetime: lifetime.NewLifetime(lifetime.Working),
idAllocator: util.NewIDAllocator(),
idAllocator: typeutil.NewIDAllocator(),
inner: basicWAL,
// TODO: make the pool size configurable.
appendExecutionPool: conc.NewPool[struct{}](10),
interceptor: interceptor,
scannerRegistry: scannerRegistry{
channel: basicWAL.Channel(),
idAllocator: util.NewIDAllocator(),
idAllocator: typeutil.NewIDAllocator(),
},
scanners: typeutil.NewConcurrentMap[int64, wal.Scanner](),
cleanup: cleanup,
@ -54,7 +53,7 @@ func adaptImplsToWAL(
// walAdaptorImpl is a wrapper of WALImpls to extend it into a WAL interface.
type walAdaptorImpl struct {
lifetime lifetime.Lifetime[lifetime.State]
idAllocator *util.IDAllocator
idAllocator *typeutil.IDAllocator
inner walimpls.WALImpls
appendExecutionPool *conc.Pool[struct{}]
interceptor interceptors.InterceptorWithReady

View File

@ -0,0 +1,62 @@
package attributes
import (
"google.golang.org/grpc/attributes"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
)
type attributesKeyType int
const (
serverIDKey attributesKeyType = iota
channelAssignmentInfoKey
sessionKey
)
type Attributes = attributes.Attributes
// GetServerID returns the serverID in the given Attributes.
func GetServerID(attr *Attributes) *int64 {
val := attr.Value(serverIDKey)
if val == nil {
return nil
}
serverID := val.(int64)
return &serverID
}
// WithServerID returns a new Attributes containing the given serverID.
func WithServerID(attr *Attributes, serverID int64) *Attributes {
return attr.WithValue(serverIDKey, serverID)
}
// WithChannelAssignmentInfo returns a new Attributes containing the given channelInfo.
func WithChannelAssignmentInfo(attr *Attributes, assignment *types.StreamingNodeAssignment) *attributes.Attributes {
return attr.WithValue(channelAssignmentInfoKey, assignment).WithValue(serverIDKey, assignment.NodeInfo.ServerID)
}
// GetChannelAssignmentInfoFromAttributes gets the channel info fetched from streamingcoord.
// Generated by the channel assignment discoverer and sent to channel assignment balancer.
func GetChannelAssignmentInfoFromAttributes(attrs *Attributes) *types.StreamingNodeAssignment {
val := attrs.Value(channelAssignmentInfoKey)
if val == nil {
return nil
}
return val.(*types.StreamingNodeAssignment)
}
// WithSession returns a new Attributes containing the given session.
func WithSession(attr *Attributes, val *sessionutil.SessionRaw) *attributes.Attributes {
return attr.WithValue(sessionKey, val).WithValue(serverIDKey, val.ServerID)
}
// GetSessionFromAttributes gets the session from the attributes.
func GetSessionFromAttributes(attrs *Attributes) *sessionutil.SessionRaw {
val := attrs.Value(sessionKey)
if val == nil {
return nil
}
return val.(*sessionutil.SessionRaw)
}

View File

@ -0,0 +1,43 @@
package attributes
import (
"testing"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/stretchr/testify/assert"
)
func TestAttributes(t *testing.T) {
attr := new(Attributes)
serverID := GetServerID(attr)
assert.Nil(t, serverID)
assert.Nil(t, GetChannelAssignmentInfoFromAttributes(attr))
assert.Nil(t, GetSessionFromAttributes(attr))
attr = new(Attributes)
attr = WithChannelAssignmentInfo(attr, &types.StreamingNodeAssignment{
NodeInfo: types.StreamingNodeInfo{
ServerID: 1,
Address: "localhost:8080",
},
})
assert.NotNil(t, GetServerID(attr))
assert.Equal(t, int64(1), *GetServerID(attr))
assert.NotNil(t, GetChannelAssignmentInfoFromAttributes(attr))
assert.Equal(t, "localhost:8080", GetChannelAssignmentInfoFromAttributes(attr).NodeInfo.Address)
attr = new(Attributes)
attr = WithSession(attr, &sessionutil.SessionRaw{
ServerID: 1,
})
assert.NotNil(t, GetServerID(attr))
assert.Equal(t, int64(1), *GetServerID(attr))
assert.NotNil(t, GetSessionFromAttributes(attr))
assert.Equal(t, int64(1), GetSessionFromAttributes(attr).ServerID)
attr = new(Attributes)
attr = WithServerID(attr, 1)
serverID = GetServerID(attr)
assert.Equal(t, int64(1), *GetServerID(attr))
}

View File

@ -0,0 +1,242 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modified by github.com/milvus-io/milvus, @chyezh
* - Add `UnReadySCs` into `PickerBuildInfo` for picker to do better chosen.
* - Remove extra log.
*
*/
package balancer
import (
"errors"
"fmt"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
var (
_ balancer.Balancer = (*baseBalancer)(nil)
_ balancer.ExitIdler = (*baseBalancer)(nil)
_ balancer.Builder = (*baseBuilder)(nil)
)
type baseBuilder struct {
name string
pickerBuilder PickerBuilder
config base.Config
}
func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
bal := &baseBalancer{
cc: cc,
pickerBuilder: bb.pickerBuilder,
subConns: resolver.NewAddressMap(),
scStates: make(map[balancer.SubConn]connectivity.State),
csEvltr: &balancer.ConnectivityStateEvaluator{},
config: bb.config,
state: connectivity.Connecting,
}
// Initialize picker to a picker that always returns
// ErrNoSubConnAvailable, because when state of a SubConn changes, we
// may call UpdateState with this picker.
bal.picker = base.NewErrPicker(balancer.ErrNoSubConnAvailable)
return bal
}
func (bb *baseBuilder) Name() string {
return bb.name
}
// baseBalancer is the base balancer for all balancers.
type baseBalancer struct {
cc balancer.ClientConn
pickerBuilder PickerBuilder
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns *resolver.AddressMap
scStates map[balancer.SubConn]connectivity.State
picker balancer.Picker
config base.Config
resolverErr error // the last error reported by the resolver; cleared on successful resolution
connErr error // the last connection error; cleared upon leaving TransientFailure
}
func (b *baseBalancer) ResolverError(err error) {
b.resolverErr = err
if b.subConns.Len() == 0 {
b.state = connectivity.TransientFailure
}
if b.state != connectivity.TransientFailure {
// The picker will not change since the balancer does not currently
// report an error.
return
}
b.regeneratePicker()
b.cc.UpdateState(balancer.State{
ConnectivityState: b.state,
Picker: b.picker,
})
}
func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
// Successful resolution; clear resolver error and ensure we return nil.
b.resolverErr = nil
// addrsSet is the set converted from addrs, it's used for quick lookup of an address.
addrsSet := resolver.NewAddressMap()
for _, a := range s.ResolverState.Addresses {
addrsSet.Set(a, nil)
if _, ok := b.subConns.Get(a); !ok {
// a is a new address (not existing in b.subConns).
sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: b.config.HealthCheck})
if err != nil {
continue
}
b.subConns.Set(a, sc)
b.scStates[sc] = connectivity.Idle
b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle)
sc.Connect()
}
}
for _, a := range b.subConns.Keys() {
sci, _ := b.subConns.Get(a)
sc := sci.(balancer.SubConn)
// a was removed by resolver.
if _, ok := addrsSet.Get(a); !ok {
b.cc.RemoveSubConn(sc)
b.subConns.Delete(a)
// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
// The entry will be deleted in UpdateSubConnState.
}
}
// If resolver state contains no addresses, return an error so ClientConn
// will trigger re-resolve. Also records this as an resolver error, so when
// the overall state turns transient failure, the error message will have
// the zero address information.
if len(s.ResolverState.Addresses) == 0 {
b.ResolverError(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
b.regeneratePicker()
b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker})
return nil
}
// mergeErrors builds an error from the last connection error and the last
// resolver error. Must only be called if b.state is TransientFailure.
func (b *baseBalancer) mergeErrors() error {
// connErr must always be non-nil unless there are no SubConns, in which
// case resolverErr must be non-nil.
if b.connErr == nil {
return fmt.Errorf("last resolver error: %v", b.resolverErr)
}
if b.resolverErr == nil {
return fmt.Errorf("last connection error: %v", b.connErr)
}
return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
}
// regeneratePicker takes a snapshot of the balancer, and generates a picker
// from it. The picker is
// - errPicker if the balancer is in TransientFailure,
// - built by the pickerBuilder with all READY SubConns otherwise.
func (b *baseBalancer) regeneratePicker() {
if b.state == connectivity.TransientFailure {
b.picker = base.NewErrPicker(b.mergeErrors())
return
}
readySCs := make(map[balancer.SubConn]base.SubConnInfo)
unReadySCs := make(map[balancer.SubConn]base.SubConnInfo)
// Filter out all ready SCs from full subConn map.
for _, addr := range b.subConns.Keys() {
sci, _ := b.subConns.Get(addr)
sc := sci.(balancer.SubConn)
if st, ok := b.scStates[sc]; ok {
if st == connectivity.Ready {
readySCs[sc] = base.SubConnInfo{Address: addr}
continue
}
unReadySCs[sc] = base.SubConnInfo{Address: addr}
}
}
b.picker = b.pickerBuilder.Build(PickerBuildInfo{
ReadySCs: readySCs,
UnReadySCs: unReadySCs,
})
}
func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
s := state.ConnectivityState
oldS, ok := b.scStates[sc]
if !ok {
return
}
if oldS == connectivity.TransientFailure &&
(s == connectivity.Connecting || s == connectivity.Idle) {
// Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or
// CONNECTING transitions to prevent the aggregated state from being
// always CONNECTING when many backends exist but are all down.
if s == connectivity.Idle {
sc.Connect()
}
return
}
b.scStates[sc] = s
switch s {
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address was removed by resolver, b called RemoveSubConn but
// kept the sc's state in scStates. Remove state for this sc here.
delete(b.scStates, sc)
case connectivity.TransientFailure:
// Save error to be reported via picker.
b.connErr = state.ConnectionError
}
b.state = b.csEvltr.RecordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc entered or left ready
// - the aggregated state of balancer is TransientFailure
// (may need to update error message)
if (s == connectivity.Ready) != (oldS == connectivity.Ready) ||
b.state == connectivity.TransientFailure {
b.regeneratePicker()
}
b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker})
}
// Close is a nop because base balancer doesn't have internal state to clean up,
// and it doesn't need to call RemoveSubConn for the SubConns.
func (b *baseBalancer) Close() {
}
// ExitIdle is a nop because the base balancer attempts to stay connected to
// all SubConns at all times.
func (b *baseBalancer) ExitIdle() {
}

View File

@ -0,0 +1,98 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package balancer
import (
"testing"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
type testClientConn struct {
balancer.ClientConn
newSubConn func([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error)
}
func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
return c.newSubConn(addrs, opts)
}
func (c *testClientConn) UpdateState(balancer.State) {}
type testSubConn struct{}
func (sc *testSubConn) UpdateAddresses(addresses []resolver.Address) {}
func (sc *testSubConn) Connect() {}
func (sc *testSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) {
return nil, nil
}
// testPickBuilder creates balancer.Picker for test.
type testPickBuilder struct {
validate func(info PickerBuildInfo)
}
func (p *testPickBuilder) Build(info PickerBuildInfo) balancer.Picker {
p.validate(info)
return nil
}
func TestBaseBalancerReserveAttributes(t *testing.T) {
v := func(info PickerBuildInfo) {
for _, sc := range info.ReadySCs {
if sc.Address.Addr == "1.1.1.1" {
if sc.Address.Attributes == nil {
t.Errorf("in picker.validate, got address %+v with nil attributes, want not nil", sc.Address)
}
foo, ok := sc.Address.Attributes.Value("foo").(string)
if !ok || foo != "2233niang" {
t.Errorf("in picker.validate, got address[1.1.1.1] with invalid attributes value %v, want 2233niang", sc.Address.Attributes.Value("foo"))
}
} else if sc.Address.Addr == "2.2.2.2" {
if sc.Address.Attributes != nil {
t.Error("in b.subConns, got address[2.2.2.2] with not nil attributes, want nil")
}
}
}
}
pickBuilder := &testPickBuilder{validate: v}
b := (&baseBuilder{pickerBuilder: pickBuilder}).Build(&testClientConn{
newSubConn: func(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) {
return &testSubConn{}, nil
},
}, balancer.BuildOptions{}).(*baseBalancer)
b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{
Addresses: []resolver.Address{
{Addr: "1.1.1.1", Attributes: attributes.New("foo", "2233niang")},
{Addr: "2.2.2.2", Attributes: nil},
},
},
})
for sc := range b.scStates {
b.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready, ConnectionError: nil})
}
}

View File

@ -0,0 +1,50 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modified by github.com/milvus-io/milvus, @chyezh
* - Only keep modified struct `PickerBuildInfo`, `PickerBuilder`, remove unmodified struct.
*
*/
package balancer
import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
)
type PickerBuildInfo struct {
// ReadySCs is a map from all ready SubConns to the Addresses that can be used.
ReadySCs map[balancer.SubConn]base.SubConnInfo
// UnReadySCs is a map from all unready SubConns to the Addresses that can be used.
UnReadySCs map[balancer.SubConn]base.SubConnInfo
}
// PickerBuilder creates balancer.Picker.
type PickerBuilder interface {
// Build returns a picker that will be used by gRPC to pick a SubConn.
Build(info PickerBuildInfo) balancer.Picker
}
// NewBalancerBuilder returns a base balancer builder configured by the provided config.
func NewBalancerBuilder(name string, pb PickerBuilder, config base.Config) balancer.Builder {
return &baseBuilder{
name: name,
pickerBuilder: pb,
config: config,
}
}

View File

@ -0,0 +1,77 @@
package picker
import (
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
bbalancer "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer"
"github.com/milvus-io/milvus/pkg/log"
)
const (
ServerIDPickerBalancerName = "server_id_picker"
)
func init() {
balancer.Register(bbalancer.NewBalancerBuilder(
ServerIDPickerBalancerName,
&serverIDPickerBuilder{},
base.Config{HealthCheck: true}),
)
}
// serverIDPickerBuilder is a picker builder that picks SubConns by server ID.
type serverIDPickerBuilder struct{}
// Build returns a picker that will be used by gRPC to pick a SubConn.
func (b *serverIDPickerBuilder) Build(info bbalancer.PickerBuildInfo) balancer.Picker {
if len(info.ReadySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
readyMap := make(map[int64]subConnInfo, len(info.ReadySCs))
readyList := make([]subConnInfo, 0, len(info.ReadySCs))
for sc, scInfo := range info.ReadySCs {
serverID := attributes.GetServerID(scInfo.Address.BalancerAttributes)
if serverID == nil {
log.Warn("no server id found in subConn", zap.String("address", scInfo.Address.Addr))
continue
}
info := subConnInfo{
serverID: *serverID,
subConn: sc,
subConnInfo: scInfo,
}
readyMap[*serverID] = info
readyList = append(readyList, info)
}
unReadyMap := make(map[int64]subConnInfo, len(info.UnReadySCs))
for sc, scInfo := range info.UnReadySCs {
serverID := attributes.GetServerID(scInfo.Address.BalancerAttributes)
if serverID == nil {
log.Warn("no server id found in subConn", zap.String("address", scInfo.Address.Addr))
continue
}
info := subConnInfo{
serverID: *serverID,
subConn: sc,
subConnInfo: scInfo,
}
unReadyMap[*serverID] = info
}
if len(readyList) == 0 {
log.Warn("no subConn available after serverID filtering")
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
p := &serverIDPicker{
next: atomic.NewInt64(0),
readySubConnsMap: readyMap,
readySubConsList: readyList,
unreadySubConnsMap: unReadyMap,
}
return p
}

View File

@ -0,0 +1,124 @@
package picker
import (
"strconv"
"github.com/cockroachdb/errors"
"go.uber.org/atomic"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/pkg/util/interceptor"
)
var _ balancer.Picker = &serverIDPicker{}
var ErrNoSubConnNotExist = status.New(codes.Unavailable, "sub connection not exist").Err()
type subConnInfo struct {
serverID int64
subConn balancer.SubConn
subConnInfo base.SubConnInfo
}
// serverIDPicker picks the SubConn with the requested server ID, falling back to round-robin when no server ID is given.
type serverIDPicker struct {
next *atomic.Int64 // index of the next subConn to pick.
readySubConsList []subConnInfo // ready resolver ordered list.
readySubConnsMap map[int64]subConnInfo // map the server id to ready subConnInfo.
unreadySubConnsMap map[int64]subConnInfo // map the server id to unready subConnInfo.
}
// Pick returns the connection to use for this RPC and related information.
//
// Pick should not block. If the balancer needs to do I/O or any blocking
// or time-consuming work to service this call, it should return
// ErrNoSubConnAvailable, and the Pick call will be repeated by gRPC when
// the Picker is updated (using ClientConn.UpdateState).
//
// If an error is returned:
//
// - If the error is ErrNoSubConnAvailable, gRPC will block until a new
// Picker is provided by the balancer (using ClientConn.UpdateState).
//
// - If the error is a status error (implemented by the grpc/status
// package), gRPC will terminate the RPC with the code and message
// provided.
//
// - For all other errors, wait for ready RPCs will wait, but non-wait for
// ready RPCs will be terminated with this error's Error() string and
// status code Unavailable.
func (p *serverIDPicker) Pick(pickInfo balancer.PickInfo) (balancer.PickResult, error) {
var conn *subConnInfo
var err error
serverID, ok := contextutil.GetPickServerID(pickInfo.Ctx)
if !ok {
// no server id given: round-robin; this path may block until a SubConn is ready.
if conn, err = p.roundRobin(); err != nil {
return balancer.PickResult{}, err
}
} else {
// server id given: pick that node directly; this path must not block.
if conn, err = p.useGivenAddr(pickInfo, serverID); err != nil {
return balancer.PickResult{}, err
}
}
return balancer.PickResult{
SubConn: conn.subConn,
Done: nil, // TODO: add a done function to handle the rpc finished.
// Add the server id to the metadata.
// See interceptor.ServerIDValidationUnaryServerInterceptor
Metadata: metadata.Pairs(
interceptor.ServerIDKey,
strconv.FormatInt(conn.serverID, 10),
),
}, nil
}
// roundRobin returns the next subConn in round robin.
func (p *serverIDPicker) roundRobin() (*subConnInfo, error) {
if len(p.readySubConsList) == 0 {
return nil, balancer.ErrNoSubConnAvailable
}
subConnsLen := len(p.readySubConsList)
nextIndex := int(p.next.Inc()) % subConnsLen
return &p.readySubConsList[nextIndex], nil
}
// useGivenAddr returns the subConn for the given server ID.
func (p *serverIDPicker) useGivenAddr(_ balancer.PickInfo, serverID int64) (*subConnInfo, error) {
sc, ok := p.readySubConnsMap[serverID]
if ok {
return &sc, nil
}
// subConn is not ready yet, return ErrNoSubConnAvailable to wait until the connection is ready.
if _, ok := p.unreadySubConnsMap[serverID]; ok {
return nil, balancer.ErrNoSubConnAvailable
}
// If the given server id is in neither readySubConnsMap nor unreadySubConnsMap, return an Unavailable error to the caller instead of blocking the rpc.
// FailedPrecondition would be converted to Internal by the grpc framework in `IsRestrictedControlPlaneCode`.
// Use Unavailable here.
// The Unavailable code is retried in many cases, which makes it a better fit than blocking when the SubConn does not exist.
return nil, ErrNoSubConnNotExist
}
// IsErrNoSubConnForPick checks whether the error is ErrNoSubConnNotExist.
func IsErrNoSubConnForPick(err error) bool {
if errors.Is(err, ErrNoSubConnNotExist) {
return true
}
if se, ok := err.(interface {
GRPCStatus() *status.Status
}); ok {
return errors.Is(se.GRPCStatus().Err(), ErrNoSubConnNotExist)
}
return false
}
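One way a caller might combine this helper with `contextutil.WithPickServerID`: when an RPC forced onto a specific node fails because the picker has never seen that server id, refresh the discovery view before retrying. A hedged sketch only; the import paths are inferred from the package layout in this patch and the `refresh` hook is a placeholder:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
	"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
)

// invokeOnNode forces a single RPC onto the node with the given server id and
// asks the caller to refresh its discovery view when that node is unknown.
func invokeOnNode(ctx context.Context, serverID int64, do func(ctx context.Context) error, refresh func()) error {
	err := do(contextutil.WithPickServerID(ctx, serverID))
	if picker.IsErrNoSubConnForPick(err) {
		// The picker has no SubConn (ready or not) for this server id;
		// re-discover before the caller retries.
		refresh()
	}
	return err
}
```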

View File

@ -0,0 +1,103 @@
package picker
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_balancer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
bbalancer "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/util/interceptor"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestServerIDPickerBuilder(t *testing.T) {
builder := &serverIDPickerBuilder{}
picker := builder.Build(bbalancer.PickerBuildInfo{})
assert.NotNil(t, picker)
_, err := picker.Pick(balancer.PickInfo{})
assert.Error(t, err)
assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable)
picker = builder.Build(bbalancer.PickerBuildInfo{
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
mock_balancer.NewMockSubConn(t): {
Address: resolver.Address{
Addr: "localhost:1",
BalancerAttributes: attributes.WithServerID(
new(attributes.Attributes),
1,
),
},
},
mock_balancer.NewMockSubConn(t): {
Address: resolver.Address{
Addr: "localhost:2",
BalancerAttributes: attributes.WithServerID(
new(attributes.Attributes),
2,
),
},
},
},
UnReadySCs: map[balancer.SubConn]base.SubConnInfo{
mock_balancer.NewMockSubConn(t): {
Address: resolver.Address{
Addr: "localhost:3",
BalancerAttributes: attributes.WithServerID(
new(attributes.Attributes),
3,
),
},
},
},
})
// Test round-robin
serverIDSet := typeutil.NewSet[string]()
info, err := picker.Pick(balancer.PickInfo{Ctx: context.Background()})
assert.NoError(t, err)
serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0])
info, err = picker.Pick(balancer.PickInfo{Ctx: context.Background()})
assert.NoError(t, err)
serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0])
serverIDSet.Insert(info.Metadata.Get(interceptor.ServerIDKey)[0])
assert.Equal(t, 2, serverIDSet.Len())
// Test force address
info, err = picker.Pick(balancer.PickInfo{
Ctx: contextutil.WithPickServerID(context.Background(), 1),
})
assert.NoError(t, err)
assert.Equal(t, "1", info.Metadata.Get(interceptor.ServerIDKey)[0])
// Test pick not ready
info, err = picker.Pick(balancer.PickInfo{
Ctx: contextutil.WithPickServerID(context.Background(), 3),
})
assert.Error(t, err)
assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable)
assert.NotNil(t, info)
// Test pick not exists
info, err = picker.Pick(balancer.PickInfo{
Ctx: contextutil.WithPickServerID(context.Background(), 4),
})
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoSubConnNotExist)
assert.NotNil(t, info)
}
func TestIsErrNoSubConnForPick(t *testing.T) {
assert.True(t, IsErrNoSubConnForPick(ErrNoSubConnNotExist))
assert.False(t, IsErrNoSubConnForPick(errors.New("test")))
err := status.ConvertStreamingError("test", ErrNoSubConnNotExist)
assert.True(t, IsErrNoSubConnForPick(err))
}

View File

@ -0,0 +1,33 @@
package contextutil
import (
"context"
)
type (
pickResultKeyType int
)
var pickResultServerIDKey pickResultKeyType = 0
// WithPickServerID returns a new context with the pick result.
func WithPickServerID(ctx context.Context, serverID int64) context.Context {
return context.WithValue(ctx, pickResultServerIDKey, &serverIDPickResult{
serverID: serverID,
})
}
// GetPickServerID returns the server id stored in the context and whether one was set.
func GetPickServerID(ctx context.Context) (int64, bool) {
pr := ctx.Value(pickResultServerIDKey)
if pr == nil {
return -1, false
}
return pr.(*serverIDPickResult).serverID, true
}
// serverIDPickResult is used to store the result of picker.
type serverIDPickResult struct {
serverID int64
}

View File

@ -0,0 +1,25 @@
package contextutil
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithPickServerID(t *testing.T) {
ctx := context.Background()
ctx = WithPickServerID(ctx, 1)
serverID, ok := GetPickServerID(ctx)
assert.True(t, ok)
assert.EqualValues(t, 1, serverID)
}
func TestGetPickServerID(t *testing.T) {
ctx := context.Background()
serverID, ok := GetPickServerID(ctx)
assert.False(t, ok)
assert.EqualValues(t, -1, serverID)
// normal case is tested in TestWithPickServerID
}

View File

@ -0,0 +1,84 @@
package discoverer
import (
"context"
"go.uber.org/zap"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// NewChannelAssignmentDiscoverer returns a new Discoverer for the channel assignment registration.
func NewChannelAssignmentDiscoverer(logCoordManager types.AssignmentDiscoverWatcher) Discoverer {
return &channelAssignmentDiscoverer{
assignmentWatcher: logCoordManager,
lastDiscovery: nil,
}
}
// channelAssignmentDiscoverer is the discoverer for channel assignment.
type channelAssignmentDiscoverer struct {
assignmentWatcher types.AssignmentDiscoverWatcher
// lastDiscovery is the last discovered assignment state and its version.
lastDiscovery *types.VersionedStreamingNodeAssignments
}
// NewVersionedState returns a lowest versioned state.
func (d *channelAssignmentDiscoverer) NewVersionedState() VersionedState {
return VersionedState{
Version: typeutil.VersionInt64Pair{Global: -1, Local: -1},
State: resolver.State{},
}
}
// Discover implements the Discoverer interface.
func (d *channelAssignmentDiscoverer) Discover(ctx context.Context, cb func(VersionedState) error) error {
// Always send the current state first.
// Outside logic may lose the last state before retrying the Discover function.
if err := cb(d.parseState()); err != nil {
return err
}
return d.assignmentWatcher.AssignmentDiscover(ctx, func(assignments *types.VersionedStreamingNodeAssignments) error {
d.lastDiscovery = assignments
return cb(d.parseState())
})
}
// parseState parses the addresses from the discovery response.
// Always perform a copy here.
func (d *channelAssignmentDiscoverer) parseState() VersionedState {
if d.lastDiscovery == nil {
return d.NewVersionedState()
}
addrs := make([]resolver.Address, 0, len(d.lastDiscovery.Assignments))
for _, assignment := range d.lastDiscovery.Assignments {
assignment := assignment
addrs = append(addrs, resolver.Address{
Addr: assignment.NodeInfo.Address,
BalancerAttributes: attributes.WithChannelAssignmentInfo(new(attributes.Attributes), &assignment),
})
}
// TODO: service config should be sent by resolver in future to achieve dynamic configuration for grpc.
return VersionedState{
Version: d.lastDiscovery.Version,
State: resolver.State{Addresses: addrs},
}
}
// ChannelAssignmentInfo returns the channel assignment info from the resolver state.
func (s *VersionedState) ChannelAssignmentInfo() map[int64]types.StreamingNodeAssignment {
assignments := make(map[int64]types.StreamingNodeAssignment)
for _, v := range s.State.Addresses {
assignment := attributes.GetChannelAssignmentInfoFromAttributes(v.BalancerAttributes)
if assignment == nil {
log.Error("no assignment found in resolver state, skip it", zap.String("address", v.Addr))
continue
}
assignments[assignment.NodeInfo.ServerID] = *assignment
}
return assignments
}

View File

@ -0,0 +1,98 @@
package discoverer
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_types"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestChannelAssignmentDiscoverer(t *testing.T) {
w := mock_types.NewMockAssignmentDiscoverWatcher(t)
ch := make(chan *types.VersionedStreamingNodeAssignments, 10)
w.EXPECT().AssignmentDiscover(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case result, ok := <-ch:
if ok {
if err := cb(result); err != nil {
return err
}
} else {
return io.EOF
}
}
}
})
d := NewChannelAssignmentDiscoverer(w)
s := d.NewVersionedState()
assert.True(t, s.Version.EQ(typeutil.VersionInt64Pair{Global: -1, Local: -1}))
expected := []*types.VersionedStreamingNodeAssignments{
{
Version: typeutil.VersionInt64Pair{Global: -1, Local: -1},
Assignments: map[int64]types.StreamingNodeAssignment{},
},
{
Version: typeutil.VersionInt64Pair{
Global: 1,
Local: 2,
},
Assignments: map[int64]types.StreamingNodeAssignment{
1: {
NodeInfo: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
Channels: map[string]types.PChannelInfo{
"ch1": {Name: "ch1", Term: 1},
},
},
},
},
{
Version: typeutil.VersionInt64Pair{
Global: 3,
Local: 4,
},
Assignments: map[int64]types.StreamingNodeAssignment{},
},
{
Version: typeutil.VersionInt64Pair{
Global: 5,
Local: 6,
},
Assignments: map[int64]types.StreamingNodeAssignment{
1: {
NodeInfo: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"},
Channels: map[string]types.PChannelInfo{
"ch2": {Name: "ch2", Term: 1},
},
},
},
},
}
idx := 0
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := d.Discover(ctx, func(state VersionedState) error {
assert.True(t, expected[idx].Version.EQ(state.Version))
assignment := state.ChannelAssignmentInfo()
assert.Equal(t, expected[idx].Assignments, assignment)
if idx < len(expected)-1 {
ch <- expected[idx+1]
idx++
return nil
}
return io.EOF
})
assert.ErrorIs(t, err, io.EOF)
}

View File

@ -0,0 +1,29 @@
package discoverer
import (
"context"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// Discoverer is the interface for the discoverer.
// It does not promise:
// 1. concurrency safety.
// 2. monotonic versions: the discovered version may repeat or decrease, so the user should check the version in the callback.
type Discoverer interface {
// NewVersionedState returns a lowest versioned state.
NewVersionedState() VersionedState
// Discover watches the service discovery on the calling goroutine.
// 1. It calls the callback whenever the discovery changes, and blocks until the discovery is canceled or breaks down.
// 2. Discover always sends the current state first and then blocks.
Discover(ctx context.Context, cb func(VersionedState) error) error
}
// VersionedState is the state with version.
type VersionedState struct {
Version typeutil.Version
State resolver.State
}
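A small sketch of how a consumer might drive this interface while honoring the second caveat: only apply states whose version is strictly newer than the last one seen. The helper is illustrative and assumes it lives alongside the interface in this package:

```go
package discoverer

import "context"

// watchDiscovery drives a Discoverer and only applies states whose version is
// strictly newer than the last one applied, since versions may repeat or decrease.
func watchDiscovery(ctx context.Context, d Discoverer, apply func(VersionedState)) error {
	last := d.NewVersionedState()
	return d.Discover(ctx, func(state VersionedState) error {
		if !state.Version.GT(last.Version) {
			// Repeated or older version; skip it.
			return nil
		}
		last = state
		apply(state)
		return nil
	})
}
```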

View File

@ -0,0 +1,202 @@
package discoverer
import (
"context"
"encoding/json"
"github.com/blang/semver/v4"
"github.com/cockroachdb/errors"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/attributes"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// NewSessionDiscoverer returns a new Discoverer for the milvus session registration.
func NewSessionDiscoverer(etcdCli *clientv3.Client, prefix string, minimumVersion string) Discoverer {
return &sessionDiscoverer{
etcdCli: etcdCli,
prefix: prefix,
versionRange: semver.MustParseRange(">=" + minimumVersion),
logger: log.With(zap.String("prefix", prefix), zap.String("expectedVersion", minimumVersion)),
revision: 0,
peerSessions: make(map[string]*sessionutil.SessionRaw),
}
}
// sessionDiscoverer is used to apply a session watch on etcd.
type sessionDiscoverer struct {
etcdCli *clientv3.Client
prefix string
logger *log.MLogger
versionRange semver.Range
revision int64
peerSessions map[string]*sessionutil.SessionRaw // map[Key]SessionRaw, map the key path of session to session.
}
// NewVersionedState returns the empty versioned state.
func (sw *sessionDiscoverer) NewVersionedState() VersionedState {
return VersionedState{
Version: typeutil.VersionInt64(-1),
State: resolver.State{},
}
}
// Discover watches the service discovery on the calling goroutine.
// It may break down if compaction happens on the etcd server.
func (sw *sessionDiscoverer) Discover(ctx context.Context, cb func(VersionedState) error) error {
// init the discoverer.
if err := sw.initDiscover(ctx); err != nil {
return err
}
// Always send the current state first.
// Outside logic may lose the last state before retrying the Discover function.
if err := cb(sw.parseState()); err != nil {
return err
}
return sw.watch(ctx, cb)
}
// watch performs the watch on etcd.
func (sw *sessionDiscoverer) watch(ctx context.Context, cb func(VersionedState) error) error {
// start a watcher in the background.
eventCh := sw.etcdCli.Watch(
ctx,
sw.prefix,
clientv3.WithPrefix(),
clientv3.WithRev(sw.revision+1),
)
for {
// Watch the etcd events.
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "cancel the discovery")
case event, ok := <-eventCh:
// Break the loop if the watch failed.
if !ok {
return errors.New("etcd watch channel closed unexpectedly")
}
if err := sw.handleETCDEvent(event); err != nil {
return err
}
}
if err := cb(sw.parseState()); err != nil {
return err
}
}
}
// handleETCDEvent handles the etcd event.
func (sw *sessionDiscoverer) handleETCDEvent(resp clientv3.WatchResponse) error {
if resp.Err() != nil {
return resp.Err()
}
for _, ev := range resp.Events {
logger := sw.logger.With(zap.String("event", ev.Type.String()),
zap.String("sessionKey", string(ev.Kv.Key)))
switch ev.Type {
case clientv3.EventTypePut:
logger = logger.With(zap.String("sessionValue", string(ev.Kv.Value)))
session, err := sw.parseSession(ev.Kv.Value)
if err != nil {
logger.Warn("failed to parse session", zap.Error(err))
continue
}
logger.Info("new server modification")
sw.peerSessions[string(ev.Kv.Key)] = session
case clientv3.EventTypeDelete:
logger.Info("old server removed")
delete(sw.peerSessions, string(ev.Kv.Key))
}
}
// Update last revision.
sw.revision = resp.Header.Revision
return nil
}
// initDiscover initializes the discoverer if needed.
func (sw *sessionDiscoverer) initDiscover(ctx context.Context) error {
if sw.revision > 0 {
return nil
}
resp, err := sw.etcdCli.Get(ctx, sw.prefix, clientv3.WithPrefix(), clientv3.WithSerializable())
if err != nil {
return err
}
for _, kv := range resp.Kvs {
logger := sw.logger.With(zap.String("sessionKey", string(kv.Key)), zap.String("sessionValue", string(kv.Value)))
session, err := sw.parseSession(kv.Value)
if err != nil {
logger.Warn("fail to parse session when initializing discoverer", zap.Error(err))
continue
}
logger.Info("new server initialization", zap.Any("session", session))
sw.peerSessions[string(kv.Key)] = session
}
sw.revision = resp.Header.Revision
return nil
}
// parseSession parses the session from the etcd value.
func (sw *sessionDiscoverer) parseSession(value []byte) (*sessionutil.SessionRaw, error) {
session := new(sessionutil.SessionRaw)
if err := json.Unmarshal(value, session); err != nil {
return nil, err
}
return session, nil
}
// parseState parses the state from peerSessions.
// Always perform a copy here.
func (sw *sessionDiscoverer) parseState() VersionedState {
addrs := make([]resolver.Address, 0, len(sw.peerSessions))
for _, session := range sw.peerSessions {
session := session
v, err := semver.Parse(session.Version)
if err != nil {
sw.logger.Error("failed to parse version for session", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version), zap.Error(err))
continue
}
// filter low version.
if !sw.versionRange(v) {
sw.logger.Info("skip low version node", zap.Int64("serverID", session.ServerID), zap.String("version", session.Version))
continue
}
// !!! important, stopping nodes should not be removed here.
attr := new(attributes.Attributes)
attr = attributes.WithSession(attr, session)
addrs = append(addrs, resolver.Address{
Addr: session.Address,
BalancerAttributes: attr,
})
}
// TODO: service config should be sent by resolver in future to achieve dynamic configuration for grpc.
return VersionedState{
Version: typeutil.VersionInt64(sw.revision),
State: resolver.State{Addresses: addrs},
}
}
// Sessions returns the sessions in the state.
// Should only be called when using session discoverer.
func (s *VersionedState) Sessions() map[int64]*sessionutil.SessionRaw {
sessions := make(map[int64]*sessionutil.SessionRaw)
for _, v := range s.State.Addresses {
session := attributes.GetSessionFromAttributes(v.BalancerAttributes)
if session == nil {
log.Error("no session found in resolver state, skip it", zap.String("address", v.Addr))
continue
}
sessions[session.ServerID] = session
}
return sessions
}
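For orientation, the `versionRange` filter built in `NewSessionDiscoverer` is a plain blang/semver range; a tiny standalone sketch of what it accepts (the concrete minimum version is just an example):

```go
package main

import (
	"fmt"

	"github.com/blang/semver/v4"
)

func main() {
	// Equivalent to what NewSessionDiscoverer builds from its minimumVersion argument.
	rng := semver.MustParseRange(">=2.4.0")

	for _, v := range []string{"2.3.9", "2.4.0", "2.4.1"} {
		fmt.Println(v, rng(semver.MustParse(v))) // false, true, true
	}
}
```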

View File

@ -0,0 +1,111 @@
package discoverer
import (
"context"
"encoding/json"
"fmt"
"io"
"testing"
"github.com/blang/semver/v4"
"github.com/stretchr/testify/assert"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestSessionDiscoverer(t *testing.T) {
err := etcd.InitEtcdServer(true, "", t.TempDir(), "stdout", "info")
assert.NoError(t, err)
defer etcd.StopEtcdServer()
etcdClient, err := etcd.GetEmbedEtcdClient()
assert.NoError(t, err)
targetVersion := "0.1.0"
d := NewSessionDiscoverer(etcdClient, "session/", targetVersion)
s := d.NewVersionedState()
assert.True(t, s.Version.EQ(typeutil.VersionInt64(-1)))
expected := []map[int64]*sessionutil.SessionRaw{
{},
{
1: {ServerID: 1, Version: "0.2.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0"},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0", Stopping: true},
},
{
1: {ServerID: 1, Version: "0.2.0"},
2: {ServerID: 2, Version: "0.4.0"},
3: {ServerID: 3, Version: "0.3.0"},
4: {ServerID: 4, Version: "0.0.1"}, // version filtering
},
}
idx := 0
var lastVersion typeutil.Version = typeutil.VersionInt64(-1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err = d.Discover(ctx, func(state VersionedState) error {
sessions := state.Sessions()
expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx]))
for k, v := range expected[idx] {
if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) {
expectedSessions[k] = v
}
}
assert.Equal(t, expectedSessions, sessions)
assert.True(t, state.Version.GT(lastVersion))
lastVersion = state.Version
if idx < len(expected)-1 {
ops := make([]clientv3.Op, 0, len(expected[idx+1]))
for k, v := range expected[idx+1] {
sessionStr, err := json.Marshal(v)
assert.NoError(t, err)
ops = append(ops, clientv3.OpPut(fmt.Sprintf("session/%d", k), string(sessionStr)))
}
resp, err := etcdClient.Txn(ctx).Then(
ops...,
).Commit()
assert.NoError(t, err)
assert.NotNil(t, resp)
idx++
return nil
}
return io.EOF
})
assert.ErrorIs(t, err, io.EOF)
// Do a init discover here.
d = NewSessionDiscoverer(etcdClient, "session/", targetVersion)
err = d.Discover(ctx, func(state VersionedState) error {
sessions := state.Sessions()
expectedSessions := make(map[int64]*sessionutil.SessionRaw, len(expected[idx]))
for k, v := range expected[idx] {
if semver.MustParse(v.Version).GT(semver.MustParse(targetVersion)) {
expectedSessions[k] = v
}
}
assert.Equal(t, expectedSessions, sessions)
return io.EOF
})
assert.ErrorIs(t, err, io.EOF)
}

View File

@ -0,0 +1,35 @@
package interceptor
import (
"context"
"strings"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/proto/streamingpb"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
)
// NewStreamingServiceUnaryClientInterceptor returns a new unary client interceptor for error handling.
func NewStreamingServiceUnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if strings.HasPrefix(method, streamingpb.ServiceMethodPrefix) {
st := status.ConvertStreamingError(method, err)
return st
}
return err
}
}
// NewStreamingServiceStreamClientInterceptor returns a new stream client interceptor for error handling.
func NewStreamingServiceStreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
clientStream, err := streamer(ctx, desc, cc, method, opts...)
if strings.HasPrefix(method, streamingpb.ServiceMethodPrefix) {
e := status.ConvertStreamingError(method, err)
return status.NewClientStreamWrapper(method, clientStream), e
}
return clientStream, err
}
}
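A hedged sketch of how a client might wire these interceptors in when dialing; the target, credentials, and the interceptor import path (inferred from the surrounding package layout) are assumptions:

```go
package example

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
)

// dialStreaming shows where the streaming-service client interceptors plug in.
func dialStreaming(target string) (*grpc.ClientConn, error) {
	return grpc.Dial(target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		// Convert streaming-service errors on unary and streaming RPCs.
		grpc.WithChainUnaryInterceptor(interceptor.NewStreamingServiceUnaryClientInterceptor()),
		grpc.WithChainStreamInterceptor(interceptor.NewStreamingServiceStreamClientInterceptor()),
	)
}
```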

View File

@ -0,0 +1,52 @@
package interceptor
import (
"context"
"strings"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/internal/proto/streamingpb"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
)
// NewStreamingServiceUnaryServerInterceptor returns a new unary server interceptor for error handling, metric...
func NewStreamingServiceUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
resp, err := handler(ctx, req)
if err == nil {
return resp, err
}
// Streaming service methods should overwrite the response error code.
if strings.HasPrefix(info.FullMethod, streamingpb.ServiceMethodPrefix) {
err := status.AsStreamingError(err)
if err == nil {
// return no error if StreamingError is ok.
return resp, nil
}
return resp, status.NewGRPCStatusFromStreamingError(err).Err()
}
return resp, err
}
}
// NewStreamingServiceStreamServerInterceptor returns a new stream server interceptor for error handling, metric...
func NewStreamingServiceStreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := handler(srv, ss)
if err == nil {
return err
}
// Streaming service methods should overwrite the response error code.
if strings.HasPrefix(info.FullMethod, streamingpb.ServiceMethodPrefix) {
err := status.AsStreamingError(err)
if err == nil {
// return no error if StreamingError is ok.
return nil
}
return status.NewGRPCStatusFromStreamingError(err).Err()
}
return err
}
}
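And the mirror image on the server side, again only a sketch with an assumed import path; other server options and interceptors are up to the caller:

```go
package example

import (
	"google.golang.org/grpc"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
)

// newStreamingGRPCServer shows where the streaming-service server interceptors plug in.
func newStreamingGRPCServer() *grpc.Server {
	return grpc.NewServer(
		grpc.ChainUnaryInterceptor(interceptor.NewStreamingServiceUnaryServerInterceptor()),
		grpc.ChainStreamInterceptor(interceptor.NewStreamingServiceStreamServerInterceptor()),
	)
}
```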

View File

@ -0,0 +1,93 @@
package lazygrpc
import (
"context"
"github.com/cenkalti/backoff/v4"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/syncutil"
)
var ErrClosed = errors.New("lazy grpc conn closed")
// NewConn creates a new lazy grpc conn.
func NewConn(dialer func(ctx context.Context) (*grpc.ClientConn, error)) Conn {
conn := &connImpl{
initializationNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
conn: syncutil.NewFuture[*grpc.ClientConn](),
dialer: dialer,
}
go conn.initialize()
return conn
}
// Conn is a lazy grpc conn implementation.
// The underlying dial is deferred: GetConn blocks until the grpc conn has been created at least once.
// Conn dials the underlying grpc conn asynchronously to avoid a dependency cycle between milvus components when creating the grpc client.
// TODO: Remove in the future if we can refactor the dependency cycle away.
type Conn interface {
// GetConn will block until the grpc.ClientConn is ready to use.
// If the context is done, return immediately with the context.Canceled or context.DeadlineExceeded error.
// Return ErrClosed if the lazy grpc conn is closed.
GetConn(ctx context.Context) (*grpc.ClientConn, error)
// Close closes the lazy grpc conn.
// It closes the underlying grpc conn if it has already been created.
Close()
}
type connImpl struct {
initializationNotifier *syncutil.AsyncTaskNotifier[struct{}]
conn *syncutil.Future[*grpc.ClientConn]
dialer func(ctx context.Context) (*grpc.ClientConn, error)
}
func (c *connImpl) initialize() {
defer c.initializationNotifier.Finish(struct{}{})
backoff.Retry(func() error {
conn, err := c.dialer(c.initializationNotifier.Context())
if err != nil {
if c.initializationNotifier.Context().Err() != nil {
log.Info("lazy grpc conn canceled", zap.Error(c.initializationNotifier.Context().Err()))
return nil
}
log.Warn("async dial failed, wait for retry...", zap.Error(err))
return err
}
c.conn.Set(conn)
return nil
}, backoff.NewExponentialBackOff())
}
func (c *connImpl) GetConn(ctx context.Context) (*grpc.ClientConn, error) {
// If the conn has already been closed, return immediately so that callers see a stable ErrClosed after closing.
if c.initializationNotifier.Context().Err() != nil {
return nil, ErrClosed
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-c.initializationNotifier.Context().Done():
return nil, ErrClosed
case <-c.conn.Done():
return c.conn.Get(), nil
}
}
func (c *connImpl) Close() {
c.initializationNotifier.Cancel()
c.initializationNotifier.BlockUntilFinish()
if c.conn.Ready() {
if err := c.conn.Get().Close(); err != nil {
log.Warn("close underlying grpc conn fail", zap.Error(err))
}
}
}

View File

@ -0,0 +1,80 @@
package lazygrpc
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
)
func TestLazyConn(t *testing.T) {
listener := bufconn.Listen(1024)
s := grpc.NewServer()
go s.Serve(listener)
defer s.Stop()
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
lconn := NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
return grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
}
})
// Get with timeout
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := lconn.GetConn(ctx)
assert.Nil(t, conn)
assert.ErrorIs(t, err, context.DeadlineExceeded)
// Get conn after timeout
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err = lconn.GetConn(ctx)
assert.NotNil(t, conn)
assert.Nil(t, err)
// Get after close.
lconn.Close()
conn, err = lconn.GetConn(context.Background())
assert.ErrorIs(t, err, ErrClosed)
assert.Nil(t, conn)
// Get before initialization completes.
ticker = time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
lconn = NewConn(func(ctx context.Context) (*grpc.ClientConn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
return grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return listener.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
}
})
// Test WithServiceCreator
grpcService := WithServiceCreator(lconn, func(*grpc.ClientConn) int {
return 1
})
realService, err := grpcService.GetService(ctx)
assert.Equal(t, 1, realService)
assert.NoError(t, err)
lconn.Close()
conn, err = lconn.GetConn(context.Background())
assert.ErrorIs(t, err, ErrClosed)
assert.Nil(t, conn)
}

View File

@ -0,0 +1,37 @@
package lazygrpc
import (
"context"
"google.golang.org/grpc"
)
// WithServiceCreator creates a lazy grpc service with a service creator.
func WithServiceCreator[T any](conn Conn, serviceCreator func(*grpc.ClientConn) T) Service[T] {
return &serviceImpl[T]{
Conn: conn,
serviceCreator: serviceCreator,
}
}
// Service is a lazy grpc service.
type Service[T any] interface {
Conn
GetService(ctx context.Context) (T, error)
}
// serviceImpl is a lazy grpc service implementation.
type serviceImpl[T any] struct {
Conn
serviceCreator func(*grpc.ClientConn) T
}
func (s *serviceImpl[T]) GetService(ctx context.Context) (T, error) {
conn, err := s.Conn.GetConn(ctx)
if err != nil {
var result T
return result, err
}
return s.serviceCreator(conn), nil
}
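
A sketch of wrapping a lazy conn into a typed client. The gRPC health client is used only as a stand-in for a real generated client, and the lazygrpc import path is assumed:

package main

import (
	"context"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/health/grpc_health_v1"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" // assumed package path
)

// newLazyHealthService constructs the generated client lazily, only once the underlying conn is ready.
func newLazyHealthService(conn lazygrpc.Conn) lazygrpc.Service[grpc_health_v1.HealthClient] {
	return lazygrpc.WithServiceCreator(conn, func(cc *grpc.ClientConn) grpc_health_v1.HealthClient {
		return grpc_health_v1.NewHealthClient(cc)
	})
}

func check(ctx context.Context, svc lazygrpc.Service[grpc_health_v1.HealthClient]) error {
	ctx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()
	// GetService blocks until the conn is ready, the context expires, or the conn is closed.
	client, err := svc.GetService(ctx)
	if err != nil {
		return err
	}
	_, err = client.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
	return err
}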

View File

@ -0,0 +1,91 @@
package resolver
import (
"errors"
"time"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
const (
// targets: milvus-session:///streamingcoord.
SessionResolverScheme = "milvus-session"
// targets: channel-assignment://external-grpc-client
ChannelAssignmentResolverScheme = "channel-assignment"
)
var idAllocator = typeutil.NewIDAllocator()
// NewChannelAssignmentBuilder creates a new resolver builder.
func NewChannelAssignmentBuilder(w types.AssignmentDiscoverWatcher) Builder {
return newBuilder(ChannelAssignmentResolverScheme, discoverer.NewChannelAssignmentDiscoverer(w))
}
// NewSessionBuilder creates a new resolver builder.
func NewSessionBuilder(c *clientv3.Client, role string) Builder {
// TODO: use 2.5.0 after 2.5.0 is released.
return newBuilder(SessionResolverScheme, discoverer.NewSessionDiscoverer(c, role, "2.4.0"))
}
// newBuilder creates a new resolver builder.
func newBuilder(scheme string, d discoverer.Discoverer) Builder {
resolver := newResolverWithDiscoverer(scheme, d, 1*time.Second) // retry interval; could be made configurable.
return &builderImpl{
lifetime: lifetime.NewLifetime(lifetime.Working),
scheme: scheme,
resolver: resolver,
}
}
// builderImpl implements resolver.Builder.
type builderImpl struct {
lifetime lifetime.Lifetime[lifetime.State]
scheme string
resolver *resolverWithDiscoverer
}
// Build creates a new resolver for the given target.
//
// gRPC dial calls Build synchronously, and fails if the returned error is
// not nil.
//
// In our implementation, resolver.Target is ignored because the resolver result is determined by the discoverer.
// The underlying resolver is created when the Builder is constructed,
// so Build simply registers a new watcher on the existing resolver to share its result.
func (b *builderImpl) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
if err := b.lifetime.Add(lifetime.IsWorking); err != nil {
return nil, errors.New("builder is closed")
}
defer b.lifetime.Done()
r := newWatchBasedGRPCResolver(cc, b.resolver.logger.With(zap.Int64("id", idAllocator.Allocate())))
b.resolver.RegisterNewWatcher(r)
return r, nil
}
func (b *builderImpl) Resolver() Resolver {
return b.resolver
}
// Scheme returns the scheme supported by this resolver. Scheme is defined
// at https://github.com/grpc/grpc/blob/master/doc/naming.md. The returned
// string should not contain uppercase characters, as they will not match
// the parsed target's scheme as defined in RFC 3986.
func (b *builderImpl) Scheme() string {
return b.scheme
}
// Close closes the builder and the underlying resolver.
func (b *builderImpl) Close() {
b.lifetime.SetState(lifetime.Stopped)
b.lifetime.Wait()
b.lifetime.Close()
b.resolver.Close()
}
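
A hedged sketch of how the session builder might be wired into a grpc dial; the role name and target service name below are illustrative assumptions:

package main

import (
	"context"

	clientv3 "go.etcd.io/etcd/client/v3"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
)

// dialBySession dials a conn whose addresses are resolved from etcd sessions of the given role.
// The returned builder must be kept alive while the conn is in use and closed afterwards.
func dialBySession(ctx context.Context, etcdCli *clientv3.Client) (*grpc.ClientConn, resolver.Builder, error) {
	builder := resolver.NewSessionBuilder(etcdCli, "streamingnode") // role name is an assumption
	conn, err := grpc.DialContext(ctx,
		resolver.SessionResolverScheme+":///streamingcoord",
		grpc.WithResolvers(builder),
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		builder.Close()
		return nil, nil, err
	}
	return conn, builder, nil
}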

View File

@ -0,0 +1,47 @@
package resolver
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_discoverer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/resolver"
)
func TestNewBuilder(t *testing.T) {
d := mock_discoverer.NewMockDiscoverer(t)
ch := make(chan discoverer.VersionedState)
d.EXPECT().Discover(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(discoverer.VersionedState) error) error {
for {
select {
case state := <-ch:
if err := cb(state); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
})
d.EXPECT().NewVersionedState().Return(discoverer.VersionedState{
Version: typeutil.VersionInt64(-1),
})
b := newBuilder("test", d)
r := b.Resolver()
assert.NotNil(t, r)
assert.Equal(t, "test", b.Scheme())
mockClientConn := mock_resolver.NewMockClientConn(t)
mockClientConn.EXPECT().UpdateState(mock.Anything).RunAndReturn(func(args resolver.State) error {
return nil
})
grpcResolver, err := b.Build(resolver.Target{}, mockClientConn, resolver.BuildOptions{})
assert.NoError(t, err)
assert.NotNil(t, grpcResolver)
b.Close()
}

View File

@ -0,0 +1,46 @@
package resolver
import (
"context"
"errors"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
)
type VersionedState = discoverer.VersionedState
var (
ErrCanceled = errors.New("canceled")
ErrInterrupted = errors.New("interrupted")
)
// Builder is the interface for the grpc resolver builder.
// It owns a Resolver instance and build grpc.Resolver from it.
type Builder interface {
resolver.Builder
// Resolver returns the underlying resolver instance.
Resolver() Resolver
// Close the builder, release the underlying resolver instance.
Close()
}
// Resolver is the interface for the service discovery in grpc.
// It allows the user to get the grpc service discovery results and watch the changes.
// Not every intermediate change is delivered by these APIs; only the newest state is guaranteed.
type Resolver interface {
// GetLatestState returns the latest state of the resolver.
// The returned state should be treated as read-only; modifying it will cause a data race.
GetLatestState() VersionedState
// Watch watches the state changes of the resolver.
// cb is called with the latest state right after the call, and again whenever the state changes.
// Intermediate versions may be skipped if the state changes too fast; only the latest version is guaranteed to be seen by cb.
// Watch keeps running until ctx is canceled or cb returns an error.
// - Returns an error marked with ErrCanceled when ctx is canceled.
// - Returns an error marked with ErrInterrupted when cb returns an error.
Watch(ctx context.Context, cb func(VersionedState) error) error
}
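
A small sketch of consuming the Resolver interface directly (the import paths are assumed to match this package layout):

package main

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
	"github.com/milvus-io/milvus/pkg/log"
)

// watchAddresses logs the discovered address list on every state change until ctx is canceled
// or the callback returns an error (in which case Watch returns it marked with ErrInterrupted).
func watchAddresses(ctx context.Context, r resolver.Resolver) error {
	return r.Watch(ctx, func(state resolver.VersionedState) error {
		addrs := make([]string, 0, len(state.State.Addresses))
		for _, addr := range state.State.Addresses {
			addrs = append(addrs, addr.Addr)
		}
		log.Info("discovery updated", zap.Any("version", state.Version), zap.Strings("addresses", addrs))
		return nil
	})
}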

View File

@ -0,0 +1,192 @@
package resolver
import (
"context"
"sync"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/syncutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var _ Resolver = (*resolverWithDiscoverer)(nil)
// newResolverWithDiscoverer creates a new resolver with discoverer.
func newResolverWithDiscoverer(scheme string, d discoverer.Discoverer, retryInterval time.Duration) *resolverWithDiscoverer {
r := &resolverWithDiscoverer{
taskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
logger: log.With(zap.String("scheme", scheme)),
registerCh: make(chan *watchBasedGRPCResolver),
discoverer: d,
retryInterval: retryInterval,
latestStateCond: syncutil.NewContextCond(&sync.Mutex{}),
latestState: d.NewVersionedState(),
}
go r.doDiscover()
return r
}
// versionStateWithError is the versionedState with error.
type versionStateWithError struct {
state VersionedState
err error
}
// resolverWithDiscoverer is a resolver backed by a service discoverer.
type resolverWithDiscoverer struct {
taskNotifier *syncutil.AsyncTaskNotifier[struct{}]
logger *log.MLogger
registerCh chan *watchBasedGRPCResolver
discoverer discoverer.Discoverer // the underlying service discoverer
retryInterval time.Duration
latestStateCond *syncutil.ContextCond
latestState discoverer.VersionedState
}
// GetLatestState returns the latest state of the resolver.
func (r *resolverWithDiscoverer) GetLatestState() VersionedState {
r.latestStateCond.L.Lock()
state := r.latestState
r.latestStateCond.L.Unlock()
return state
}
// Watch watches the state changes of the resolver.
func (r *resolverWithDiscoverer) Watch(ctx context.Context, cb func(VersionedState) error) error {
state := r.GetLatestState()
if err := cb(state); err != nil {
return errors.Mark(err, ErrInterrupted)
}
version := state.Version
for {
if err := r.watchStateChange(ctx, version); err != nil {
return errors.Mark(err, ErrCanceled)
}
state := r.GetLatestState()
if err := cb(state); err != nil {
return errors.Mark(err, ErrInterrupted)
}
version = state.Version
}
}
// Close closes the resolver.
func (r *resolverWithDiscoverer) Close() {
// Cancel underlying task and close the discovery service.
r.taskNotifier.Cancel()
r.taskNotifier.BlockUntilFinish()
}
// watchStateChange blocks until the state is changed.
func (r *resolverWithDiscoverer) watchStateChange(ctx context.Context, version typeutil.Version) error {
r.latestStateCond.L.Lock()
for version.EQ(r.latestState.Version) {
if err := r.latestStateCond.Wait(ctx); err != nil {
return err
}
}
r.latestStateCond.L.Unlock()
return nil
}
// RegisterNewWatcher registers a new grpc resolver.
// RegisterNewWatcher should always be called before Close.
func (r *resolverWithDiscoverer) RegisterNewWatcher(grpcResolver *watchBasedGRPCResolver) error {
select {
case <-r.taskNotifier.Context().Done():
return errors.Mark(r.taskNotifier.Context().Err(), ErrCanceled)
case r.registerCh <- grpcResolver:
return nil
}
}
// doDiscover runs the discovery loop in the background.
func (r *resolverWithDiscoverer) doDiscover() {
grpcResolvers := make(map[*watchBasedGRPCResolver]struct{}, 0)
defer func() {
// Check if all grpc resolvers have stopped.
for r := range grpcResolvers {
if err := lifetime.IsWorking(r.State()); err == nil {
r.logger.Warn("resolver is stopped before grpc watcher exist, maybe bug here")
break
}
}
r.logger.Info("resolver stopped")
r.taskNotifier.Finish(struct{}{})
}()
for {
ch := r.asyncDiscover(r.taskNotifier.Context())
r.logger.Info("service discover task started, listening...")
L:
for {
select {
case watcher := <-r.registerCh:
// New grpc resolver registered.
// Push the latest state to the newly registered grpc resolver.
if err := watcher.Update(r.GetLatestState()); err != nil {
r.logger.Info("resolver is closed, ignore the new grpc resolver", zap.Error(err))
} else {
grpcResolvers[watcher] = struct{}{}
}
case stateWithError := <-ch:
if stateWithError.err != nil {
if r.taskNotifier.Context().Err() != nil {
// resolver stopped.
return
}
r.logger.Warn("service discover break down", zap.Error(stateWithError.err), zap.Duration("retryInterval", r.retryInterval))
time.Sleep(r.retryInterval)
break L
}
// Check whether the incoming state is newer.
state := stateWithError.state
latestState := r.GetLatestState()
if !state.Version.GT(latestState.Version) {
// Ignore the old version.
r.logger.Info("service discover update, ignore old version", zap.Any("state", state))
continue
}
// Update all grpc resolver.
r.logger.Info("service discover update, update resolver", zap.Any("state", state), zap.Int("resolver_count", len(grpcResolvers)))
for watcher := range grpcResolvers {
// the update operation does not block.
if err := watcher.Update(state); err != nil {
r.logger.Info("resolver is closed, unregister the resolver", zap.Error(err))
delete(grpcResolvers, watcher)
}
}
r.logger.Info("update resolver done")
// Updating the latest state and notifying resolver watchers must happen after all grpc watchers have been updated.
r.latestStateCond.LockAndBroadcast()
r.latestState = state
r.latestStateCond.L.Unlock()
}
}
}
}
// asyncDiscover is a non-blocking version of Discover.
func (r *resolverWithDiscoverer) asyncDiscover(ctx context.Context) <-chan versionStateWithError {
ch := make(chan versionStateWithError, 1)
go func() {
err := r.discoverer.Discover(ctx, func(vs discoverer.VersionedState) error {
ch <- versionStateWithError{
state: vs,
}
return nil
})
ch <- versionStateWithError{err: err}
}()
return ch
}

View File

@ -0,0 +1,166 @@
package resolver
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver"
"github.com/milvus-io/milvus/internal/mocks/util/streamingutil/service/mock_discoverer"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestResolverWithDiscoverer(t *testing.T) {
d := mock_discoverer.NewMockDiscoverer(t)
ch := make(chan discoverer.VersionedState)
d.EXPECT().Discover(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(discoverer.VersionedState) error) error {
for {
select {
case state := <-ch:
if err := cb(state); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
})
d.EXPECT().NewVersionedState().Return(discoverer.VersionedState{
Version: typeutil.VersionInt64(-1),
})
r := newResolverWithDiscoverer("test", d, time.Second)
var resultOfGRPCResolver resolver.State
mockClientConn := mock_resolver.NewMockClientConn(t)
mockClientConn.EXPECT().UpdateState(mock.Anything).RunAndReturn(func(args resolver.State) error {
resultOfGRPCResolver = args
return nil
})
w := newWatchBasedGRPCResolver(mockClientConn, log.With())
w2 := newWatchBasedGRPCResolver(nil, log.With())
w2.Close()
// Test registering a grpc resolver watcher.
err := r.RegisterNewWatcher(w)
assert.NoError(t, err)
err = r.RegisterNewWatcher(w2) // A closed watcher should be removed automatically by the resolver.
assert.NoError(t, err)
state := r.GetLatestState()
assert.Equal(t, typeutil.VersionInt64(-1), state.Version)
time.Sleep(500 * time.Millisecond)
state = r.GetLatestState()
assert.Equal(t, typeutil.VersionInt64(-1), state.Version)
// should not block after the context is canceled
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
err = r.Watch(ctx, func(s VersionedState) error {
state = s
return nil
})
assert.Equal(t, typeutil.VersionInt64(-1), state.Version)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.True(t, errors.Is(err, ErrCanceled))
// should not block after the callback returns an error.
testErr := errors.New("test error")
err = r.Watch(context.Background(), func(s VersionedState) error {
return testErr
})
assert.ErrorIs(t, err, testErr)
assert.True(t, errors.Is(err, ErrInterrupted))
outCh := make(chan VersionedState, 1)
go func() {
var state VersionedState
err := r.Watch(context.Background(), func(s VersionedState) error {
state = s
if state.Version.GT(typeutil.VersionInt64(2)) {
return testErr
}
return nil
})
assert.ErrorIs(t, err, testErr)
outCh <- state
}()
// should block.
shouldbeBlock(t, outCh)
ch <- discoverer.VersionedState{
Version: typeutil.VersionInt64(1),
State: resolver.State{
Addresses: []resolver.Address{},
},
}
// version not reached yet, should block.
shouldbeBlock(t, outCh)
ch <- discoverer.VersionedState{
Version: typeutil.VersionInt64(3),
State: resolver.State{
Addresses: []resolver.Address{{Addr: "1"}},
Attributes: attributes.New("1", "1"),
},
}
// version reached, should not block.
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
select {
case state = <-outCh:
assert.Equal(t, typeutil.VersionInt64(3), state.Version)
assert.NotNil(t, state.State.Attributes)
assert.NotNil(t, state.State.Addresses)
case <-ctx.Done():
t.Errorf("should not be block")
}
// after unblocking, the grpc watcher should have seen the last state.
assert.Len(t, resultOfGRPCResolver.Addresses, 1)
// old version should be filtered.
ch <- discoverer.VersionedState{
Version: typeutil.VersionInt64(2),
State: resolver.State{
Addresses: []resolver.Address{{Addr: "1"}},
Attributes: attributes.New("1", "1"),
},
}
shouldbeBlock(t, outCh)
w.Close() // closed watcher should be removed in next update.
ch <- discoverer.VersionedState{
Version: typeutil.VersionInt64(5),
State: resolver.State{
Addresses: []resolver.Address{{Addr: "1"}},
Attributes: attributes.New("1", "1"),
},
}
r.Close()
// after close, new registration is not allowed.
err = r.RegisterNewWatcher(nil)
assert.True(t, errors.Is(err, ErrCanceled))
}
func shouldbeBlock(t *testing.T, ch <-chan VersionedState) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
select {
case <-ch:
t.Errorf("should be block")
case <-ctx.Done():
}
}

View File

@ -0,0 +1,63 @@
package resolver
import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/lifetime"
)
var _ resolver.Resolver = (*watchBasedGRPCResolver)(nil)
// newWatchBasedGRPCResolver creates a new watch based grpc resolver.
func newWatchBasedGRPCResolver(cc resolver.ClientConn, logger *log.MLogger) *watchBasedGRPCResolver {
return &watchBasedGRPCResolver{
lifetime: lifetime.NewLifetime(lifetime.Working),
cc: cc,
logger: logger,
}
}
// watchBasedGRPCResolver is a watch based grpc resolver.
type watchBasedGRPCResolver struct {
lifetime lifetime.Lifetime[lifetime.State]
cc resolver.ClientConn
logger *log.MLogger
}
// ResolveNow will be called by gRPC to try to resolve the target name
// again. It's just a hint, resolver can ignore this if it's not necessary.
//
// It could be called multiple times concurrently.
func (r *watchBasedGRPCResolver) ResolveNow(_ resolver.ResolveNowOptions) {
}
// Close closes the resolver and marks it as stopped.
// Update is rejected after Close returns.
func (r *watchBasedGRPCResolver) Close() {
r.lifetime.SetState(lifetime.Stopped)
r.lifetime.Wait()
r.lifetime.Close()
}
func (r *watchBasedGRPCResolver) Update(state VersionedState) error {
if r.lifetime.Add(lifetime.IsWorking) != nil {
return errors.New("resolver is closed")
}
defer r.lifetime.Done()
if err := r.cc.UpdateState(state.State); err != nil {
// A watch based resolver can ignore the error, but should not report success when the update failed.
r.logger.Warn("fail to update resolver state", zap.Error(err))
return nil
}
r.logger.Info("update resolver state success", zap.Any("state", state.State))
return nil
}
// State returns the state of the resolver.
func (r *watchBasedGRPCResolver) State() lifetime.State {
return r.lifetime.GetState()
}

View File

@ -0,0 +1,35 @@
package resolver
import (
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/resolver"
"github.com/milvus-io/milvus/internal/mocks/google.golang.org/grpc/mock_resolver"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestMain(m *testing.M) {
paramtable.Init()
m.Run()
}
func TestWatchBasedGRPCResolver(t *testing.T) {
cc := mock_resolver.NewMockClientConn(t)
cc.EXPECT().UpdateState(mock.Anything).Return(nil)
r := newWatchBasedGRPCResolver(cc, log.With())
assert.NoError(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}}))
cc.EXPECT().UpdateState(mock.Anything).Unset()
cc.EXPECT().UpdateState(mock.Anything).Return(errors.New("err"))
// watch based resolver could ignore the error.
assert.NoError(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}}))
r.Close()
assert.Error(t, r.Update(VersionedState{State: resolver.State{Addresses: []resolver.Address{{Addr: "addr"}}}}))
}

View File

@ -14,14 +14,13 @@ import (
var streamingErrorToGRPCStatus = map[streamingpb.StreamingCode]codes.Code{
streamingpb.StreamingCode_STREAMING_CODE_OK: codes.OK,
streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST: codes.AlreadyExists,
streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION: codes.FailedPrecondition,
streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Unavailable,
streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Internal,
streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT: codes.InvalidArgument,
streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN: codes.Unknown,
}

View File

@ -29,13 +29,19 @@ func (e *StreamingError) AsPBError() *streamingpb.StreamingError {
}
// IsWrongStreamingNode returns true if the error is caused by wrong streamingnode.
// Client should report these error to coord and block until new assignment term coming.
// Clients for producing and consuming should report these errors to the coord and block until a new assignment term arrives.
func (e *StreamingError) IsWrongStreamingNode() bool {
return e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM || // channel term not match
e.Code == streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST || // channel do not exist on streamingnode
e.Code == streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED // channel fenced on these node.
}
// IsSkippedOperation returns true if the operation is ignored or skipped.
func (e *StreamingError) IsSkippedOperation() bool {
return e.Code == streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION ||
e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM
}
// NewOnShutdownError creates a new StreamingError with code STREAMING_CODE_ON_SHUTDOWN.
func NewOnShutdownError(format string, args ...interface{}) *StreamingError {
return New(streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, format, args...)
@ -51,11 +57,6 @@ func NewInvalidRequestSeq(format string, args ...interface{}) *StreamingError {
return New(streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ, format, args...)
}
// NewChannelExist creates a new StreamingError with code StreamingCode_STREAMING_CODE_CHANNEL_EXIST.
func NewChannelExist(format string, args ...interface{}) *StreamingError {
return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, format, args...)
}
// NewChannelNotExist creates a new StreamingError with code STREAMING_CODE_CHANNEL_NOT_EXIST.
func NewChannelNotExist(channel string) *StreamingError {
return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, "%s not exist", channel)

View File

@ -27,12 +27,6 @@ func TestStreamingError(t *testing.T) {
pbErr = streamingErr.AsPBError()
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ, pbErr.Code)
streamingErr = NewChannelExist("test")
assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_CHANNEL_EXIST, cause: test")
assert.False(t, streamingErr.IsWrongStreamingNode())
pbErr = streamingErr.AsPBError()
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, pbErr.Code)
streamingErr = NewChannelNotExist("test")
assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_CHANNEL_NOT_EXIST, cause: test")
assert.True(t, streamingErr.IsWrongStreamingNode())

View File

@ -22,4 +22,8 @@ packages:
WALImpls:
Interceptor:
InterceptorWithReady:
InterceptorBuilder:
InterceptorBuilder:
github.com/milvus-io/milvus/pkg/streaming/util/types:
interfaces:
AssignmentDiscoverWatcher:

View File

@ -0,0 +1,80 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_types
import (
context "context"
types "github.com/milvus-io/milvus/pkg/streaming/util/types"
mock "github.com/stretchr/testify/mock"
)
// MockAssignmentDiscoverWatcher is an autogenerated mock type for the AssignmentDiscoverWatcher type
type MockAssignmentDiscoverWatcher struct {
mock.Mock
}
type MockAssignmentDiscoverWatcher_Expecter struct {
mock *mock.Mock
}
func (_m *MockAssignmentDiscoverWatcher) EXPECT() *MockAssignmentDiscoverWatcher_Expecter {
return &MockAssignmentDiscoverWatcher_Expecter{mock: &_m.Mock}
}
// AssignmentDiscover provides a mock function with given fields: ctx, cb
func (_m *MockAssignmentDiscoverWatcher) AssignmentDiscover(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error) error {
ret := _m.Called(ctx, cb)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error); ok {
r0 = rf(ctx, cb)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAssignmentDiscoverWatcher_AssignmentDiscover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignmentDiscover'
type MockAssignmentDiscoverWatcher_AssignmentDiscover_Call struct {
*mock.Call
}
// AssignmentDiscover is a helper method to define mock.On call
// - ctx context.Context
// - cb func(*types.VersionedStreamingNodeAssignments) error
func (_e *MockAssignmentDiscoverWatcher_Expecter) AssignmentDiscover(ctx interface{}, cb interface{}) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call {
return &MockAssignmentDiscoverWatcher_AssignmentDiscover_Call{Call: _e.mock.On("AssignmentDiscover", ctx, cb)}
}
func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) Run(run func(ctx context.Context, cb func(*types.VersionedStreamingNodeAssignments) error)) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(func(*types.VersionedStreamingNodeAssignments) error))
})
return _c
}
func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) Return(_a0 error) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call) RunAndReturn(run func(context.Context, func(*types.VersionedStreamingNodeAssignments) error) error) *MockAssignmentDiscoverWatcher_AssignmentDiscover_Call {
_c.Call.Return(run)
return _c
}
// NewMockAssignmentDiscoverWatcher creates a new instance of MockAssignmentDiscoverWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAssignmentDiscoverWatcher(t interface {
mock.TestingT
Cleanup(func())
}) *MockAssignmentDiscoverWatcher {
mock := &MockAssignmentDiscoverWatcher{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -1,6 +1,8 @@
package types
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/util/typeutil"
@ -11,6 +13,14 @@ var (
ErrNotAlive = errors.New("streaming node is not alive")
)
// AssignmentDiscoverWatcher is the interface for watching the assignment discovery.
type AssignmentDiscoverWatcher interface {
// AssignmentDiscover watches the assignment discovery.
// The callback will be called whenever the assignment discovery changes.
// The final error is returned when the watcher is closed or broken.
AssignmentDiscover(ctx context.Context, cb func(*VersionedStreamingNodeAssignments) error) error
}
// VersionedStreamingNodeAssignments is the relation between server and channels with version.
type VersionedStreamingNodeAssignments struct {
Version typeutil.VersionInt64Pair
@ -20,7 +30,7 @@ type VersionedStreamingNodeAssignments struct {
// StreamingNodeAssignment is the relation between server and channels.
type StreamingNodeAssignment struct {
NodeInfo StreamingNodeInfo
Channels []PChannelInfo
Channels map[string]PChannelInfo
}
// StreamingNodeInfo is the relation between server and channels.
@ -40,3 +50,11 @@ type StreamingNodeStatus struct {
func (n *StreamingNodeStatus) IsHealthy() bool {
return n.Err == nil
}
// ErrorOfNode returns the error of the streaming node.
func (n *StreamingNodeStatus) ErrorOfNode() error {
if n == nil {
return ErrNotAlive
}
return n.Err
}
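
A minimal sketch of a consumer of AssignmentDiscoverWatcher; only the version is inspected here since the assignment payload layout is elided above:

package main

import (
	"context"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
)

// watchAssignments blocks and invokes the callback on every assignment change until ctx is done
// or the callback returns an error.
func watchAssignments(ctx context.Context, w types.AssignmentDiscoverWatcher) error {
	return w.AssignmentDiscover(ctx, func(assignments *types.VersionedStreamingNodeAssignments) error {
		log.Info("assignment updated", zap.Any("version", assignments.Version))
		return nil
	})
}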

View File

@ -1,17 +1,20 @@
package util
package typeutil
import (
"go.uber.org/atomic"
)
// NewIDAllocator creates a new IDAllocator.
func NewIDAllocator() *IDAllocator {
return &IDAllocator{}
}
// IDAllocator is a thread-safe ID allocator.
type IDAllocator struct {
underlying atomic.Int64
}
// Allocate allocates a new ID.
func (ida *IDAllocator) Allocate() int64 {
return ida.underlying.Inc()
}
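
A tiny usage sketch of the allocator, mirroring how builderImpl.Build tags each grpc resolver with a unique id:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

func main() {
	ida := typeutil.NewIDAllocator()
	// Allocate is safe for concurrent use; ids start at 1 and increase monotonically.
	for i := 0; i < 3; i++ {
		fmt.Println(ida.Allocate()) // prints 1, 2, 3
	}
}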

View File

@ -108,6 +108,7 @@ go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/importutilv2/..." -f
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/proxyutil/..." -failfast -count=1 -ldflags="-r ${RPATH}"
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/initcore/..." -failfast -count=1 -ldflags="-r ${RPATH}"
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/cgo/..." -failfast -count=1 -ldflags="-r ${RPATH}"
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/streamingutil/..." -failfast -count=1 -ldflags="-r ${RPATH}"
}
function test_pkg()
@ -163,6 +164,13 @@ function test_cmd()
go test -race -cover -tags dynamic,test "${ROOT_DIR}/cmd/tools/..." -failfast -count=1 -ldflags="-r ${RPATH}"
}
function test_streaming()
{
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingcoord/..." -failfast -count=1 -ldflags="-r ${RPATH}"
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/streamingnode/..." -failfast -count=1 -ldflags="-r ${RPATH}"
go test -race -cover -tags dynamic,test "${MILVUS_DIR}/util/streamingutil/..." -failfast -count=1 -ldflags="-r ${RPATH}"
}
function test_all()
{
test_proxy
@ -181,6 +189,7 @@ test_util
test_pkg
test_metastore
test_cmd
test_streaming
}
@ -237,6 +246,9 @@ case "${TEST_TAG}" in
cmd)
test_cmd
;;
streaming)
test_streaming
;;
*) echo "Test All";
test_all
;;