enhance: timetick interceptor implementation (#34238)

issue: #33285

- optimize the message package
- add interceptor package to achieve append operation intercepting.
- add timetick interceptor to attach timetick properties for message.
- add timetick background task to send timetick message.

Signed-off-by: chyezh <chyezh@outlook.com>
pull/34340/head v2.2-testing-20240702
chyezh 2024-07-02 14:42:08 +08:00 committed by GitHub
parent a5be322ab2
commit 3563136c2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 1899 additions and 180 deletions

View File

@ -8,7 +8,7 @@ import (
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
mock "github.com/stretchr/testify/mock"
streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb"
types "github.com/milvus-io/milvus/pkg/streaming/util/types"
wal "github.com/milvus-io/milvus/internal/streamingnode/server/wal"
)
@ -117,16 +117,14 @@ func (_c *MockWAL_AppendAsync_Call) RunAndReturn(run func(context.Context, messa
}
// Channel provides a mock function with given fields:
func (_m *MockWAL) Channel() *streamingpb.PChannelInfo {
func (_m *MockWAL) Channel() types.PChannelInfo {
ret := _m.Called()
var r0 *streamingpb.PChannelInfo
if rf, ok := ret.Get(0).(func() *streamingpb.PChannelInfo); ok {
var r0 types.PChannelInfo
if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*streamingpb.PChannelInfo)
}
r0 = ret.Get(0).(types.PChannelInfo)
}
return r0
@ -149,12 +147,12 @@ func (_c *MockWAL_Channel_Call) Run(run func()) *MockWAL_Channel_Call {
return _c
}
func (_c *MockWAL_Channel_Call) Return(_a0 *streamingpb.PChannelInfo) *MockWAL_Channel_Call {
func (_c *MockWAL_Channel_Call) Return(_a0 types.PChannelInfo) *MockWAL_Channel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWAL_Channel_Call) RunAndReturn(run func() *streamingpb.PChannelInfo) *MockWAL_Channel_Call {
func (_c *MockWAL_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockWAL_Channel_Call {
_c.Call.Return(run)
return _c
}

View File

@ -1,11 +1,12 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_walimpls
package mock_interceptors
import (
context "context"
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
mock "github.com/stretchr/testify/mock"
)

View File

@ -1,9 +1,9 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_walimpls
package mock_interceptors
import (
walimpls "github.com/milvus-io/milvus/pkg/streaming/walimpls"
interceptors "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
mock "github.com/stretchr/testify/mock"
)
@ -20,16 +20,16 @@ func (_m *MockInterceptorBuilder) EXPECT() *MockInterceptorBuilder_Expecter {
return &MockInterceptorBuilder_Expecter{mock: &_m.Mock}
}
// Build provides a mock function with given fields: wal
func (_m *MockInterceptorBuilder) Build(wal <-chan walimpls.WALImpls) walimpls.BasicInterceptor {
ret := _m.Called(wal)
// Build provides a mock function with given fields: param
func (_m *MockInterceptorBuilder) Build(param interceptors.InterceptorBuildParam) interceptors.BasicInterceptor {
ret := _m.Called(param)
var r0 walimpls.BasicInterceptor
if rf, ok := ret.Get(0).(func(<-chan walimpls.WALImpls) walimpls.BasicInterceptor); ok {
r0 = rf(wal)
var r0 interceptors.BasicInterceptor
if rf, ok := ret.Get(0).(func(interceptors.InterceptorBuildParam) interceptors.BasicInterceptor); ok {
r0 = rf(param)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(walimpls.BasicInterceptor)
r0 = ret.Get(0).(interceptors.BasicInterceptor)
}
}
@ -42,24 +42,24 @@ type MockInterceptorBuilder_Build_Call struct {
}
// Build is a helper method to define mock.On call
// - wal <-chan walimpls.WALImpls
func (_e *MockInterceptorBuilder_Expecter) Build(wal interface{}) *MockInterceptorBuilder_Build_Call {
return &MockInterceptorBuilder_Build_Call{Call: _e.mock.On("Build", wal)}
// - param interceptors.InterceptorBuildParam
func (_e *MockInterceptorBuilder_Expecter) Build(param interface{}) *MockInterceptorBuilder_Build_Call {
return &MockInterceptorBuilder_Build_Call{Call: _e.mock.On("Build", param)}
}
func (_c *MockInterceptorBuilder_Build_Call) Run(run func(wal <-chan walimpls.WALImpls)) *MockInterceptorBuilder_Build_Call {
func (_c *MockInterceptorBuilder_Build_Call) Run(run func(param interceptors.InterceptorBuildParam)) *MockInterceptorBuilder_Build_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(<-chan walimpls.WALImpls))
run(args[0].(interceptors.InterceptorBuildParam))
})
return _c
}
func (_c *MockInterceptorBuilder_Build_Call) Return(_a0 walimpls.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
func (_c *MockInterceptorBuilder_Build_Call) Return(_a0 interceptors.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockInterceptorBuilder_Build_Call) RunAndReturn(run func(<-chan walimpls.WALImpls) walimpls.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
func (_c *MockInterceptorBuilder_Build_Call) RunAndReturn(run func(interceptors.InterceptorBuildParam) interceptors.BasicInterceptor) *MockInterceptorBuilder_Build_Call {
_c.Call.Return(run)
return _c
}

View File

@ -1,11 +1,12 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package mock_walimpls
package mock_interceptors
import (
context "context"
message "github.com/milvus-io/milvus/pkg/streaming/util/message"
mock "github.com/stretchr/testify/mock"
)

View File

@ -1,42 +1,6 @@
package streamingpb
import (
"google.golang.org/protobuf/types/known/emptypb"
)
const (
ServiceMethodPrefix = "/milvus.proto.log"
InitialTerm = int64(-1)
)
func NewDeliverAll() *DeliverPolicy {
return &DeliverPolicy{
Policy: &DeliverPolicy_All{
All: &emptypb.Empty{},
},
}
}
func NewDeliverLatest() *DeliverPolicy {
return &DeliverPolicy{
Policy: &DeliverPolicy_Latest{
Latest: &emptypb.Empty{},
},
}
}
func NewDeliverStartFrom(messageID *MessageID) *DeliverPolicy {
return &DeliverPolicy{
Policy: &DeliverPolicy_StartFrom{
StartFrom: messageID,
},
}
}
func NewDeliverStartAfter(messageID *MessageID) *DeliverPolicy {
return &DeliverPolicy{
Policy: &DeliverPolicy_StartAfter{
StartAfter: messageID,
},
}
}

View File

@ -0,0 +1,76 @@
package resource
import (
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
"github.com/milvus-io/milvus/internal/types"
)
var r *resourceImpl // singleton resource instance
// optResourceInit is the option to initialize the resource.
type optResourceInit func(r *resourceImpl)
// OptETCD provides the etcd client to the resource.
func OptETCD(etcd *clientv3.Client) optResourceInit {
return func(r *resourceImpl) {
r.etcdClient = etcd
}
}
// OptRootCoordClient provides the root coordinator client to the resource.
func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit {
return func(r *resourceImpl) {
r.rootCoordClient = rootCoordClient
}
}
// Init initializes the singleton of resources.
// Should be call when streaming node startup.
func Init(opts ...optResourceInit) {
r = &resourceImpl{}
for _, opt := range opts {
opt(r)
}
r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient)
assertNotNil(r.TimestampAllocator())
assertNotNil(r.ETCD())
assertNotNil(r.RootCoordClient())
}
// Resource access the underlying singleton of resources.
func Resource() *resourceImpl {
return r
}
// resourceImpl is a basic resource dependency for streamingnode server.
// All utility on it is concurrent-safe and singleton.
type resourceImpl struct {
timestampAllocator timestamp.Allocator
etcdClient *clientv3.Client
rootCoordClient types.RootCoordClient
}
// TimestampAllocator returns the timestamp allocator to allocate timestamp.
func (r *resourceImpl) TimestampAllocator() timestamp.Allocator {
return r.timestampAllocator
}
// ETCD returns the etcd client.
func (r *resourceImpl) ETCD() *clientv3.Client {
return r.etcdClient
}
// RootCoordClient returns the root coordinator client.
func (r *resourceImpl) RootCoordClient() types.RootCoordClient {
return r.rootCoordClient
}
// assertNotNil panics if the resource is nil.
func assertNotNil(v interface{}) {
if v == nil {
panic("nil resource")
}
}

View File

@ -0,0 +1,31 @@
package resource
import (
"testing"
"github.com/stretchr/testify/assert"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/mocks"
)
func TestInit(t *testing.T) {
assert.Panics(t, func() {
Init()
})
assert.Panics(t, func() {
Init(OptETCD(&clientv3.Client{}))
})
assert.Panics(t, func() {
Init(OptETCD(&clientv3.Client{}))
})
Init(OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t)))
assert.NotNil(t, Resource().TimestampAllocator())
assert.NotNil(t, Resource().ETCD())
assert.NotNil(t, Resource().RootCoordClient())
}
func TestInitForTest(t *testing.T) {
InitForTest()
}

View File

@ -0,0 +1,17 @@
//go:build test
// +build test
package resource
import "github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
// InitForTest initializes the singleton of resources for test.
func InitForTest(opts ...optResourceInit) {
r = &resourceImpl{}
for _, opt := range opts {
opt(r)
}
if r.rootCoordClient != nil {
r.timestampAllocator = timestamp.NewAllocator(r.rootCoordClient)
}
}

View File

@ -0,0 +1,95 @@
package timestamp
import (
"context"
"fmt"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
var errExhausted = errors.New("exhausted")
// newLocalAllocator creates a new local allocator.
func newLocalAllocator() *localAllocator {
return &localAllocator{
nextStartID: 0,
endStartID: 0,
}
}
// localAllocator allocates timestamp locally.
type localAllocator struct {
nextStartID uint64 // Allocate timestamp locally.
endStartID uint64
}
// AllocateOne allocates a timestamp.
func (a *localAllocator) allocateOne() (uint64, error) {
if a.nextStartID < a.endStartID {
id := a.nextStartID
a.nextStartID++
return id, nil
}
return 0, errExhausted
}
// update updates the local allocator.
func (a *localAllocator) update(start uint64, count int) {
// local allocator can be only increasing.
if start >= a.endStartID {
a.nextStartID = start
a.endStartID = start + uint64(count)
}
}
// expire expires all id in the local allocator.
func (a *localAllocator) exhausted() {
a.nextStartID = a.endStartID
}
// remoteAllocator allocate timestamp from remote root coordinator.
type remoteAllocator struct {
rc types.RootCoordClient
nodeID int64
}
// newRemoteAllocator creates a new remote allocator.
func newRemoteAllocator(rc types.RootCoordClient) *remoteAllocator {
a := &remoteAllocator{
nodeID: paramtable.GetNodeID(),
rc: rc,
}
return a
}
func (ta *remoteAllocator) allocate(ctx context.Context, count uint32) (uint64, int, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req := &rootcoordpb.AllocTimestampRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(ta.nodeID),
),
Count: count,
}
resp, err := ta.rc.AllocTimestamp(ctx, req)
if err != nil {
return 0, 0, fmt.Errorf("syncTimestamp Failed:%w", err)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return 0, 0, fmt.Errorf("syncTimeStamp Failed:%s", resp.GetStatus().GetReason())
}
if resp == nil {
return 0, 0, fmt.Errorf("empty AllocTimestampResponse")
}
return resp.GetTimestamp(), int(resp.GetCount()), nil
}

View File

@ -0,0 +1,97 @@
package timestamp
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/atomic"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestLocalAllocator(t *testing.T) {
allocator := newLocalAllocator()
ts, err := allocator.allocateOne()
assert.Error(t, err)
assert.Zero(t, ts)
allocator.update(1, 100)
counter := atomic.NewUint64(0)
for i := 0; i < 100; i++ {
ts, err := allocator.allocateOne()
assert.NoError(t, err)
assert.NotZero(t, ts)
counter.Add(ts)
}
assert.Equal(t, uint64(5050), counter.Load())
// allocator exhausted.
ts, err = allocator.allocateOne()
assert.Error(t, err)
assert.Zero(t, ts)
// allocator can not be rollback.
allocator.update(90, 100)
ts, err = allocator.allocateOne()
assert.Error(t, err)
assert.Zero(t, ts)
// allocator can be only increasing.
allocator.update(101, 100)
ts, err = allocator.allocateOne()
assert.NoError(t, err)
assert.Equal(t, ts, uint64(101))
// allocator can be exhausted.
allocator.exhausted()
ts, err = allocator.allocateOne()
assert.Error(t, err)
assert.Zero(t, ts)
}
func TestRemoteAllocator(t *testing.T) {
paramtable.Init()
paramtable.SetNodeID(1)
client := NewMockRootCoordClient(t)
allocator := newRemoteAllocator(client)
ts, count, err := allocator.allocate(context.Background(), 100)
assert.NoError(t, err)
assert.NotZero(t, ts)
assert.Equal(t, count, 100)
// Test error.
client = mocks.NewMockRootCoordClient(t)
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return nil, errors.New("test")
},
)
allocator = newRemoteAllocator(client)
_, _, err = allocator.allocate(context.Background(), 100)
assert.Error(t, err)
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Unset()
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_ForceDeny,
},
}, nil
},
)
allocator = newRemoteAllocator(client)
_, _, err = allocator.allocate(context.Background(), 100)
assert.Error(t, err)
}

View File

@ -0,0 +1,39 @@
//go:build test
// +build test
package timestamp
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"go.uber.org/atomic"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
)
func NewMockRootCoordClient(t *testing.T) *mocks.MockRootCoordClient {
counter := atomic.NewUint64(1)
client := mocks.NewMockRootCoordClient(t)
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
if atr.Count > 1000 {
panic(fmt.Sprintf("count %d is too large", atr.Count))
}
c := counter.Add(uint64(atr.Count))
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Timestamp: c - uint64(atr.Count),
Count: atr.Count,
}, nil
},
)
return client
}

View File

@ -0,0 +1,88 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package timestamp
import (
"context"
"sync"
"github.com/milvus-io/milvus/internal/types"
)
// batchAllocateSize is the size of batch allocate from remote allocator.
const batchAllocateSize = 1000
var _ Allocator = (*allocatorImpl)(nil)
// NewAllocator creates a new allocator.
func NewAllocator(rc types.RootCoordClient) Allocator {
return &allocatorImpl{
mu: sync.Mutex{},
remoteAllocator: newRemoteAllocator(rc),
localAllocator: newLocalAllocator(),
}
}
type Allocator interface {
// Allocate allocates a timestamp.
Allocate(ctx context.Context) (uint64, error)
// Sync expire the local allocator messages,
// syncs the local allocator and remote allocator.
Sync()
}
type allocatorImpl struct {
mu sync.Mutex
remoteAllocator *remoteAllocator
localAllocator *localAllocator
}
// AllocateOne allocates a timestamp.
func (ta *allocatorImpl) Allocate(ctx context.Context) (uint64, error) {
ta.mu.Lock()
defer ta.mu.Unlock()
// allocate one from local allocator first.
if id, err := ta.localAllocator.allocateOne(); err == nil {
return id, nil
}
// allocate from remote.
return ta.allocateRemote(ctx)
}
// Sync expire the local allocator messages,
// syncs the local allocator and remote allocator.
func (ta *allocatorImpl) Sync() {
ta.mu.Lock()
defer ta.mu.Unlock()
ta.localAllocator.exhausted()
}
// allocateRemote allocates timestamp from remote root coordinator.
func (ta *allocatorImpl) allocateRemote(ctx context.Context) (uint64, error) {
// Update local allocator from remote.
start, count, err := ta.remoteAllocator.allocate(ctx, batchAllocateSize)
if err != nil {
return 0, err
}
ta.localAllocator.update(start, count)
// Get from local again.
return ta.localAllocator.allocateOne()
}

View File

@ -0,0 +1,52 @@
package timestamp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestTimestampAllocator(t *testing.T) {
paramtable.Init()
paramtable.SetNodeID(1)
client := NewMockRootCoordClient(t)
allocator := NewAllocator(client)
for i := 0; i < 5000; i++ {
ts, err := allocator.Allocate(context.Background())
assert.NoError(t, err)
assert.NotZero(t, ts)
}
for i := 0; i < 100; i++ {
ts, err := allocator.Allocate(context.Background())
assert.NoError(t, err)
assert.NotZero(t, ts)
time.Sleep(time.Millisecond * 1)
allocator.Sync()
}
// error test
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Unset()
client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_ForceDeny,
},
}, nil
},
)
allocator = NewAllocator(client)
_, err := allocator.Allocate(context.Background())
assert.Error(t, err)
}

View File

@ -3,8 +3,7 @@ package wal
import (
"context"
"github.com/milvus-io/milvus/internal/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/streaming/util/types"
)
// OpenerBuilder is the interface for build wal opener.
@ -17,8 +16,7 @@ type OpenerBuilder interface {
// OpenOption is the option for allocating wal instance.
type OpenOption struct {
Channel *streamingpb.PChannelInfo
InterceptorBuilders []walimpls.InterceptorBuilder // Interceptor builders to build when open.
Channel types.PChannelInfo
}
// Opener is the interface for build wal instance.

View File

@ -0,0 +1,95 @@
package interceptors
import (
"context"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
)
var _ InterceptorWithReady = (*chainedInterceptor)(nil)
type (
// appendInterceptorCall is the common function to execute the append interceptor.
appendInterceptorCall = func(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error)
)
// NewChainedInterceptor creates a new chained interceptor.
func NewChainedInterceptor(interceptors ...BasicInterceptor) InterceptorWithReady {
appendCalls := make([]appendInterceptorCall, 0, len(interceptors))
for _, i := range interceptors {
if r, ok := i.(AppendInterceptor); ok {
appendCalls = append(appendCalls, r.DoAppend)
}
}
return &chainedInterceptor{
closed: make(chan struct{}),
interceptors: interceptors,
appendCall: chainAppendInterceptors(appendCalls),
}
}
// chainedInterceptor chains all interceptors into one.
type chainedInterceptor struct {
closed chan struct{}
interceptors []BasicInterceptor
appendCall appendInterceptorCall
}
// Ready wait all interceptors to be ready.
func (c *chainedInterceptor) Ready() <-chan struct{} {
ready := make(chan struct{})
go func() {
for _, i := range c.interceptors {
// check if ready is implemented
if r, ok := i.(InterceptorReady); ok {
select {
case <-r.Ready():
case <-c.closed:
return
}
}
}
close(ready)
}()
return ready
}
// DoAppend execute the append operation with all interceptors.
func (c *chainedInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error) {
return c.appendCall(ctx, msg, append)
}
// Close close all interceptors.
func (c *chainedInterceptor) Close() {
close(c.closed)
for _, i := range c.interceptors {
i.Close()
}
}
// chainAppendInterceptors chains all unary client interceptors into one.
func chainAppendInterceptors(interceptorCalls []appendInterceptorCall) appendInterceptorCall {
if len(interceptorCalls) == 0 {
// Do nothing if no interceptors.
return func(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error) {
return append(ctx, msg)
}
} else if len(interceptorCalls) == 1 {
return interceptorCalls[0]
}
return func(ctx context.Context, msg message.MutableMessage, invoker Append) (message.MessageID, error) {
return interceptorCalls[0](ctx, msg, getChainAppendInvoker(interceptorCalls, 0, invoker))
}
}
// getChainAppendInvoker recursively generate the chained unary invoker.
func getChainAppendInvoker(interceptors []appendInterceptorCall, idx int, finalInvoker Append) Append {
// all interceptor is called, so return the final invoker.
if idx == len(interceptors)-1 {
return finalInvoker
}
// recursively generate the chained invoker.
return func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
return interceptors[idx+1](ctx, msg, getChainAppendInvoker(interceptors, idx+1, finalInvoker))
}
}

View File

@ -0,0 +1,116 @@
package interceptors_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/mock_interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
)
func TestChainInterceptor(t *testing.T) {
for i := 0; i < 5; i++ {
testChainInterceptor(t, i)
}
}
func TestChainReady(t *testing.T) {
count := 5
channels := make([]chan struct{}, 0, count)
ips := make([]interceptors.BasicInterceptor, 0, count)
for i := 0; i < count; i++ {
ch := make(chan struct{})
channels = append(channels, ch)
interceptor := mock_interceptors.NewMockInterceptorWithReady(t)
interceptor.EXPECT().Ready().Return(ch)
interceptor.EXPECT().Close().Return()
ips = append(ips, interceptor)
}
chainInterceptor := interceptors.NewChainedInterceptor(ips...)
for i := 0; i < count; i++ {
// part of interceptors is not ready
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
select {
case <-chainInterceptor.Ready():
t.Fatal("should not ready")
case <-ctx.Done():
}
close(channels[i])
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
select {
case <-chainInterceptor.Ready():
case <-ctx.Done():
t.Fatal("interceptor should be ready now")
}
chainInterceptor.Close()
interceptor := mock_interceptors.NewMockInterceptorWithReady(t)
ch := make(chan struct{})
interceptor.EXPECT().Ready().Return(ch)
interceptor.EXPECT().Close().Return()
chainInterceptor = interceptors.NewChainedInterceptor(interceptor)
chainInterceptor.Close()
// closed chain interceptor should block the ready (internal interceptor is not ready)
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
select {
case <-chainInterceptor.Ready():
t.Fatal("chan interceptor that closed but internal interceptor is not ready should block the ready")
case <-ctx.Done():
}
}
func testChainInterceptor(t *testing.T, count int) {
type record struct {
before bool
after bool
closed bool
}
appendInterceptorRecords := make([]record, 0, count)
ips := make([]interceptors.BasicInterceptor, 0, count)
for i := 0; i < count; i++ {
j := i
appendInterceptorRecords = append(appendInterceptorRecords, record{})
interceptor := mock_interceptors.NewMockInterceptor(t)
interceptor.EXPECT().DoAppend(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, mm message.MutableMessage, f func(context.Context, message.MutableMessage) (message.MessageID, error)) (message.MessageID, error) {
appendInterceptorRecords[j].before = true
msgID, err := f(ctx, mm)
appendInterceptorRecords[j].after = true
return msgID, err
})
interceptor.EXPECT().Close().Run(func() {
appendInterceptorRecords[j].closed = true
})
ips = append(ips, interceptor)
}
interceptor := interceptors.NewChainedInterceptor(ips...)
// fast return
<-interceptor.Ready()
msg, err := interceptor.DoAppend(context.Background(), nil, func(context.Context, message.MutableMessage) (message.MessageID, error) {
return nil, nil
})
assert.NoError(t, err)
assert.Nil(t, msg)
interceptor.Close()
for i := 0; i < count; i++ {
assert.True(t, appendInterceptorRecords[i].before)
assert.True(t, appendInterceptorRecords[i].after)
assert.True(t, appendInterceptorRecords[i].closed)
}
}

View File

@ -1,25 +1,31 @@
package walimpls
package interceptors
import (
"context"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/util/syncutil"
)
type (
// Append is the common function to append a msg to the wal.
Append = func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error)
// Read is the common function to read a msg from the wal.
Read = func(ctx context.Context, opt ReadOption) (ScannerImpls, error)
)
type InterceptorBuildParam struct {
WALImpls walimpls.WALImpls // The underlying walimpls implementation, can be used anytime.
WAL *syncutil.Future[wal.WAL] // The wal final object, can be used after interceptor is ready.
}
// InterceptorBuilder is the interface to build a interceptor.
// 1. InterceptorBuilder is concurrent safe.
// 2. InterceptorBuilder can used to build a interceptor with cross-wal shared resources.
type InterceptorBuilder interface {
// Build build a interceptor with wal that interceptor will work on.
// the wal object will be sent to the interceptor builder when the wal is constructed with all interceptors.
Build(wal <-chan WALImpls) BasicInterceptor
Build(param InterceptorBuildParam) BasicInterceptor
}
type BasicInterceptor interface {

View File

@ -0,0 +1,87 @@
package ack
import (
"go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
var _ typeutil.HeapInterface = (*timestampWithAckArray)(nil)
// newAcker creates a new acker.
func newAcker(ts uint64, lastConfirmedMessageID message.MessageID) *Acker {
return &Acker{
acknowledged: atomic.NewBool(false),
detail: newAckDetail(ts, lastConfirmedMessageID),
}
}
// Acker records the timestamp and last confirmed message id that has not been acknowledged.
type Acker struct {
acknowledged *atomic.Bool // is acknowledged.
detail *AckDetail // info is available after acknowledged.
}
// LastConfirmedMessageID returns the last confirmed message id.
func (ta *Acker) LastConfirmedMessageID() message.MessageID {
return ta.detail.LastConfirmedMessageID
}
// Timestamp returns the timestamp.
func (ta *Acker) Timestamp() uint64 {
return ta.detail.Timestamp
}
// Ack marks the timestamp as acknowledged.
func (ta *Acker) Ack(opts ...AckOption) {
for _, opt := range opts {
opt(ta.detail)
}
ta.acknowledged.Store(true)
}
// ackDetail returns the ack info, only can be called after acknowledged.
func (ta *Acker) ackDetail() *AckDetail {
if !ta.acknowledged.Load() {
panic("unreachable: ackDetail can only be called after acknowledged")
}
return ta.detail
}
// timestampWithAckArray is a heap underlying represent of timestampAck.
type timestampWithAckArray []*Acker
// Len returns the length of the heap.
func (h timestampWithAckArray) Len() int {
return len(h)
}
// Less returns true if the element at index i is less than the element at index j.
func (h timestampWithAckArray) Less(i, j int) bool {
return h[i].detail.Timestamp < h[j].detail.Timestamp
}
// Swap swaps the elements at indexes i and j.
func (h timestampWithAckArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
// Push pushes the last one at len.
func (h *timestampWithAckArray) Push(x interface{}) {
// Push and Pop use pointer receivers because they modify the slice's length,
// not just its contents.
*h = append(*h, x.(*Acker))
}
// Pop pop the last one at len.
func (h *timestampWithAckArray) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// Peek returns the element at the top of the heap.
func (h *timestampWithAckArray) Peek() interface{} {
return (*h)[0]
}

View File

@ -0,0 +1,120 @@
package ack
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource/timestamp"
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestAck(t *testing.T) {
paramtable.Init()
paramtable.SetNodeID(1)
ctx := context.Background()
rc := timestamp.NewMockRootCoordClient(t)
resource.InitForTest(resource.OptRootCoordClient(rc))
ackManager := NewAckManager()
msgID := mock_message.NewMockMessageID(t)
msgID.EXPECT().EQ(msgID).Return(true)
ackManager.AdvanceLastConfirmedMessageID(msgID)
ackers := map[uint64]*Acker{}
for i := 0; i < 10; i++ {
acker, err := ackManager.Allocate(ctx)
assert.NoError(t, err)
assert.True(t, acker.LastConfirmedMessageID().EQ(msgID))
ackers[acker.Timestamp()] = acker
}
// notAck: [1, 2, 3, ..., 10]
// ack: []
details, err := ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [1, 3, ..., 10]
// ack: [2]
ackers[2].Ack()
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [1, 3, 5, ..., 10]
// ack: [2, 4]
ackers[4].Ack()
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [3, 5, ..., 10]
// ack: [1, 2, 4]
ackers[1].Ack()
// notAck: [3, 5, ..., 10]
// ack: [4]
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Equal(t, 2, len(details))
assert.Equal(t, uint64(1), details[0].Timestamp)
assert.Equal(t, uint64(2), details[1].Timestamp)
// notAck: [3, 5, ..., 10]
// ack: [4]
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [3]
// ack: [4, ..., 10]
for i := 5; i <= 10; i++ {
ackers[uint64(i)].Ack()
}
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [3, ...,x, y]
// ack: [4, ..., 10]
tsX, err := ackManager.Allocate(ctx)
assert.NoError(t, err)
tsY, err := ackManager.Allocate(ctx)
assert.NoError(t, err)
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
// notAck: [...,x, y]
// ack: [3, ..., 10]
ackers[3].Ack()
// notAck: [...,x, y]
// ack: []
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Greater(t, len(details), 8) // with some sync operation.
// notAck: []
// ack: [11, 12]
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Empty(t, details)
tsX.Ack()
tsY.Ack()
// notAck: []
// ack: []
details, err = ackManager.SyncAndGetAcknowledged(ctx)
assert.NoError(t, err)
assert.Greater(t, len(details), 2) // with some sync operation.
// no more timestamp to ack.
assert.Zero(t, ackManager.notAckHeap.Len())
}

View File

@ -0,0 +1,45 @@
package ack
import (
"fmt"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
)
// newAckDetail creates a new default acker detail.
func newAckDetail(ts uint64, lastConfirmedMessageID message.MessageID) *AckDetail {
if ts <= 0 {
panic(fmt.Sprintf("ts should never less than 0 %d", ts))
}
return &AckDetail{
Timestamp: ts,
LastConfirmedMessageID: lastConfirmedMessageID,
IsSync: false,
Err: nil,
}
}
// AckDetail records the information of acker.
type AckDetail struct {
Timestamp uint64
LastConfirmedMessageID message.MessageID
IsSync bool
Err error
}
// AckOption is the option for acker.
type AckOption func(*AckDetail)
// OptSync marks the acker is sync message.
func OptSync() AckOption {
return func(detail *AckDetail) {
detail.IsSync = true
}
}
// OptError marks the timestamp ack with error info.
func OptError(err error) AckOption {
return func(detail *AckDetail) {
detail.Err = err
}
}

View File

@ -0,0 +1,29 @@
package ack
import (
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
)
func TestDetail(t *testing.T) {
assert.Panics(t, func() {
newAckDetail(0, mock_message.NewMockMessageID(t))
})
msgID := mock_message.NewMockMessageID(t)
msgID.EXPECT().EQ(msgID).Return(true)
ackDetail := newAckDetail(1, msgID)
assert.Equal(t, uint64(1), ackDetail.Timestamp)
assert.True(t, ackDetail.LastConfirmedMessageID.EQ(msgID))
assert.False(t, ackDetail.IsSync)
assert.NoError(t, ackDetail.Err)
OptSync()(ackDetail)
assert.True(t, ackDetail.IsSync)
OptError(errors.New("test"))(ackDetail)
assert.Error(t, ackDetail.Err)
}

View File

@ -0,0 +1,89 @@
package ack
import (
"context"
"sync"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// AckManager manages the timestampAck.
type AckManager struct {
mu sync.Mutex
notAckHeap typeutil.Heap[*Acker] // a minimum heap of timestampAck to search minimum timestamp in list.
lastConfirmedMessageID message.MessageID
}
// NewAckManager creates a new timestampAckHelper.
func NewAckManager() *AckManager {
return &AckManager{
mu: sync.Mutex{},
notAckHeap: typeutil.NewHeap[*Acker](&timestampWithAckArray{}),
}
}
// Allocate allocates a timestamp.
// Concurrent safe to call with Sync and Allocate.
func (ta *AckManager) Allocate(ctx context.Context) (*Acker, error) {
ta.mu.Lock()
defer ta.mu.Unlock()
// allocate one from underlying allocator first.
ts, err := resource.Resource().TimestampAllocator().Allocate(ctx)
if err != nil {
return nil, err
}
// create new timestampAck for ack process.
// add ts to heap wait for ack.
tsWithAck := newAcker(ts, ta.lastConfirmedMessageID)
ta.notAckHeap.Push(tsWithAck)
return tsWithAck, nil
}
// SyncAndGetAcknowledged syncs the ack records with allocator, and get the last all acknowledged info.
// Concurrent safe to call with Allocate.
func (ta *AckManager) SyncAndGetAcknowledged(ctx context.Context) ([]*AckDetail, error) {
// local timestamp may out of date, sync the underlying allocator before get last all acknowledged.
resource.Resource().TimestampAllocator().Sync()
// Allocate may be uncalled in long term, and the recorder may be out of date.
// Do a Allocate and Ack, can sync up the recorder with internal timetick.TimestampAllocator latest time.
tsWithAck, err := ta.Allocate(ctx)
if err != nil {
return nil, err
}
tsWithAck.Ack(OptSync())
// update a new snapshot of acknowledged timestamps after sync up.
return ta.popUntilLastAllAcknowledged(), nil
}
// popUntilLastAllAcknowledged pops the timestamps until the one that all timestamps before it have been acknowledged.
func (ta *AckManager) popUntilLastAllAcknowledged() []*AckDetail {
ta.mu.Lock()
defer ta.mu.Unlock()
// pop all acknowledged timestamps.
details := make([]*AckDetail, 0, 5)
for ta.notAckHeap.Len() > 0 && ta.notAckHeap.Peek().acknowledged.Load() {
ack := ta.notAckHeap.Pop()
details = append(details, ack.ackDetail())
}
return details
}
// AdvanceLastConfirmedMessageID update the last confirmed message id.
func (ta *AckManager) AdvanceLastConfirmedMessageID(msgID message.MessageID) {
if msgID == nil {
return
}
ta.mu.Lock()
if ta.lastConfirmedMessageID == nil || ta.lastConfirmedMessageID.LT(msgID) {
ta.lastConfirmedMessageID = msgID
}
ta.mu.Unlock()
}

View File

@ -0,0 +1,43 @@
package timetick
import "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
// ackDetails records the information of AckDetail.
// Used to analyze the ack details.
// TODO: add more analysis methods. e.g. such as counter function with filter.
type ackDetails struct {
detail []*ack.AckDetail
}
// AddDetails adds details to AckDetails.
func (ad *ackDetails) AddDetails(details []*ack.AckDetail) {
if len(details) == 0 {
return
}
if len(ad.detail) == 0 {
ad.detail = details
return
}
ad.detail = append(ad.detail, details...)
}
// Empty returns true if the AckDetails is empty.
func (ad *ackDetails) Empty() bool {
return len(ad.detail) == 0
}
// Len returns the count of AckDetail.
func (ad *ackDetails) Len() int {
return len(ad.detail)
}
// LastAllAcknowledgedTimestamp returns the last timestamp which all timestamps before it have been acknowledged.
// panic if no timestamp has been acknowledged.
func (ad *ackDetails) LastAllAcknowledgedTimestamp() uint64 {
return ad.detail[len(ad.detail)-1].Timestamp
}
// Clear clears the AckDetails.
func (ad *ackDetails) Clear() {
ad.detail = nil
}

View File

@ -0,0 +1,41 @@
package timetick
import (
"context"
"time"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
var _ interceptors.InterceptorBuilder = (*interceptorBuilder)(nil)
// NewInterceptorBuilder creates a new interceptor builder.
// 1. Add timetick to all message before append to wal.
// 2. Collect timetick info, and generate sync-timetick message to wal.
func NewInterceptorBuilder() interceptors.InterceptorBuilder {
return &interceptorBuilder{}
}
// interceptorBuilder is a builder to build timeTickAppendInterceptor.
type interceptorBuilder struct{}
// Build implements Builder.
func (b *interceptorBuilder) Build(param interceptors.InterceptorBuildParam) interceptors.BasicInterceptor {
ctx, cancel := context.WithCancel(context.Background())
interceptor := &timeTickAppendInterceptor{
ctx: ctx,
cancel: cancel,
ready: make(chan struct{}),
ackManager: ack.NewAckManager(),
ackDetails: &ackDetails{},
sourceID: paramtable.GetNodeID(),
}
go interceptor.executeSyncTimeTick(
// TODO: move the configuration to streamingnode.
paramtable.Get().ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond),
param,
)
return interceptor
}

View File

@ -0,0 +1,158 @@
package timetick
import (
"context"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
)
var _ interceptors.AppendInterceptor = (*timeTickAppendInterceptor)(nil)
// timeTickAppendInterceptor is a append interceptor.
type timeTickAppendInterceptor struct {
ctx context.Context
cancel context.CancelFunc
ready chan struct{}
ackManager *ack.AckManager
ackDetails *ackDetails
sourceID int64
}
// Ready implements AppendInterceptor.
func (impl *timeTickAppendInterceptor) Ready() <-chan struct{} {
return impl.ready
}
// Do implements AppendInterceptor.
func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (msgID message.MessageID, err error) {
if msg.MessageType() != message.MessageTypeTimeTick {
// Allocate new acker for message.
var acker *ack.Acker
if acker, err = impl.ackManager.Allocate(ctx); err != nil {
return nil, errors.Wrap(err, "allocate timestamp failed")
}
defer func() {
acker.Ack(ack.OptError(err))
impl.ackManager.AdvanceLastConfirmedMessageID(msgID)
}()
// Assign timestamp to message and call append method.
msg = msg.
WithTimeTick(acker.Timestamp()). // message assigned with these timetick.
WithLastConfirmed(acker.LastConfirmedMessageID()) // start consuming from these message id, the message which timetick greater than current timetick will never be lost.
}
return append(ctx, msg)
}
// Close implements AppendInterceptor.
func (impl *timeTickAppendInterceptor) Close() {
impl.cancel()
}
// execute start a background task.
func (impl *timeTickAppendInterceptor) executeSyncTimeTick(interval time.Duration, param interceptors.InterceptorBuildParam) {
underlyingWALImpls := param.WALImpls
logger := log.With(zap.Any("channel", underlyingWALImpls.Channel()))
logger.Info("start to sync time tick...")
defer logger.Info("sync time tick stopped")
// Send first timetick message to wal before interceptor is ready.
for count := 0; ; count++ {
// Sent first timetick message to wal before ready.
// New TT is always greater than all tt on previous streamingnode.
// A fencing operation of underlying WAL is needed to make exclusive produce of topic.
// Otherwise, the TT principle may be violated.
// And sendTsMsg must be done, to help ackManager to get first LastConfirmedMessageID
// !!! Send a timetick message into walimpls directly is safe.
select {
case <-impl.ctx.Done():
return
default:
}
if err := impl.sendTsMsg(impl.ctx, underlyingWALImpls.Append); err != nil {
log.Warn("send first timestamp message failed", zap.Error(err), zap.Int("retryCount", count))
// TODO: exponential backoff.
time.Sleep(50 * time.Millisecond)
continue
}
break
}
// interceptor is ready now.
close(impl.ready)
logger.Info("start to sync time ready")
// interceptor is ready, wait for the final wal object is ready to use.
wal := param.WAL.Get()
// TODO: sync time tick message to wal periodically.
// Add a trigger on `AckManager` to sync time tick message without periodically.
// `AckManager` gather detail information, time tick sync can check it and make the message between tt more smaller.
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-impl.ctx.Done():
return
case <-ticker.C:
if err := impl.sendTsMsg(impl.ctx, wal.Append); err != nil {
log.Warn("send time tick sync message failed", zap.Error(err))
}
}
}
}
// syncAcknowledgedDetails syncs the timestamp acknowledged details.
func (impl *timeTickAppendInterceptor) syncAcknowledgedDetails() {
// Sync up and get last confirmed timestamp.
ackDetails, err := impl.ackManager.SyncAndGetAcknowledged(impl.ctx)
if err != nil {
log.Warn("sync timestamp ack manager failed", zap.Error(err))
}
// Add ack details to ackDetails.
impl.ackDetails.AddDetails(ackDetails)
}
// sendTsMsg sends first timestamp message to wal.
// TODO: TT lag warning.
func (impl *timeTickAppendInterceptor) sendTsMsg(_ context.Context, appender func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error)) error {
// Sync the timestamp acknowledged details.
impl.syncAcknowledgedDetails()
if impl.ackDetails.Empty() {
// No acknowledged info can be sent.
// Some message sent operation is blocked, new TT cannot be pushed forward.
return nil
}
// Construct time tick message.
msg, err := newTimeTickMsg(impl.ackDetails.LastAllAcknowledgedTimestamp(), impl.sourceID)
if err != nil {
return errors.Wrap(err, "at build time tick msg")
}
// Append it to wal.
msgID, err := appender(impl.ctx, msg)
if err != nil {
return errors.Wrapf(err,
"append time tick msg to wal failed, timestamp: %d, previous message counter: %d",
impl.ackDetails.LastAllAcknowledgedTimestamp(),
impl.ackDetails.Len(),
)
}
// Ack details has been committed to wal, clear it.
impl.ackDetails.Clear()
impl.ackManager.AdvanceLastConfirmedMessageID(msgID)
return nil
}

View File

@ -0,0 +1,48 @@
package timetick
import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
)
func newTimeTickMsg(ts uint64, sourceID int64) (message.MutableMessage, error) {
// TODO: time tick should be put on properties, for compatibility, we put it on message body now.
msgstreamMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
TimeTickMsg: msgpb.TimeTickMsg{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
commonpbutil.WithMsgID(0),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(sourceID),
),
},
}
bytes, err := msgstreamMsg.Marshal(msgstreamMsg)
if err != nil {
return nil, errors.Wrap(err, "marshal time tick message failed")
}
payload, ok := bytes.([]byte)
if !ok {
return nil, errors.New("marshal time tick message as []byte failed")
}
// Common message's time tick is set on interceptor.
// TimeTickMsg's time tick should be set here.
msg := message.NewMutableMessageBuilder().
WithMessageType(message.MessageTypeTimeTick).
WithPayload(payload).
BuildMutable().
WithTimeTick(ts)
return msg, nil
}

View File

@ -82,16 +82,14 @@ func (_c *MockWALImpls_Append_Call) RunAndReturn(run func(context.Context, messa
}
// Channel provides a mock function with given fields:
func (_m *MockWALImpls) Channel() *types.PChannelInfo {
func (_m *MockWALImpls) Channel() types.PChannelInfo {
ret := _m.Called()
var r0 *types.PChannelInfo
if rf, ok := ret.Get(0).(func() *types.PChannelInfo); ok {
var r0 types.PChannelInfo
if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*types.PChannelInfo)
}
r0 = ret.Get(0).(types.PChannelInfo)
}
return r0
@ -114,12 +112,12 @@ func (_c *MockWALImpls_Channel_Call) Run(run func()) *MockWALImpls_Channel_Call
return _c
}
func (_c *MockWALImpls_Channel_Call) Return(_a0 *types.PChannelInfo) *MockWALImpls_Channel_Call {
func (_c *MockWALImpls_Channel_Call) Return(_a0 types.PChannelInfo) *MockWALImpls_Channel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockWALImpls_Channel_Call) RunAndReturn(run func() *types.PChannelInfo) *MockWALImpls_Channel_Call {
func (_c *MockWALImpls_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockWALImpls_Channel_Call {
_c.Call.Return(run)
return _c
}

View File

@ -315,6 +315,47 @@ func (_c *MockImmutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *M
return _c
}
// VChannel provides a mock function with given fields:
func (_m *MockImmutableMessage) VChannel() string {
ret := _m.Called()
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// MockImmutableMessage_VChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VChannel'
type MockImmutableMessage_VChannel_Call struct {
*mock.Call
}
// VChannel is a helper method to define mock.On call
func (_e *MockImmutableMessage_Expecter) VChannel() *MockImmutableMessage_VChannel_Call {
return &MockImmutableMessage_VChannel_Call{Call: _e.mock.On("VChannel")}
}
func (_c *MockImmutableMessage_VChannel_Call) Run(run func()) *MockImmutableMessage_VChannel_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockImmutableMessage_VChannel_Call) Return(_a0 string) *MockImmutableMessage_VChannel_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImmutableMessage_VChannel_Call) RunAndReturn(run func() string) *MockImmutableMessage_VChannel_Call {
_c.Call.Return(run)
return _c
}
// Version provides a mock function with given fields:
func (_m *MockImmutableMessage) Version() message.Version {
ret := _m.Called()

View File

@ -232,6 +232,47 @@ func (_c *MockMutableMessage_Properties_Call) RunAndReturn(run func() message.Pr
return _c
}
// Version provides a mock function with given fields:
func (_m *MockMutableMessage) Version() message.Version {
ret := _m.Called()
var r0 message.Version
if rf, ok := ret.Get(0).(func() message.Version); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(message.Version)
}
return r0
}
// MockMutableMessage_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version'
type MockMutableMessage_Version_Call struct {
*mock.Call
}
// Version is a helper method to define mock.On call
func (_e *MockMutableMessage_Expecter) Version() *MockMutableMessage_Version_Call {
return &MockMutableMessage_Version_Call{Call: _e.mock.On("Version")}
}
func (_c *MockMutableMessage_Version_Call) Run(run func()) *MockMutableMessage_Version_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockMutableMessage_Version_Call) Return(_a0 message.Version) *MockMutableMessage_Version_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockMutableMessage_Version_Call) RunAndReturn(run func() message.Version) *MockMutableMessage_Version_Call {
_c.Call.Return(run)
return _c
}
// WithLastConfirmed provides a mock function with given fields: id
func (_m *MockMutableMessage) WithLastConfirmed(id message.MessageID) message.MutableMessage {
ret := _m.Called(id)

View File

@ -1,60 +1,72 @@
package message
// NewBuilder creates a new builder.
func NewBuilder() *Builder {
return &Builder{
id: nil,
// NewImmutableMessage creates a new immutable message.
func NewImmutableMesasge(
id MessageID,
payload []byte,
properties map[string]string,
) ImmutableMessage {
return &immutableMessageImpl{
id: id,
messageImpl: messageImpl{
payload: payload,
properties: properties,
},
}
}
// NewMutableMessageBuilder creates a new builder.
// Should only used at client side.
func NewMutableMessageBuilder() *MutableMesasgeBuilder {
return &MutableMesasgeBuilder{
payload: nil,
properties: make(propertiesImpl),
}
}
// Builder is the builder for message.
type Builder struct {
id MessageID
// MutableMesasgeBuilder is the builder for message.
type MutableMesasgeBuilder struct {
payload []byte
properties propertiesImpl
}
// WithMessageID creates a new builder with message id.
func (b *Builder) WithMessageID(id MessageID) *Builder {
b.id = id
func (b *MutableMesasgeBuilder) WithMessageType(t MessageType) *MutableMesasgeBuilder {
b.properties.Set(messageTypeKey, t.marshal())
return b
}
// WithMessageType creates a new builder with message type.
func (b *Builder) WithMessageType(t MessageType) *Builder {
b.properties.Set(messageTypeKey, t.marshal())
// WithPayload creates a new builder with message payload.
// The MessageType is required to indicate which message type payload is.
func (b *MutableMesasgeBuilder) WithPayload(payload []byte) *MutableMesasgeBuilder {
b.payload = payload
return b
}
// WithProperty creates a new builder with message property.
// A key started with '_' is reserved for log system, should never used at user of client.
func (b *Builder) WithProperty(key string, val string) *Builder {
func (b *MutableMesasgeBuilder) WithProperty(key string, val string) *MutableMesasgeBuilder {
b.properties.Set(key, val)
return b
}
// WithProperties creates a new builder with message properties.
// A key started with '_' is reserved for log system, should never used at user of client.
func (b *Builder) WithProperties(kvs map[string]string) *Builder {
func (b *MutableMesasgeBuilder) WithProperties(kvs map[string]string) *MutableMesasgeBuilder {
for key, val := range kvs {
b.properties.Set(key, val)
}
return b
}
// WithPayload creates a new builder with message payload.
func (b *Builder) WithPayload(payload []byte) *Builder {
b.payload = payload
return b
}
// BuildMutable builds a mutable message.
// Panic if set the message id.
func (b *Builder) BuildMutable() MutableMessage {
if b.id != nil {
panic("build a mutable message, message id should be nil")
// Panic if not set payload and message type.
// should only used at client side.
func (b *MutableMesasgeBuilder) BuildMutable() MutableMessage {
if b.payload == nil {
panic("message builder not ready for payload field")
}
if !b.properties.Exist(messageTypeKey) {
panic("message builder not ready for message type field")
}
// Set message version.
b.properties.Set(messageVersion, VersionV1.String())
@ -63,18 +75,3 @@ func (b *Builder) BuildMutable() MutableMessage {
properties: b.properties,
}
}
// BuildImmutable builds a immutable message.
// Panic if not set the message id.
func (b *Builder) BuildImmutable() ImmutableMessage {
if b.id == nil {
panic("build a immutable message, message id should not be nil")
}
return &immutableMessageImpl{
id: b.id,
messageImpl: messageImpl{
payload: b.payload,
properties: b.properties,
},
}
}

View File

@ -11,6 +11,11 @@ type BasicMessage interface {
// MessageType returns the type of message.
MessageType() MessageType
// Version returns the message version.
// 0: old version before streamingnode.
// from 1: new version after streamingnode.
Version() Version
// Message payload.
Payload() []byte
@ -47,6 +52,11 @@ type ImmutableMessage interface {
// WALName returns the name of message related wal.
WALName() string
// VChannel returns the virtual channel of current message.
// Available only when the message's version greater than 0.
// Otherwise, it will panic.
VChannel() string
// TimeTick returns the time tick of current message.
// Available only when the message's version greater than 0.
// Otherwise, it will panic.
@ -64,9 +74,4 @@ type ImmutableMessage interface {
// Properties returns the message read only properties.
Properties() RProperties
// Version returns the message format version.
// 0: old version before streamingnode.
// from 1: new version after streamingnode.
Version() Version
}

View File

@ -12,8 +12,9 @@ import (
)
func TestMessage(t *testing.T) {
b := message.NewBuilder()
mutableMessage := b.WithMessageType(message.MessageTypeTimeTick).
b := message.NewMutableMessageBuilder()
mutableMessage := b.
WithMessageType(message.MessageTypeTimeTick).
WithPayload([]byte("payload")).
WithProperties(map[string]string{"key": "value"}).
BuildMutable()
@ -49,17 +50,15 @@ func TestMessage(t *testing.T) {
panic(fmt.Sprintf("unexpected data: %s", data))
})
b = message.NewBuilder()
immutableMessage := b.WithMessageID(msgID).
WithPayload([]byte("payload")).
WithProperties(map[string]string{
immutableMessage := message.NewImmutableMesasge(msgID,
[]byte("payload"),
map[string]string{
"key": "value",
"_t": "1",
"_tt": string(proto.EncodeVarint(456)),
"_v": "1",
"_lc": "lcMsgID",
}).
BuildImmutable()
})
assert.True(t, immutableMessage.MessageID().EQ(msgID))
assert.Equal(t, "payload", string(immutableMessage.Payload()))
@ -73,12 +72,13 @@ func TestMessage(t *testing.T) {
assert.Equal(t, uint64(456), immutableMessage.TimeTick())
assert.NotNil(t, immutableMessage.LastConfirmedMessageID())
b = message.NewBuilder()
immutableMessage = b.WithMessageID(msgID).
WithPayload([]byte("payload")).
WithProperty("key", "value").
WithProperty("_t", "1").
BuildImmutable()
immutableMessage = message.NewImmutableMesasge(
msgID,
[]byte("payload"),
map[string]string{
"key": "value",
"_t": "1",
})
assert.True(t, immutableMessage.MessageID().EQ(msgID))
assert.Equal(t, "payload", string(immutableMessage.Payload()))
@ -97,9 +97,6 @@ func TestMessage(t *testing.T) {
})
assert.Panics(t, func() {
message.NewBuilder().WithMessageID(msgID).BuildMutable()
})
assert.Panics(t, func() {
message.NewBuilder().BuildImmutable()
message.NewMutableMessageBuilder().BuildMutable()
})
}

View File

@ -20,6 +20,15 @@ func (m *messageImpl) MessageType() MessageType {
return unmarshalMessageType(val)
}
// Version returns the message format version.
func (m *messageImpl) Version() Version {
value, ok := m.properties.Get(messageVersion)
if !ok {
return VersionOld
}
return newMessageVersionFromString(value)
}
// Payload returns payload of current message.
func (m *messageImpl) Payload() []byte {
return m.payload
@ -98,16 +107,15 @@ func (m *immutableMessageImpl) MessageID() MessageID {
return m.id
}
func (m *immutableMessageImpl) VChannel() string {
value, ok := m.properties.Get(messageVChannel)
if !ok {
panic(fmt.Sprintf("there's a bug in the message codes, vchannel lost in properties of message, id: %+v", m.id))
}
return value
}
// Properties returns the message read only properties.
func (m *immutableMessageImpl) Properties() RProperties {
return m.properties
}
// Version returns the message format version.
func (m *immutableMessageImpl) Version() Version {
value, ok := m.properties.Get(messageVersion)
if !ok {
return VersionOld
}
return newMessageVersionFromString(value)
}

View File

@ -24,6 +24,12 @@ func (t MessageType) marshal() string {
return strconv.FormatInt(int64(t), 10)
}
// Valid checks if the MessageType is valid.
func (t MessageType) Valid() bool {
return t == MessageTypeTimeTick
// TODO: fill more.
}
// unmarshalMessageType unmarshal MessageType from string.
func unmarshalMessageType(s string) MessageType {
i, err := strconv.ParseInt(s, 10, 32)

View File

@ -6,6 +6,7 @@ const (
messageTypeKey = "_t" // message type key.
messageTimeTick = "_tt" // message time tick.
messageLastConfirmed = "_lc" // message last confirmed message id.
messageVChannel = "_vc" // message virtual channel.
)
var (

View File

@ -23,3 +23,7 @@ func newMessageVersionFromString(s string) Version {
func (v Version) String() string {
return strconv.FormatInt(int64(v), 10)
}
func (v Version) GT(v2 Version) bool {
return v > v2
}

View File

@ -0,0 +1,31 @@
package options
import (
"testing"
"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
"github.com/stretchr/testify/assert"
)
func TestDeliver(t *testing.T) {
policy := DeliverPolicyAll()
assert.Equal(t, DeliverPolicyTypeAll, policy.Policy())
assert.Panics(t, func() {
policy.MessageID()
})
policy = DeliverPolicyLatest()
assert.Equal(t, DeliverPolicyTypeLatest, policy.Policy())
assert.Panics(t, func() {
policy.MessageID()
})
messageID := mock_message.NewMockMessageID(t)
policy = DeliverPolicyStartFrom(messageID)
assert.Equal(t, DeliverPolicyTypeStartFrom, policy.Policy())
assert.Equal(t, messageID, policy.MessageID())
policy = DeliverPolicyStartAfter(messageID)
assert.Equal(t, DeliverPolicyTypeStartAfter, policy.Policy())
assert.Equal(t, messageID, policy.MessageID())
}

View File

@ -57,11 +57,11 @@ func (s *scannerImpl) executeConsume() {
s.Finish(err)
return
}
newImmutableMessage := message.NewBuilder().
WithMessageID(pulsarID{msg.ID()}).
WithPayload(msg.Payload()).
WithProperties(msg.Properties()).
BuildImmutable()
newImmutableMessage := message.NewImmutableMesasge(
pulsarID{msg.ID()},
msg.Payload(),
msg.Properties(),
)
select {
case <-s.Context().Done():

View File

@ -66,11 +66,11 @@ func (s *scannerImpl) executeConsume() {
// record the last message id to avoid repeated consume message.
// and exclude message id should be filterred.
if s.exclude == nil || !s.exclude.EQ(msgID) {
s.msgChannel <- message.NewBuilder().
WithMessageID(msgID).
WithPayload(msg.Payload()).
WithProperties(msg.Properties()).
BuildImmutable()
s.msgChannel <- message.NewImmutableMesasge(
msgID,
msg.Payload(),
msg.Properties(),
)
}
}
}

View File

@ -10,13 +10,13 @@ import (
)
const (
walName = "test"
WALName = "test"
)
func init() {
// register the builder to the registry.
registry.RegisterBuilder(&openerBuilder{})
message.RegisterMessageIDUnmsarshaler(walName, UnmarshalTestMessageID)
message.RegisterMessageIDUnmsarshaler(WALName, UnmarshalTestMessageID)
}
var _ walimpls.OpenerBuilderImpls = &openerBuilder{}
@ -24,7 +24,7 @@ var _ walimpls.OpenerBuilderImpls = &openerBuilder{}
type openerBuilder struct{}
func (o *openerBuilder) Name() string {
return walName
return WALName
}
func (o *openerBuilder) Build() (walimpls.OpenerImpls, error) {

View File

@ -39,7 +39,7 @@ type testMessageID int64
// WALName returns the name of message id related wal.
func (id testMessageID) WALName() string {
return walName
return WALName
}
// LT less than.

View File

@ -247,20 +247,15 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) (
"const": "t",
}
typ := message.MessageTypeUnknown
msg := message.NewBuilder().
msg := message.NewMutableMessageBuilder().
WithMessageType(typ).
WithPayload(payload).
WithProperties(properties).
WithMessageType(typ).
BuildMutable()
id, err := w.Append(ctx, msg)
assert.NoError(f.t, err)
assert.NotNil(f.t, id)
ids[i] = message.NewBuilder().
WithPayload(payload).
WithProperties(properties).
WithMessageID(id).
WithMessageType(typ).
BuildImmutable()
ids[i] = msg.IntoImmutableMessage(id)
}(i)
}
swg.Wait()
@ -280,19 +275,14 @@ func (f *testOneWALImplsFramework) testAppend(ctx context.Context, w WALImpls) (
"const": "t",
"term": strconv.FormatInt(int64(f.term), 10),
}
msg := message.NewBuilder().
msg := message.NewMutableMessageBuilder().
WithPayload(payload).
WithProperties(properties).
WithMessageType(message.MessageTypeTimeTick).
BuildMutable()
id, err := w.Append(ctx, msg)
assert.NoError(f.t, err)
ids[f.messageCount-1] = message.NewBuilder().
WithPayload(payload).
WithProperties(properties).
WithMessageID(id).
WithMessageType(message.MessageTypeTimeTick).
BuildImmutable()
ids[f.messageCount-1] = msg.IntoImmutableMessage(id)
return ids, nil
}

View File

@ -0,0 +1,41 @@
package syncutil
// Future is a future value that can be set and retrieved.
type Future[T any] struct {
ch chan struct{}
value T
}
// NewFuture creates a new future.
func NewFuture[T any]() *Future[T] {
return &Future[T]{
ch: make(chan struct{}),
}
}
// Set sets the value of the future.
func (f *Future[T]) Set(value T) {
f.value = value
close(f.ch)
}
// Get retrieves the value of the future if set, otherwise block until set.
func (f *Future[T]) Get() T {
<-f.ch
return f.value
}
// Done returns a channel that is closed when the future is set.
func (f *Future[T]) Done() <-chan struct{} {
return f.ch
}
// Ready returns true if the future is set.
func (f *Future[T]) Ready() bool {
select {
case <-f.ch:
return true
default:
return false
}
}

View File

@ -0,0 +1,51 @@
package syncutil
import (
"testing"
"time"
)
func TestFuture_SetAndGet(t *testing.T) {
f := NewFuture[int]()
go func() {
time.Sleep(1 * time.Second) // Simulate some work
f.Set(42)
}()
val := f.Get()
if val != 42 {
t.Errorf("Expected value 42, got %d", val)
}
}
func TestFuture_Done(t *testing.T) {
f := NewFuture[string]()
go func() {
f.Set("done")
}()
select {
case <-f.Done():
// Success
case <-time.After(20 * time.Millisecond):
t.Error("Expected future to be done within 2 seconds")
}
}
func TestFuture_Ready(t *testing.T) {
f := NewFuture[float64]()
go func() {
time.Sleep(20 * time.Millisecond) // Simulate some work
f.Set(3.14)
}()
if f.Ready() {
t.Error("Expected future not to be ready immediately")
}
<-f.Done() // Wait for the future to be set
if !f.Ready() {
t.Error("Expected future to be ready after being set")
}
}

133
pkg/util/typeutil/heap.go Normal file
View File

@ -0,0 +1,133 @@
package typeutil
import (
"container/heap"
"golang.org/x/exp/constraints"
)
var _ HeapInterface = (*heapArray[int])(nil)
// HeapInterface is the interface that a heap must implement.
type HeapInterface interface {
heap.Interface
Peek() interface{}
}
// Heap is a heap of E.
// Use `golang.org/x/exp/constraints` directly if you want to change any element.
type Heap[E any] interface {
// Len returns the size of the heap.
Len() int
// Push pushes an element onto the heap.
Push(x E)
// Pop returns the element at the top of the heap.
// Panics if the heap is empty.
Pop() E
// Peek returns the element at the top of the heap.
// Panics if the heap is empty.
Peek() E
}
// heapArray is a heap backed by an array.
type heapArray[E constraints.Ordered] []E
// Len returns the length of the heap.
func (h heapArray[E]) Len() int {
return len(h)
}
// Less returns true if the element at index i is less than the element at index j.
func (h heapArray[E]) Less(i, j int) bool {
return h[i] < h[j]
}
// Swap swaps the elements at indexes i and j.
func (h heapArray[E]) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
// Push pushes the last one at len.
func (h *heapArray[E]) Push(x interface{}) {
// Push and Pop use pointer receivers because they modify the slice's length,
// not just its contents.
*h = append(*h, x.(E))
}
// Pop pop the last one at len.
func (h *heapArray[E]) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
// Peek returns the element at the top of the heap.
func (h *heapArray[E]) Peek() interface{} {
return (*h)[0]
}
// reverseOrderedInterface is a heap base interface that reverses the order of the elements.
type reverseOrderedInterface[E constraints.Ordered] struct {
HeapInterface
}
// Less returns true if the element at index j is less than the element at index i.
func (r reverseOrderedInterface[E]) Less(i, j int) bool {
return r.HeapInterface.Less(j, i)
}
// NewHeap returns a new heap from a underlying representation.
func NewHeap[E any](inner HeapInterface) Heap[E] {
return &heapImpl[E, HeapInterface]{
inner: inner,
}
}
// NewArrayBasedMaximumHeap returns a new maximum heap.
func NewArrayBasedMaximumHeap[E constraints.Ordered](initial []E) Heap[E] {
ha := heapArray[E](initial)
reverse := reverseOrderedInterface[E]{
HeapInterface: &ha,
}
heap.Init(reverse)
return &heapImpl[E, reverseOrderedInterface[E]]{
inner: reverse,
}
}
// NewArrayBasedMinimumHeap returns a new minimum heap.
func NewArrayBasedMinimumHeap[E constraints.Ordered](initial []E) Heap[E] {
ha := heapArray[E](initial)
heap.Init(&ha)
return &heapImpl[E, *heapArray[E]]{
inner: &ha,
}
}
// heapImpl is a min-heap of E.
type heapImpl[E any, H HeapInterface] struct {
inner H
}
// Len returns the length of the heap.
func (h *heapImpl[E, H]) Len() int {
return h.inner.Len()
}
// Push pushes an element onto the heap.
func (h *heapImpl[E, H]) Push(x E) {
heap.Push(h.inner, x)
}
// Pop pops an element from the heap.
func (h *heapImpl[E, H]) Pop() E {
return heap.Pop(h.inner).(E)
}
// Peek returns the element at the top of the heap.
func (h *heapImpl[E, H]) Peek() E {
return h.inner.Peek().(E)
}

View File

@ -0,0 +1,41 @@
package typeutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMinimumHeap(t *testing.T) {
h := []int{4, 5, 2}
heap := NewArrayBasedMinimumHeap(h)
assert.Equal(t, 2, heap.Peek())
assert.Equal(t, 3, heap.Len())
heap.Push(3)
assert.Equal(t, 2, heap.Peek())
assert.Equal(t, 4, heap.Len())
heap.Push(1)
assert.Equal(t, 1, heap.Peek())
assert.Equal(t, 5, heap.Len())
for i := 1; i <= 5; i++ {
assert.Equal(t, i, heap.Peek())
assert.Equal(t, i, heap.Pop())
}
}
func TestMaximumHeap(t *testing.T) {
h := []int{4, 1, 2}
heap := NewArrayBasedMaximumHeap(h)
assert.Equal(t, 4, heap.Peek())
assert.Equal(t, 3, heap.Len())
heap.Push(3)
assert.Equal(t, 4, heap.Peek())
assert.Equal(t, 4, heap.Len())
heap.Push(5)
assert.Equal(t, 5, heap.Peek())
assert.Equal(t, 5, heap.Len())
for i := 5; i >= 1; i-- {
assert.Equal(t, i, heap.Peek())
assert.Equal(t, i, heap.Pop())
}
}