mirror of https://github.com/milvus-io/milvus.git
Check target node ID for query/search (#20967)
Signed-off-by: yah01 <yang.cen@zilliz.com> Signed-off-by: yah01 <yang.cen@zilliz.com>pull/20969/head
parent
8094eea2ed
commit
25a3b9ae19
|
@ -404,8 +404,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
|
||||
retrieveReq := typeutil.Clone(t.RetrieveRequest)
|
||||
retrieveReq.GetBase().TargetID = nodeID
|
||||
req := &querypb.QueryRequest{
|
||||
Req: t.RetrieveRequest,
|
||||
Req: retrieveReq,
|
||||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
|
|
|
@ -477,8 +477,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string) error {
|
||||
searchReq := typeutil.Clone(t.SearchRequest)
|
||||
searchReq.GetBase().TargetID = nodeID
|
||||
req := &querypb.SearchRequest{
|
||||
Req: t.SearchRequest,
|
||||
Req: searchReq,
|
||||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
|
|
|
@ -710,6 +710,15 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
|
||||
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
|
||||
|
||||
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
|
||||
return &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
|
||||
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
failRet := &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
@ -1065,6 +1074,15 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
|
|||
zap.Uint64("guaranteeTimestamp", req.Req.GetGuaranteeTimestamp()),
|
||||
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))
|
||||
|
||||
if req.GetReq().GetBase().GetTargetID() != paramtable.GetNodeID() {
|
||||
return &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
|
||||
Reason: common.WrapNodeIDNotMatchMsg(req.GetReq().GetBase().GetTargetID(), paramtable.GetNodeID()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
failRet := &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
|
|
|
@ -686,6 +686,15 @@ func TestImpl_Search(t *testing.T) {
|
|||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
req.GetBase().TargetID = -1
|
||||
ret, err := node.Search(ctx, &queryPb.SearchRequest{
|
||||
Req: req,
|
||||
FromShardLeader: false,
|
||||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func TestImpl_searchWithDmlChannel(t *testing.T) {
|
||||
|
@ -797,6 +806,15 @@ func TestImpl_Query(t *testing.T) {
|
|||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
req.GetBase().TargetID = -1
|
||||
ret, err := node.Query(ctx, &queryPb.QueryRequest{
|
||||
Req: req,
|
||||
FromShardLeader: false,
|
||||
DmlChannels: []string{defaultDMLChannel},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, ret.GetStatus().GetErrorCode())
|
||||
}
|
||||
|
||||
func TestImpl_queryWithDmlChannel(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue