mirror of https://github.com/milvus-io/milvus.git
enhance: timetick interceptor implementation (#34238)
issue: #33285 - optimize the message package - add interceptor package to achieve append operation intercepting. - add timetick interceptor to attach timetick properties for message. - add timetick background task to send timetick message. Signed-off-by: chyezh <chyezh@outlook.com>pull/34340/head v2.2-testing-20240702
parent
a5be322ab2
commit
3563136c2a
|
@ -8,7 +8,7 @@ import (
|
|||
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb"
|
||||
types "github.com/milvus-io/milvus/pkg/streaming/util/types"
|
||||
|
||||
wal "github.com/milvus-io/milvus/internal/streamingnode/server/wal"
|
||||
)
|
||||
|
@ -117,16 +117,14 @@ func (_c *MockWAL_AppendAsync_Call) RunAndReturn(run func(context.Context, messa
|
|||
}
|
||||
|
||||
// Channel provides a mock function with given fields:
|
||||
func (_m *MockWAL) Channel() *streamingpb.PChannelInfo {
|
||||
func (_m *MockWAL) Channel() types.PChannelInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 *streamingpb.PChannelInfo
|
||||
if rf, ok := ret.Get(0).(func() *streamingpb.PChannelInfo); ok {
|
||||
var r0 types.PChannelInfo
|
||||
if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*streamingpb.PChannelInfo)
|
||||
}
|
||||
r0 = ret.Get(0).(types.PChannelInfo)
|
||||
}
|
||||
|
||||
return r0
|
||||
|
@ -149,12 +147,12 @@ func (_c *MockWAL_Channel_Call) Run(run func()) *MockWAL_Channel_Call {
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWAL_Channel_Call) Return(_a0 *streamingpb.PChannelInfo) *MockWAL_Channel_Call {
|
||||
func (_c *MockWAL_Channel_Call) Return(_a0 types.PChannelInfo) *MockWAL_Channel_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWAL_Channel_Call) RunAndReturn(run func() *streamingpb.PChannelInfo) *MockWAL_Channel_Call {
|
||||
func (_c *MockWAL_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockWAL_Channel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
// Code generated by mockery v2.32.4. DO NOT EDIT.
|
||||
|
||||
package mock_walimpls
|
||||
package mock_interceptors
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
// Code generated by mockery v2.32.4. DO NOT EDIT.
|
||||
|
||||
package mock_walimpls
|
||||
package mock_interceptors
|
||||
|
||||
import (
|
||||
walimpls "github.com/milvus-io/milvus/pkg/streaming/walimpls"
|
||||
interceptors "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
|
@ -20,16 +20,16 @@ func (_m *MockInterceptorBuilder) EXPECT() *MockInterceptorBuilder_Expecter {
|
|||
return &MockInterceptorBuilder_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Build provides a mock function with given fields: wal
|
||||
func (_m *MockInterceptorBuilder) Build(wal <-chan walimpls.WALImpls) walimpls.BasicInterceptor {
|
||||
ret := _m.Called(wal)
|
||||
// Build provides a mock function with given fields: param
|
||||
func (_m *MockInterceptorBuilder) Build(param interceptors.InterceptorBuildParam) interceptors.BasicInterceptor {
|
||||
ret := _m.Called(param)
|
||||
|
||||
var r0 walimpls.BasicInterceptor
|
||||
if rf, ok := ret.Get(0).(func(<-chan walimpls.WALImpls) walimpls.BasicInterceptor); ok {
|
||||
r0 = rf(wal)
|
||||
var r0 interceptors.BasicInterceptor
|
||||
if rf, ok := ret.Get(0).(func(interceptors.InterceptorBuildParam) interceptors.BasicInterceptor); ok {
|
||||
r0 = rf(param)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(walimpls.BasicInterceptor)
|
||||
r0 = ret.Get(0).(interceptors.BasicInterceptor)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,24 +42,24 @@ type MockInterceptorBuilder_Build_Call struct {
|
|||
}
|
||||
|
||||
// Build is a helper method to define mock.On call
|
||||
// - wal <-chan walimpls.WALImpls
|
||||
func (_e *MockInterceptorBuilder_Expecter) Build(wal interface{}) *MockInterceptorBuilder_Build_Call {
|
||||
return &MockInterceptorBuilder_Build_Call{Call: _e.mock.On("Build", wal)}
|
||||
// - param interceptors.InterceptorBuildParam
|
||||
func (_e *MockInterceptorBuilder_Expecter) Build(param interface{}) *MockInterceptorBuilder_Build_Call {
|
||||
return &MockInterceptorBuilder_Build_Call{Call: _e.mock.On("Build", param)}
|
||||
}
|
||||
|
||||
func (_c *MockInterceptorBuilder_Build_Call) Run(run func(wal <-chan walimpls.WALImpls)) *MockInterceptorBuilder_Build_Call {
|
||||
func (_c *MockInterceptorBuilder_Build_Call) Run(run func(param interceptors.InterceptorBuildParam)) *MockInterceptorBuilder_Build_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(<-chan walimpls.WALImpls))
|
||||
run(args[0].(interceptors.InterceptorBuildParam))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockInterceptorBuilder_Build_Call) Return(_a0 walimpls.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
|
||||
func (_c *MockInterceptorBuilder_Build_Call) Return(_a0 interceptors.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockInterceptorBuilder_Build_Call) RunAndReturn(run func(<-chan walimpls.WALImpls) walimpls.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
|
||||
func (_c *MockInterceptorBuilder_Build_Call) RunAndReturn(run func(interceptors.InterceptorBuildParam) interceptors.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
|
@ -1,11 +1,12 @@
|
|||
// Code generated by mockery v2.32.4. DO NOT EDIT.
|
||||
|
||||
package mock_walimpls
|
||||
package mock_interceptors
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
|
@ -1,42 +1,6 @@
|
|||
package streamingpb
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
const (
|
||||
ServiceMethodPrefix = "/milvus.proto.log"
|
||||
InitialTerm = int64(-1)
|
||||
)
|
||||
|
||||
func NewDeliverAll() *DeliverPolicy {
|
||||
return &DeliverPolicy{
|
||||
Policy: &DeliverPolicy_All{
|
||||
All: &emptypb.Empty{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewDeliverLatest() *DeliverPolicy {
|
||||
return &DeliverPolicy{
|
||||
Policy: &DeliverPolicy_Latest{
|
||||
Latest: &emptypb.Empty{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewDeliverStartFrom(messageID *MessageID) *DeliverPolicy {
|
||||
return &DeliverPolicy{
|
||||
Policy: &DeliverPolicy_StartFrom{
|
||||
StartFrom: messageID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewDeliverStartAfter(messageID *MessageID) *DeliverPolicy {
|
||||
return &DeliverPolicy{
|
||||
Policy: &DeliverPolicy_StartAfter{
|
||||
StartAfter: messageID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
package resource
|
||||
|
||||
import (
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
var r *resourceImpl // singleton resource instance
|
||||
|
||||
// optResourceInit is the option to initialize the resource.
|
||||
type optResourceInit func(r *resourceImpl)
|
||||
|
||||
// OptETCD provides the etcd client to the resource.
|
||||
func OptETCD(etcd *clientv3.Client) optResourceInit {
|
||||
return func(r *resourceImpl) {
|
||||
r.etcdClient = etcd
|
||||
}
|
||||
}
|
||||
|
||||
// OptRootCoordClient provides the root coordinator client to the resource.
|
||||
func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit {
|
||||
return func(r *resourceImpl) {
|
||||
r.rootCoordClient = rootCoordClient
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the singleton of resources.
|
||||
// Should be call when streaming node startup.
|
||||
func Init(opts ...optResourceInit) {
|
||||
r = &resourceImpl{}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient)
|
||||
|
||||
assertNotNil(r.TimestampAllocator())
|
||||
assertNotNil(r.ETCD())
|
||||
assertNotNil(r.RootCoordClient())
|
||||
}
|
||||
|
||||
// Resource access the underlying singleton of resources.
|
||||
func Resource() *resourceImpl {
|
||||
return r
|
||||
}
|
||||
|
||||
// resourceImpl is a basic resource dependency for streamingnode server.
|
||||
// All utility on it is concurrent-safe and singleton.
|
||||
type resourceImpl struct {
|
||||
timestampAllocator timestamp.Allocator
|
||||
etcdClient *clientv3.Client
|
||||
rootCoordClient types.RootCoordClient
|
||||
}
|
||||
|
||||
// TimestampAllocator returns the timestamp allocator to allocate timestamp.
|
||||
func (r *resourceImpl) TimestampAllocator() timestamp.Allocator {
|
||||
return r.timestampAllocator
|
||||
}
|
||||
|
||||
// ETCD returns the etcd client.
|
||||
func (r *resourceImpl) ETCD() *clientv3.Client {
|
||||
return r.etcdClient
|
||||
}
|
||||
|
||||
// RootCoordClient returns the root coordinator client.
|
||||
func (r *resourceImpl) RootCoordClient() types.RootCoordClient {
|
||||
return r.rootCoordClient
|
||||
}
|
||||
|
||||
// assertNotNil panics if the resource is nil.
|
||||
func assertNotNil(v interface{}) {
|
||||
if v == nil {
|
||||
panic("nil resource")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package resource
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
Init()
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
Init(OptETCD(&clientv3.Client{}))
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
Init(OptETCD(&clientv3.Client{}))
|
||||
})
|
||||
Init(OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t)))
|
||||
|
||||
assert.NotNil(t, Resource().TimestampAllocator())
|
||||
assert.NotNil(t, Resource().ETCD())
|
||||
assert.NotNil(t, Resource().RootCoordClient())
|
||||
}
|
||||
|
||||
func TestInitForTest(t *testing.T) {
|
||||
InitForTest()
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
//go:build test
|
||||
// +build test
|
||||
|
||||
package resource
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
|
||||
|
||||
// InitForTest initializes the singleton of resources for test.
|
||||
func InitForTest(opts ...optResourceInit) {
|
||||
r = &resourceImpl{}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
if r.rootCoordClient != nil {
|
||||
r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
package timestamp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
var errExhausted = errors.New("exhausted")
|
||||
|
||||
// newLocalAllocator creates a new local allocator.
|
||||
func newLocalAllocator() *localAllocator {
|
||||
return &localAllocator{
|
||||
nextStartID: 0,
|
||||
endStartID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// localAllocator allocates timestamp locally.
|
||||
type localAllocator struct {
|
||||
nextStartID uint64 // Allocate timestamp locally.
|
||||
endStartID uint64
|
||||
}
|
||||
|
||||
// AllocateOne allocates a timestamp.
|
||||
func (a *localAllocator) allocateOne() (uint64, error) {
|
||||
if a.nextStartID < a.endStartID {
|
||||
id := a.nextStartID
|
||||
a.nextStartID++
|
||||
return id, nil
|
||||
}
|
||||
return 0, errExhausted
|
||||
}
|
||||
|
||||
// update updates the local allocator.
|
||||
func (a *localAllocator) update(start uint64, count int) {
|
||||
// local allocator can be only increasing.
|
||||
if start >= a.endStartID {
|
||||
a.nextStartID = start
|
||||
a.endStartID = start + uint64(count)
|
||||
}
|
||||
}
|
||||
|
||||
// expire expires all id in the local allocator.
|
||||
func (a *localAllocator) exhausted() {
|
||||
a.nextStartID = a.endStartID
|
||||
}
|
||||
|
||||
// remoteAllocator allocate timestamp from remote root coordinator.
|
||||
type remoteAllocator struct {
|
||||
rc types.RootCoordClient
|
||||
nodeID int64
|
||||
}
|
||||
|
||||
// newRemoteAllocator creates a new remote allocator.
|
||||
func newRemoteAllocator(rc types.RootCoordClient) *remoteAllocator {
|
||||
a := &remoteAllocator{
|
||||
nodeID: paramtable.GetNodeID(),
|
||||
rc: rc,
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (ta *remoteAllocator) allocate(ctx context.Context, count uint32) (uint64, int, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
req := &rootcoordpb.AllocTimestampRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
|
||||
commonpbutil.WithMsgID(0),
|
||||
commonpbutil.WithSourceID(ta.nodeID),
|
||||
),
|
||||
Count: count,
|
||||
}
|
||||
|
||||
resp, err := ta.rc.AllocTimestamp(ctx, req)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("syncTimestamp Failed:%w", err)
|
||||
}
|
||||
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return 0, 0, fmt.Errorf("syncTimeStamp Failed:%s", resp.GetStatus().GetReason())
|
||||
}
|
||||
if resp == nil {
|
||||
return 0, 0, fmt.Errorf("empty AllocTimestampResponse")
|
||||
}
|
||||
return resp.GetTimestamp(), int(resp.GetCount()), nil
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package timestamp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestLocalAllocator(t *testing.T) {
|
||||
allocator := newLocalAllocator()
|
||||
|
||||
ts, err := allocator.allocateOne()
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, ts)
|
||||
|
||||
allocator.update(1, 100)
|
||||
|
||||
counter := atomic.NewUint64(0)
|
||||
for i := 0; i < 100; i++ {
|
||||
ts, err := allocator.allocateOne()
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, ts)
|
||||
counter.Add(ts)
|
||||
}
|
||||
assert.Equal(t, uint64(5050), counter.Load())
|
||||
|
||||
// allocator exhausted.
|
||||
ts, err = allocator.allocateOne()
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, ts)
|
||||
|
||||
// allocator can not be rollback.
|
||||
allocator.update(90, 100)
|
||||
ts, err = allocator.allocateOne()
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, ts)
|
||||
|
||||
// allocator can be only increasing.
|
||||
allocator.update(101, 100)
|
||||
ts, err = allocator.allocateOne()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ts, uint64(101))
|
||||
|
||||
// allocator can be exhausted.
|
||||
allocator.exhausted()
|
||||
ts, err = allocator.allocateOne()
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, ts)
|
||||
}
|
||||
|
||||
func TestRemoteAllocator(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.SetNodeID(1)
|
||||
|
||||
client := NewMockRootCoordClient(t)
|
||||
|
||||
allocator := newRemoteAllocator(client)
|
||||
ts, count, err := allocator.allocate(context.Background(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, ts)
|
||||
assert.Equal(t, count, 100)
|
||||
|
||||
// Test error.
|
||||
client = mocks.NewMockRootCoordClient(t)
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
return nil, errors.New("test")
|
||||
},
|
||||
)
|
||||
allocator = newRemoteAllocator(client)
|
||||
_, _, err = allocator.allocate(context.Background(), 100)
|
||||
assert.Error(t, err)
|
||||
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Unset()
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
return &rootcoordpb.AllocTimestampResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_ForceDeny,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
allocator = newRemoteAllocator(client)
|
||||
_, _, err = allocator.allocate(context.Background(), 100)
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
//go:build test
|
||||
// +build test
|
||||
|
||||
package timestamp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"go.uber.org/atomic"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
)
|
||||
|
||||
func NewMockRootCoordClient(t *testing.T) *mocks.MockRootCoordClient {
|
||||
counter := atomic.NewUint64(1)
|
||||
client := mocks.NewMockRootCoordClient(t)
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
if atr.Count > 1000 {
|
||||
panic(fmt.Sprintf("count %d is too large", atr.Count))
|
||||
}
|
||||
c := counter.Add(uint64(atr.Count))
|
||||
return &rootcoordpb.AllocTimestampResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Timestamp: c - uint64(atr.Count),
|
||||
Count: atr.Count,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
return client
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package timestamp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
// batchAllocateSize is the size of batch allocate from remote allocator.
|
||||
const batchAllocateSize = 1000
|
||||
|
||||
var _ Allocator = (*allocatorImpl)(nil)
|
||||
|
||||
// NewAllocator creates a new allocator.
|
||||
func NewAllocator(rc types.RootCoordClient) Allocator {
|
||||
return &allocatorImpl{
|
||||
mu: sync.Mutex{},
|
||||
remoteAllocator: newRemoteAllocator(rc),
|
||||
localAllocator: newLocalAllocator(),
|
||||
}
|
||||
}
|
||||
|
||||
type Allocator interface {
|
||||
// Allocate allocates a timestamp.
|
||||
Allocate(ctx context.Context) (uint64, error)
|
||||
|
||||
// Sync expire the local allocator messages,
|
||||
// syncs the local allocator and remote allocator.
|
||||
Sync()
|
||||
}
|
||||
|
||||
type allocatorImpl struct {
|
||||
mu sync.Mutex
|
||||
remoteAllocator *remoteAllocator
|
||||
localAllocator *localAllocator
|
||||
}
|
||||
|
||||
// AllocateOne allocates a timestamp.
|
||||
func (ta *allocatorImpl) Allocate(ctx context.Context) (uint64, error) {
|
||||
ta.mu.Lock()
|
||||
defer ta.mu.Unlock()
|
||||
|
||||
// allocate one from local allocator first.
|
||||
if id, err := ta.localAllocator.allocateOne(); err == nil {
|
||||
return id, nil
|
||||
}
|
||||
// allocate from remote.
|
||||
return ta.allocateRemote(ctx)
|
||||
}
|
||||
|
||||
// Sync expire the local allocator messages,
|
||||
// syncs the local allocator and remote allocator.
|
||||
func (ta *allocatorImpl) Sync() {
|
||||
ta.mu.Lock()
|
||||
defer ta.mu.Unlock()
|
||||
|
||||
ta.localAllocator.exhausted()
|
||||
}
|
||||
|
||||
// allocateRemote allocates timestamp from remote root coordinator.
|
||||
func (ta *allocatorImpl) allocateRemote(ctx context.Context) (uint64, error) {
|
||||
// Update local allocator from remote.
|
||||
start, count, err := ta.remoteAllocator.allocate(ctx, batchAllocateSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ta.localAllocator.update(start, count)
|
||||
|
||||
// Get from local again.
|
||||
return ta.localAllocator.allocateOne()
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package timestamp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestTimestampAllocator(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.SetNodeID(1)
|
||||
|
||||
client := NewMockRootCoordClient(t)
|
||||
allocator := NewAllocator(client)
|
||||
|
||||
for i := 0; i < 5000; i++ {
|
||||
ts, err := allocator.Allocate(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, ts)
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
ts, err := allocator.Allocate(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, ts)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
allocator.Sync()
|
||||
}
|
||||
|
||||
// error test
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Unset()
|
||||
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
return &rootcoordpb.AllocTimestampResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_ForceDeny,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
allocator = NewAllocator(client)
|
||||
_, err := allocator.Allocate(context.Background())
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -3,8 +3,7 @@ package wal
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/streamingpb"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/types"
|
||||
)
|
||||
|
||||
// OpenerBuilder is the interface for build wal opener.
|
||||
|
@ -17,8 +16,7 @@ type OpenerBuilder interface {
|
|||
|
||||
// OpenOption is the option for allocating wal instance.
|
||||
type OpenOption struct {
|
||||
Channel *streamingpb.PChannelInfo
|
||||
InterceptorBuilders []walimpls.InterceptorBuilder // Interceptor builders to build when open.
|
||||
Channel types.PChannelInfo
|
||||
}
|
||||
|
||||
// Opener is the interface for build wal instance.
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package interceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
var _ InterceptorWithReady = (*chainedInterceptor)(nil)
|
||||
|
||||
type (
|
||||
// appendInterceptorCall is the common function to execute the append interceptor.
|
||||
appendInterceptorCall = func(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error)
|
||||
)
|
||||
|
||||
// NewChainedInterceptor creates a new chained interceptor.
|
||||
func NewChainedInterceptor(interceptors ...BasicInterceptor) InterceptorWithReady {
|
||||
appendCalls := make([]appendInterceptorCall, 0, len(interceptors))
|
||||
for _, i := range interceptors {
|
||||
if r, ok := i.(AppendInterceptor); ok {
|
||||
appendCalls = append(appendCalls, r.DoAppend)
|
||||
}
|
||||
}
|
||||
return &chainedInterceptor{
|
||||
closed: make(chan struct{}),
|
||||
interceptors: interceptors,
|
||||
appendCall: chainAppendInterceptors(appendCalls),
|
||||
}
|
||||
}
|
||||
|
||||
// chainedInterceptor chains all interceptors into one.
|
||||
type chainedInterceptor struct {
|
||||
closed chan struct{}
|
||||
interceptors []BasicInterceptor
|
||||
appendCall appendInterceptorCall
|
||||
}
|
||||
|
||||
// Ready wait all interceptors to be ready.
|
||||
func (c *chainedInterceptor) Ready() <-chan struct{} {
|
||||
ready := make(chan struct{})
|
||||
go func() {
|
||||
for _, i := range c.interceptors {
|
||||
// check if ready is implemented
|
||||
if r, ok := i.(InterceptorReady); ok {
|
||||
select {
|
||||
case <-r.Ready():
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
close(ready)
|
||||
}()
|
||||
return ready
|
||||
}
|
||||
|
||||
// DoAppend execute the append operation with all interceptors.
|
||||
func (c *chainedInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error) {
|
||||
return c.appendCall(ctx, msg, append)
|
||||
}
|
||||
|
||||
// Close close all interceptors.
|
||||
func (c *chainedInterceptor) Close() {
|
||||
close(c.closed)
|
||||
for _, i := range c.interceptors {
|
||||
i.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// chainAppendInterceptors chains all unary client interceptors into one.
|
||||
func chainAppendInterceptors(interceptorCalls []appendInterceptorCall) appendInterceptorCall {
|
||||
if len(interceptorCalls) == 0 {
|
||||
// Do nothing if no interceptors.
|
||||
return func(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error) {
|
||||
return append(ctx, msg)
|
||||
}
|
||||
} else if len(interceptorCalls) == 1 {
|
||||
return interceptorCalls[0]
|
||||
}
|
||||
return func(ctx context.Context, msg message.MutableMessage, invoker Append) (message.MessageID, error) {
|
||||
return interceptorCalls[0](ctx, msg, getChainAppendInvoker(interceptorCalls, 0, invoker))
|
||||
}
|
||||
}
|
||||
|
||||
// getChainAppendInvoker recursively generate the chained unary invoker.
|
||||
func getChainAppendInvoker(interceptors []appendInterceptorCall, idx int, finalInvoker Append) Append {
|
||||
// all interceptor is called, so return the final invoker.
|
||||
if idx == len(interceptors)-1 {
|
||||
return finalInvoker
|
||||
}
|
||||
// recursively generate the chained invoker.
|
||||
return func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
|
||||
return interceptors[idx+1](ctx, msg, getChainAppendInvoker(interceptors, idx+1, finalInvoker))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package interceptors_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/mock_interceptors"
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
func TestChainInterceptor(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
testChainInterceptor(t, i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainReady(t *testing.T) {
|
||||
count := 5
|
||||
channels := make([]chan struct{}, 0, count)
|
||||
ips := make([]interceptors.BasicInterceptor, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
ch := make(chan struct{})
|
||||
channels = append(channels, ch)
|
||||
interceptor := mock_interceptors.NewMockInterceptorWithReady(t)
|
||||
interceptor.EXPECT().Ready().Return(ch)
|
||||
interceptor.EXPECT().Close().Return()
|
||||
ips = append(ips, interceptor)
|
||||
}
|
||||
chainInterceptor := interceptors.NewChainedInterceptor(ips...)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
// part of interceptors is not ready
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-chainInterceptor.Ready():
|
||||
t.Fatal("should not ready")
|
||||
case <-ctx.Done():
|
||||
}
|
||||
close(channels[i])
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-chainInterceptor.Ready():
|
||||
case <-ctx.Done():
|
||||
t.Fatal("interceptor should be ready now")
|
||||
}
|
||||
chainInterceptor.Close()
|
||||
|
||||
interceptor := mock_interceptors.NewMockInterceptorWithReady(t)
|
||||
ch := make(chan struct{})
|
||||
interceptor.EXPECT().Ready().Return(ch)
|
||||
interceptor.EXPECT().Close().Return()
|
||||
chainInterceptor = interceptors.NewChainedInterceptor(interceptor)
|
||||
chainInterceptor.Close()
|
||||
|
||||
// closed chain interceptor should block the ready (internal interceptor is not ready)
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-chainInterceptor.Ready():
|
||||
t.Fatal("chan interceptor that closed but internal interceptor is not ready should block the ready")
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func testChainInterceptor(t *testing.T, count int) {
|
||||
type record struct {
|
||||
before bool
|
||||
after bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
appendInterceptorRecords := make([]record, 0, count)
|
||||
ips := make([]interceptors.BasicInterceptor, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
j := i
|
||||
appendInterceptorRecords = append(appendInterceptorRecords, record{})
|
||||
|
||||
interceptor := mock_interceptors.NewMockInterceptor(t)
|
||||
interceptor.EXPECT().DoAppend(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
|
||||
func(ctx context.Context, mm message.MutableMessage, f func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error) {
|
||||
appendInterceptorRecords[j].before = true
|
||||
msgID, err := f(ctx, mm)
|
||||
appendInterceptorRecords[j].after = true
|
||||
return msgID, err
|
||||
})
|
||||
interceptor.EXPECT().Close().Run(func() {
|
||||
appendInterceptorRecords[j].closed = true
|
||||
})
|
||||
ips = append(ips, interceptor)
|
||||
}
|
||||
interceptor := interceptors.NewChainedInterceptor(ips...)
|
||||
|
||||
// fast return
|
||||
<-interceptor.Ready()
|
||||
|
||||
msg, err := interceptor.DoAppend(context.Background(), nil, func(context.Context, message.MutableMessage) (message.MessageID, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, msg)
|
||||
interceptor.Close()
|
||||
for i := 0; i < count; i++ {
|
||||
assert.True(t, appendInterceptorRecords[i].before)
|
||||
assert.True(t, appendInterceptorRecords[i].after)
|
||||
assert.True(t, appendInterceptorRecords[i].closed)
|
||||
}
|
||||
}
|
|
@ -1,25 +1,31 @@
|
|||
package walimpls
|
||||
package interceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
|
||||
"github.com/milvus-io/milvus/pkg/util/syncutil"
|
||||
)
|
||||
|
||||
type (
|
||||
// Append is the common function to append a msg to the wal.
|
||||
Append = func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error)
|
||||
// Read is the common function to read a msg from the wal.
|
||||
Read = func(ctx context.Context, opt ReadOption) (ScannerImpls, error)
|
||||
)
|
||||
|
||||
type InterceptorBuildParam struct {
|
||||
WALImpls walimpls.WALImpls // The underlying walimpls implementation, can be used anytime.
|
||||
WAL *syncutil.Future[wal.WAL] // The wal final object, can be used after interceptor is ready.
|
||||
}
|
||||
|
||||
// InterceptorBuilder is the interface to build a interceptor.
|
||||
// 1. InterceptorBuilder is concurrent safe.
|
||||
// 2. InterceptorBuilder can used to build a interceptor with cross-wal shared resources.
|
||||
type InterceptorBuilder interface {
|
||||
// Build build a interceptor with wal that interceptor will work on.
|
||||
// the wal object will be sent to the interceptor builder when the wal is constructed with all interceptors.
|
||||
Build(wal <-chan WALImpls) BasicInterceptor
|
||||
Build(param InterceptorBuildParam) BasicInterceptor
|
||||
}
|
||||
|
||||
type BasicInterceptor interface {
|
|
@ -0,0 +1,87 @@
|
|||
package ack
|
||||
|
||||
import (
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
var _ typeutil.HeapInterface = (*timestampWithAckArray)(nil)
|
||||
|
||||
// newAcker creates a new acker.
|
||||
func newAcker(ts uint64, lastConfirmedMessageID message.MessageID) *Acker {
|
||||
return &Acker{
|
||||
acknowledged: atomic.NewBool(false),
|
||||
detail: newAckDetail(ts, lastConfirmedMessageID),
|
||||
}
|
||||
}
|
||||
|
||||
// Acker records the timestamp and last confirmed message id that has not been acknowledged.
|
||||
type Acker struct {
|
||||
acknowledged *atomic.Bool // is acknowledged.
|
||||
detail *AckDetail // info is available after acknowledged.
|
||||
}
|
||||
|
||||
// LastConfirmedMessageID returns the last confirmed message id.
|
||||
func (ta *Acker) LastConfirmedMessageID() message.MessageID {
|
||||
return ta.detail.LastConfirmedMessageID
|
||||
}
|
||||
|
||||
// Timestamp returns the timestamp.
|
||||
func (ta *Acker) Timestamp() uint64 {
|
||||
return ta.detail.Timestamp
|
||||
}
|
||||
|
||||
// Ack marks the timestamp as acknowledged.
|
||||
func (ta *Acker) Ack(opts ...AckOption) {
|
||||
for _, opt := range opts {
|
||||
opt(ta.detail)
|
||||
}
|
||||
ta.acknowledged.Store(true)
|
||||
}
|
||||
|
||||
// ackDetail returns the ack info, only can be called after acknowledged.
|
||||
func (ta *Acker) ackDetail() *AckDetail {
|
||||
if !ta.acknowledged.Load() {
|
||||
panic("unreachable: ackDetail can only be called after acknowledged")
|
||||
}
|
||||
return ta.detail
|
||||
}
|
||||
|
||||
// timestampWithAckArray is a heap underlying represent of timestampAck.
|
||||
type timestampWithAckArray []*Acker
|
||||
|
||||
// Len returns the length of the heap.
|
||||
func (h timestampWithAckArray) Len() int {
|
||||
return len(h)
|
||||
}
|
||||
|
||||
// Less returns true if the element at index i is less than the element at index j.
|
||||
func (h timestampWithAckArray) Less(i, j int) bool {
|
||||
return h[i].detail.Timestamp < h[j].detail.Timestamp
|
||||
}
|
||||
|
||||
// Swap swaps the elements at indexes i and j.
|
||||
func (h timestampWithAckArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
// Push pushes the last one at len.
|
||||
func (h *timestampWithAckArray) Push(x interface{}) {
|
||||
// Push and Pop use pointer receivers because they modify the slice's length,
|
||||
// not just its contents.
|
||||
*h = append(*h, x.(*Acker))
|
||||
}
|
||||
|
||||
// Pop pop the last one at len.
|
||||
func (h *timestampWithAckArray) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// Peek returns the element at the top of the heap.
|
||||
func (h *timestampWithAckArray) Peek() interface{} {
|
||||
return (*h)[0]
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package ack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
|
||||
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
func TestAck(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.SetNodeID(1)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rc := timestamp.NewMockRootCoordClient(t)
|
||||
resource.InitForTest(resource.OptRootCoordClient(rc))
|
||||
|
||||
ackManager := NewAckManager()
|
||||
msgID := mock_message.NewMockMessageID(t)
|
||||
msgID.EXPECT().EQ(msgID).Return(true)
|
||||
ackManager.AdvanceLastConfirmedMessageID(msgID)
|
||||
|
||||
ackers := map[uint64]*Acker{}
|
||||
for i := 0; i < 10; i++ {
|
||||
acker, err := ackManager.Allocate(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, acker.LastConfirmedMessageID().EQ(msgID))
|
||||
ackers[acker.Timestamp()] = acker
|
||||
}
|
||||
|
||||
// notAck: [1, 2, 3, ..., 10]
|
||||
// ack: []
|
||||
details, err := ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [1, 3, ..., 10]
|
||||
// ack: [2]
|
||||
ackers[2].Ack()
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [1, 3, 5, ..., 10]
|
||||
// ack: [2, 4]
|
||||
ackers[4].Ack()
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [3, 5, ..., 10]
|
||||
// ack: [1, 2, 4]
|
||||
ackers[1].Ack()
|
||||
// notAck: [3, 5, ..., 10]
|
||||
// ack: [4]
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(details))
|
||||
assert.Equal(t, uint64(1), details[0].Timestamp)
|
||||
assert.Equal(t, uint64(2), details[1].Timestamp)
|
||||
|
||||
// notAck: [3, 5, ..., 10]
|
||||
// ack: [4]
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [3]
|
||||
// ack: [4, ..., 10]
|
||||
for i := 5; i <= 10; i++ {
|
||||
ackers[uint64(i)].Ack()
|
||||
}
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [3, ...,x, y]
|
||||
// ack: [4, ..., 10]
|
||||
tsX, err := ackManager.Allocate(ctx)
|
||||
assert.NoError(t, err)
|
||||
tsY, err := ackManager.Allocate(ctx)
|
||||
assert.NoError(t, err)
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
// notAck: [...,x, y]
|
||||
// ack: [3, ..., 10]
|
||||
ackers[3].Ack()
|
||||
|
||||
// notAck: [...,x, y]
|
||||
// ack: []
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(details), 8) // with some sync operation.
|
||||
|
||||
// notAck: []
|
||||
// ack: [11, 12]
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, details)
|
||||
|
||||
tsX.Ack()
|
||||
tsY.Ack()
|
||||
|
||||
// notAck: []
|
||||
// ack: []
|
||||
details, err = ackManager.SyncAndGetAcknowledged(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(details), 2) // with some sync operation.
|
||||
|
||||
// no more timestamp to ack.
|
||||
assert.Zero(t, ackManager.notAckHeap.Len())
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package ack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
// newAckDetail creates a new default acker detail.
|
||||
func newAckDetail(ts uint64, lastConfirmedMessageID message.MessageID) *AckDetail {
|
||||
if ts <= 0 {
|
||||
panic(fmt.Sprintf("ts should never less than 0 %d", ts))
|
||||
}
|
||||
return &AckDetail{
|
||||
Timestamp: ts,
|
||||
LastConfirmedMessageID: lastConfirmedMessageID,
|
||||
IsSync: false,
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// AckDetail records the information of acker.
|
||||
type AckDetail struct {
|
||||
Timestamp uint64
|
||||
LastConfirmedMessageID message.MessageID
|
||||
IsSync bool
|
||||
Err error
|
||||
}
|
||||
|
||||
// AckOption is the option for acker.
|
||||
type AckOption func(*AckDetail)
|
||||
|
||||
// OptSync marks the acker is sync message.
|
||||
func OptSync() AckOption {
|
||||
return func(detail *AckDetail) {
|
||||
detail.IsSync = true
|
||||
}
|
||||
}
|
||||
|
||||
// OptError marks the timestamp ack with error info.
|
||||
func OptError(err error) AckOption {
|
||||
return func(detail *AckDetail) {
|
||||
detail.Err = err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package ack
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
|
||||
)
|
||||
|
||||
func TestDetail(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
newAckDetail(0, mock_message.NewMockMessageID(t))
|
||||
})
|
||||
msgID := mock_message.NewMockMessageID(t)
|
||||
msgID.EXPECT().EQ(msgID).Return(true)
|
||||
|
||||
ackDetail := newAckDetail(1, msgID)
|
||||
assert.Equal(t, uint64(1), ackDetail.Timestamp)
|
||||
assert.True(t, ackDetail.LastConfirmedMessageID.EQ(msgID))
|
||||
assert.False(t, ackDetail.IsSync)
|
||||
assert.NoError(t, ackDetail.Err)
|
||||
|
||||
OptSync()(ackDetail)
|
||||
assert.True(t, ackDetail.IsSync)
|
||||
OptError(errors.New("test"))(ackDetail)
|
||||
assert.Error(t, ackDetail.Err)
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
package ack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// AckManager manages the timestampAck.
|
||||
type AckManager struct {
|
||||
mu sync.Mutex
|
||||
notAckHeap typeutil.Heap[*Acker] // a minimum heap of timestampAck to search minimum timestamp in list.
|
||||
lastConfirmedMessageID message.MessageID
|
||||
}
|
||||
|
||||
// NewAckManager creates a new timestampAckHelper.
|
||||
func NewAckManager() *AckManager {
|
||||
return &AckManager{
|
||||
mu: sync.Mutex{},
|
||||
notAckHeap: typeutil.NewHeap[*Acker](×tampWithAckArray{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate allocates a timestamp.
|
||||
// Concurrent safe to call with Sync and Allocate.
|
||||
func (ta *AckManager) Allocate(ctx context.Context) (*Acker, error) {
|
||||
ta.mu.Lock()
|
||||
defer ta.mu.Unlock()
|
||||
|
||||
// allocate one from underlying allocator first.
|
||||
ts, err := resource.Resource().TimestampAllocator().Allocate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create new timestampAck for ack process.
|
||||
// add ts to heap wait for ack.
|
||||
tsWithAck := newAcker(ts, ta.lastConfirmedMessageID)
|
||||
ta.notAckHeap.Push(tsWithAck)
|
||||
return tsWithAck, nil
|
||||
}
|
||||
|
||||
// SyncAndGetAcknowledged syncs the ack records with allocator, and get the last all acknowledged info.
|
||||
// Concurrent safe to call with Allocate.
|
||||
func (ta *AckManager) SyncAndGetAcknowledged(ctx context.Context) ([]*AckDetail, error) {
|
||||
// local timestamp may out of date, sync the underlying allocator before get last all acknowledged.
|
||||
resource.Resource().TimestampAllocator().Sync()
|
||||
|
||||
// Allocate may be uncalled in long term, and the recorder may be out of date.
|
||||
// Do a Allocate and Ack, can sync up the recorder with internal timetick.TimestampAllocator latest time.
|
||||
tsWithAck, err := ta.Allocate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tsWithAck.Ack(OptSync())
|
||||
|
||||
// update a new snapshot of acknowledged timestamps after sync up.
|
||||
return ta.popUntilLastAllAcknowledged(), nil
|
||||
}
|
||||
|
||||
// popUntilLastAllAcknowledged pops the timestamps until the one that all timestamps before it have been acknowledged.
|
||||
func (ta *AckManager) popUntilLastAllAcknowledged() []*AckDetail {
|
||||
ta.mu.Lock()
|
||||
defer ta.mu.Unlock()
|
||||
|
||||
// pop all acknowledged timestamps.
|
||||
details := make([]*AckDetail, 0, 5)
|
||||
for ta.notAckHeap.Len() > 0 && ta.notAckHeap.Peek().acknowledged.Load() {
|
||||
ack := ta.notAckHeap.Pop()
|
||||
details = append(details, ack.ackDetail())
|
||||
}
|
||||
return details
|
||||
}
|
||||
|
||||
// AdvanceLastConfirmedMessageID update the last confirmed message id.
|
||||
func (ta *AckManager) AdvanceLastConfirmedMessageID(msgID message.MessageID) {
|
||||
if msgID == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ta.mu.Lock()
|
||||
if ta.lastConfirmedMessageID == nil || ta.lastConfirmedMessageID.LT(msgID) {
|
||||
ta.lastConfirmedMessageID = msgID
|
||||
}
|
||||
ta.mu.Unlock()
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package timetick
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
|
||||
|
||||
// ackDetails records the information of AckDetail.
|
||||
// Used to analyze the ack details.
|
||||
// TODO: add more analysis methods. e.g. such as counter function with filter.
|
||||
type ackDetails struct {
|
||||
detail []*ack.AckDetail
|
||||
}
|
||||
|
||||
// AddDetails adds details to AckDetails.
|
||||
func (ad *ackDetails) AddDetails(details []*ack.AckDetail) {
|
||||
if len(details) == 0 {
|
||||
return
|
||||
}
|
||||
if len(ad.detail) == 0 {
|
||||
ad.detail = details
|
||||
return
|
||||
}
|
||||
ad.detail = append(ad.detail, details...)
|
||||
}
|
||||
|
||||
// Empty returns true if the AckDetails is empty.
|
||||
func (ad *ackDetails) Empty() bool {
|
||||
return len(ad.detail) == 0
|
||||
}
|
||||
|
||||
// Len returns the count of AckDetail.
|
||||
func (ad *ackDetails) Len() int {
|
||||
return len(ad.detail)
|
||||
}
|
||||
|
||||
// LastAllAcknowledgedTimestamp returns the last timestamp which all timestamps before it have been acknowledged.
|
||||
// panic if no timestamp has been acknowledged.
|
||||
func (ad *ackDetails) LastAllAcknowledgedTimestamp() uint64 {
|
||||
return ad.detail[len(ad.detail)-1].Timestamp
|
||||
}
|
||||
|
||||
// Clear clears the AckDetails.
|
||||
func (ad *ackDetails) Clear() {
|
||||
ad.detail = nil
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package timetick
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
var _ interceptors.InterceptorBuilder = (*interceptorBuilder)(nil)
|
||||
|
||||
// NewInterceptorBuilder creates a new interceptor builder.
|
||||
// 1. Add timetick to all message before append to wal.
|
||||
// 2. Collect timetick info, and generate sync-timetick message to wal.
|
||||
func NewInterceptorBuilder() interceptors.InterceptorBuilder {
|
||||
return &interceptorBuilder{}
|
||||
}
|
||||
|
||||
// interceptorBuilder is a builder to build timeTickAppendInterceptor.
|
||||
type interceptorBuilder struct{}
|
||||
|
||||
// Build implements Builder.
|
||||
func (b *interceptorBuilder) Build(param interceptors.InterceptorBuildParam) interceptors.BasicInterceptor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
interceptor := &timeTickAppendInterceptor{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
ready: make(chan struct{}),
|
||||
ackManager: ack.NewAckManager(),
|
||||
ackDetails: &ackDetails{},
|
||||
sourceID: paramtable.GetNodeID(),
|
||||
}
|
||||
go interceptor.executeSyncTimeTick(
|
||||
// TODO: move the configuration to streamingnode.
|
||||
paramtable.Get().ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond),
|
||||
param,
|
||||
)
|
||||
return interceptor
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
package timetick
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
)
|
||||
|
||||
var _ interceptors.AppendInterceptor = (*timeTickAppendInterceptor)(nil)
|
||||
|
||||
// timeTickAppendInterceptor is a append interceptor.
|
||||
type timeTickAppendInterceptor struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ready chan struct{}
|
||||
|
||||
ackManager *ack.AckManager
|
||||
ackDetails *ackDetails
|
||||
sourceID int64
|
||||
}
|
||||
|
||||
// Ready implements AppendInterceptor.
|
||||
func (impl *timeTickAppendInterceptor) Ready() <-chan struct{} {
|
||||
return impl.ready
|
||||
}
|
||||
|
||||
// Do implements AppendInterceptor.
|
||||
func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (msgID message.MessageID, err error) {
|
||||
if msg.MessageType() != message.MessageTypeTimeTick {
|
||||
// Allocate new acker for message.
|
||||
var acker *ack.Acker
|
||||
if acker, err = impl.ackManager.Allocate(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "allocate timestamp failed")
|
||||
}
|
||||
defer func() {
|
||||
acker.Ack(ack.OptError(err))
|
||||
impl.ackManager.AdvanceLastConfirmedMessageID(msgID)
|
||||
}()
|
||||
|
||||
// Assign timestamp to message and call append method.
|
||||
msg = msg.
|
||||
WithTimeTick(acker.Timestamp()). // message assigned with these timetick.
|
||||
WithLastConfirmed(acker.LastConfirmedMessageID()) // start consuming from these message id, the message which timetick greater than current timetick will never be lost.
|
||||
}
|
||||
return append(ctx, msg)
|
||||
}
|
||||
|
||||
// Close implements AppendInterceptor.
|
||||
func (impl *timeTickAppendInterceptor) Close() {
|
||||
impl.cancel()
|
||||
}
|
||||
|
||||
// execute start a background task.
|
||||
func (impl *timeTickAppendInterceptor) executeSyncTimeTick(interval time.Duration, param interceptors.InterceptorBuildParam) {
|
||||
underlyingWALImpls := param.WALImpls
|
||||
|
||||
logger := log.With(zap.Any("channel", underlyingWALImpls.Channel()))
|
||||
logger.Info("start to sync time tick...")
|
||||
defer logger.Info("sync time tick stopped")
|
||||
|
||||
// Send first timetick message to wal before interceptor is ready.
|
||||
for count := 0; ; count++ {
|
||||
// Sent first timetick message to wal before ready.
|
||||
// New TT is always greater than all tt on previous streamingnode.
|
||||
// A fencing operation of underlying WAL is needed to make exclusive produce of topic.
|
||||
// Otherwise, the TT principle may be violated.
|
||||
// And sendTsMsg must be done, to help ackManager to get first LastConfirmedMessageID
|
||||
// !!! Send a timetick message into walimpls directly is safe.
|
||||
select {
|
||||
case <-impl.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
if err := impl.sendTsMsg(impl.ctx, underlyingWALImpls.Append); err != nil {
|
||||
log.Warn("send first timestamp message failed", zap.Error(err), zap.Int("retryCount", count))
|
||||
// TODO: exponential backoff.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// interceptor is ready now.
|
||||
close(impl.ready)
|
||||
logger.Info("start to sync time ready")
|
||||
|
||||
// interceptor is ready, wait for the final wal object is ready to use.
|
||||
wal := param.WAL.Get()
|
||||
|
||||
// TODO: sync time tick message to wal periodically.
|
||||
// Add a trigger on `AckManager` to sync time tick message without periodically.
|
||||
// `AckManager` gather detail information, time tick sync can check it and make the message between tt more smaller.
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-impl.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := impl.sendTsMsg(impl.ctx, wal.Append); err != nil {
|
||||
log.Warn("send time tick sync message failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// syncAcknowledgedDetails syncs the timestamp acknowledged details.
|
||||
func (impl *timeTickAppendInterceptor) syncAcknowledgedDetails() {
|
||||
// Sync up and get last confirmed timestamp.
|
||||
ackDetails, err := impl.ackManager.SyncAndGetAcknowledged(impl.ctx)
|
||||
if err != nil {
|
||||
log.Warn("sync timestamp ack manager failed", zap.Error(err))
|
||||
}
|
||||
|
||||
// Add ack details to ackDetails.
|
||||
impl.ackDetails.AddDetails(ackDetails)
|
||||
}
|
||||
|
||||
// sendTsMsg sends first timestamp message to wal.
|
||||
// TODO: TT lag warning.
|
||||
func (impl *timeTickAppendInterceptor) sendTsMsg(_ context.Context, appender func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error)) error {
|
||||
// Sync the timestamp acknowledged details.
|
||||
impl.syncAcknowledgedDetails()
|
||||
|
||||
if impl.ackDetails.Empty() {
|
||||
// No acknowledged info can be sent.
|
||||
// Some message sent operation is blocked, new TT cannot be pushed forward.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Construct time tick message.
|
||||
msg, err := newTimeTickMsg(impl.ackDetails.LastAllAcknowledgedTimestamp(), impl.sourceID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "at build time tick msg")
|
||||
}
|
||||
|
||||
// Append it to wal.
|
||||
msgID, err := appender(impl.ctx, msg)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err,
|
||||
"append time tick msg to wal failed, timestamp: %d, previous message counter: %d",
|
||||
impl.ackDetails.LastAllAcknowledgedTimestamp(),
|
||||
impl.ackDetails.Len(),
|
||||
)
|
||||
}
|
||||
|
||||
// Ack details has been committed to wal, clear it.
|
||||
impl.ackDetails.Clear()
|
||||
impl.ackManager.AdvanceLastConfirmedMessageID(msgID)
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package timetick
|
||||
|
||||
import (
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/pkg/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
)
|
||||
|
||||
func newTimeTickMsg(ts uint64, sourceID int64) (message.MutableMessage, error) {
|
||||
// TODO: time tick should be put on properties, for compatibility, we put it on message body now.
|
||||
msgstreamMsg := &msgstream.TimeTickMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
TimeTickMsg: msgpb.TimeTickMsg{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
|
||||
commonpbutil.WithMsgID(0),
|
||||
commonpbutil.WithTimeStamp(ts),
|
||||
commonpbutil.WithSourceID(sourceID),
|
||||
),
|
||||
},
|
||||
}
|
||||
bytes, err := msgstreamMsg.Marshal(msgstreamMsg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal time tick message failed")
|
||||
}
|
||||
|
||||
payload, ok := bytes.([]byte)
|
||||
if !ok {
|
||||
return nil, errors.New("marshal time tick message as []byte failed")
|
||||
}
|
||||
|
||||
// Common message's time tick is set on interceptor.
|
||||
// TimeTickMsg's time tick should be set here.
|
||||
msg := message.NewMutableMessageBuilder().
|
||||
WithMessageType(message.MessageTypeTimeTick).
|
||||
WithPayload(payload).
|
||||
BuildMutable().
|
||||
WithTimeTick(ts)
|
||||
return msg, nil
|
||||
}
|
|
@ -82,16 +82,14 @@ func (_c *MockWALImpls_Append_Call) RunAndReturn(run func(context.Context, messa
|
|||
}
|
||||
|
||||
// Channel provides a mock function with given fields:
|
||||
func (_m *MockWALImpls) Channel() *types.PChannelInfo {
|
||||
func (_m *MockWALImpls) Channel() types.PChannelInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 *types.PChannelInfo
|
||||
if rf, ok := ret.Get(0).(func() *types.PChannelInfo); ok {
|
||||
var r0 types.PChannelInfo
|
||||
if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*types.PChannelInfo)
|
||||
}
|
||||
r0 = ret.Get(0).(types.PChannelInfo)
|
||||
}
|
||||
|
||||
return r0
|
||||
|
@ -114,12 +112,12 @@ func (_c *MockWALImpls_Channel_Call) Run(run func()) *MockWALImpls_Channel_Call
|
|||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWALImpls_Channel_Call) Return(_a0 *types.PChannelInfo) *MockWALImpls_Channel_Call {
|
||||
func (_c *MockWALImpls_Channel_Call) Return(_a0 types.PChannelInfo) *MockWALImpls_Channel_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWALImpls_Channel_Call) RunAndReturn(run func() *types.PChannelInfo) *MockWALImpls_Channel_Call {
|
||||
func (_c *MockWALImpls_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockWALImpls_Channel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
|
|
@ -315,6 +315,47 @@ func (_c *MockImmutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *M
|
|||
return _c
|
||||
}
|
||||
|
||||
// VChannel provides a mock function with given fields:
|
||||
func (_m *MockImmutableMessage) VChannel() string {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockImmutableMessage_VChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VChannel'
|
||||
type MockImmutableMessage_VChannel_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VChannel is a helper method to define mock.On call
|
||||
func (_e *MockImmutableMessage_Expecter) VChannel() *MockImmutableMessage_VChannel_Call {
|
||||
return &MockImmutableMessage_VChannel_Call{Call: _e.mock.On("VChannel")}
|
||||
}
|
||||
|
||||
func (_c *MockImmutableMessage_VChannel_Call) Run(run func()) *MockImmutableMessage_VChannel_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockImmutableMessage_VChannel_Call) Return(_a0 string) *MockImmutableMessage_VChannel_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockImmutableMessage_VChannel_Call) RunAndReturn(run func() string) *MockImmutableMessage_VChannel_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Version provides a mock function with given fields:
|
||||
func (_m *MockImmutableMessage) Version() message.Version {
|
||||
ret := _m.Called()
|
||||
|
|
|
@ -232,6 +232,47 @@ func (_c *MockMutableMessage_Properties_Call) RunAndReturn(run func() message.Pr
|
|||
return _c
|
||||
}
|
||||
|
||||
// Version provides a mock function with given fields:
|
||||
func (_m *MockMutableMessage) Version() message.Version {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 message.Version
|
||||
if rf, ok := ret.Get(0).(func() message.Version); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(message.Version)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMutableMessage_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version'
|
||||
type MockMutableMessage_Version_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Version is a helper method to define mock.On call
|
||||
func (_e *MockMutableMessage_Expecter) Version() *MockMutableMessage_Version_Call {
|
||||
return &MockMutableMessage_Version_Call{Call: _e.mock.On("Version")}
|
||||
}
|
||||
|
||||
func (_c *MockMutableMessage_Version_Call) Run(run func()) *MockMutableMessage_Version_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMutableMessage_Version_Call) Return(_a0 message.Version) *MockMutableMessage_Version_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMutableMessage_Version_Call) RunAndReturn(run func() message.Version) *MockMutableMessage_Version_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// WithLastConfirmed provides a mock function with given fields: id
|
||||
func (_m *MockMutableMessage) WithLastConfirmed(id message.MessageID) message.MutableMessage {
|
||||
ret := _m.Called(id)
|
||||
|
|
|
@ -1,60 +1,72 @@
|
|||
package message
|
||||
|
||||
// NewBuilder creates a new builder.
|
||||
func NewBuilder() *Builder {
|
||||
return &Builder{
|
||||
id: nil,
|
||||
// NewImmutableMessage creates a new immutable message.
|
||||
func NewImmutableMesasge(
|
||||
id MessageID,
|
||||
payload []byte,
|
||||
properties map[string]string,
|
||||
) ImmutableMessage {
|
||||
return &immutableMessageImpl{
|
||||
id: id,
|
||||
messageImpl: messageImpl{
|
||||
payload: payload,
|
||||
properties: properties,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMutableMessageBuilder creates a new builder.
|
||||
// Should only used at client side.
|
||||
func NewMutableMessageBuilder() *MutableMesasgeBuilder {
|
||||
return &MutableMesasgeBuilder{
|
||||
payload: nil,
|
||||
properties: make(propertiesImpl),
|
||||
}
|
||||
}
|
||||
|
||||
// Builder is the builder for message.
|
||||
type Builder struct {
|
||||
id MessageID
|
||||
// MutableMesasgeBuilder is the builder for message.
|
||||
type MutableMesasgeBuilder struct {
|
||||
payload []byte
|
||||
properties propertiesImpl
|
||||
}
|
||||
|
||||
// WithMessageID creates a new builder with message id.
|
||||
func (b *Builder) WithMessageID(id MessageID) *Builder {
|
||||
b.id = id
|
||||
func (b *MutableMesasgeBuilder) WithMessageType(t MessageType) *MutableMesasgeBuilder {
|
||||
b.properties.Set(messageTypeKey, t.marshal())
|
||||
return b
|
||||
}
|
||||
|
||||
// WithMessageType creates a new builder with message type.
|
||||
func (b *Builder) WithMessageType(t MessageType) *Builder {
|
||||
b.properties.Set(messageTypeKey, t.marshal())
|
||||
// WithPayload creates a new builder with message payload.
|
||||
// The MessageType is required to indicate which message type payload is.
|
||||
func (b *MutableMesasgeBuilder) WithPayload(payload []byte) *MutableMesasgeBuilder {
|
||||
b.payload = payload
|
||||
return b
|
||||
}
|
||||
|
||||
// WithProperty creates a new builder with message property.
|
||||
// A key started with '_' is reserved for log system, should never used at user of client.
|
||||
func (b *Builder) WithProperty(key string, val string) *Builder {
|
||||
func (b *MutableMesasgeBuilder) WithProperty(key string, val string) *MutableMesasgeBuilder {
|
||||
b.properties.Set(key, val)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithProperties creates a new builder with message properties.
|
||||
// A key started with '_' is reserved for log system, should never used at user of client.
|
||||
func (b *Builder) WithProperties(kvs map[string]string) *Builder {
|
||||
func (b *MutableMesasgeBuilder) WithProperties(kvs map[string]string) *MutableMesasgeBuilder {
|
||||
for key, val := range kvs {
|
||||
b.properties.Set(key, val)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPayload creates a new builder with message payload.
|
||||
func (b *Builder) WithPayload(payload []byte) *Builder {
|
||||
b.payload = payload
|
||||
return b
|
||||
}
|
||||
|
||||
// BuildMutable builds a mutable message.
|
||||
// Panic if set the message id.
|
||||
func (b *Builder) BuildMutable() MutableMessage {
|
||||
if b.id != nil {
|
||||
panic("build a mutable message, message id should be nil")
|
||||
// Panic if not set payload and message type.
|
||||
// should only used at client side.
|
||||
func (b *MutableMesasgeBuilder) BuildMutable() MutableMessage {
|
||||
if b.payload == nil {
|
||||
panic("message builder not ready for payload field")
|
||||
}
|
||||
if !b.properties.Exist(messageTypeKey) {
|
||||
panic("message builder not ready for message type field")
|
||||
}
|
||||
// Set message version.
|
||||
b.properties.Set(messageVersion, VersionV1.String())
|
||||
|
@ -63,18 +75,3 @@ func (b *Builder) BuildMutable() MutableMessage {
|
|||
properties: b.properties,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildImmutable builds a immutable message.
|
||||
// Panic if not set the message id.
|
||||
func (b *Builder) BuildImmutable() ImmutableMessage {
|
||||
if b.id == nil {
|
||||
panic("build a immutable message, message id should not be nil")
|
||||
}
|
||||
return &immutableMessageImpl{
|
||||
id: b.id,
|
||||
messageImpl: messageImpl{
|
||||
payload: b.payload,
|
||||
properties: b.properties,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,11 @@ type BasicMessage interface {
|
|||
// MessageType returns the type of message.
|
||||
MessageType() MessageType
|
||||
|
||||
// Version returns the message version.
|
||||
// 0: old version before streamingnode.
|
||||
// from 1: new version after streamingnode.
|
||||
Version() Version
|
||||
|
||||
// Message payload.
|
||||
Payload() []byte
|
||||
|
||||
|
@ -47,6 +52,11 @@ type ImmutableMessage interface {
|
|||
// WALName returns the name of message related wal.
|
||||
WALName() string
|
||||
|
||||
// VChannel returns the virtual channel of current message.
|
||||
// Available only when the message's version greater than 0.
|
||||
// Otherwise, it will panic.
|
||||
VChannel() string
|
||||
|
||||
// TimeTick returns the time tick of current message.
|
||||
// Available only when the message's version greater than 0.
|
||||
// Otherwise, it will panic.
|
||||
|
@ -64,9 +74,4 @@ type ImmutableMessage interface {
|
|||
|
||||
// Properties returns the message read only properties.
|
||||
Properties() RProperties
|
||||
|
||||
// Version returns the message format version.
|
||||
// 0: old version before streamingnode.
|
||||
// from 1: new version after streamingnode.
|
||||
Version() Version
|
||||
}
|
||||
|
|
|
@ -12,8 +12,9 @@ import (
|
|||
)
|
||||
|
||||
func TestMessage(t *testing.T) {
|
||||
b := message.NewBuilder()
|
||||
mutableMessage := b.WithMessageType(message.MessageTypeTimeTick).
|
||||
b := message.NewMutableMessageBuilder()
|
||||
mutableMessage := b.
|
||||
WithMessageType(message.MessageTypeTimeTick).
|
||||
WithPayload([]byte("payload")).
|
||||
WithProperties(map[string]string{"key": "value"}).
|
||||
BuildMutable()
|
||||
|
@ -49,17 +50,15 @@ func TestMessage(t *testing.T) {
|
|||
panic(fmt.Sprintf("unexpected data: %s", data))
|
||||
})
|
||||
|
||||
b = message.NewBuilder()
|
||||
immutableMessage := b.WithMessageID(msgID).
|
||||
WithPayload([]byte("payload")).
|
||||
WithProperties(map[string]string{
|
||||
immutableMessage := message.NewImmutableMesasge(msgID,
|
||||
[]byte("payload"),
|
||||
map[string]string{
|
||||
"key": "value",
|
||||
"_t": "1",
|
||||
"_tt": string(proto.EncodeVarint(456)),
|
||||
"_v": "1",
|
||||
"_lc": "lcMsgID",
|
||||
}).
|
||||
BuildImmutable()
|
||||
})
|
||||
|
||||
assert.True(t, immutableMessage.MessageID().EQ(msgID))
|
||||
assert.Equal(t, "payload", string(immutableMessage.Payload()))
|
||||
|
@ -73,12 +72,13 @@ func TestMessage(t *testing.T) {
|
|||
assert.Equal(t, uint64(456), immutableMessage.TimeTick())
|
||||
assert.NotNil(t, immutableMessage.LastConfirmedMessageID())
|
||||
|
||||
b = message.NewBuilder()
|
||||
immutableMessage = b.WithMessageID(msgID).
|
||||
WithPayload([]byte("payload")).
|
||||
WithProperty("key", "value").
|
||||
WithProperty("_t", "1").
|
||||
BuildImmutable()
|
||||
immutableMessage = message.NewImmutableMesasge(
|
||||
msgID,
|
||||
[]byte("payload"),
|
||||
map[string]string{
|
||||
"key": "value",
|
||||
"_t": "1",
|
||||
})
|
||||
|
||||
assert.True(t, immutableMessage.MessageID().EQ(msgID))
|
||||
assert.Equal(t, "payload", string(immutableMessage.Payload()))
|
||||
|
@ -97,9 +97,6 @@ func TestMessage(t *testing.T) {
|
|||
})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
message.NewBuilder().WithMessageID(msgID).BuildMutable()
|
||||
})
|
||||
assert.Panics(t, func() {
|
||||
message.NewBuilder().BuildImmutable()
|
||||
message.NewMutableMessageBuilder().BuildMutable()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -20,6 +20,15 @@ func (m *messageImpl) MessageType() MessageType {
|
|||
return unmarshalMessageType(val)
|
||||
}
|
||||
|
||||
// Version returns the message format version.
|
||||
func (m *messageImpl) Version() Version {
|
||||
value, ok := m.properties.Get(messageVersion)
|
||||
if !ok {
|
||||
return VersionOld
|
||||
}
|
||||
return newMessageVersionFromString(value)
|
||||
}
|
||||
|
||||
// Payload returns payload of current message.
|
||||
func (m *messageImpl) Payload() []byte {
|
||||
return m.payload
|
||||
|
@ -98,16 +107,15 @@ func (m *immutableMessageImpl) MessageID() MessageID {
|
|||
return m.id
|
||||
}
|
||||
|
||||
func (m *immutableMessageImpl) VChannel() string {
|
||||
value, ok := m.properties.Get(messageVChannel)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("there's a bug in the message codes, vchannel lost in properties of message, id: %+v", m.id))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// Properties returns the message read only properties.
|
||||
func (m *immutableMessageImpl) Properties() RProperties {
|
||||
return m.properties
|
||||
}
|
||||
|
||||
// Version returns the message format version.
|
||||
func (m *immutableMessageImpl) Version() Version {
|
||||
value, ok := m.properties.Get(messageVersion)
|
||||
if !ok {
|
||||
return VersionOld
|
||||
}
|
||||
return newMessageVersionFromString(value)
|
||||
}
|
||||
|
|
|
@ -24,6 +24,12 @@ func (t MessageType) marshal() string {
|
|||
return strconv.FormatInt(int64(t), 10)
|
||||
}
|
||||
|
||||
// Valid checks if the MessageType is valid.
|
||||
func (t MessageType) Valid() bool {
|
||||
return t == MessageTypeTimeTick
|
||||
// TODO: fill more.
|
||||
}
|
||||
|
||||
// unmarshalMessageType unmarshal MessageType from string.
|
||||
func unmarshalMessageType(s string) MessageType {
|
||||
i, err := strconv.ParseInt(s, 10, 32)
|
||||
|
|
|
@ -6,6 +6,7 @@ const (
|
|||
messageTypeKey = "_t" // message type key.
|
||||
messageTimeTick = "_tt" // message time tick.
|
||||
messageLastConfirmed = "_lc" // message last confirmed message id.
|
||||
messageVChannel = "_vc" // message virtual channel.
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -23,3 +23,7 @@ func newMessageVersionFromString(s string) Version {
|
|||
func (v Version) String() string {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
|
||||
func (v Version) GT(v2 Version) bool {
|
||||
return v > v2
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
package options
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeliver(t *testing.T) {
|
||||
policy := DeliverPolicyAll()
|
||||
assert.Equal(t, DeliverPolicyTypeAll, policy.Policy())
|
||||
assert.Panics(t, func() {
|
||||
policy.MessageID()
|
||||
})
|
||||
|
||||
policy = DeliverPolicyLatest()
|
||||
assert.Equal(t, DeliverPolicyTypeLatest, policy.Policy())
|
||||
assert.Panics(t, func() {
|
||||
policy.MessageID()
|
||||
})
|
||||
|
||||
messageID := mock_message.NewMockMessageID(t)
|
||||
policy = DeliverPolicyStartFrom(messageID)
|
||||
assert.Equal(t, DeliverPolicyTypeStartFrom, policy.Policy())
|
||||
assert.Equal(t, messageID, policy.MessageID())
|
||||
|
||||
policy = DeliverPolicyStartAfter(messageID)
|
||||
assert.Equal(t, DeliverPolicyTypeStartAfter, policy.Policy())
|
||||
assert.Equal(t, messageID, policy.MessageID())
|
||||
}
|
|
@ -57,11 +57,11 @@ func (s *scannerImpl) executeConsume() {
|
|||
s.Finish(err)
|
||||
return
|
||||
}
|
||||
newImmutableMessage := message.NewBuilder().
|
||||
WithMessageID(pulsarID{msg.ID()}).
|
||||
WithPayload(msg.Payload()).
|
||||
WithProperties(msg.Properties()).
|
||||
BuildImmutable()
|
||||
newImmutableMessage := message.NewImmutableMesasge(
|
||||
pulsarID{msg.ID()},
|
||||
msg.Payload(),
|
||||
msg.Properties(),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-s.Context().Done():
|
||||
|
|
|
@ -66,11 +66,11 @@ func (s *scannerImpl) executeConsume() {
|
|||
// record the last message id to avoid repeated consume message.
|
||||
// and exclude message id should be filterred.
|
||||
if s.exclude == nil || !s.exclude.EQ(msgID) {
|
||||
s.msgChannel <- message.NewBuilder().
|
||||
WithMessageID(msgID).
|
||||
WithPayload(msg.Payload()).
|
||||
WithProperties(msg.Properties()).
|
||||
BuildImmutable()
|
||||
s.msgChannel <- message.NewImmutableMesasge(
|
||||
msgID,
|
||||
msg.Payload(),
|
||||
msg.Properties(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,13 +10,13 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
walName = "test"
|
||||
WALName = "test"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// register the builder to the registry.
|
||||
registry.RegisterBuilder(&openerBuilder{})
|
||||
message.RegisterMessageIDUnmsarshaler(walName, UnmarshalTestMessageID)
|
||||
message.RegisterMessageIDUnmsarshaler(WALName, UnmarshalTestMessageID)
|
||||
}
|
||||
|
||||
var _ walimpls.OpenerBuilderImpls = &openerBuilder{}
|
||||
|
@ -24,7 +24,7 @@ var _ walimpls.OpenerBuilderImpls = &openerBuilder{}
|
|||
type openerBuilder struct{}
|
||||
|
||||
func (o *openerBuilder) Name() string {
|
||||
return walName
|
||||
return WALName
|
||||
}
|
||||
|
||||
func (o *openerBuilder) Build() (walimpls.OpenerImpls, error) {
|
||||
|
|
|
@ -39,7 +39,7 @@ type testMessageID int64
|
|||
|
||||
// WALName returns the name of message id related wal.
|
||||
func (id testMessageID) WALName() string {
|
||||
return walName
|
||||
return WALName
|
||||
}
|
||||
|
||||
// LT less than.
|
||||
|
|
|
@ -247,20 +247,15 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) (
|
|||
"const": "t",
|
||||
}
|
||||
typ := message.MessageTypeUnknown
|
||||
msg := message.NewBuilder().
|
||||
msg := message.NewMutableMessageBuilder().
|
||||
WithMessageType(typ).
|
||||
WithPayload(payload).
|
||||
WithProperties(properties).
|
||||
WithMessageType(typ).
|
||||
BuildMutable()
|
||||
id, err := w.Append(ctx, msg)
|
||||
assert.NoError(f.t, err)
|
||||
assert.NotNil(f.t, id)
|
||||
ids[i] = message.NewBuilder().
|
||||
WithPayload(payload).
|
||||
WithProperties(properties).
|
||||
WithMessageID(id).
|
||||
WithMessageType(typ).
|
||||
BuildImmutable()
|
||||
ids[i] = msg.IntoImmutableMessage(id)
|
||||
}(i)
|
||||
}
|
||||
swg.Wait()
|
||||
|
@ -280,19 +275,14 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) (
|
|||
"const": "t",
|
||||
"term": strconv.FormatInt(int64(f.term), 10),
|
||||
}
|
||||
msg := message.NewBuilder().
|
||||
msg := message.NewMutableMessageBuilder().
|
||||
WithPayload(payload).
|
||||
WithProperties(properties).
|
||||
WithMessageType(message.MessageTypeTimeTick).
|
||||
BuildMutable()
|
||||
id, err := w.Append(ctx, msg)
|
||||
assert.NoError(f.t, err)
|
||||
ids[f.messageCount-1] = message.NewBuilder().
|
||||
WithPayload(payload).
|
||||
WithProperties(properties).
|
||||
WithMessageID(id).
|
||||
WithMessageType(message.MessageTypeTimeTick).
|
||||
BuildImmutable()
|
||||
ids[f.messageCount-1] = msg.IntoImmutableMessage(id)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package syncutil
|
||||
|
||||
// Future is a future value that can be set and retrieved.
|
||||
type Future[T any] struct {
|
||||
ch chan struct{}
|
||||
value T
|
||||
}
|
||||
|
||||
// NewFuture creates a new future.
|
||||
func NewFuture[T any]() *Future[T] {
|
||||
return &Future[T]{
|
||||
ch: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value of the future.
|
||||
func (f *Future[T]) Set(value T) {
|
||||
f.value = value
|
||||
close(f.ch)
|
||||
}
|
||||
|
||||
// Get retrieves the value of the future if set, otherwise block until set.
|
||||
func (f *Future[T]) Get() T {
|
||||
<-f.ch
|
||||
return f.value
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the future is set.
|
||||
func (f *Future[T]) Done() <-chan struct{} {
|
||||
return f.ch
|
||||
}
|
||||
|
||||
// Ready returns true if the future is set.
|
||||
func (f *Future[T]) Ready() bool {
|
||||
select {
|
||||
case <-f.ch:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package syncutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFuture_SetAndGet(t *testing.T) {
|
||||
f := NewFuture[int]()
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second) // Simulate some work
|
||||
f.Set(42)
|
||||
}()
|
||||
|
||||
val := f.Get()
|
||||
if val != 42 {
|
||||
t.Errorf("Expected value 42, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuture_Done(t *testing.T) {
|
||||
f := NewFuture[string]()
|
||||
go func() {
|
||||
f.Set("done")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-f.Done():
|
||||
// Success
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Error("Expected future to be done within 2 seconds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuture_Ready(t *testing.T) {
|
||||
f := NewFuture[float64]()
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond) // Simulate some work
|
||||
f.Set(3.14)
|
||||
}()
|
||||
|
||||
if f.Ready() {
|
||||
t.Error("Expected future not to be ready immediately")
|
||||
}
|
||||
|
||||
<-f.Done() // Wait for the future to be set
|
||||
|
||||
if !f.Ready() {
|
||||
t.Error("Expected future to be ready after being set")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
var _ HeapInterface = (*heapArray[int])(nil)
|
||||
|
||||
// HeapInterface is the interface that a heap must implement.
|
||||
type HeapInterface interface {
|
||||
heap.Interface
|
||||
Peek() interface{}
|
||||
}
|
||||
|
||||
// Heap is a heap of E.
|
||||
// Use `golang.org/x/exp/constraints` directly if you want to change any element.
|
||||
type Heap[E any] interface {
|
||||
// Len returns the size of the heap.
|
||||
Len() int
|
||||
|
||||
// Push pushes an element onto the heap.
|
||||
Push(x E)
|
||||
|
||||
// Pop returns the element at the top of the heap.
|
||||
// Panics if the heap is empty.
|
||||
Pop() E
|
||||
|
||||
// Peek returns the element at the top of the heap.
|
||||
// Panics if the heap is empty.
|
||||
Peek() E
|
||||
}
|
||||
|
||||
// heapArray is a heap backed by an array.
|
||||
type heapArray[E constraints.Ordered] []E
|
||||
|
||||
// Len returns the length of the heap.
|
||||
func (h heapArray[E]) Len() int {
|
||||
return len(h)
|
||||
}
|
||||
|
||||
// Less returns true if the element at index i is less than the element at index j.
|
||||
func (h heapArray[E]) Less(i, j int) bool {
|
||||
return h[i] < h[j]
|
||||
}
|
||||
|
||||
// Swap swaps the elements at indexes i and j.
|
||||
func (h heapArray[E]) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
// Push pushes the last one at len.
|
||||
func (h *heapArray[E]) Push(x interface{}) {
|
||||
// Push and Pop use pointer receivers because they modify the slice's length,
|
||||
// not just its contents.
|
||||
*h = append(*h, x.(E))
|
||||
}
|
||||
|
||||
// Pop pop the last one at len.
|
||||
func (h *heapArray[E]) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
// Peek returns the element at the top of the heap.
|
||||
func (h *heapArray[E]) Peek() interface{} {
|
||||
return (*h)[0]
|
||||
}
|
||||
|
||||
// reverseOrderedInterface is a heap base interface that reverses the order of the elements.
|
||||
type reverseOrderedInterface[E constraints.Ordered] struct {
|
||||
HeapInterface
|
||||
}
|
||||
|
||||
// Less returns true if the element at index j is less than the element at index i.
|
||||
func (r reverseOrderedInterface[E]) Less(i, j int) bool {
|
||||
return r.HeapInterface.Less(j, i)
|
||||
}
|
||||
|
||||
// NewHeap returns a new heap from a underlying representation.
|
||||
func NewHeap[E any](inner HeapInterface) Heap[E] {
|
||||
return &heapImpl[E, HeapInterface]{
|
||||
inner: inner,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayBasedMaximumHeap returns a new maximum heap.
|
||||
func NewArrayBasedMaximumHeap[E constraints.Ordered](initial []E) Heap[E] {
|
||||
ha := heapArray[E](initial)
|
||||
reverse := reverseOrderedInterface[E]{
|
||||
HeapInterface: &ha,
|
||||
}
|
||||
heap.Init(reverse)
|
||||
return &heapImpl[E, reverseOrderedInterface[E]]{
|
||||
inner: reverse,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayBasedMinimumHeap returns a new minimum heap.
|
||||
func NewArrayBasedMinimumHeap[E constraints.Ordered](initial []E) Heap[E] {
|
||||
ha := heapArray[E](initial)
|
||||
heap.Init(&ha)
|
||||
return &heapImpl[E, *heapArray[E]]{
|
||||
inner: &ha,
|
||||
}
|
||||
}
|
||||
|
||||
// heapImpl is a min-heap of E.
|
||||
type heapImpl[E any, H HeapInterface] struct {
|
||||
inner H
|
||||
}
|
||||
|
||||
// Len returns the length of the heap.
|
||||
func (h *heapImpl[E, H]) Len() int {
|
||||
return h.inner.Len()
|
||||
}
|
||||
|
||||
// Push pushes an element onto the heap.
|
||||
func (h *heapImpl[E, H]) Push(x E) {
|
||||
heap.Push(h.inner, x)
|
||||
}
|
||||
|
||||
// Pop pops an element from the heap.
|
||||
func (h *heapImpl[E, H]) Pop() E {
|
||||
return heap.Pop(h.inner).(E)
|
||||
}
|
||||
|
||||
// Peek returns the element at the top of the heap.
|
||||
func (h *heapImpl[E, H]) Peek() E {
|
||||
return h.inner.Peek().(E)
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package typeutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMinimumHeap(t *testing.T) {
|
||||
h := []int{4, 5, 2}
|
||||
heap := NewArrayBasedMinimumHeap(h)
|
||||
assert.Equal(t, 2, heap.Peek())
|
||||
assert.Equal(t, 3, heap.Len())
|
||||
heap.Push(3)
|
||||
assert.Equal(t, 2, heap.Peek())
|
||||
assert.Equal(t, 4, heap.Len())
|
||||
heap.Push(1)
|
||||
assert.Equal(t, 1, heap.Peek())
|
||||
assert.Equal(t, 5, heap.Len())
|
||||
for i := 1; i <= 5; i++ {
|
||||
assert.Equal(t, i, heap.Peek())
|
||||
assert.Equal(t, i, heap.Pop())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaximumHeap(t *testing.T) {
|
||||
h := []int{4, 1, 2}
|
||||
heap := NewArrayBasedMaximumHeap(h)
|
||||
assert.Equal(t, 4, heap.Peek())
|
||||
assert.Equal(t, 3, heap.Len())
|
||||
heap.Push(3)
|
||||
assert.Equal(t, 4, heap.Peek())
|
||||
assert.Equal(t, 4, heap.Len())
|
||||
heap.Push(5)
|
||||
assert.Equal(t, 5, heap.Peek())
|
||||
assert.Equal(t, 5, heap.Len())
|
||||
for i := 5; i >= 1; i-- {
|
||||
assert.Equal(t, i, heap.Peek())
|
||||
assert.Equal(t, i, heap.Pop())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue