enable look aside balancer on replica selection (#24791)

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
pull/24963/head
wei liu 2023-06-16 18:38:39 +08:00 committed by GitHub
parent a413842e38
commit 46f7d903a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 754 additions and 48 deletions

View File

@ -362,6 +362,6 @@ generate-mockery: getdeps
$(PWD)/bin/mockery --dir=internal/datacoord --name=compactionPlanContext --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage $(PWD)/bin/mockery --dir=internal/datacoord --name=compactionPlanContext --filename=mock_compaction_plan_context.go --output=internal/datacoord --structname=MockCompactionPlanContext --with-expecter --inpackage
$(PWD)/bin/mockery --dir=internal/datacoord --name=Handler --filename=mock_handler.go --output=internal/datacoord --structname=NMockHandler --with-expecter --inpackage $(PWD)/bin/mockery --dir=internal/datacoord --name=Handler --filename=mock_handler.go --output=internal/datacoord --structname=NMockHandler --with-expecter --inpackage
#internal/proxy #internal/proxy
$(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy $(PWD)/bin/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage
$(PWD)/bin/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage $(PWD)/bin/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage
$(PWD)/bin/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage $(PWD)/bin/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage

View File

@ -16,6 +16,11 @@
package proxy package proxy
import "github.com/milvus-io/milvus/internal/proto/internalpb"
type LBBalancer interface { type LBBalancer interface {
SelectNode(availableNodes []int64, nq int64) (int64, error) SelectNode(availableNodes []int64, nq int64) (int64, error)
CancelWorkload(node int64, nq int64)
UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
Close()
} }

View File

@ -18,6 +18,7 @@ package proxy
import ( import (
"context" "context"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
@ -48,6 +49,8 @@ type CollectionWorkLoad struct {
type LBPolicy interface { type LBPolicy interface {
Execute(ctx context.Context, workload CollectionWorkLoad) error Execute(ctx context.Context, workload CollectionWorkLoad) error
ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error
UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
Close()
} }
type LBPolicyImpl struct { type LBPolicyImpl struct {
@ -55,7 +58,8 @@ type LBPolicyImpl struct {
clientMgr shardClientMgr clientMgr shardClientMgr
} }
func NewLBPolicyImpl(balancer LBBalancer, clientMgr shardClientMgr) *LBPolicyImpl { func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
balancer := NewLookAsideBalancer(clientMgr)
return &LBPolicyImpl{ return &LBPolicyImpl{
balancer: balancer, balancer: balancer,
clientMgr: clientMgr, clientMgr: clientMgr,
@ -135,6 +139,9 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
zap.Int64("nodeID", targetNode), zap.Int64("nodeID", targetNode),
zap.Error(err)) zap.Error(err))
excludeNodes.Insert(targetNode) excludeNodes.Insert(targetNode)
// cancel work load which assign to the target node
lb.balancer.CancelWorkload(targetNode, workload.nq)
return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error()) return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error())
} }
@ -144,8 +151,11 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
zap.Int64("nodeID", targetNode), zap.Int64("nodeID", targetNode),
zap.Error(err)) zap.Error(err))
excludeNodes.Insert(targetNode) excludeNodes.Insert(targetNode)
lb.balancer.CancelWorkload(targetNode, workload.nq)
return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error()) return merr.WrapErrShardDelegatorAccessFailed(workload.channel, err.Error())
} }
lb.balancer.CancelWorkload(targetNode, workload.nq)
return nil return nil
}, retry.Attempts(workload.retryTimes)) }, retry.Attempts(workload.retryTimes))
@ -179,3 +189,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
err = wg.Wait() err = wg.Wait()
return err return err
} }
func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
lb.balancer.UpdateCostMetrics(node, cost)
}
func (lb *LBPolicyImpl) Close() {
lb.balancer.Close()
}

View File

@ -24,6 +24,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
@ -91,11 +92,13 @@ func (s *LBPolicySuite) SetupTest() {
s.qn = types.NewMockQueryNode(s.T()) s.qn = types.NewMockQueryNode(s.T())
s.qn.EXPECT().GetAddress().Return("localhost").Maybe() s.qn.EXPECT().GetAddress().Return("localhost").Maybe()
s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
s.mgr = NewMockShardClientManager(s.T()) s.mgr = NewMockShardClientManager(s.T())
s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
s.lbBalancer = NewMockLBBalancer(s.T()) s.lbBalancer = NewMockLBBalancer(s.T())
s.lbPolicy = NewLBPolicyImpl(s.lbBalancer, s.mgr) s.lbPolicy = NewLBPolicyImpl(s.mgr)
s.lbPolicy.balancer = s.lbBalancer
err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr) err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
s.NoError(err) s.NoError(err)
@ -223,6 +226,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
@ -255,6 +259,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
@ -270,6 +275,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.ExpectedCalls = nil s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
collection: s.collection, collection: s.collection,
channel: s.channels[0], channel: s.channels[0],
@ -287,6 +293,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0 counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
collection: s.collection, collection: s.collection,
@ -310,6 +317,7 @@ func (s *LBPolicySuite) TestExecute() {
// test all channel success // test all channel success
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{ err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
collection: s.collection, collection: s.collection,
nq: 1, nq: 1,
@ -348,6 +356,11 @@ func (s *LBPolicySuite) TestExecute() {
s.Error(err) s.Error(err)
} }
func (s *LBPolicySuite) TestUpdateCostMetrics() {
s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
}
func TestLBPolicySuite(t *testing.T) { func TestLBPolicySuite(t *testing.T) {
suite.Run(t, new(LBPolicySuite)) suite.Run(t, new(LBPolicySuite))
} }

View File

@ -0,0 +1,185 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"math"
"sync"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/atomic"
"go.uber.org/zap"
)
var (
checkQueryNodeHealthInterval = 500 * time.Millisecond
)
type LookAsideBalancer struct {
clientMgr shardClientMgr
// query node -> workload latest metrics
metricsMap *typeutil.ConcurrentMap[int64, *internalpb.CostAggregation]
// query node -> last update metrics ts
metricsUpdateTs *typeutil.ConcurrentMap[int64, int64]
// query node -> total nq of requests which already send but response hasn't received
executingTaskTotalNQ *typeutil.ConcurrentMap[int64, *atomic.Int64]
unreachableQueryNodes *typeutil.ConcurrentSet[int64]
closeCh chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
}
func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
balancer := &LookAsideBalancer{
clientMgr: clientMgr,
metricsMap: typeutil.NewConcurrentMap[int64, *internalpb.CostAggregation](),
metricsUpdateTs: typeutil.NewConcurrentMap[int64, int64](),
executingTaskTotalNQ: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
unreachableQueryNodes: typeutil.NewConcurrentSet[int64](),
closeCh: make(chan struct{}),
}
balancer.wg.Add(1)
go balancer.checkQueryNodeHealthLoop()
return balancer
}
func (b *LookAsideBalancer) Close() {
b.closeOnce.Do(func() {
close(b.closeCh)
b.wg.Wait()
})
}
func (b *LookAsideBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
targetNode := int64(-1)
targetScore := float64(math.MaxFloat64)
for _, node := range availableNodes {
if b.unreachableQueryNodes.Contain(node) {
continue
}
cost, _ := b.metricsMap.Get(node)
executingNQ, ok := b.executingTaskTotalNQ.Get(node)
if !ok {
executingNQ = atomic.NewInt64(0)
b.executingTaskTotalNQ.Insert(node, executingNQ)
}
score := b.calculateScore(cost, executingNQ.Load())
if targetNode == -1 || score < targetScore {
targetScore = score
targetNode = node
}
}
// update executing task cost
totalNQ, ok := b.executingTaskTotalNQ.Get(targetNode)
if !ok {
totalNQ = atomic.NewInt64(0)
}
totalNQ.Add(cost)
return targetNode, nil
}
// when task canceled, should reduce executing total nq cost
func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) {
totalNQ, ok := b.executingTaskTotalNQ.Get(node)
if ok {
totalNQ.Sub(nq)
}
}
// UpdateCostMetrics used for cache some metrics of recent search/query cost
func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
// cache the latest query node cost metrics for updating the score
b.metricsMap.Insert(node, cost)
b.metricsUpdateTs.Insert(node, time.Now().UnixMilli())
}
// calculateScore compute the query node's workload score
// https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh
func (b *LookAsideBalancer) calculateScore(cost *internalpb.CostAggregation, executingNQ int64) float64 {
if cost == nil || cost.ResponseTime == 0 {
return float64(executingNQ)
}
return float64(cost.ResponseTime) - float64(1)/float64(cost.ServiceTime) + math.Pow(float64(1+cost.TotalNQ+executingNQ), 3.0)/float64(cost.ServiceTime)
}
func (b *LookAsideBalancer) checkQueryNodeHealthLoop() {
defer b.wg.Done()
ticker := time.NewTicker(checkQueryNodeHealthInterval)
defer ticker.Stop()
log.Info("Start check query node health loop")
for {
select {
case <-b.closeCh:
log.Info("check query node health loop exit")
return
case <-ticker.C:
now := time.Now().UnixMilli()
b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool {
if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
checkHealthFailed := func(err error) bool {
log.Warn("query node check health failed, add it to unreachable nodes list",
zap.Int64("nodeID", node),
zap.Error(err))
b.unreachableQueryNodes.Insert(node)
return true
}
qn, err := b.clientMgr.GetClient(ctx, node)
if err != nil {
return checkHealthFailed(err)
}
resp, err := qn.GetComponentStates(ctx)
if err != nil {
return checkHealthFailed(err)
}
if resp.GetState().GetStateCode() != commonpb.StateCode_Healthy {
return checkHealthFailed(merr.WrapErrNodeOffline(node))
}
// check health successfully, update check health ts
b.metricsUpdateTs.Insert(node, time.Now().Local().UnixMilli())
b.unreachableQueryNodes.Remove(node)
}
return true
})
}
}
}

View File

@ -0,0 +1,292 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
)
type LookAsideBalancerSuite struct {
suite.Suite
clientMgr *MockShardClientManager
balancer *LookAsideBalancer
}
func (suite *LookAsideBalancerSuite) SetupTest() {
suite.clientMgr = NewMockShardClientManager(suite.T())
suite.balancer = NewLookAsideBalancer(suite.clientMgr)
qn := types.NewMockQueryNode(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe()
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, errors.New("fake error")).Maybe()
}
func (suite *LookAsideBalancerSuite) TearDownTest() {
suite.balancer.Close()
}
func (suite *LookAsideBalancerSuite) TestUpdateMetrics() {
costMetrics := &internalpb.CostAggregation{
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 1,
}
suite.balancer.UpdateCostMetrics(1, costMetrics)
lastUpdateTs, ok := suite.balancer.metricsUpdateTs.Get(1)
suite.True(ok)
suite.True(time.Now().UnixMilli()-lastUpdateTs <= 5)
}
func (suite *LookAsideBalancerSuite) TestCalculateScore() {
costMetrics1 := &internalpb.CostAggregation{
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 1,
}
costMetrics2 := &internalpb.CostAggregation{
ResponseTime: 5,
ServiceTime: 2,
TotalNQ: 1,
}
costMetrics3 := &internalpb.CostAggregation{
ResponseTime: 10,
ServiceTime: 1,
TotalNQ: 1,
}
costMetrics4 := &internalpb.CostAggregation{
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
}
score1 := suite.balancer.calculateScore(costMetrics1, 0)
score2 := suite.balancer.calculateScore(costMetrics2, 0)
score3 := suite.balancer.calculateScore(costMetrics3, 0)
score4 := suite.balancer.calculateScore(costMetrics4, 0)
suite.Equal(float64(12), score1)
suite.Equal(float64(8.5), score2)
suite.Equal(float64(17), score3)
suite.Equal(float64(5), score4)
score5 := suite.balancer.calculateScore(costMetrics1, 5)
score6 := suite.balancer.calculateScore(costMetrics2, 5)
score7 := suite.balancer.calculateScore(costMetrics3, 5)
score8 := suite.balancer.calculateScore(costMetrics4, 5)
suite.Equal(float64(347), score5)
suite.Equal(float64(176), score6)
suite.Equal(float64(352), score7)
suite.Equal(float64(220), score8)
}
func (suite *LookAsideBalancerSuite) TestSelectNode() {
type testcase struct {
name string
costMetrics map[int64]*internalpb.CostAggregation
executingNQ map[int64]int64
requestCount int
result map[int64]int64
}
cases := []testcase{
{
name: "each qn has same cost metrics",
costMetrics: map[int64]*internalpb.CostAggregation{
1: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
2: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
3: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
},
executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0},
requestCount: 100,
result: map[int64]int64{1: 34, 2: 33, 3: 33},
},
{
name: "each qn has different service time",
costMetrics: map[int64]*internalpb.CostAggregation{
1: {
ResponseTime: 30,
ServiceTime: 20,
TotalNQ: 0,
},
2: {
ResponseTime: 50,
ServiceTime: 40,
TotalNQ: 0,
},
3: {
ResponseTime: 70,
ServiceTime: 60,
TotalNQ: 0,
},
},
executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0},
requestCount: 100,
result: map[int64]int64{1: 27, 2: 34, 3: 39},
},
{
name: "one qn has task in queue",
costMetrics: map[int64]*internalpb.CostAggregation{
1: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
2: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
3: {
ResponseTime: 100,
ServiceTime: 1,
TotalNQ: 20,
},
},
executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0},
requestCount: 100,
result: map[int64]int64{1: 40, 2: 40, 3: 20},
},
{
name: "qn with executing task",
costMetrics: map[int64]*internalpb.CostAggregation{
1: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
2: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
3: {
ResponseTime: 5,
ServiceTime: 1,
TotalNQ: 0,
},
},
executingNQ: map[int64]int64{1: 0, 2: 0, 3: 20},
requestCount: 100,
result: map[int64]int64{1: 40, 2: 40, 3: 20},
},
{
name: "qn with empty metrics",
costMetrics: map[int64]*internalpb.CostAggregation{
1: {},
2: {},
3: {},
},
executingNQ: map[int64]int64{1: 0, 2: 0, 3: 0},
requestCount: 100,
result: map[int64]int64{1: 34, 2: 33, 3: 33},
},
}
for _, c := range cases {
suite.Run(c.name, func() {
for node, cost := range c.costMetrics {
suite.balancer.UpdateCostMetrics(node, cost)
}
for node, executingNQ := range c.executingNQ {
suite.balancer.executingTaskTotalNQ.Insert(node, atomic.NewInt64(executingNQ))
}
counter := make(map[int64]int64)
for i := 0; i < c.requestCount; i++ {
node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 1)
suite.NoError(err)
counter[node]++
}
for node, result := range c.result {
suite.Equal(result, counter[node])
}
})
}
}
func (suite *LookAsideBalancerSuite) TestCancelWorkload() {
node, err := suite.balancer.SelectNode([]int64{1, 2, 3}, 10)
suite.NoError(err)
suite.balancer.CancelWorkload(node, 10)
executingNQ, ok := suite.balancer.executingTaskTotalNQ.Get(node)
suite.True(ok)
suite.Equal(int64(0), executingNQ.Load())
}
func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
qn2 := types.NewMockQueryNode(suite.T())
suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil)
qn2.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
}, nil)
suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli())
suite.balancer.metricsUpdateTs.Insert(2, time.Now().UnixMilli())
suite.Eventually(func() bool {
return suite.balancer.unreachableQueryNodes.Contain(1)
}, 2*time.Second, 100*time.Millisecond)
suite.Eventually(func() bool {
return !suite.balancer.unreachableQueryNodes.Contain(2)
}, 3*time.Second, 100*time.Millisecond)
}
func TestLookAsideBalancerSuite(t *testing.T) {
suite.Run(t, new(LookAsideBalancerSuite))
}

View File

@ -2,7 +2,10 @@
package proxy package proxy
import mock "github.com/stretchr/testify/mock" import (
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock"
)
// MockLBBalancer is an autogenerated mock type for the LBBalancer type // MockLBBalancer is an autogenerated mock type for the LBBalancer type
type MockLBBalancer struct { type MockLBBalancer struct {
@ -17,6 +20,72 @@ func (_m *MockLBBalancer) EXPECT() *MockLBBalancer_Expecter {
return &MockLBBalancer_Expecter{mock: &_m.Mock} return &MockLBBalancer_Expecter{mock: &_m.Mock}
} }
// CancelWorkload provides a mock function with given fields: node, nq
func (_m *MockLBBalancer) CancelWorkload(node int64, nq int64) {
_m.Called(node, nq)
}
// MockLBBalancer_CancelWorkload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CancelWorkload'
type MockLBBalancer_CancelWorkload_Call struct {
*mock.Call
}
// CancelWorkload is a helper method to define mock.On call
// - node int64
// - nq int64
func (_e *MockLBBalancer_Expecter) CancelWorkload(node interface{}, nq interface{}) *MockLBBalancer_CancelWorkload_Call {
return &MockLBBalancer_CancelWorkload_Call{Call: _e.mock.On("CancelWorkload", node, nq)}
}
func (_c *MockLBBalancer_CancelWorkload_Call) Run(run func(node int64, nq int64)) *MockLBBalancer_CancelWorkload_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(int64))
})
return _c
}
func (_c *MockLBBalancer_CancelWorkload_Call) Return() *MockLBBalancer_CancelWorkload_Call {
_c.Call.Return()
return _c
}
func (_c *MockLBBalancer_CancelWorkload_Call) RunAndReturn(run func(int64, int64)) *MockLBBalancer_CancelWorkload_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with given fields:
func (_m *MockLBBalancer) Close() {
_m.Called()
}
// MockLBBalancer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockLBBalancer_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockLBBalancer_Expecter) Close() *MockLBBalancer_Close_Call {
return &MockLBBalancer_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockLBBalancer_Close_Call) Run(run func()) *MockLBBalancer_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockLBBalancer_Close_Call) Return() *MockLBBalancer_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Close_Call {
_c.Call.Return(run)
return _c
}
// SelectNode provides a mock function with given fields: availableNodes, nq // SelectNode provides a mock function with given fields: availableNodes, nq
func (_m *MockLBBalancer) SelectNode(availableNodes []int64, nq int64) (int64, error) { func (_m *MockLBBalancer) SelectNode(availableNodes []int64, nq int64) (int64, error) {
ret := _m.Called(availableNodes, nq) ret := _m.Called(availableNodes, nq)
@ -70,6 +139,40 @@ func (_c *MockLBBalancer_SelectNode_Call) RunAndReturn(run func([]int64, int64)
return _c return _c
} }
// UpdateCostMetrics provides a mock function with given fields: node, cost
func (_m *MockLBBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
_m.Called(node, cost)
}
// MockLBBalancer_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics'
type MockLBBalancer_UpdateCostMetrics_Call struct {
*mock.Call
}
// UpdateCostMetrics is a helper method to define mock.On call
// - node int64
// - cost *internalpb.CostAggregation
func (_e *MockLBBalancer_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) *MockLBBalancer_UpdateCostMetrics_Call {
return &MockLBBalancer_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)}
}
func (_c *MockLBBalancer_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(*internalpb.CostAggregation))
})
return _c
}
func (_c *MockLBBalancer_UpdateCostMetrics_Call) Return() *MockLBBalancer_UpdateCostMetrics_Call {
_c.Call.Return()
return _c
}
func (_c *MockLBBalancer_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBBalancer_UpdateCostMetrics_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewMockLBBalancer interface { type mockConstructorTestingTNewMockLBBalancer interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())

View File

@ -5,6 +5,7 @@ package proxy
import ( import (
context "context" context "context"
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
@ -64,13 +65,13 @@ func (_c *MockLBPolicy_Execute_Call) RunAndReturn(run func(context.Context, Coll
return _c return _c
} }
// ExecuteWithRetry provides a mock function with given fields: ctx, workload, retryTimes // ExecuteWithRetry provides a mock function with given fields: ctx, workload
func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload, retryTimes uint) error { func (_m *MockLBPolicy) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
ret := _m.Called(ctx, workload, retryTimes) ret := _m.Called(ctx, workload)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload, uint) error); ok { if rf, ok := ret.Get(0).(func(context.Context, ChannelWorkload) error); ok {
r0 = rf(ctx, workload, retryTimes) r0 = rf(ctx, workload)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -86,14 +87,13 @@ type MockLBPolicy_ExecuteWithRetry_Call struct {
// ExecuteWithRetry is a helper method to define mock.On call // ExecuteWithRetry is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - workload ChannelWorkload // - workload ChannelWorkload
// - retryTimes uint func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}) *MockLBPolicy_ExecuteWithRetry_Call {
func (_e *MockLBPolicy_Expecter) ExecuteWithRetry(ctx interface{}, workload interface{}, retryTimes interface{}) *MockLBPolicy_ExecuteWithRetry_Call { return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload)}
return &MockLBPolicy_ExecuteWithRetry_Call{Call: _e.mock.On("ExecuteWithRetry", ctx, workload, retryTimes)}
} }
func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload, retryTimes uint)) *MockLBPolicy_ExecuteWithRetry_Call { func (_c *MockLBPolicy_ExecuteWithRetry_Call) Run(run func(ctx context.Context, workload ChannelWorkload)) *MockLBPolicy_ExecuteWithRetry_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(ChannelWorkload), args[2].(uint)) run(args[0].(context.Context), args[1].(ChannelWorkload))
}) })
return _c return _c
} }
@ -103,7 +103,41 @@ func (_c *MockLBPolicy_ExecuteWithRetry_Call) Return(_a0 error) *MockLBPolicy_Ex
return _c return _c
} }
func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload, uint) error) *MockLBPolicy_ExecuteWithRetry_Call { func (_c *MockLBPolicy_ExecuteWithRetry_Call) RunAndReturn(run func(context.Context, ChannelWorkload) error) *MockLBPolicy_ExecuteWithRetry_Call {
_c.Call.Return(run)
return _c
}
// UpdateCostMetrics provides a mock function with given fields: node, cost
func (_m *MockLBPolicy) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
_m.Called(node, cost)
}
// MockLBPolicy_UpdateCostMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCostMetrics'
type MockLBPolicy_UpdateCostMetrics_Call struct {
*mock.Call
}
// UpdateCostMetrics is a helper method to define mock.On call
// - node int64
// - cost *internalpb.CostAggregation
func (_e *MockLBPolicy_Expecter) UpdateCostMetrics(node interface{}, cost interface{}) *MockLBPolicy_UpdateCostMetrics_Call {
return &MockLBPolicy_UpdateCostMetrics_Call{Call: _e.mock.On("UpdateCostMetrics", node, cost)}
}
func (_c *MockLBPolicy_UpdateCostMetrics_Call) Run(run func(node int64, cost *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(*internalpb.CostAggregation))
})
return _c
}
func (_c *MockLBPolicy_UpdateCostMetrics_Call) Return() *MockLBPolicy_UpdateCostMetrics_Call {
_c.Call.Return()
return _c
}
func (_c *MockLBPolicy_UpdateCostMetrics_Call) RunAndReturn(run func(int64, *internalpb.CostAggregation)) *MockLBPolicy_UpdateCostMetrics_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -127,7 +127,7 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
searchResultCh: make(chan *internalpb.SearchResults, n), searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr, shardMgr: mgr,
multiRateLimiter: NewMultiRateLimiter(), multiRateLimiter: NewMultiRateLimiter(),
lbPolicy: NewLBPolicyImpl(NewRoundRobinBalancer(), mgr), lbPolicy: NewLBPolicyImpl(mgr),
} }
node.UpdateStateCode(commonpb.StateCode_Abnormal) node.UpdateStateCode(commonpb.StateCode_Abnormal)
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load())) logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
@ -437,6 +437,10 @@ func (node *Proxy) Stop() error {
node.chMgr.removeAllDMLStream() node.chMgr.removeAllDMLStream()
} }
if node.lbPolicy != nil {
node.lbPolicy.Close()
}
// https://github.com/milvus-io/milvus/issues/12282 // https://github.com/milvus-io/milvus/issues/12282
node.UpdateStateCode(commonpb.StateCode_Abnormal) node.UpdateStateCode(commonpb.StateCode_Abnormal)

View File

@ -16,38 +16,56 @@
package proxy package proxy
import ( import (
"sync" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/atomic"
) )
type RoundRobinBalancer struct { type RoundRobinBalancer struct {
// request num send to each node // request num send to each node
mutex sync.RWMutex nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64]
nodeWorkload map[int64]int64
} }
func NewRoundRobinBalancer() *RoundRobinBalancer { func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{ return &RoundRobinBalancer{
nodeWorkload: make(map[int64]int64), nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
} }
} }
func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, workload int64) (int64, error) { func (b *RoundRobinBalancer) SelectNode(availableNodes []int64, cost int64) (int64, error) {
if len(availableNodes) == 0 { if len(availableNodes) == 0 {
return -1, merr.ErrNoAvailableNode return -1, merr.ErrNoAvailableNode
} }
b.mutex.Lock()
defer b.mutex.Unlock()
targetNode := int64(-1) targetNode := int64(-1)
targetNodeWorkload := int64(-1) var targetNodeWorkload *atomic.Int64
for _, node := range availableNodes { for _, node := range availableNodes {
if targetNodeWorkload == -1 || b.nodeWorkload[node] < targetNodeWorkload { workload, ok := b.nodeWorkload.Get(node)
if !ok {
workload = atomic.NewInt64(0)
b.nodeWorkload.Insert(node, workload)
}
if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() {
targetNode = node targetNode = node
targetNodeWorkload = b.nodeWorkload[node] targetNodeWorkload = workload
} }
} }
b.nodeWorkload[targetNode] += workload targetNodeWorkload.Add(cost)
return targetNode, nil return targetNode, nil
} }
func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
load, ok := b.nodeWorkload.Get(node)
if ok {
load.Sub(nq)
}
}
func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}
func (b *RoundRobinBalancer) Close() {}

