Delete unused messages for the mq (#20295)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
pull/20257/head
SimFG 2022-11-03 21:41:35 +08:00 committed by GitHub
parent e0506352a4
commit deb9963d0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 806 deletions

View File

@ -439,51 +439,6 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) {
outputStream.Close()
}
func TestStream_PulsarMsgStream_Search(t *testing.T) {
pulsarAddress := getPulsarAddress()
c := funcutil.RandomString(8)
producerChannels := []string{c}
consumerChannels := []string{c}
consumerSubName := funcutil.RandomString(8)
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 3))
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
func TestStream_PulsarMsgStream_SearchResult(t *testing.T) {
pulsarAddress := getPulsarAddress()
c := funcutil.RandomString(8)
producerChannels := []string{c}
consumerChannels := []string{c}
consumerSubName := funcutil.RandomString(8)
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3))
ctx := context.Background()
inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels)
outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName)
err := inputStream.Produce(&msgPack)
require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err))
receiveMsg(ctx, outputStream, len(msgPack.Msgs))
inputStream.Close()
outputStream.Close()
}
func TestStream_PulsarMsgStream_TimeTick(t *testing.T) {
pulsarAddress := getPulsarAddress()
c := funcutil.RandomString(8)
@ -672,8 +627,8 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) {
msgPack := MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 1))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 2))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 2))
msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 3))
factory := ProtoUDFactory{}
@ -1572,8 +1527,8 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
// would not dedup for non-dml messages
msgPack2 := MsgPack{}
msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2))
msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2))
msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_CreateCollection, 2))
msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_CreateCollection, 2))
msgPack3 := MsgPack{}
msgPack3.Msgs = append(msgPack3.Msgs, getTimeTickMsg(15))
@ -1608,8 +1563,8 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) {
seekMsg := consumer(ctx, outputStream)
assert.Equal(t, len(seekMsg.Msgs), 1+2)
assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1)
assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[1].Type())
assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[2].Type())
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type())
assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type())
Close(rocksdbName, inputStream, outputStream, etcdKV)
}
@ -1958,37 +1913,29 @@ func getTsMsg(msgType MsgType, reqID UniqueID) TsMsg {
DeleteRequest: deleteRequest,
}
return deleteMsg
case commonpb.MsgType_Search:
searchRequest := internalpb.SearchRequest{
case commonpb.MsgType_CreateCollection:
createCollectionRequest := internalpb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgType: commonpb.MsgType_CreateCollection,
MsgID: reqID,
Timestamp: 11,
SourceID: reqID,
},
ReqID: 0,
DbName: "test_db",
CollectionName: "test_collection",
PartitionName: "test_partition",
DbID: 4,
CollectionID: 5,
PartitionID: 6,
Schema: []byte{},
VirtualChannelNames: []string{},
PhysicalChannelNames: []string{},
}
searchMsg := &SearchMsg{
BaseMsg: baseMsg,
SearchRequest: searchRequest,
createCollectionMsg := &CreateCollectionMsg{
BaseMsg: baseMsg,
CreateCollectionRequest: createCollectionRequest,
}
return searchMsg
case commonpb.MsgType_SearchResult:
searchResult := internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: reqID,
Timestamp: 1,
SourceID: reqID,
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
ReqID: 0,
}
searchResultMsg := &SearchResultMsg{
BaseMsg: baseMsg,
SearchResults: searchResult,
}
return searchResultMsg
return createCollectionMsg
case commonpb.MsgType_TimeTick:
timeTickResult := internalpb.TimeTickMsg{
Base: &commonpb.MsgBase{

View File

@ -20,7 +20,6 @@ import (
"context"
"errors"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
@ -32,8 +31,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/timerecord"
)
// MsgType is an alias of commonpb.MsgType
@ -383,292 +380,6 @@ func (dt *DeleteMsg) CheckAligned() error {
return nil
}
/////////////////////////////////////////Search//////////////////////////////////////////
// SearchMsg is a message pack that contains search request
type SearchMsg struct {
BaseMsg
internalpb.SearchRequest
tr *timerecord.TimeRecorder
}
// interface implementation validation
var _ TsMsg = &SearchMsg{}
// ID returns the ID of this message pack
func (st *SearchMsg) ID() UniqueID {
return st.Base.MsgID
}
// Type returns the type of this message pack
func (st *SearchMsg) Type() MsgType {
return st.Base.MsgType
}
// SourceID indicates which component generated this message
func (st *SearchMsg) SourceID() int64 {
return st.Base.SourceID
}
// GuaranteeTs returns the guarantee timestamp that querynode can perform this search request. This timestamp
// filled in client(e.g. pymilvus). The timestamp will be 0 if client never execute any insert, otherwise equals
// the timestamp from last insert response.
func (st *SearchMsg) GuaranteeTs() Timestamp {
return st.GetGuaranteeTimestamp()
}
// TravelTs returns the timestamp of a time travel search request
func (st *SearchMsg) TravelTs() Timestamp {
return st.GetTravelTimestamp()
}
// TimeoutTs returns the timestamp of timeout
func (st *SearchMsg) TimeoutTs() Timestamp {
return st.GetTimeoutTimestamp()
}
// SetTimeRecorder sets the timeRecorder for RetrieveMsg
func (st *SearchMsg) SetTimeRecorder() {
st.tr = timerecord.NewTimeRecorder("searchMsg")
}
// ElapseSpan returns the duration from the beginning
func (st *SearchMsg) ElapseSpan() time.Duration {
return st.tr.ElapseSpan()
}
// RecordSpan returns the duration from last record
func (st *SearchMsg) RecordSpan() time.Duration {
return st.tr.RecordSpan()
}
// Marshal is used to serializing a message pack to byte array
func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) {
searchTask := input.(*SearchMsg)
searchRequest := &searchTask.SearchRequest
mb, err := proto.Marshal(searchRequest)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (st *SearchMsg) Unmarshal(input MarshalType) (TsMsg, error) {
searchRequest := internalpb.SearchRequest{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &searchRequest)
if err != nil {
return nil, err
}
searchMsg := &SearchMsg{SearchRequest: searchRequest}
searchMsg.BeginTimestamp = searchMsg.Base.Timestamp
searchMsg.EndTimestamp = searchMsg.Base.Timestamp
return searchMsg, nil
}
/////////////////////////////////////////SearchResult//////////////////////////////////////////
// SearchResultMsg is a message pack that contains the result of search request
type SearchResultMsg struct {
BaseMsg
internalpb.SearchResults
}
// interface implementation validation
var _ TsMsg = &SearchResultMsg{}
// ID returns the ID of this message pack
func (srt *SearchResultMsg) ID() UniqueID {
return srt.Base.MsgID
}
// Type returns the type of this message pack
func (srt *SearchResultMsg) Type() MsgType {
return srt.Base.MsgType
}
// SourceID indicates which component generated this message
func (srt *SearchResultMsg) SourceID() int64 {
return srt.Base.SourceID
}
// Marshal is used to serializing a message pack to byte array
func (srt *SearchResultMsg) Marshal(input TsMsg) (MarshalType, error) {
searchResultTask := input.(*SearchResultMsg)
searchResultRequest := &searchResultTask.SearchResults
mb, err := proto.Marshal(searchResultRequest)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (srt *SearchResultMsg) Unmarshal(input MarshalType) (TsMsg, error) {
searchResultRequest := internalpb.SearchResults{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &searchResultRequest)
if err != nil {
return nil, err
}
searchResultMsg := &SearchResultMsg{SearchResults: searchResultRequest}
searchResultMsg.BeginTimestamp = searchResultMsg.Base.Timestamp
searchResultMsg.EndTimestamp = searchResultMsg.Base.Timestamp
return searchResultMsg, nil
}
////////////////////////////////////////Retrieve/////////////////////////////////////////
// RetrieveMsg is a message pack that contains retrieve request
type RetrieveMsg struct {
BaseMsg
internalpb.RetrieveRequest
tr *timerecord.TimeRecorder
}
// interface implementation validation
var _ TsMsg = &RetrieveMsg{}
// ID returns the ID of this message pack
func (rm *RetrieveMsg) ID() UniqueID {
return rm.Base.MsgID
}
// Type returns the type of this message pack
func (rm *RetrieveMsg) Type() MsgType {
return rm.Base.MsgType
}
// SourceID indicates which component generated this message
func (rm *RetrieveMsg) SourceID() int64 {
return rm.Base.SourceID
}
// GuaranteeTs returns the guarantee timestamp that querynode can perform this query request. This timestamp
// filled in client(e.g. pymilvus). The timestamp will be 0 if client never execute any insert, otherwise equals
// the timestamp from last insert response.
func (rm *RetrieveMsg) GuaranteeTs() Timestamp {
return rm.GetGuaranteeTimestamp()
}
// TravelTs returns the timestamp of a time travel query request
func (rm *RetrieveMsg) TravelTs() Timestamp {
return rm.GetTravelTimestamp()
}
// TimeoutTs returns the timestamp of timeout
func (rm *RetrieveMsg) TimeoutTs() Timestamp {
return rm.GetTimeoutTimestamp()
}
// SetTimeRecorder sets the timeRecorder for RetrieveMsg
func (rm *RetrieveMsg) SetTimeRecorder() {
rm.tr = timerecord.NewTimeRecorder("retrieveMsg")
}
// ElapseSpan returns the duration from the beginning
func (rm *RetrieveMsg) ElapseSpan() time.Duration {
return rm.tr.ElapseSpan()
}
// RecordSpan returns the duration from last record
func (rm *RetrieveMsg) RecordSpan() time.Duration {
return rm.tr.RecordSpan()
}
// Marshal is used to serializing a message pack to byte array
func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) {
retrieveTask := input.(*RetrieveMsg)
retrieveRequest := &retrieveTask.RetrieveRequest
mb, err := proto.Marshal(retrieveRequest)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (rm *RetrieveMsg) Unmarshal(input MarshalType) (TsMsg, error) {
retrieveRequest := internalpb.RetrieveRequest{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &retrieveRequest)
if err != nil {
return nil, err
}
retrieveMsg := &RetrieveMsg{RetrieveRequest: retrieveRequest}
retrieveMsg.BeginTimestamp = retrieveMsg.Base.Timestamp
retrieveMsg.EndTimestamp = retrieveMsg.Base.Timestamp
return retrieveMsg, nil
}
//////////////////////////////////////RetrieveResult///////////////////////////////////////
// RetrieveResultMsg is a message pack that contains the result of query request
type RetrieveResultMsg struct {
BaseMsg
internalpb.RetrieveResults
}
// interface implementation validation
var _ TsMsg = &RetrieveResultMsg{}
// ID returns the ID of this message pack
func (rrm *RetrieveResultMsg) ID() UniqueID {
return rrm.Base.MsgID
}
// Type returns the type of this message pack
func (rrm *RetrieveResultMsg) Type() MsgType {
return rrm.Base.MsgType
}
// SourceID indicates which component generated this message
func (rrm *RetrieveResultMsg) SourceID() int64 {
return rrm.Base.SourceID
}
// Marshal is used to serializing a message pack to byte array
func (rrm *RetrieveResultMsg) Marshal(input TsMsg) (MarshalType, error) {
retrieveResultTask := input.(*RetrieveResultMsg)
retrieveResultRequest := &retrieveResultTask.RetrieveResults
mb, err := proto.Marshal(retrieveResultRequest)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (rrm *RetrieveResultMsg) Unmarshal(input MarshalType) (TsMsg, error) {
retrieveResultRequest := internalpb.RetrieveResults{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &retrieveResultRequest)
if err != nil {
return nil, err
}
retrieveResultMsg := &RetrieveResultMsg{RetrieveResults: retrieveResultRequest}
retrieveResultMsg.BeginTimestamp = retrieveResultMsg.Base.Timestamp
retrieveResultMsg.EndTimestamp = retrieveResultMsg.Base.Timestamp
return retrieveResultMsg, nil
}
/////////////////////////////////////////TimeTick//////////////////////////////////////////
// TimeTickMsg is a message pack that contains time tick only
@ -944,122 +655,6 @@ func (dp *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) {
return dropPartitionMsg, nil
}
/////////////////////////////////////////LoadIndex//////////////////////////////////////////
// FIXME(wxyu): comment it until really needed
/*
type LoadIndexMsg struct {
BaseMsg
internalpb.LoadIndex
}
// TraceCtx returns the context of opentracing
func (lim *LoadIndexMsg) TraceCtx() context.Context {
return lim.BaseMsg.Ctx
}
// SetTraceCtx is used to set context for opentracing
func (lim *LoadIndexMsg) SetTraceCtx(ctx context.Context) {
lim.BaseMsg.Ctx = ctx
}
// ID returns the ID of this message pack
func (lim *LoadIndexMsg) ID() UniqueID {
return lim.Base.MsgID
}
// Type returns the type of this message pack
func (lim *LoadIndexMsg) Type() MsgType {
return lim.Base.MsgType
}
// SourceID indicated which component generated this message
func (lim *LoadIndexMsg) SourceID() int64 {
return lim.Base.SourceID
}
// Marshal is used to serializing a message pack to byte array
func (lim *LoadIndexMsg) Marshal(input TsMsg) (MarshalType, error) {
loadIndexMsg := input.(*LoadIndexMsg)
loadIndexRequest := &loadIndexMsg.LoadIndex
mb, err := proto.Marshal(loadIndexRequest)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (lim *LoadIndexMsg) Unmarshal(input MarshalType) (TsMsg, error) {
loadIndexRequest := internalpb.LoadIndex{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &loadIndexRequest)
if err != nil {
return nil, err
}
loadIndexMsg := &LoadIndexMsg{LoadIndex: loadIndexRequest}
return loadIndexMsg, nil
}
*/
/////////////////////////////////////////SealedSegmentsChangeInfoMsg//////////////////////////////////////////
// SealedSegmentsChangeInfoMsg is a message pack that contains sealed segments change info
type SealedSegmentsChangeInfoMsg struct {
BaseMsg
querypb.SealedSegmentsChangeInfo
}
// interface implementation validation
var _ TsMsg = &SealedSegmentsChangeInfoMsg{}
// ID returns the ID of this message pack
func (s *SealedSegmentsChangeInfoMsg) ID() UniqueID {
return s.Base.MsgID
}
// Type returns the type of this message pack
func (s *SealedSegmentsChangeInfoMsg) Type() MsgType {
return s.Base.MsgType
}
// SourceID indicates which component generated this message
func (s *SealedSegmentsChangeInfoMsg) SourceID() int64 {
return s.Base.SourceID
}
// Marshal is used to serializing a message pack to byte array
func (s *SealedSegmentsChangeInfoMsg) Marshal(input TsMsg) (MarshalType, error) {
changeInfoMsg := input.(*SealedSegmentsChangeInfoMsg)
changeInfo := &changeInfoMsg.SealedSegmentsChangeInfo
mb, err := proto.Marshal(changeInfo)
if err != nil {
return nil, err
}
return mb, nil
}
// Unmarshal is used to deserializing a message pack from byte array
func (s *SealedSegmentsChangeInfoMsg) Unmarshal(input MarshalType) (TsMsg, error) {
changeInfo := querypb.SealedSegmentsChangeInfo{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, &changeInfo)
if err != nil {
return nil, err
}
changeInfoMsg := &SealedSegmentsChangeInfoMsg{SealedSegmentsChangeInfo: changeInfo}
changeInfoMsg.BeginTimestamp = changeInfo.Base.Timestamp
changeInfoMsg.EndTimestamp = changeInfo.Base.Timestamp
return changeInfoMsg, nil
}
/////////////////////////////////////////DataNodeTtMsg//////////////////////////////////////////
// DataNodeTtMsg is a message pack that contains datanode time tick

View File

@ -26,7 +26,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
func TestBaseMsg(t *testing.T) {
@ -320,240 +319,6 @@ func TestDeleteMsg_Unmarshal_IllegalParameter(t *testing.T) {
assert.Nil(t, tsMsg)
}
func TestSearchMsg(t *testing.T) {
searchMsg := &SearchMsg{
BaseMsg: generateBaseMsg(),
SearchRequest: internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
DbID: 4,
CollectionID: 5,
PartitionIDs: []int64{},
Dsl: "dsl",
PlaceholderGroup: []byte{},
DslType: commonpb.DslType_BoolExprV1,
SerializedExprPlan: []byte{},
OutputFieldsId: []int64{},
TravelTimestamp: 6,
GuaranteeTimestamp: 7,
TimeoutTimestamp: 8,
},
}
assert.NotNil(t, searchMsg.TraceCtx())
ctx := context.Background()
searchMsg.SetTraceCtx(ctx)
assert.Equal(t, ctx, searchMsg.TraceCtx())
assert.Equal(t, int64(1), searchMsg.ID())
assert.Equal(t, commonpb.MsgType_Search, searchMsg.Type())
assert.Equal(t, int64(3), searchMsg.SourceID())
assert.Equal(t, uint64(7), searchMsg.GuaranteeTs())
assert.Equal(t, uint64(6), searchMsg.TravelTs())
assert.Equal(t, uint64(8), searchMsg.TimeoutTs())
bytes, err := searchMsg.Marshal(searchMsg)
assert.Nil(t, err)
tsMsg, err := searchMsg.Unmarshal(bytes)
assert.Nil(t, err)
searchMsg2, ok := tsMsg.(*SearchMsg)
assert.True(t, ok)
assert.Equal(t, int64(1), searchMsg2.ID())
assert.Equal(t, commonpb.MsgType_Search, searchMsg2.Type())
assert.Equal(t, int64(3), searchMsg2.SourceID())
assert.Equal(t, uint64(7), searchMsg2.GuaranteeTs())
assert.Equal(t, uint64(6), searchMsg2.TravelTs())
}
func TestSearchMsg_Unmarshal_IllegalParameter(t *testing.T) {
searchMsg := &SearchMsg{}
tsMsg, err := searchMsg.Unmarshal(10)
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}
func TestSearchResultMsg(t *testing.T) {
searchResultMsg := &SearchResultMsg{
BaseMsg: generateBaseMsg(),
SearchResults: internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
MetricType: "l2",
NumQueries: 5,
TopK: 6,
SealedSegmentIDsSearched: []int64{7},
ChannelIDsSearched: []string{"test-searched"},
GlobalSealedSegmentIDs: []int64{8},
},
}
assert.NotNil(t, searchResultMsg.TraceCtx())
ctx := context.Background()
searchResultMsg.SetTraceCtx(ctx)
assert.Equal(t, ctx, searchResultMsg.TraceCtx())
assert.Equal(t, int64(1), searchResultMsg.ID())
assert.Equal(t, commonpb.MsgType_SearchResult, searchResultMsg.Type())
assert.Equal(t, int64(3), searchResultMsg.SourceID())
bytes, err := searchResultMsg.Marshal(searchResultMsg)
assert.Nil(t, err)
tsMsg, err := searchResultMsg.Unmarshal(bytes)
assert.Nil(t, err)
searchResultMsg2, ok := tsMsg.(*SearchResultMsg)
assert.True(t, ok)
assert.Equal(t, int64(1), searchResultMsg2.ID())
assert.Equal(t, commonpb.MsgType_SearchResult, searchResultMsg2.Type())
assert.Equal(t, int64(3), searchResultMsg2.SourceID())
}
func TestSearchResultMsg_Unmarshal_IllegalParameter(t *testing.T) {
searchResultMsg := &SearchResultMsg{}
tsMsg, err := searchResultMsg.Unmarshal(10)
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}
func TestRetrieveMsg(t *testing.T) {
retrieveMsg := &RetrieveMsg{
BaseMsg: generateBaseMsg(),
RetrieveRequest: internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
DbID: 5,
CollectionID: 6,
PartitionIDs: []int64{7, 8},
SerializedExprPlan: []byte{},
OutputFieldsId: []int64{8, 9},
TravelTimestamp: 10,
GuaranteeTimestamp: 11,
TimeoutTimestamp: 12,
},
}
assert.NotNil(t, retrieveMsg.TraceCtx())
ctx := context.Background()
retrieveMsg.SetTraceCtx(ctx)
assert.Equal(t, ctx, retrieveMsg.TraceCtx())
assert.Equal(t, int64(1), retrieveMsg.ID())
assert.Equal(t, commonpb.MsgType_Retrieve, retrieveMsg.Type())
assert.Equal(t, int64(3), retrieveMsg.SourceID())
assert.Equal(t, uint64(11), retrieveMsg.GuaranteeTs())
assert.Equal(t, uint64(10), retrieveMsg.TravelTs())
assert.Equal(t, uint64(12), retrieveMsg.TimeoutTs())
bytes, err := retrieveMsg.Marshal(retrieveMsg)
assert.Nil(t, err)
tsMsg, err := retrieveMsg.Unmarshal(bytes)
assert.Nil(t, err)
retrieveMsg2, ok := tsMsg.(*RetrieveMsg)
assert.True(t, ok)
assert.Equal(t, int64(1), retrieveMsg2.ID())
assert.Equal(t, commonpb.MsgType_Retrieve, retrieveMsg2.Type())
assert.Equal(t, int64(3), retrieveMsg2.SourceID())
assert.Equal(t, uint64(11), retrieveMsg2.GuaranteeTs())
assert.Equal(t, uint64(10), retrieveMsg2.TravelTs())
}
func TestRetrieveMsg_Unmarshal_IllegalParameter(t *testing.T) {
retrieveMsg := &RetrieveMsg{}
tsMsg, err := retrieveMsg.Unmarshal(10)
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}
func TestRetrieveResultMsg(t *testing.T) {
retrieveResultMsg := &RetrieveResultMsg{
BaseMsg: generateBaseMsg(),
RetrieveResults: internalpb.RetrieveResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RetrieveResult,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{},
},
},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_FloatVector,
FieldName: "vector_field",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 4,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2, 3.3, 4.4},
},
},
},
},
FieldId: 5,
},
},
SealedSegmentIDsRetrieved: []int64{6, 7},
ChannelIDsRetrieved: []string{"test-retrieved-channel"},
GlobalSealedSegmentIDs: []int64{8, 9},
},
}
assert.NotNil(t, retrieveResultMsg.TraceCtx())
ctx := context.Background()
retrieveResultMsg.SetTraceCtx(ctx)
assert.Equal(t, ctx, retrieveResultMsg.TraceCtx())
assert.Equal(t, int64(1), retrieveResultMsg.ID())
assert.Equal(t, commonpb.MsgType_RetrieveResult, retrieveResultMsg.Type())
assert.Equal(t, int64(3), retrieveResultMsg.SourceID())
bytes, err := retrieveResultMsg.Marshal(retrieveResultMsg)
assert.Nil(t, err)
tsMsg, err := retrieveResultMsg.Unmarshal(bytes)
assert.Nil(t, err)
retrieveResultMsg2, ok := tsMsg.(*RetrieveResultMsg)
assert.True(t, ok)
assert.Equal(t, int64(1), retrieveResultMsg2.ID())
assert.Equal(t, commonpb.MsgType_RetrieveResult, retrieveResultMsg2.Type())
assert.Equal(t, int64(3), retrieveResultMsg2.SourceID())
}
func TestRetrieveResultMsg_Unmarshal_IllegalParameter(t *testing.T) {
retrieveResultMsg := &RetrieveResultMsg{}
tsMsg, err := retrieveResultMsg.Unmarshal(10)
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}
func TestTimeTickMsg(t *testing.T) {
timeTickMsg := &TimeTickMsg{
BaseMsg: generateBaseMsg(),
@ -838,67 +603,3 @@ func TestDataNodeTtMsg_Unmarshal_IllegalParameter(t *testing.T) {
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}
func TestSealedSegmentsChangeInfoMsg(t *testing.T) {
genSimpleSegmentInfo := func(segmentID UniqueID) *querypb.SegmentInfo {
return &querypb.SegmentInfo{
SegmentID: segmentID,
}
}
changeInfo := &querypb.SegmentChangeInfo{
OnlineNodeID: int64(1),
OnlineSegments: []*querypb.SegmentInfo{
genSimpleSegmentInfo(1),
genSimpleSegmentInfo(2),
genSimpleSegmentInfo(3),
},
OfflineNodeID: int64(2),
OfflineSegments: []*querypb.SegmentInfo{
genSimpleSegmentInfo(4),
genSimpleSegmentInfo(5),
genSimpleSegmentInfo(6),
},
}
changeInfoMsg := &SealedSegmentsChangeInfoMsg{
BaseMsg: generateBaseMsg(),
SealedSegmentsChangeInfo: querypb.SealedSegmentsChangeInfo{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SealedSegmentsChangeInfo,
MsgID: 1,
Timestamp: 2,
SourceID: 3,
},
Infos: []*querypb.SegmentChangeInfo{changeInfo},
},
}
assert.NotNil(t, changeInfoMsg.TraceCtx())
ctx := context.Background()
changeInfoMsg.SetTraceCtx(ctx)
assert.Equal(t, ctx, changeInfoMsg.TraceCtx())
assert.Equal(t, int64(1), changeInfoMsg.ID())
assert.Equal(t, commonpb.MsgType_SealedSegmentsChangeInfo, changeInfoMsg.Type())
assert.Equal(t, int64(3), changeInfoMsg.SourceID())
bytes, err := changeInfoMsg.Marshal(changeInfoMsg)
assert.Nil(t, err)
tsMsg, err := changeInfoMsg.Unmarshal(bytes)
assert.Nil(t, err)
changeInfoMsg2, ok := tsMsg.(*SealedSegmentsChangeInfoMsg)
assert.True(t, ok)
assert.Equal(t, int64(1), changeInfoMsg2.ID())
assert.Equal(t, commonpb.MsgType_SealedSegmentsChangeInfo, changeInfoMsg2.Type())
assert.Equal(t, int64(3), changeInfoMsg2.SourceID())
}
func TestSealedSegmentsChangeInfoMsg_Unmarshal_IllegalParameter(t *testing.T) {
changeInfoMsg := &SealedSegmentsChangeInfoMsg{}
tsMsg, err := changeInfoMsg.Unmarshal(10)
assert.NotNil(t, err)
assert.Nil(t, tsMsg)
}

View File

@ -56,33 +56,23 @@ type ProtoUDFactory struct{}
func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher {
insertMsg := InsertMsg{}
deleteMsg := DeleteMsg{}
searchMsg := SearchMsg{}
searchResultMsg := SearchResultMsg{}
retrieveMsg := RetrieveMsg{}
retrieveResultMsg := RetrieveResultMsg{}
timeTickMsg := TimeTickMsg{}
createCollectionMsg := CreateCollectionMsg{}
dropCollectionMsg := DropCollectionMsg{}
createPartitionMsg := CreatePartitionMsg{}
dropPartitionMsg := DropPartitionMsg{}
dataNodeTtMsg := DataNodeTtMsg{}
sealedSegmentsChangeInfoMsg := SealedSegmentsChangeInfoMsg{}
p := &ProtoUnmarshalDispatcher{}
p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc)
p.TempMap[commonpb.MsgType_Insert] = insertMsg.Unmarshal
p.TempMap[commonpb.MsgType_Delete] = deleteMsg.Unmarshal
p.TempMap[commonpb.MsgType_Search] = searchMsg.Unmarshal
p.TempMap[commonpb.MsgType_SearchResult] = searchResultMsg.Unmarshal
p.TempMap[commonpb.MsgType_Retrieve] = retrieveMsg.Unmarshal
p.TempMap[commonpb.MsgType_RetrieveResult] = retrieveResultMsg.Unmarshal
p.TempMap[commonpb.MsgType_TimeTick] = timeTickMsg.Unmarshal
p.TempMap[commonpb.MsgType_CreateCollection] = createCollectionMsg.Unmarshal
p.TempMap[commonpb.MsgType_DropCollection] = dropCollectionMsg.Unmarshal
p.TempMap[commonpb.MsgType_CreatePartition] = createPartitionMsg.Unmarshal
p.TempMap[commonpb.MsgType_DropPartition] = dropPartitionMsg.Unmarshal
p.TempMap[commonpb.MsgType_DataNodeTt] = dataNodeTtMsg.Unmarshal
p.TempMap[commonpb.MsgType_SealedSegmentsChangeInfo] = sealedSegmentsChangeInfoMsg.Unmarshal
return p
}

View File

@ -1480,13 +1480,13 @@ func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error
}
func genSimpleRetrievePlan(collection *Collection) (*RetrievePlan, error) {
retrieveMsg, err := genRetrieveMsg(collection.schema)
timestamp := Timestamp(1000)
planBytes, err := genSimpleRetrievePlanExpr(collection.schema)
if err != nil {
return nil, err
}
timestamp := retrieveMsg.RetrieveRequest.TravelTimestamp
plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp, 100)
plan, err2 := createRetrievePlanByExpr(collection, planBytes, timestamp, 100)
return plan, err2
}
@ -1546,20 +1546,6 @@ func genRetrieveRequest(schema *schemapb.CollectionSchema) (*internalpb.Retrieve
}, nil
}
func genRetrieveMsg(schema *schemapb.CollectionSchema) (*msgstream.RetrieveMsg, error) {
req, err := genRetrieveRequest(schema)
if err != nil {
return nil, err
}
msg := &msgstream.RetrieveMsg{
BaseMsg: genMsgStreamBaseMsg(),
RetrieveRequest: *req,
}
msg.SetTimeRecorder()
return msg, nil
}
func genQueryResultChannel() Channel {
const queryResultChannelPrefix = "query-node-unittest-query-result-channel-"
return queryResultChannelPrefix + strconv.Itoa(rand.Int())