View File

@ -38,16 +38,24 @@ func (s *RoundRobinBalancerSuite) TestRoundRobin() {
s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1)
s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1)
s.Equal(int64(2), s.balancer.nodeWorkload[1]) workload, ok := s.balancer.nodeWorkload.Get(1)
s.Equal(int64(2), s.balancer.nodeWorkload[2]) s.True(ok)
s.Equal(int64(2), workload.Load())
workload, ok = s.balancer.nodeWorkload.Get(1)
s.True(ok)
s.Equal(int64(2), workload.Load())
s.balancer.SelectNode(availableNodes, 3) s.balancer.SelectNode(availableNodes, 3)
s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1)
s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1)
s.balancer.SelectNode(availableNodes, 1) s.balancer.SelectNode(availableNodes, 1)
s.Equal(int64(5), s.balancer.nodeWorkload[1]) workload, ok = s.balancer.nodeWorkload.Get(1)
s.Equal(int64(5), s.balancer.nodeWorkload[2]) s.True(ok)
s.Equal(int64(5), workload.Load())
workload, ok = s.balancer.nodeWorkload.Get(1)
s.True(ok)
s.Equal(int64(5), workload.Load())
} }
func (s *RoundRobinBalancerSuite) TestNoAvailableNode() { func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
@ -56,6 +64,17 @@ func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
s.Error(err) s.Error(err)
} }
func (s *RoundRobinBalancerSuite) TestCancelWorkload() {
availableNodes := []int64{101}
_, err := s.balancer.SelectNode(availableNodes, 5)
s.NoError(err)
workload, ok := s.balancer.nodeWorkload.Get(101)
s.True(ok)
s.Equal(int64(5), workload.Load())
s.balancer.CancelWorkload(101, 5)
s.Equal(int64(0), workload.Load())
}
func TestRoundRobinBalancerSuite(t *testing.T) { func TestRoundRobinBalancerSuite(t *testing.T) {
suite.Run(t, new(RoundRobinBalancerSuite)) suite.Run(t, new(RoundRobinBalancerSuite))
} }

View File

@ -491,6 +491,7 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
log.Debug("get query result") log.Debug("get query result")
t.resultBuf.Insert(result) t.resultBuf.Insert(result)
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
return nil return nil
} }

View File

@ -58,6 +58,8 @@ func TestQueryTask_all(t *testing.T) {
hitNum = 10 hitNum = 10
) )
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc.EXPECT().Start().Return(nil) qc.EXPECT().Start().Return(nil)
qc.EXPECT().Stop().Return(nil) qc.EXPECT().Stop().Return(nil)
@ -73,12 +75,10 @@ func TestQueryTask_all(t *testing.T) {
}, },
}, nil).Maybe() }, nil).Maybe()
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) { mgr := NewMockShardClientManager(t)
return qn, nil mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
} mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
mgr := newShardClientMgr(withShardClientCreator(mockCreator))
lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr)
rc.Start() rc.Start()
defer rc.Stop() defer rc.Stop()
@ -217,10 +217,12 @@ func TestQueryTask_all(t *testing.T) {
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField) task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
task.ctx = ctx task.ctx = ctx
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, task.Execute(ctx)) assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NotShardLeader, ErrorCode: commonpb.ErrorCode_NotShardLeader,
@ -230,6 +232,7 @@ func TestQueryTask_all(t *testing.T) {
assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error()))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -238,6 +241,7 @@ func TestQueryTask_all(t *testing.T) {
assert.Error(t, task.Execute(ctx)) assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil) qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.Execute(ctx))

View File

@ -521,6 +521,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason())
} }
t.resultBuf.Insert(result) t.resultBuf.Insert(result)
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
return nil return nil
} }

View File

@ -1543,12 +1543,12 @@ func TestSearchTask_ErrExecute(t *testing.T) {
collectionName = t.Name() + funcutil.GenRandomStr() collectionName = t.Name() + funcutil.GenRandomStr()
) )
mockCreator := func(ctx context.Context, address string) (types.QueryNode, error) { qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
return qn, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator)) mgr := NewMockShardClientManager(t)
lb := NewLBPolicyImpl(NewRoundRobinBalancer(), mgr) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
lb := NewLBPolicyImpl(mgr)
rc.Start() rc.Start()
defer rc.Stop() defer rc.Stop()
@ -1661,6 +1661,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
assert.Error(t, task.Execute(ctx)) assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_NotShardLeader, ErrorCode: commonpb.ErrorCode_NotShardLeader,
@ -1670,6 +1671,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error()))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
@ -1678,6 +1680,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
assert.Error(t, task.Execute(ctx)) assert.Error(t, task.Execute(ctx))
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,

View File

@ -77,11 +77,11 @@ func (s *StatisticTaskSuite) SetupTest() {
s.rc.Start() s.rc.Start()
s.qn = types.NewMockQueryNode(s.T()) s.qn = types.NewMockQueryNode(s.T())
mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) { s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe()
return s.qn, nil mgr := NewMockShardClientManager(s.T())
} mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
mgr := newShardClientMgr(withShardClientCreator(mockCreator)) mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe()
s.lb = NewLBPolicyImpl(NewRoundRobinBalancer(), mgr) s.lb = NewLBPolicyImpl(mgr)
err := InitMetaCache(context.Background(), s.rc, s.qc, mgr) err := InitMetaCache(context.Background(), s.rc, s.qc, mgr)
s.NoError(err) s.NoError(err)

View File

@ -1017,6 +1017,12 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
collector.Rate.Add(metricsinfo.NQPerSecond, 1) collector.Rate.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
} }
if ret.CostAggregation != nil {
// update channel's response time
currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
ret.CostAggregation.TotalNQ = currentTotalNQ
}
return ret, nil return ret, nil
} }