Refactor QueryNode ()

Signed-off-by: yah01 <yang.cen@zilliz.com>
Co-authored-by: Congqi Xia <congqi.xia@zilliz.com>
Co-authored-by: aoiasd <zhicheng.yue@zilliz.com>
pull/22298/head
yah01 2023-03-27 00:42:00 +08:00 committed by GitHub
parent 977943463e
commit 081572d31c
167 changed files with 24994 additions and 485 deletions


@ -365,5 +365,10 @@ generate-mockery: getdeps
#internal/types
$(PWD)/bin/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord --outpkg=types --inpackage
$(PWD)/bin/mockery --name=QueryNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querynode.go --with-expecter --structname=MockQueryNode --outpkg=types --inpackage
# internal/querynodev2
$(PWD)/bin/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage
$(PWD)/bin/mockery --name=Loader --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_loader.go --with-expecter --outpkg=segments --structname=MockLoader --inpackage
$(PWD)/bin/mockery --name=Worker --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_worker.go --with-expecter --outpkg=worker --structname=MockWorker --inpackage
ci-ut: build-cpp-with-coverage generated-proto-go-without-cpp codecov-cpp codecov-go
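For reference, the expecter-style mocks these rules generate are consumed in tests roughly as in the hedged sketch below. It reuses MockManager and NewMockManager from internal/querynodev2/cluster exactly as they appear later in this diff; the test name and the error value are illustrative only.

package cluster

import (
	"testing"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
)

func TestGetWorkerFails(t *testing.T) {
	// NewMockManager registers a cleanup hook that asserts all expectations.
	manager := NewMockManager(t)
	manager.EXPECT().GetWorker(int64(1)).Return(nil, errors.New("node offline"))

	_, err := manager.GetWorker(1)
	assert.Error(t, err)
}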


@ -182,13 +182,15 @@ queryCoord:
heartbeatAvailableInterval: 10000 # 10s, Only QueryNodes which fetched heartbeats within the duration are available
loadTimeoutSeconds: 600
checkHandoffInterval: 5000
enableActiveStandby: false
port: 19531
grpc:
serverMaxSendSize: 536870912
serverMaxRecvSize: 536870912
clientMaxSendSize: 268435456
clientMaxRecvSize: 268435456
taskMergeCap: 1
taskExecutionCap: 256
enableActiveStandby: false # Enable active-standby
# Related configuration of queryNode, used to run hybrid search between vector and scalar data.
queryNode:


@ -46,7 +46,7 @@ datatype_sizeof(DataType data_type, int dim = 1) {
case DataType::VECTOR_FLOAT:
return sizeof(float) * dim;
case DataType::VECTOR_BINARY: {
Assert(dim % 8 == 0);
AssertInfo(dim % 8 == 0, "dim=" + std::to_string(dim));
return dim / 8;
}
default: {


@ -15,11 +15,12 @@
// limitations under the License.
#include "index/VectorDiskIndex.h"
#include "common/Utils.h"
#include "config/ConfigKnowhere.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "storage/LocalChunkManager.h"
#include "config/ConfigKnowhere.h"
#include "storage/Util.h"
#include "common/Consts.h"
#include "common/Utils.h"


@ -54,6 +54,6 @@ target_link_libraries(milvus_segcore
install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}")
add_executable(velox_demo VeloxDemo.cpp)
target_link_libraries(velox_demo ${CONAN_LIBS} velox_bundled)
install(TARGETS velox_demo DESTINATION "${CMAKE_INSTALL_BINDIR}")
# add_executable(velox_demo VeloxDemo.cpp)
# target_link_libraries(velox_demo ${CONAN_LIBS} velox_bundled)
# install(TARGETS velox_demo DESTINATION "${CMAKE_INSTALL_BINDIR}")


@ -274,9 +274,27 @@ SegmentSealedImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
auto timestamps = reinterpret_cast<const Timestamp*>(info.timestamps);
// step 2: fill pks and timestamps
ssize_t n = deleted_record_.ack_responder_.GetAck();
ssize_t divide_point = 0;
// Truncate the overlapping prefix
if (n > 0) {
auto last = deleted_record_.timestamps_[n - 1];
divide_point =
std::lower_bound(timestamps, timestamps + size, last + 1) -
timestamps;
}
// All these delete records have been loaded
if (divide_point == size) {
return;
}
size -= divide_point;
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
deleted_record_.pks_.set_data_raw(
reserved_begin, pks.data() + divide_point, size);
deleted_record_.timestamps_.set_data_raw(
reserved_begin, timestamps + divide_point, size);
deleted_record_.ack_responder_.AddSegment(reserved_begin,
reserved_begin + size);
}
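In short, LoadDeletedRecord now skips the prefix of incoming delete records whose timestamps are already covered by previously loaded records, locating the cut with a binary search for the first timestamp greater than the newest acknowledged one. A minimal Go sketch of the same divide-point computation, assuming the incoming timestamps arrive sorted as the C++ lower_bound requires (names are illustrative, not the segcore API):

package main

import "sort"

// overlapDividePoint returns the index of the first incoming timestamp strictly
// greater than the newest already-loaded timestamp, mirroring
// std::lower_bound(timestamps, timestamps + size, last + 1) above.
// Only incoming[divide:] would then be appended to the delete record.
func overlapDividePoint(loaded, incoming []uint64) int {
	if len(loaded) == 0 {
		return 0
	}
	last := loaded[len(loaded)-1]
	return sort.Search(len(incoming), func(i int) bool { return incoming[i] > last })
}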


@ -9,15 +9,16 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "segcore/load_index_c.h"
#include "common/CDataType.h"
#include "common/FieldMeta.h"
#include "common/Utils.h"
#include "index/IndexFactory.h"
#include "index/Meta.h"
#include "index/Utils.h"
#include "index/IndexFactory.h"
#include "storage/Util.h"
#include "segcore/load_index_c.h"
#include "segcore/Types.h"
#include "storage/Util.h"
CStatus
NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info,


@ -739,7 +739,6 @@ TEST(Sealed, Delete) {
LoadDeletedRecordInfo info = {timestamps.data(), ids.get(), row_count};
segment->LoadDeletedRecord(info);
std::vector<uint8_t> tmp_block{0, 0};
BitsetType bitset(N, false);
segment->mask_with_delete(bitset, 10, 11);
ASSERT_EQ(bitset.count(), pks.size());
@ -758,6 +757,90 @@ TEST(Sealed, Delete) {
reinterpret_cast<const Timestamp*>(new_timestamps.data()));
}
TEST(Sealed, OverlapDelete) {
auto dim = 16;
auto topK = 5;
auto N = 10;
auto metric_type = knowhere::metric::L2;
auto schema = std::make_shared<Schema>();
auto fakevec_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
auto counter_id = schema->AddDebugField("counter", DataType::INT64);
auto double_id = schema->AddDebugField("double", DataType::DOUBLE);
auto nothing_id = schema->AddDebugField("nothing", DataType::INT32);
schema->set_primary_field_id(counter_id);
auto dataset = DataGen(schema, N);
auto fakevec = dataset.get_col<float>(fakevec_id);
auto segment = CreateSealedSegment(schema);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"double": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5,
"round_decimal": 3
}
}
}
]
}
})";
Timestamp time = 1000000;
auto plan = CreatePlan(*schema, dsl);
auto num_queries = 5;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
SealedLoadFieldData(dataset, *segment);
int64_t row_count = 5;
std::vector<idx_t> pks{1, 2, 3, 4, 5};
auto ids = std::make_unique<IdArray>();
ids->mutable_int_id()->mutable_data()->Add(pks.begin(), pks.end());
std::vector<Timestamp> timestamps{10, 10, 10, 10, 10};
LoadDeletedRecordInfo info = {timestamps.data(), ids.get(), row_count};
segment->LoadDeletedRecord(info);
ASSERT_EQ(segment->get_deleted_count(), pks.size())
<< "deleted_count=" << segment->get_deleted_count() << " pks_count=" << pks.size() << std::endl;
// Load overlapping delete records
row_count += 3;
pks.insert(pks.end(), {6, 7, 8});
auto new_ids = std::make_unique<IdArray>();
new_ids->mutable_int_id()->mutable_data()->Add(pks.begin(), pks.end());
timestamps.insert(timestamps.end(), {11, 11, 11});
LoadDeletedRecordInfo overlap_info = {timestamps.data(), new_ids.get(), row_count};
segment->LoadDeletedRecord(overlap_info);
BitsetType bitset(N, false);
// NOTE: need to change delete timestamp, so not to hit the cache
ASSERT_EQ(segment->get_deleted_count(), pks.size())
<< "deleted_count=" << segment->get_deleted_count() << " pks_count=" << pks.size() << std::endl;
segment->mask_with_delete(bitset, 10, 12);
ASSERT_EQ(bitset.count(), pks.size())
<< "bitset_count=" << bitset.count() << " pks_count=" << pks.size() << std::endl;
}
auto
GenMaxFloatVecs(int N, int dim) {
std::vector<float> vecs;


@ -255,7 +255,7 @@ func (b *binlogIO) genDeltaBlobs(data *DeleteData, collID, partID, segID UniqueI
// genInsertBlobs returns kvs, insert-paths, stats-paths
func (b *binlogIO) genInsertBlobs(data *InsertData, partID, segID UniqueID, meta *etcdpb.CollectionMeta) (map[string][]byte, map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) {
inCodec := storage.NewInsertCodec(meta)
inCodec := storage.NewInsertCodecWithSchema(meta)
inlogs, statslogs, err := inCodec.Serialize(partID, segID, data)
if err != nil {
return nil, nil, nil, err


@ -554,7 +554,7 @@ func getInt64DeltaBlobs(segID UniqueID, pks []UniqueID, tss []Timestamp) ([]*Blo
}
func getInsertBlobs(segID UniqueID, iData *InsertData, meta *etcdpb.CollectionMeta) ([]*Blob, error) {
iCodec := storage.NewInsertCodec(meta)
iCodec := storage.NewInsertCodecWithSchema(meta)
iblobs, _, err := iCodec.Serialize(10, segID, iData)
return iblobs, err


@ -364,7 +364,7 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni
}
// encode data and convert output data
inCodec := storage.NewInsertCodec(meta)
inCodec := storage.NewInsertCodecWithSchema(meta)
binLogs, statsBinlogs, err := inCodec.Serialize(partID, segmentID, data.buffer)
if err != nil {


@ -917,7 +917,7 @@ func createBinLogs(rowNum int, schema *schemapb.CollectionSchema, ts Timestamp,
ID: colID,
Schema: schema,
}
binLogs, statsBinLogs, err := storage.NewInsertCodec(meta).Serialize(partID, segmentID, data.buffer)
binLogs, statsBinLogs, err := storage.NewInsertCodecWithSchema(meta).Serialize(partID, segmentID, data.buffer)
if err != nil {
return nil, nil, err
}


@ -414,3 +414,23 @@ func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
}
return ret.(*commonpb.Status), err
}
// Delete is used to forward delete message between delegator and workers.
func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()),
)
ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.Delete(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}


@ -36,7 +36,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
qn "github.com/milvus-io/milvus/internal/querynode"
qn "github.com/milvus-io/milvus/internal/querynodev2"
"github.com/milvus-io/milvus/internal/tracer"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
@ -331,3 +331,8 @@ func (s *Server) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi
func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
return s.querynode.SyncDistribution(ctx, req)
}
// Delete is used to forward delete message between delegator and workers.
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
return s.querynode.Delete(ctx, req)
}


@ -161,6 +161,10 @@ func (m *MockQueryNode) SyncDistribution(context.Context, *querypb.SyncDistribut
return m.status, m.err
}
func (m *MockQueryNode) Delete(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error) {
return m.status, m.err
}
type MockRootCoord struct {
types.RootCoord
initErr error


@ -0,0 +1,116 @@
// Code generated by mockery v2.15.0. DO NOT EDIT.
package msgdispatcher
import (
"github.com/milvus-io/milvus-proto/go-api/msgpb"
mock "github.com/stretchr/testify/mock"
mqwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
msgstream "github.com/milvus-io/milvus/internal/mq/msgstream"
)
// MockClient is an autogenerated mock type for the Client type
type MockClient struct {
mock.Mock
}
type MockClient_Expecter struct {
mock *mock.Mock
}
func (_m *MockClient) EXPECT() *MockClient_Expecter {
return &MockClient_Expecter{mock: &_m.Mock}
}
// Deregister provides a mock function with given fields: vchannel
func (_m *MockClient) Deregister(vchannel string) {
_m.Called(vchannel)
}
// MockClient_Deregister_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Deregister'
type MockClient_Deregister_Call struct {
*mock.Call
}
// Deregister is a helper method to define mock.On call
// - vchannel string
func (_e *MockClient_Expecter) Deregister(vchannel interface{}) *MockClient_Deregister_Call {
return &MockClient_Deregister_Call{Call: _e.mock.On("Deregister", vchannel)}
}
func (_c *MockClient_Deregister_Call) Run(run func(vchannel string)) *MockClient_Deregister_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockClient_Deregister_Call) Return() *MockClient_Deregister_Call {
_c.Call.Return()
return _c
}
// Register provides a mock function with given fields: vchannel, pos, subPos
func (_m *MockClient) Register(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
ret := _m.Called(vchannel, pos, subPos)
var r0 <-chan *msgstream.MsgPack
if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
r0 = rf(vchannel, pos, subPos)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *msgstream.MsgPack)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); ok {
r1 = rf(vchannel, pos, subPos)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockClient_Register_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Register'
type MockClient_Register_Call struct {
*mock.Call
}
// Register is a helper method to define mock.On call
// - vchannel string
// - pos *msgpb.MsgPosition
// - subPos mqwrapper.SubscriptionInitialPosition
func (_e *MockClient_Expecter) Register(vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
return &MockClient_Register_Call{Call: _e.mock.On("Register", vchannel, pos, subPos)}
}
func (_c *MockClient_Register_Call) Run(run func(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(*msgpb.MsgPosition), args[2].(mqwrapper.SubscriptionInitialPosition))
})
return _c
}
func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 error) *MockClient_Register_Call {
_c.Call.Return(_a0, _a1)
return _c
}
type mockConstructorTestingTNewMockClient interface {
mock.TestingT
Cleanup(func())
}
// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockClient(t mockConstructorTestingTNewMockClient) *MockClient {
mock := &MockClient{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}


@ -1,20 +1,393 @@
// Code generated by mockery v2.15.0. DO NOT EDIT.
package msgstream
import (
"github.com/milvus-io/milvus-proto/go-api/msgpb"
mock "github.com/stretchr/testify/mock"
mqwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
)
// MockMsgStream is an autogenerated mock type for the MsgStream type
type MockMsgStream struct {
MsgStream
AsProducerFunc func(channels []string)
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
BroadcastFunc func(*MsgPack) error
mock.Mock
}
func NewMockMsgStream() *MockMsgStream {
return &MockMsgStream{}
type MockMsgStream_Expecter struct {
mock *mock.Mock
}
func (m MockMsgStream) AsProducer(channels []string) {
m.AsProducerFunc(channels)
func (_m *MockMsgStream) EXPECT() *MockMsgStream_Expecter {
return &MockMsgStream_Expecter{mock: &_m.Mock}
}
func (m MockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
return m.BroadcastMarkFunc(pack)
// AsConsumer provides a mock function with given fields: channels, subName, position
func (_m *MockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
_m.Called(channels, subName, position)
}
// MockMsgStream_AsConsumer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsConsumer'
type MockMsgStream_AsConsumer_Call struct {
*mock.Call
}
// AsConsumer is a helper method to define mock.On call
// - channels []string
// - subName string
// - position mqwrapper.SubscriptionInitialPosition
func (_e *MockMsgStream_Expecter) AsConsumer(channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call {
return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", channels, subName, position)}
}
func (_c *MockMsgStream_AsConsumer_Call) Run(run func(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].(string), args[2].(mqwrapper.SubscriptionInitialPosition))
})
return _c
}
func (_c *MockMsgStream_AsConsumer_Call) Return() *MockMsgStream_AsConsumer_Call {
_c.Call.Return()
return _c
}
// AsProducer provides a mock function with given fields: channels
func (_m *MockMsgStream) AsProducer(channels []string) {
_m.Called(channels)
}
// MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer'
type MockMsgStream_AsProducer_Call struct {
*mock.Call
}
// AsProducer is a helper method to define mock.On call
// - channels []string
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
}
func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string))
})
return _c
}
func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call {
_c.Call.Return()
return _c
}
// Broadcast provides a mock function with given fields: _a0
func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]mqwrapper.MessageID, error) {
ret := _m.Called(_a0)
var r0 map[string][]mqwrapper.MessageID
if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]mqwrapper.MessageID); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]mqwrapper.MessageID)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*MsgPack) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockMsgStream_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast'
type MockMsgStream_Broadcast_Call struct {
*mock.Call
}
// Broadcast is a helper method to define mock.On call
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
}
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack))
})
return _c
}
func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]mqwrapper.MessageID, _a1 error) *MockMsgStream_Broadcast_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// Chan provides a mock function with given fields:
func (_m *MockMsgStream) Chan() <-chan *MsgPack {
ret := _m.Called()
var r0 <-chan *MsgPack
if rf, ok := ret.Get(0).(func() <-chan *MsgPack); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan *MsgPack)
}
}
return r0
}
// MockMsgStream_Chan_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Chan'
type MockMsgStream_Chan_Call struct {
*mock.Call
}
// Chan is a helper method to define mock.On call
func (_e *MockMsgStream_Expecter) Chan() *MockMsgStream_Chan_Call {
return &MockMsgStream_Chan_Call{Call: _e.mock.On("Chan")}
}
func (_c *MockMsgStream_Chan_Call) Run(run func()) *MockMsgStream_Chan_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Chan_Call {
_c.Call.Return(_a0)
return _c
}
// Close provides a mock function with given fields:
func (_m *MockMsgStream) Close() {
_m.Called()
}
// MockMsgStream_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockMsgStream_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockMsgStream_Expecter) Close() *MockMsgStream_Close_Call {
return &MockMsgStream_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockMsgStream_Close_Call) Run(run func()) *MockMsgStream_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockMsgStream_Close_Call) Return() *MockMsgStream_Close_Call {
_c.Call.Return()
return _c
}
// GetLatestMsgID provides a mock function with given fields: channel
func (_m *MockMsgStream) GetLatestMsgID(channel string) (mqwrapper.MessageID, error) {
ret := _m.Called(channel)
var r0 mqwrapper.MessageID
if rf, ok := ret.Get(0).(func(string) mqwrapper.MessageID); ok {
r0 = rf(channel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(mqwrapper.MessageID)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(channel)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockMsgStream_GetLatestMsgID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestMsgID'
type MockMsgStream_GetLatestMsgID_Call struct {
*mock.Call
}
// GetLatestMsgID is a helper method to define mock.On call
// - channel string
func (_e *MockMsgStream_Expecter) GetLatestMsgID(channel interface{}) *MockMsgStream_GetLatestMsgID_Call {
return &MockMsgStream_GetLatestMsgID_Call{Call: _e.mock.On("GetLatestMsgID", channel)}
}
func (_c *MockMsgStream_GetLatestMsgID_Call) Run(run func(channel string)) *MockMsgStream_GetLatestMsgID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockMsgStream_GetLatestMsgID_Call) Return(_a0 mqwrapper.MessageID, _a1 error) *MockMsgStream_GetLatestMsgID_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetProduceChannels provides a mock function with given fields:
func (_m *MockMsgStream) GetProduceChannels() []string {
ret := _m.Called()
var r0 []string
if rf, ok := ret.Get(0).(func() []string); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockMsgStream_GetProduceChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProduceChannels'
type MockMsgStream_GetProduceChannels_Call struct {
*mock.Call
}
// GetProduceChannels is a helper method to define mock.On call
func (_e *MockMsgStream_Expecter) GetProduceChannels() *MockMsgStream_GetProduceChannels_Call {
return &MockMsgStream_GetProduceChannels_Call{Call: _e.mock.On("GetProduceChannels")}
}
func (_c *MockMsgStream_GetProduceChannels_Call) Run(run func()) *MockMsgStream_GetProduceChannels_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockMsgStream_GetProduceChannels_Call) Return(_a0 []string) *MockMsgStream_GetProduceChannels_Call {
_c.Call.Return(_a0)
return _c
}
// Produce provides a mock function with given fields: _a0
func (_m *MockMsgStream) Produce(_a0 *MsgPack) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(*MsgPack) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockMsgStream_Produce_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Produce'
type MockMsgStream_Produce_Call struct {
*mock.Call
}
// Produce is a helper method to define mock.On call
// - _a0 *MsgPack
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
}
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*MsgPack))
})
return _c
}
func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_Call {
_c.Call.Return(_a0)
return _c
}
// Seek provides a mock function with given fields: offset
func (_m *MockMsgStream) Seek(offset []*msgpb.MsgPosition) error {
ret := _m.Called(offset)
var r0 error
if rf, ok := ret.Get(0).(func([]*msgpb.MsgPosition) error); ok {
r0 = rf(offset)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockMsgStream_Seek_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Seek'
type MockMsgStream_Seek_Call struct {
*mock.Call
}
// Seek is a helper method to define mock.On call
// - offset []*msgpb.MsgPosition
func (_e *MockMsgStream_Expecter) Seek(offset interface{}) *MockMsgStream_Seek_Call {
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", offset)}
}
func (_c *MockMsgStream_Seek_Call) Run(run func(offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*msgpb.MsgPosition))
})
return _c
}
func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call {
_c.Call.Return(_a0)
return _c
}
// SetRepackFunc provides a mock function with given fields: repackFunc
func (_m *MockMsgStream) SetRepackFunc(repackFunc RepackFunc) {
_m.Called(repackFunc)
}
// MockMsgStream_SetRepackFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRepackFunc'
type MockMsgStream_SetRepackFunc_Call struct {
*mock.Call
}
// SetRepackFunc is a helper method to define mock.On call
// - repackFunc RepackFunc
func (_e *MockMsgStream_Expecter) SetRepackFunc(repackFunc interface{}) *MockMsgStream_SetRepackFunc_Call {
return &MockMsgStream_SetRepackFunc_Call{Call: _e.mock.On("SetRepackFunc", repackFunc)}
}
func (_c *MockMsgStream_SetRepackFunc_Call) Run(run func(repackFunc RepackFunc)) *MockMsgStream_SetRepackFunc_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(RepackFunc))
})
return _c
}
func (_c *MockMsgStream_SetRepackFunc_Call) Return() *MockMsgStream_SetRepackFunc_Call {
_c.Call.Return()
return _c
}
type mockConstructorTestingTNewMockMsgStream interface {
mock.TestingT
Cleanup(func())
}
// NewMockMsgStream creates a new instance of MockMsgStream. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockMsgStream(t mockConstructorTestingTNewMockMsgStream) *MockMsgStream {
mock := &MockMsgStream{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}


@ -0,0 +1,25 @@
package msgstream
type WastedMockMsgStream struct {
MsgStream
AsProducerFunc func(channels []string)
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
BroadcastFunc func(*MsgPack) error
ChanFunc func() <-chan *MsgPack
}
func NewWastedMockMsgStream() *WastedMockMsgStream {
return &WastedMockMsgStream{}
}
func (m WastedMockMsgStream) AsProducer(channels []string) {
m.AsProducerFunc(channels)
}
func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
return m.BroadcastMarkFunc(pack)
}
func (m WastedMockMsgStream) Chan() <-chan *MsgPack {
return m.ChanFunc()
}


@ -72,6 +72,7 @@ service QueryNode {
rpc GetDataDistribution(GetDataDistributionRequest) returns (GetDataDistributionResponse) {}
rpc SyncDistribution(SyncDistributionRequest) returns (common.Status) {}
rpc Delete(DeleteRequest) returns (common.Status) {}
}
//--------------------QueryCoord grpc request and response proto------------------
@ -243,6 +244,7 @@ message SegmentLoadInfo {
int64 segment_size = 12;
string insert_channel = 13;
msg.MsgPosition start_position = 14;
msg.MsgPosition end_position = 15;
}
message FieldIndexInfo {
@ -259,6 +261,11 @@ message FieldIndexInfo {
int64 num_rows = 10;
}
enum LoadScope {
Full = 0;
Delta = 1;
}
message LoadSegmentsRequest {
common.MsgBase base = 1;
int64 dst_nodeID = 2;
@ -271,6 +278,7 @@ message LoadSegmentsRequest {
repeated msg.MsgPosition delta_positions = 9;
int64 version = 10;
bool need_transfer = 11;
LoadScope load_scope = 12;
}
message ReleaseSegmentsRequest {
@ -316,6 +324,18 @@ message ReplicaSegmentsInfo {
repeated int64 versions = 4;
}
message GetLoadInfoRequest {
common.MsgBase base = 1;
int64 collection_id = 2;
}
message GetLoadInfoResponse {
common.Status status = 1;
schema.CollectionSchema schema = 2;
LoadType load_type = 3;
repeated int64 partitions = 4;
}
//----------------request auto triggered by QueryCoord-----------------
message HandoffSegmentsRequest {
common.MsgBase base = 1;
@ -445,6 +465,7 @@ message SealedSegmentsChangeInfo {
message GetDataDistributionRequest {
common.MsgBase base = 1;
map<string, msg.MsgPosition> checkpoints = 2;
}
message GetDataDistributionResponse {
@ -475,6 +496,7 @@ message SegmentVersionInfo {
int64 partition = 3;
string channel = 4;
int64 version = 5;
uint64 last_delta_timestamp = 6;
}
message ChannelVersionInfo {
@ -516,6 +538,7 @@ message Replica {
enum SyncType {
Remove = 0;
Set = 1;
Amend = 2;
}
message SyncAction {
@ -524,6 +547,7 @@ message SyncAction {
int64 segmentID = 3;
int64 nodeID = 4;
int64 version = 5;
SegmentLoadInfo info = 6;
}
message SyncDistributionRequest {
@ -568,4 +592,13 @@ message ResourceGroupInfo {
map<int64, int32> num_outgoing_node = 5;
// collection id -> be accessed node num by other rg
map<int64, int32> num_incoming_node = 6;
}
}
message DeleteRequest {
common.MsgBase base = 1;
int64 collection_id = 2;
int64 partition_id = 3;
string vchannel_name = 4;
int64 segment_id = 5;
schema.IDs primary_keys = 6;
repeated uint64 timestamps = 7;
}
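For illustration, a delegator-side caller could fill the new DeleteRequest in Go roughly as below. The helper and its parameters are hypothetical, field names follow the standard protoc-gen-go mapping of the message above, and imports of querypb, schemapb, and commonpbutil are assumed.

// buildDeleteRequest is an illustrative helper, not part of this change.
func buildDeleteRequest(collectionID, partitionID, segmentID int64, vchannel string, pks []int64, tss []uint64) *querypb.DeleteRequest {
	return &querypb.DeleteRequest{
		Base:         commonpbutil.NewMsgBase(),
		CollectionId: collectionID,
		PartitionId:  partitionID,
		VchannelName: vchannel,
		SegmentId:    segmentID,
		PrimaryKeys: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: pks}},
		},
		Timestamps: tss,
	}
}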

File diff suppressed because it is too large


@ -191,7 +191,7 @@ func (rl *rateLimiter) registerLimiters() {
}
limit := ratelimitutil.Limit(r.GetAsFloat())
burst := r.GetAsFloat() // use rate as burst, because Limiter is with punishment mechanism, burst is insignificant.
rl.limiters.InsertIfNotPresent(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
rl.limiters.GetOrInsert(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
onEvent := func(rateType internalpb.RateType) func(*config.Event) {
return func(event *config.Event) {
f, err := strconv.ParseFloat(event.Value, 64)


@ -135,3 +135,7 @@ func (m *QueryNodeMock) GetDataDistribution(context.Context, *querypb.GetDataDis
func (m *QueryNodeMock) SyncDistribution(context.Context, *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
return nil, nil
}
func (m *QueryNodeMock) Delete(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error) {
return nil, nil
}


@ -7,6 +7,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"go.opentelemetry.io/otel"
@ -299,8 +300,10 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
}
func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error {
nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest)
nodeReq.Base.TargetID = nodeID
req := &querypb.GetStatisticsRequest{
Req: g.GetStatisticsRequest,
Req: nodeReq,
DmlChannels: channelIDs,
Scope: querypb.DataScope_All,
}


@ -160,7 +160,8 @@ func (c *SegmentChecker) getStreamingSegmentsDist(distMgr *meta.DistributionMana
}
// GetHistoricalSegmentDiff get historical segment diff between target and dist
func (c *SegmentChecker) getHistoricalSegmentDiff(targetMgr *meta.TargetManager,
func (c *SegmentChecker) getHistoricalSegmentDiff(
targetMgr *meta.TargetManager,
distMgr *meta.DistributionManager,
metaInfo *meta.Meta,
collectionID int64,
@ -179,7 +180,7 @@ func (c *SegmentChecker) getHistoricalSegmentDiff(targetMgr *meta.TargetManager,
nextTargetMap := targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.NextTarget)
currentTargetMap := targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.CurrentTarget)
//get segment which exist on next target, but not on dist
// Segment which exist on next target, but not on dist
for segmentID, segment := range nextTargetMap {
if !distMap.Contain(segmentID) {
toLoad = append(toLoad, segment)


@ -23,6 +23,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -119,14 +120,16 @@ func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistribut
PartitionID: s.GetPartition(),
InsertChannel: s.GetChannel(),
},
Node: resp.GetNodeID(),
Version: s.GetVersion(),
Node: resp.GetNodeID(),
Version: s.GetVersion(),
LastDeltaTimestamp: s.GetLastDeltaTimestamp(),
}
} else {
segment = &meta.Segment{
SegmentInfo: proto.Clone(segmentInfo).(*datapb.SegmentInfo),
Node: resp.GetNodeID(),
Version: s.GetVersion(),
SegmentInfo: proto.Clone(segmentInfo).(*datapb.SegmentInfo),
Node: resp.GetNodeID(),
Version: s.GetVersion(),
LastDeltaTimestamp: s.GetLastDeltaTimestamp(),
}
}
updates = append(updates, segment)
@ -200,12 +203,24 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons
func (dh *distHandler) getDistribution(ctx context.Context) error {
dh.mu.Lock()
defer dh.mu.Unlock()
channels := make(map[string]*msgpb.MsgPosition)
for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) {
targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget)
if targetChannel == nil {
continue
}
channels[channel.GetChannelName()] = targetChannel.GetSeekPosition()
}
ctx, cancel := context.WithTimeout(ctx, distReqTimeout)
defer cancel()
resp, err := dh.client.GetDataDistribution(ctx, dh.nodeID, &querypb.GetDataDistributionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetDistribution),
),
Checkpoints: channels,
})
if err != nil {


@ -26,8 +26,9 @@ import (
type Segment struct {
*datapb.SegmentInfo
Node int64 // Node the segment is in
Version int64 // Version is the timestamp of loading segment
Node int64 // Node the segment is in
Version int64 // Version is the timestamp of loading segment
LastDeltaTimestamp uint64 // The timestamp of the last delta record
}
func SegmentFromInfo(info *datapb.SegmentInfo) *Segment {


@ -29,6 +29,53 @@ func (_m *MockQueryNodeServer) EXPECT() *MockQueryNodeServer_Expecter {
return &MockQueryNodeServer_Expecter{mock: &_m.Mock}
}
// Delete provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) Delete(_a0 context.Context, _a1 *querypb.DeleteRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.DeleteRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeServer_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
type MockQueryNodeServer_Delete_Call struct {
*mock.Call
}
// Delete is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.DeleteRequest
func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call {
return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)}
}
func (_c *MockQueryNodeServer_Delete_Call) Run(run func(_a0 context.Context, _a1 *querypb.DeleteRequest)) *MockQueryNodeServer_Delete_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.DeleteRequest))
})
return _c
}
func (_c *MockQueryNodeServer_Delete_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeServer_Delete_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// GetComponentStates provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) {
ret := _m.Called(_a0, _a1)


@ -31,7 +31,10 @@ import (
"go.uber.org/zap"
)
const interval = 1 * time.Second
const (
interval = 1 * time.Second
RPCTimeout = 3 * time.Second
)
// LeaderObserver is to sync the distribution with leader
type LeaderObserver struct {
@ -40,6 +43,7 @@ type LeaderObserver struct {
dist *meta.DistributionManager
meta *meta.Meta
target *meta.TargetManager
broker meta.Broker
cluster session.Cluster
stopOnce sync.Once
@ -106,18 +110,54 @@ func (o *LeaderObserver) findNeedLoadedSegments(leaderView *meta.LeaderView, dis
dists = utils.FindMaxVersionSegments(dists)
for _, s := range dists {
version, ok := leaderView.Segments[s.GetID()]
existInCurrentTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.CurrentTarget) != nil
currentTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.CurrentTarget)
existInCurrentTarget := currentTarget != nil
existInNextTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.NextTarget) != nil
if ok && version.GetVersion() >= s.Version || (!existInCurrentTarget && !existInNextTarget) {
if !existInCurrentTarget && !existInNextTarget {
continue
}
ret = append(ret, &querypb.SyncAction{
Type: querypb.SyncType_Set,
PartitionID: s.GetPartitionID(),
SegmentID: s.GetID(),
NodeID: s.Node,
Version: s.Version,
})
if !ok || version.GetVersion() < s.Version { // Leader misses this segment
ctx := context.Background()
resp, err := o.broker.GetSegmentInfo(ctx, s.GetID())
if err != nil || len(resp.GetInfos()) == 0 {
log.Warn("failed to get segment info from DataCoord", zap.Error(err))
continue
}
segment := resp.GetInfos()[0]
loadInfo := utils.PackSegmentLoadInfo(segment, nil)
// Fix the leader view with lacks of delta logs
if existInCurrentTarget && s.LastDeltaTimestamp < currentTarget.GetDmlPosition().GetTimestamp() {
log.Info("Fix QueryNode delta logs lag",
zap.Int64("nodeID", s.Node),
zap.Int64("collectionID", s.GetCollectionID()),
zap.Int64("partitionID", s.GetPartitionID()),
zap.Int64("segmentID", s.GetID()),
zap.Uint64("segmentDeltaTimestamp", s.LastDeltaTimestamp),
zap.Uint64("channelTimestamp", currentTarget.GetDmlPosition().GetTimestamp()),
)
ret = append(ret, &querypb.SyncAction{
Type: querypb.SyncType_Amend,
PartitionID: s.GetPartitionID(),
SegmentID: s.GetID(),
NodeID: s.Node,
Version: s.Version,
Info: loadInfo,
})
}
ret = append(ret, &querypb.SyncAction{
Type: querypb.SyncType_Set,
PartitionID: s.GetPartitionID(),
SegmentID: s.GetID(),
NodeID: s.Node,
Version: s.Version,
Info: loadInfo,
})
}
}
return ret
}
@ -181,6 +221,7 @@ func NewLeaderObserver(
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
broker meta.Broker,
cluster session.Cluster,
) *LeaderObserver {
return &LeaderObserver{
@ -188,6 +229,7 @@ func NewLeaderObserver(
dist: dist,
meta: meta,
target: targetMgr,
broker: broker,
cluster: cluster,
}
}
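Put differently, when the leader misses a segment the observer now emits an Amend action carrying a fresh SegmentLoadInfo if the leader's delta logs lag the current target's channel checkpoint, followed by a Set action to update the leader view. A hedged one-line restatement of that trigger (illustrative, not an actual helper in this PR):

// needAmend: amend delta logs only when the segment is in the current target
// and its last applied delta timestamp is older than the channel checkpoint.
func needAmend(existInCurrentTarget bool, lastDeltaTs, checkpointTs uint64) bool {
	return existInCurrentTarget && lastDeltaTs < checkpointTs
}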


@ -73,7 +73,7 @@ func (suite *LeaderObserverTestSuite) SetupTest() {
suite.mockCluster = session.NewMockCluster(suite.T())
distManager := meta.NewDistributionManager()
targetManager := meta.NewTargetManager(suite.broker, suite.meta)
suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.mockCluster)
suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.broker, suite.mockCluster)
}
func (suite *LeaderObserverTestSuite) TearDownTest() {
@ -98,6 +98,15 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
},
}
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel",
}
loadInfo := utils.PackSegmentLoadInfo(info, nil)
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -118,6 +127,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
SegmentID: 1,
NodeID: 1,
Version: 1,
Info: loadInfo,
},
},
}
@ -154,6 +164,15 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
},
}
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel",
}
loadInfo := utils.PackSegmentLoadInfo(info, nil)
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -176,6 +195,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
SegmentID: 1,
NodeID: 1,
Version: 1,
Info: loadInfo,
},
},
}
@ -252,6 +272,15 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
},
}
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{
ID: 1,
CollectionID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel",
}
loadInfo := utils.PackSegmentLoadInfo(info, nil)
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -274,6 +303,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
SegmentID: 1,
NodeID: 1,
Version: 1,
Info: loadInfo,
},
},
}


@ -331,6 +331,7 @@ func (s *Server) initObserver() {
s.dist,
s.meta,
s.targetMgr,
s.broker,
s.cluster,
)
s.targetObserver = observers.NewTargetObserver(


@ -77,6 +77,7 @@ type SegmentAction struct {
func NewSegmentAction(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID) *SegmentAction {
return NewSegmentActionWithScope(nodeID, typ, shard, segmentID, querypb.DataScope_All)
}
func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID, scope querypb.DataScope) *SegmentAction {
base := NewBaseAction(nodeID, typ, shard)
return &SegmentAction{


@ -99,10 +99,7 @@ func packLoadSegmentRequest(
segment := resp.GetInfos()[0]
var posSrcStr string
if resp.GetChannelCheckpoint() != nil && resp.ChannelCheckpoint[segment.InsertChannel] != nil {
deltaPosition = resp.ChannelCheckpoint[segment.InsertChannel]
posSrcStr = "channelCheckpoint"
} else if segment.GetDmlPosition() != nil {
if segment.GetDmlPosition() != nil {
deltaPosition = segment.GetDmlPosition()
posSrcStr = "segmentDMLPos"
} else {


@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/tsoutil"
)
@ -36,7 +37,6 @@ func Test_packLoadSegmentRequest(t *testing.T) {
t0 := tsoutil.ComposeTSByTime(time.Now().Add(-20*time.Minute), 0)
t1 := tsoutil.ComposeTSByTime(time.Now().Add(-8*time.Minute), 0)
t2 := tsoutil.ComposeTSByTime(time.Now().Add(-5*time.Minute), 0)
t3 := tsoutil.ComposeTSByTime(time.Now().Add(-1*time.Minute), 0)
segmentInfo := &datapb.SegmentInfo{
ID: 0,
@ -51,28 +51,6 @@ func Test_packLoadSegmentRequest(t *testing.T) {
},
}
t.Run("test set deltaPosition from channel checkpoint", func(t *testing.T) {
segmentAction := NewSegmentAction(0, 0, "", 0)
segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, segmentAction)
assert.NoError(t, err)
resp := &datapb.GetSegmentInfoResponse{
Infos: []*datapb.SegmentInfo{
proto.Clone(segmentInfo).(*datapb.SegmentInfo),
},
ChannelCheckpoint: map[string]*msgpb.MsgPosition{
mockVChannel: {
ChannelName: mockPChannel,
Timestamp: t3,
},
},
}
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
assert.Equal(t, 1, len(req.GetDeltaPositions()))
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
assert.Equal(t, t3, req.DeltaPositions[0].Timestamp)
})
t.Run("test set deltaPosition from segment dmlPosition", func(t *testing.T) {
segmentAction := NewSegmentAction(0, 0, "", 0)
segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, segmentAction)
@ -83,7 +61,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
proto.Clone(segmentInfo).(*datapb.SegmentInfo),
},
}
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
assert.Equal(t, 1, len(req.GetDeltaPositions()))
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
assert.Equal(t, t2, req.DeltaPositions[0].Timestamp)
@ -99,7 +77,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
resp := &datapb.GetSegmentInfoResponse{
Infos: []*datapb.SegmentInfo{segInfo},
}
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
assert.Equal(t, 1, len(req.GetDeltaPositions()))
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
assert.Equal(t, t1, req.DeltaPositions[0].Timestamp)
@ -115,7 +93,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
resp := &datapb.GetSegmentInfoResponse{
Infos: []*datapb.SegmentInfo{segInfo},
}
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
assert.Equal(t, 1, len(req.GetDeltaPositions()))
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
assert.Equal(t, t0, req.DeltaPositions[0].Timestamp)


@ -90,6 +90,8 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, indexes []*querypb.FieldIn
Deltalogs: segment.Deltalogs,
InsertChannel: segment.InsertChannel,
IndexInfos: indexes,
StartPosition: segment.GetStartPosition(),
EndPosition: segment.GetDmlPosition(),
}
loadInfo.SegmentSize = calculateSegmentSize(loadInfo)
return loadInfo


@ -1499,3 +1499,10 @@ func (node *QueryNode) GetSession() *sessionutil.Session {
defer node.sessionMu.Unlock()
return node.session
}
func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented in qnv1",
}, nil
}


@ -622,24 +622,24 @@ func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
zap.Int64("segment ID", segID),
zap.String("segment type", segType.String()),
zap.Int64("row count", rowCount),
zap.Uint64("segment indexed fields", segment.indexedFieldInfos.Len()),
)
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Inc()
if rowCount > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
/*
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Add(float64(rowCount))
}
).Inc()
if rowCount > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Add(float64(rowCount))
}*/
return nil
}
@ -730,25 +730,25 @@ func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID, segType seg
zap.Int64("segment ID", segmentID),
zap.String("segment type", segType.String()),
zap.Int64("row count", rowCount),
zap.Uint64("segment indexed fields", segment.indexedFieldInfos.Len()),
)
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
// Note: this field is mutable after segment is loaded.
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Dec()
if rowCount > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
/*
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
// Note: this field is mutable after segment is loaded.
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Sub(float64(rowCount))
}
).Dec()
if rowCount > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.collectionID),
fmt.Sprint(segment.partitionID),
segType.String(),
fmt.Sprint(segment.indexedFieldInfos.Len()),
).Sub(float64(rowCount))
}*/
}
replica.sendNoSegmentSignal()
}


@ -595,10 +595,7 @@ func genVectorChunkManager(ctx context.Context, col *Collection) (*storage.Vecto
return nil, err
}
vcm, err := storage.NewVectorChunkManager(ctx, lcm, rcm, &etcdpb.CollectionMeta{
ID: col.id,
Schema: col.schema,
}, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), false)
vcm, err := storage.NewVectorChunkManager(ctx, lcm, rcm, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), false)
if err != nil {
return nil, err
}
@ -908,7 +905,7 @@ func genStorageBlob(collectionID UniqueID,
}
tmpSchema.Fields = append(tmpSchema.Fields, schema.Fields...)
collMeta := genCollectionMeta(collectionID, tmpSchema)
inCodec := storage.NewInsertCodec(collMeta)
inCodec := storage.NewInsertCodecWithSchema(collMeta)
insertData, err := genInsertData(msgLength, schema)
if err != nil {
return nil, nil, err


@ -23,7 +23,6 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
@ -68,11 +67,11 @@ func newQueryShard(
if remoteChunkManager == nil {
return nil, fmt.Errorf("can not create vector chunk manager for remote chunk manager is nil")
}
vectorChunkManager, err := storage.NewVectorChunkManager(ctx, localChunkManager, remoteChunkManager,
&etcdpb.CollectionMeta{
ID: collectionID,
Schema: collection.schema,
}, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), localCacheEnabled)
vectorChunkManager, err := storage.NewVectorChunkManager(ctx,
localChunkManager,
remoteChunkManager,
Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(),
localCacheEnabled)
if err != nil {
return nil, err
}


@ -122,7 +122,7 @@ func (s *Segment) getType() segmentType {
}
func (s *Segment) setIndexedFieldInfo(fieldID UniqueID, info *IndexedFieldInfo) {
s.indexedFieldInfos.InsertIfNotPresent(fieldID, info)
s.indexedFieldInfos.GetOrInsert(fieldID, info)
}
func (s *Segment) getIndexedFieldInfo(fieldID UniqueID) (*IndexedFieldInfo, error) {

View File

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cluster
import (
"fmt"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// Manager is the interface for worker manager.
type Manager interface {
GetWorker(nodeID int64) (Worker, error)
}
// WorkerBuilder is function alias to build a worker from NodeID
type WorkerBuilder func(nodeID int64) (Worker, error)
type grpcWorkerManager struct {
workers *typeutil.ConcurrentMap[int64, Worker]
builder WorkerBuilder
}
// GetWorker returns worker with specified nodeID.
func (m *grpcWorkerManager) GetWorker(nodeID int64) (Worker, error) {
worker, ok := m.workers.Get(nodeID)
var err error
if !ok {
//TODO merge request?
worker, err = m.builder(nodeID)
if err != nil {
log.Warn("failed to build worker",
zap.Int64("nodeID", nodeID),
zap.Error(err),
)
return nil, err
}
old, exist := m.workers.GetOrInsert(nodeID, worker)
if exist {
worker.Stop()
worker = old
}
}
if !worker.IsHealthy() {
// TODO wrap error
return nil, fmt.Errorf("node is not healthy: %d", nodeID)
}
return worker, nil
}
func NewWorkerManager(builder WorkerBuilder) Manager {
return &grpcWorkerManager{
workers: typeutil.NewConcurrentMap[int64, Worker](),
builder: builder,
}
}
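As a usage note, callers construct the manager once with a builder and fetch workers lazily; below is a hedged sketch inside the cluster package, where dial stands in for the gRPC-backed WorkerBuilder wired elsewhere in this PR.

// exampleGetWorker is an illustrative helper, not part of this change.
func exampleGetWorker(dial WorkerBuilder) (Worker, error) {
	manager := NewWorkerManager(dial) // workers are built on first use and cached
	return manager.GetWorker(101)     // errors if dial fails or the worker is unhealthy
}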


@ -0,0 +1,82 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cluster
import (
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
)
func TestManager(t *testing.T) {
t.Run("normal_get", func(t *testing.T) {
worker := &MockWorker{}
worker.EXPECT().IsHealthy().Return(true)
var buildErr error
var called int
builder := func(nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
w, err := manager.GetWorker(0)
assert.Equal(t, worker, w)
assert.NoError(t, err)
assert.Equal(t, 1, called)
w, err = manager.GetWorker(0)
assert.Equal(t, worker, w)
assert.NoError(t, err)
assert.Equal(t, 1, called)
})
t.Run("builder_return_error", func(t *testing.T) {
worker := &MockWorker{}
worker.EXPECT().IsHealthy().Return(true)
var buildErr error
var called int
buildErr = errors.New("mocked error")
builder := func(nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
_, err := manager.GetWorker(0)
assert.Error(t, err)
assert.Equal(t, 1, called)
})
t.Run("worker_not_healthy", func(t *testing.T) {
worker := &MockWorker{}
worker.EXPECT().IsHealthy().Return(false)
var buildErr error
var called int
builder := func(nodeID int64) (Worker, error) {
called++
return worker, buildErr
}
manager := NewWorkerManager(builder)
_, err := manager.GetWorker(0)
assert.Error(t, err)
assert.Equal(t, 1, called)
})
}


@ -0,0 +1,79 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
package cluster
import mock "github.com/stretchr/testify/mock"
// MockManager is an autogenerated mock type for the Manager type
type MockManager struct {
mock.Mock
}
type MockManager_Expecter struct {
mock *mock.Mock
}
func (_m *MockManager) EXPECT() *MockManager_Expecter {
return &MockManager_Expecter{mock: &_m.Mock}
}
// GetWorker provides a mock function with given fields: nodeID
func (_m *MockManager) GetWorker(nodeID int64) (Worker, error) {
ret := _m.Called(nodeID)
var r0 Worker
if rf, ok := ret.Get(0).(func(int64) Worker); ok {
r0 = rf(nodeID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(Worker)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int64) error); ok {
r1 = rf(nodeID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockManager_GetWorker_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorker'
type MockManager_GetWorker_Call struct {
*mock.Call
}
// GetWorker is a helper method to define mock.On call
// - nodeID int64
func (_e *MockManager_Expecter) GetWorker(nodeID interface{}) *MockManager_GetWorker_Call {
return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", nodeID)}
}
func (_c *MockManager_GetWorker_Call) Run(run func(nodeID int64)) *MockManager_GetWorker_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockManager_GetWorker_Call) Return(_a0 Worker, _a1 error) *MockManager_GetWorker_Call {
_c.Call.Return(_a0, _a1)
return _c
}
type mockConstructorTestingTNewMockManager interface {
mock.TestingT
Cleanup(func())
}
// NewMockManager creates a new instance of MockManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockManager(t mockConstructorTestingTNewMockManager) *MockManager {
mock := &MockManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,358 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
package cluster
import (
context "context"
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
)
// MockWorker is an autogenerated mock type for the Worker type
type MockWorker struct {
mock.Mock
}
type MockWorker_Expecter struct {
mock *mock.Mock
}
func (_m *MockWorker) EXPECT() *MockWorker_Expecter {
return &MockWorker_Expecter{mock: &_m.Mock}
}
// Delete provides a mock function with given fields: ctx, req
func (_m *MockWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockWorker_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
type MockWorker_Delete_Call struct {
*mock.Call
}
// Delete is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.DeleteRequest
func (_e *MockWorker_Expecter) Delete(ctx interface{}, req interface{}) *MockWorker_Delete_Call {
return &MockWorker_Delete_Call{Call: _e.mock.On("Delete", ctx, req)}
}
func (_c *MockWorker_Delete_Call) Run(run func(ctx context.Context, req *querypb.DeleteRequest)) *MockWorker_Delete_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.DeleteRequest))
})
return _c
}
func (_c *MockWorker_Delete_Call) Return(_a0 error) *MockWorker_Delete_Call {
_c.Call.Return(_a0)
return _c
}
// GetStatistics provides a mock function with given fields: ctx, req
func (_m *MockWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
ret := _m.Called(ctx, req)
var r0 *internalpb.GetStatisticsResponse
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.GetStatisticsResponse)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockWorker_GetStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatistics'
type MockWorker_GetStatistics_Call struct {
*mock.Call
}
// GetStatistics is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.GetStatisticsRequest
func (_e *MockWorker_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockWorker_GetStatistics_Call {
return &MockWorker_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)}
}
func (_c *MockWorker_GetStatistics_Call) Run(run func(ctx context.Context, req *querypb.GetStatisticsRequest)) *MockWorker_GetStatistics_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest))
})
return _c
}
func (_c *MockWorker_GetStatistics_Call) Return(_a0 *internalpb.GetStatisticsResponse, _a1 error) *MockWorker_GetStatistics_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// IsHealthy provides a mock function with given fields:
func (_m *MockWorker) IsHealthy() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockWorker_IsHealthy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsHealthy'
type MockWorker_IsHealthy_Call struct {
*mock.Call
}
// IsHealthy is a helper method to define mock.On call
func (_e *MockWorker_Expecter) IsHealthy() *MockWorker_IsHealthy_Call {
return &MockWorker_IsHealthy_Call{Call: _e.mock.On("IsHealthy")}
}
func (_c *MockWorker_IsHealthy_Call) Run(run func()) *MockWorker_IsHealthy_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWorker_IsHealthy_Call) Return(_a0 bool) *MockWorker_IsHealthy_Call {
_c.Call.Return(_a0)
return _c
}
// LoadSegments provides a mock function with given fields: _a0, _a1
func (_m *MockWorker) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) error {
ret := _m.Called(_a0, _a1)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockWorker_LoadSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadSegments'
type MockWorker_LoadSegments_Call struct {
*mock.Call
}
// LoadSegments is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.LoadSegmentsRequest
func (_e *MockWorker_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockWorker_LoadSegments_Call {
return &MockWorker_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
}
func (_c *MockWorker_LoadSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest)) *MockWorker_LoadSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest))
})
return _c
}
func (_c *MockWorker_LoadSegments_Call) Return(_a0 error) *MockWorker_LoadSegments_Call {
_c.Call.Return(_a0)
return _c
}
// Query provides a mock function with given fields: ctx, req
func (_m *MockWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
ret := _m.Called(ctx, req)
var r0 *internalpb.RetrieveResults
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.RetrieveResults)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockWorker_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query'
type MockWorker_Query_Call struct {
*mock.Call
}
// Query is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.QueryRequest
func (_e *MockWorker_Expecter) Query(ctx interface{}, req interface{}) *MockWorker_Query_Call {
return &MockWorker_Query_Call{Call: _e.mock.On("Query", ctx, req)}
}
func (_c *MockWorker_Query_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockWorker_Query_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.QueryRequest))
})
return _c
}
func (_c *MockWorker_Query_Call) Return(_a0 *internalpb.RetrieveResults, _a1 error) *MockWorker_Query_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// ReleaseSegments provides a mock function with given fields: _a0, _a1
func (_m *MockWorker) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) error {
ret := _m.Called(_a0, _a1)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockWorker_ReleaseSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseSegments'
type MockWorker_ReleaseSegments_Call struct {
*mock.Call
}
// ReleaseSegments is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.ReleaseSegmentsRequest
func (_e *MockWorker_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockWorker_ReleaseSegments_Call {
return &MockWorker_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
}
func (_c *MockWorker_ReleaseSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest)) *MockWorker_ReleaseSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest))
})
return _c
}
func (_c *MockWorker_ReleaseSegments_Call) Return(_a0 error) *MockWorker_ReleaseSegments_Call {
_c.Call.Return(_a0)
return _c
}
// Search provides a mock function with given fields: ctx, req
func (_m *MockWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
ret := _m.Called(ctx, req)
var r0 *internalpb.SearchResults
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.SearchResults)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockWorker_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search'
type MockWorker_Search_Call struct {
*mock.Call
}
// Search is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.SearchRequest
func (_e *MockWorker_Expecter) Search(ctx interface{}, req interface{}) *MockWorker_Search_Call {
return &MockWorker_Search_Call{Call: _e.mock.On("Search", ctx, req)}
}
func (_c *MockWorker_Search_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockWorker_Search_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SearchRequest))
})
return _c
}
func (_c *MockWorker_Search_Call) Return(_a0 *internalpb.SearchResults, _a1 error) *MockWorker_Search_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// Stop provides a mock function with given fields:
func (_m *MockWorker) Stop() {
_m.Called()
}
// MockWorker_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type MockWorker_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *MockWorker_Expecter) Stop() *MockWorker_Stop_Call {
return &MockWorker_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *MockWorker_Stop_Call) Run(run func()) *MockWorker_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockWorker_Stop_Call) Return() *MockWorker_Stop_Call {
_c.Call.Return()
return _c
}
type mockConstructorTestingTNewMockWorker interface {
mock.TestingT
Cleanup(func())
}
// NewMockWorker creates a new instance of MockWorker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockWorker(t mockConstructorTestingTNewMockWorker) *MockWorker {
mock := &MockWorker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -0,0 +1,137 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cluster contains the worker abstraction and the manager of querynode workers.
package cluster
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/merr"
"go.uber.org/zap"
)
// Worker is the interface definition for querynode worker role.
type Worker interface {
LoadSegments(context.Context, *querypb.LoadSegmentsRequest) error
ReleaseSegments(context.Context, *querypb.ReleaseSegmentsRequest) error
Delete(ctx context.Context, req *querypb.DeleteRequest) error
Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error)
Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)
IsHealthy() bool
Stop()
}
// remoteWorker wraps grpc QueryNode client as Worker.
type remoteWorker struct {
client types.QueryNode
}
// NewRemoteWorker creates a remoteWorker wrapping the given QueryNode client.
func NewRemoteWorker(client types.QueryNode) Worker {
return &remoteWorker{
client: client,
}
}
// LoadSegments implements Worker.
func (w *remoteWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("workerID", req.GetDstNodeID()),
)
status, err := w.client.LoadSegments(ctx, req)
if err != nil {
log.Warn("failed to call LoadSegments via grpc worker",
zap.Error(err),
)
return err
} else if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to call LoadSegments, worker return error",
zap.String("errorCode", status.GetErrorCode().String()),
zap.String("reason", status.GetReason()),
)
return fmt.Errorf("%s", status.GetReason())
}
return nil
}
func (w *remoteWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("workerID", req.GetNodeID()),
)
status, err := w.client.ReleaseSegments(ctx, req)
if err != nil {
log.Warn("failed to call ReleaseSegments via grpc worker",
zap.Error(err),
)
return err
} else if status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("failed to call ReleaseSegments, worker return error",
zap.String("errorCode", status.GetErrorCode().String()),
zap.String("reason", status.GetReason()),
)
return fmt.Errorf("%s", status.GetReason())
}
return nil
}
func (w *remoteWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("workerID", req.GetBase().GetTargetID()),
)
status, err := w.client.Delete(ctx, req)
if err != nil {
log.Warn("failed to call Delete via grpc worker",
zap.Error(err),
)
return err
} else if status.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("failed to call Delete, worker return error",
zap.String("errorCode", status.GetErrorCode().String()),
zap.String("reason", status.GetReason()),
)
return merr.Error(status)
}
return nil
}
func (w *remoteWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
return w.client.Search(ctx, req)
}
func (w *remoteWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return w.client.Query(ctx, req)
}
func (w *remoteWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
return w.client.GetStatistics(ctx, req)
}
func (w *remoteWorker) IsHealthy() bool {
return true
}
func (w *remoteWorker) Stop() {
w.client.Stop()
}
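The Worker interface above can also be satisfied by an in-process stub, which makes the method set easy to see at a glance. A hedged sketch with illustrative names only, not part of this change:

package cluster

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

// noopWorker is an illustrative Worker that accepts every request locally.
type noopWorker struct{}

var _ Worker = noopWorker{}

func (noopWorker) LoadSegments(context.Context, *querypb.LoadSegmentsRequest) error { return nil }
func (noopWorker) ReleaseSegments(context.Context, *querypb.ReleaseSegmentsRequest) error {
	return nil
}
func (noopWorker) Delete(context.Context, *querypb.DeleteRequest) error { return nil }
func (noopWorker) Search(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error) {
	return &internalpb.SearchResults{}, nil
}
func (noopWorker) Query(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
	return &internalpb.RetrieveResults{}, nil
}
func (noopWorker) GetStatistics(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
	return &internalpb.GetStatisticsResponse{}, nil
}
func (noopWorker) IsHealthy() bool { return true }
func (noopWorker) Stop()           {}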

View File

@ -0,0 +1,372 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file contains tests for the remoteWorker implementation in package cluster.
package cluster
import (
context "context"
"testing"
"github.com/cockroachdb/errors"
commonpb "github.com/milvus-io/milvus-proto/go-api/commonpb"
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type RemoteWorkerSuite struct {
suite.Suite
mockClient *types.MockQueryNode
worker *remoteWorker
}
func (s *RemoteWorkerSuite) SetupTest() {
s.mockClient = &types.MockQueryNode{}
s.worker = &remoteWorker{client: s.mockClient}
}
func (s *RemoteWorkerSuite) TearDownTest() {
s.mockClient = nil
s.worker = nil
}
func (s *RemoteWorkerSuite) TestLoadSegments() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
s.NoError(err)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil, errors.New("mocked error"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
s.Error(err)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
s.Error(err)
})
}
func (s *RemoteWorkerSuite) TestReleaseSegments() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
s.NoError(err)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
Return(nil, errors.New("mocked error"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
s.Error(err)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
s.Error(err)
})
}
func (s *RemoteWorkerSuite) TestDelete() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
s.NoError(err)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
Return(nil, errors.New("mocked error"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
s.Error(err)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
s.Error(err)
})
}
func (s *RemoteWorkerSuite) TestSearch() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.SearchResults
var err error
result = &internalpb.SearchResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.SearchResults
err := errors.New("mocked error")
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.SearchResults
var err error
result = &internalpb.SearchResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
}
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
}
func (s *RemoteWorkerSuite) TestQuery() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.RetrieveResults
var err error
result = &internalpb.RetrieveResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.RetrieveResults
err := errors.New("mocked error")
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.RetrieveResults
var err error
result = &internalpb.RetrieveResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
}
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
}
func (s *RemoteWorkerSuite) TestGetStatistics() {
s.Run("normal_run", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.GetStatisticsResponse
var err error
result = &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_error", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.GetStatisticsResponse
err := errors.New("mocked error")
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
s.Run("client_return_fail_status", func() {
defer func() { s.mockClient.ExpectedCalls = nil }()
var result *internalpb.GetStatisticsResponse
var err error
result = &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
}
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Return(result, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
s.Equal(err, serr)
s.Equal(result, sr)
})
}
func (s *RemoteWorkerSuite) TestBasic() {
s.True(s.worker.IsHealthy())
s.mockClient.EXPECT().Stop().Return(nil)
s.worker.Stop()
s.mockClient.AssertCalled(s.T(), "Stop")
}
func TestRemoteWorker(t *testing.T) {
suite.Run(t, new(RemoteWorkerSuite))
}
func TestNewRemoteWorker(t *testing.T) {
client := &types.MockQueryNode{}
w := NewRemoteWorker(client)
rw, ok := w.(*remoteWorker)
assert.True(t, ok)
assert.Equal(t, client, rw.client)
}

View File

@ -0,0 +1,78 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"sync"
)
type averageData struct {
total float64
count int32
}
func (v *averageData) Add(value float64) {
v.total += value
v.count++
}
func (v *averageData) Value() float64 {
if v.count == 0 {
return 0
}
return v.total / float64(v.count)
}
type averageCollector struct {
sync.Mutex
averages map[string]*averageData
}
func (c *averageCollector) Register(label string) {
c.Lock()
defer c.Unlock()
if _, ok := c.averages[label]; !ok {
c.averages[label] = &averageData{}
}
}
func (c *averageCollector) Add(label string, value float64) {
c.Lock()
defer c.Unlock()
if average, ok := c.averages[label]; ok {
average.Add(value)
}
}
func (c *averageCollector) Average(label string) (float64, error) {
c.Lock()
defer c.Unlock()
average, ok := c.averages[label]
if !ok {
return 0, WrapErrAvarageLabelNotRegister(label)
}
return average.Value(), nil
}
func newAverageCollector() *averageCollector {
return &averageCollector{
averages: make(map[string]*averageData),
}
}
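A minimal sketch of how the average collector is used: register a label once, feed it samples, then read the running mean. The label name below is illustrative.

package collector

import "fmt"

func exampleAverageUsage() {
	c := newAverageCollector()
	c.Register("example_queue_latency") // Add is a no-op for labels that were never registered
	c.Add("example_queue_latency", 12)
	c.Add("example_queue_latency", 8)

	avg, err := c.Average("example_queue_latency")
	if err != nil {
		fmt.Println(err) // returned only when the label is not registered
		return
	}
	fmt.Println(avg) // 10
}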

View File

@ -0,0 +1,59 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"testing"
"github.com/stretchr/testify/suite"
)
type AverageCollectorTestSuite struct {
suite.Suite
label string
average *averageCollector
}
func (suite *AverageCollectorTestSuite) SetupSuite() {
suite.average = newAverageCollector()
suite.label = "test_label"
}
func (suite *AverageCollectorTestSuite) TestBasic() {
// get average before the label is registered
_, err := suite.average.Average(suite.label)
suite.Error(err)
//register and get
suite.average.Register(suite.label)
value, err := suite.average.Average(suite.label)
suite.Equal(float64(0), value)
suite.NoError(err)
//add and get
sum := 4
for i := 0; i <= sum; i++ {
suite.average.Add(suite.label, float64(i))
}
value, err = suite.average.Average(suite.label)
suite.NoError(err)
suite.Equal(float64(sum)/2, value)
}
func TestAverageCollector(t *testing.T) {
suite.Run(t, new(AverageCollectorTestSuite))
}

View File

@ -0,0 +1,78 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"fmt"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/ratelimitutil"
)
var Average *averageCollector
var Rate *ratelimitutil.RateCollector
var Counter *counter
func RateMetrics() []string {
return []string{
metricsinfo.NQPerSecond,
metricsinfo.SearchThroughput,
metricsinfo.InsertConsumeThroughput,
metricsinfo.DeleteConsumeThroughput,
}
}
func AverageMetrics() []string {
return []string{
metricsinfo.QueryQueueMetric,
metricsinfo.SearchQueueMetric,
}
}
func ConstructLabel(subs ...string) string {
label := ""
for id, sub := range subs {
label += sub
if id != len(subs)-1 {
label += "-"
}
}
return label
}
func init() {
var err error
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
if err != nil {
err = fmt.Errorf("querynode collector init failed, err = %s", err)
log.Error(err.Error())
panic(err)
}
Average = newAverageCollector()
Counter = newCounter()
//init rate Metric
for _, label := range RateMetrics() {
Rate.Register(label)
}
//init average metric
for _, label := range AverageMetrics() {
Average.Register(label)
}
}
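A hedged sketch of how the package-level Average collector and ConstructLabel compose: sub-labels are joined with '-' so per-collection metrics can share one collector, and such composite labels still need to be registered because init() only registers the bare metric names. The numeric suffix is illustrative.

package collector

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/util/metricsinfo"
)

func exampleLabelUsage() {
	// a composite label, e.g. a metric name plus a collection ID
	label := ConstructLabel(metricsinfo.QueryQueueMetric, "100")
	Average.Register(label)
	Average.Add(label, 3.5)

	v, _ := Average.Average(label)
	fmt.Println(label, v)
}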

View File

@ -0,0 +1,67 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"sync"
)
type counter struct {
sync.Mutex
values map[string]int64
}
func (c *counter) Inc(label string, value int64) {
c.Lock()
defer c.Unlock()
v, ok := c.values[label]
if !ok {
c.values[label] = value
} else {
v += value
c.values[label] = v
}
}
func (c *counter) Get(label string) int64 {
c.Lock()
defer c.Unlock()
v, ok := c.values[label]
if !ok {
return 0
}
return v
}
func (c *counter) Remove(label string) bool {
c.Lock()
defer c.Unlock()
_, ok := c.values[label]
if ok {
delete(c.values, label)
}
return ok
}
func newCounter() *counter {
return &counter{
values: make(map[string]int64),
}
}
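A minimal sketch of the counter: Inc accumulates per label, Get reads the current value, and Remove drops the label entirely; the label name is illustrative.

package collector

import "fmt"

func exampleCounterUsage() {
	c := newCounter()
	c.Inc("example_delete_rows", 5)
	c.Inc("example_delete_rows", 2)
	fmt.Println(c.Get("example_delete_rows")) // 7

	c.Remove("example_delete_rows")
	fmt.Println(c.Get("example_delete_rows")) // 0 after removal
}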

View File

@ -0,0 +1,54 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"testing"
"github.com/stretchr/testify/suite"
)
type CounterTestSuite struct {
suite.Suite
label string
counter *counter
}
func (suite *CounterTestSuite) SetupSuite() {
suite.counter = newCounter()
suite.label = "test_label"
}
func (suite *CounterTestSuite) TestBasic() {
//get default value(zero)
value := suite.counter.Get(suite.label)
suite.Equal(int64(0), value)
//get after inc
suite.counter.Inc(suite.label, 3)
value = suite.counter.Get(suite.label)
suite.Equal(int64(3), value)
// remove
suite.counter.Remove(suite.label)
value = suite.counter.Get(suite.label)
suite.Equal(int64(0), value)
}
func TestCounter(t *testing.T) {
suite.Run(t, new(CounterTestSuite))
}

View File

@ -0,0 +1,31 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package collector
import (
"fmt"
"github.com/cockroachdb/errors"
)
var (
ErrAvarageLabelNotRegister = errors.New("AvarageLabelNotRegister")
)
func WrapErrAvarageLabelNotRegister(label string) error {
return fmt.Errorf("%w :%s", ErrAvarageLabelNotRegister, label)
}
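Because WrapErrAvarageLabelNotRegister wraps the sentinel with the %w verb, callers can match it with errors.Is; a small sketch using the standard library errors package for illustration:

package collector

import (
	"errors"
	"fmt"
)

func exampleErrorCheck() {
	err := WrapErrAvarageLabelNotRegister("example_label")
	if errors.Is(err, ErrAvarageLabelNotRegister) {
		fmt.Println("label was never registered:", err)
	}
}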

View File

@ -0,0 +1,565 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// delegator package contains the logic of shard delegator.
package delegator
import (
"context"
"fmt"
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
)
type lifetime struct {
state atomic.Int32
closeCh chan struct{}
closeOnce sync.Once
}
func (lt *lifetime) SetState(state int32) {
lt.state.Store(state)
}
func (lt *lifetime) GetState() int32 {
return lt.state.Load()
}
func (lt *lifetime) Close() {
lt.closeOnce.Do(func() {
close(lt.closeCh)
})
}
func newLifetime() *lifetime {
return &lifetime{
closeCh: make(chan struct{}),
}
}
// ShardDelegator is the interface definition.
type ShardDelegator interface {
Collection() int64
Version() int64
GetDistribution() *distribution
SyncDistribution(ctx context.Context, entries ...SegmentEntry)
Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
//data
ProcessInsert(insertRecords map[int64]*InsertData)
ProcessDelete(deleteData []*DeleteData, ts uint64)
LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error
// control
Serviceable() bool
Start()
Close()
}
var _ ShardDelegator = (*shardDelegator)(nil)
const (
initializing int32 = iota
working
stopped
)
// shardDelegator maintains the shard distribution and streaming part of the data.
type shardDelegator struct {
// shard information attributes
collectionID int64
replicaID int64
vchannelName string
version int64
// collection schema
collection *segments.Collection
workerManager cluster.Manager
lifetime *lifetime
distribution *distribution
segmentManager segments.SegmentManager
tsafeManager tsafe.Manager
pkOracle pkoracle.PkOracle
// L0 delete buffer
deleteMut sync.Mutex
deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item]
//dispatcherClient msgdispatcher.Client
factory msgstream.Factory
loader segments.Loader
wg sync.WaitGroup
tsCond *sync.Cond
latestTsafe *atomic.Uint64
}
// getLogger returns the zap logger with pre-defined shard attributes.
func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger {
return log.Ctx(ctx).With(
zap.Int64("collectionID", sd.collectionID),
zap.String("channel", sd.vchannelName),
zap.Int64("replicaID", sd.replicaID),
)
}
// Serviceable returns whether delegator is serviceable now.
func (sd *shardDelegator) Serviceable() bool {
return sd.lifetime.GetState() == working
}
// Start sets delegator to working state.
func (sd *shardDelegator) Start() {
sd.lifetime.SetState(working)
}
// Collection returns delegator collection id.
func (sd *shardDelegator) Collection() int64 {
return sd.collectionID
}
// Version returns delegator version.
func (sd *shardDelegator) Version() int64 {
return sd.version
}
func (sd *shardDelegator) GetDistribution() *distribution {
return sd.distribution
}
// SyncDistribution revises distribution.
func (sd *shardDelegator) SyncDistribution(ctx context.Context, entries ...SegmentEntry) {
log := sd.getLogger(ctx)
log.Info("sync distribution", zap.Any("entries", entries))
sd.distribution.AddDistributions(entries...)
}
func modifySearchRequest(req *querypb.SearchRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.SearchRequest {
nodeReq := proto.Clone(req).(*querypb.SearchRequest)
nodeReq.Scope = scope
nodeReq.Req.Base.TargetID = targetID
nodeReq.SegmentIDs = segmentIDs
nodeReq.FromShardLeader = true
return nodeReq
}
func modifyQueryRequest(req *querypb.QueryRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.QueryRequest {
nodeReq := proto.Clone(req).(*querypb.QueryRequest)
nodeReq.Scope = scope
nodeReq.Req.Base.TargetID = targetID
nodeReq.SegmentIDs = segmentIDs
nodeReq.FromShardLeader = true
return nodeReq
}
// Search performs the search operation on the shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
return nil, errors.New("delegator is not serviceable")
}
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received search request not belongs to it",
zap.Strings("reqChannels", req.GetDmlChannels()),
)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}
// wait tsafe
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
return nil, err
}
sealed, growing, version := sd.distribution.GetCurrent(req.GetReq().GetPartitionIDs()...)
defer sd.distribution.FinishUsage(version)
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
}
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, modifySearchRequest)
if err != nil {
log.Warn("Search organizeSubTask failed", zap.Error(err))
return nil, err
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
return worker.Search(ctx, req)
}, "Search", log)
if err != nil {
log.Warn("Delegator search failed", zap.Error(err))
return nil, err
}
log.Info("Delegator search done")
return results, nil
}
// Query performs query operation on shard.
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
return nil, errors.New("delegator is not serviceable")
}
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received query request not belongs to it",
zap.Strings("reqChannels", req.GetDmlChannels()),
)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}
// wait tsafe
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return nil, err
}
sealed, growing, version := sd.distribution.GetCurrent(req.GetReq().GetPartitionIDs()...)
defer sd.distribution.FinishUsage(version)
if req.Req.IgnoreGrowing {
growing = []SegmentEntry{}
}
log.Info("query segments...",
zap.Int("sealedNum", len(sealed)),
zap.Int("growingNum", len(growing)),
)
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, modifyQueryRequest)
if err != nil {
log.Warn("query organizeSubTask failed", zap.Error(err))
return nil, err
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.QueryRequest, worker cluster.Worker) (*internalpb.RetrieveResults, error) {
return worker.Query(ctx, req)
}, "Query", log)
if err != nil {
log.Warn("Delegator query failed", zap.Error(err))
return nil, err
}
log.Info("Delegator Query done")
return results, nil
}
// GetStatistics returns statistics aggregated by delegator.
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
log := sd.getLogger(ctx)
if !sd.Serviceable() {
return nil, errors.New("delegator is not serviceable")
}
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
log.Warn("deletgator received query request not belongs to it",
zap.Strings("reqChannels", req.GetDmlChannels()),
)
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
}
// wait tsafe
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return nil, err
}
sealed, growing, version := sd.distribution.GetCurrent(req.Req.GetPartitionIDs()...)
defer sd.distribution.FinishUsage(version)
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, func(req *querypb.GetStatisticsRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.GetStatisticsRequest {
nodeReq := proto.Clone(req).(*querypb.GetStatisticsRequest)
nodeReq.GetReq().GetBase().TargetID = targetID
nodeReq.Scope = scope
nodeReq.SegmentIDs = segmentIDs
nodeReq.FromShardLeader = true
return nodeReq
})
if err != nil {
log.Warn("Get statistics organizeSubTask failed", zap.Error(err))
return nil, err
}
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.GetStatisticsRequest, worker cluster.Worker) (*internalpb.GetStatisticsResponse, error) {
return worker.GetStatistics(ctx, req)
}, "GetStatistics", log)
if err != nil {
log.Warn("Delegator get statistics failed", zap.Error(err))
return nil, err
}
return results, nil
}
type subTask[T any] struct {
req T
targetID int64
worker cluster.Worker
}
func organizeSubTask[T any](req T, sealed []SnapshotItem, growing []SegmentEntry, workerManager cluster.Manager, modify func(T, querypb.DataScope, []int64, int64) T) ([]subTask[T], error) {
result := make([]subTask[T], 0, len(sealed)+1)
packSubTask := func(segments []SegmentEntry, workerID int64, scope querypb.DataScope) error {
segmentIDs := lo.Map(segments, func(item SegmentEntry, _ int) int64 {
return item.SegmentID
})
if len(segmentIDs) == 0 {
return nil
}
// update request
req := modify(req, scope, segmentIDs, workerID)
worker, err := workerManager.GetWorker(workerID)
if err != nil {
log.Warn("failed to get worker",
zap.Int64("nodeID", workerID),
zap.Error(err),
)
return fmt.Errorf("failed to get worker %d, %w", workerID, err)
}
result = append(result, subTask[T]{
req: req,
targetID: workerID,
worker: worker,
})
return nil
}
for _, entry := range sealed {
err := packSubTask(entry.Segments, entry.NodeID, querypb.DataScope_Historical)
if err != nil {
return nil, err
}
}
packSubTask(growing, paramtable.GetNodeID(), querypb.DataScope_Streaming)
return result, nil
}
func executeSubTasks[T any, R interface {
GetStatus() *commonpb.Status
}](ctx context.Context, tasks []subTask[T], execute func(context.Context, T, cluster.Worker) (R, error), taskType string, log *log.MLogger) ([]R, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var wg sync.WaitGroup
wg.Add(len(tasks))
resultCh := make(chan R, len(tasks))
errCh := make(chan error, 1)
for _, task := range tasks {
go func(task subTask[T]) {
defer wg.Done()
result, err := execute(ctx, task.req, task.worker)
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
}
if err != nil {
log.Warn("failed to execute sub task",
zap.String("taskType", taskType),
zap.Int64("nodeID", task.targetID),
zap.Error(err),
)
select {
case errCh <- err: // must be the first
default: // skip other errors
}
cancel()
return
}
resultCh <- result
}(task)
}
wg.Wait()
close(resultCh)
select {
case err := <-errCh:
log.Warn("Delegator execute subTask failed",
zap.String("taskType", taskType),
zap.Error(err),
)
return nil, err
default:
}
results := make([]R, 0, len(tasks))
for result := range resultCh {
results = append(results, result)
}
return results, nil
}
// waitTSafe returns when the tsafe listener notifies a timestamp which meets the guarantee ts.
func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
log := sd.getLogger(ctx)
// already safe to search
if sd.latestTsafe.Load() >= ts {
return nil
}
// check lag duration too large
st, _ := tsoutil.ParseTS(sd.latestTsafe.Load())
gt, _ := tsoutil.ParseTS(ts)
lag := gt.Sub(st)
maxLag := paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)
if lag > maxLag {
log.Warn("guarantee and servicable ts larger than MaxLag",
zap.Time("guaranteeTime", gt),
zap.Time("serviceableTime", st),
zap.Duration("lag", lag),
zap.Duration("maxTsLag", maxLag),
)
return WrapErrTsLagTooLarge(lag, maxLag)
}
ch := make(chan struct{})
go func() {
sd.tsCond.L.Lock()
defer sd.tsCond.L.Unlock()
for sd.latestTsafe.Load() < ts && ctx.Err() == nil {
sd.tsCond.Wait()
}
close(ch)
}()
for {
select {
// timeout
case <-ctx.Done():
// notify wait goroutine to quit
sd.tsCond.Broadcast()
return ctx.Err()
case <-ch:
return nil
}
}
}
// watchTSafe is the worker function to update serviceable timestamp.
func (sd *shardDelegator) watchTSafe() {
defer sd.wg.Done()
listener := sd.tsafeManager.WatchChannel(sd.vchannelName)
sd.updateTSafe()
log := sd.getLogger(context.Background())
for {
select {
case _, ok := <-listener.On():
if !ok {
// listener close
log.Warn("tsafe listener closed")
return
}
sd.updateTSafe()
case <-sd.lifetime.closeCh:
log.Info("updateTSafe quit")
// shard delegator closed
return
}
}
}
// updateTSafe reads the current tsafe value from tsafeManager.
func (sd *shardDelegator) updateTSafe() {
sd.tsCond.L.Lock()
tsafe, err := sd.tsafeManager.Get(sd.vchannelName)
if err != nil {
log.Warn("tsafeManager failed to get lastest", zap.Error(err))
}
if tsafe > sd.latestTsafe.Load() {
sd.latestTsafe.Store(tsafe)
sd.tsCond.Broadcast()
}
sd.tsCond.L.Unlock()
}
// Close closes the delegator.
func (sd *shardDelegator) Close() {
sd.lifetime.SetState(stopped)
sd.lifetime.Close()
sd.wg.Wait()
}
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64) (ShardDelegator, error) {
collection := manager.Collection.Get(collectionID)
if collection == nil {
return nil, fmt.Errorf("collection(%d) not found in manager", collectionID)
}
maxSegmentDeleteBuffer := paramtable.Get().QueryNodeCfg.MaxSegmentDeleteBuffer.GetAsInt64()
log.Info("Init delte cache", zap.Int64("maxSegmentCacheBuffer", maxSegmentDeleteBuffer), zap.Time("startTime", tsoutil.PhysicalTime(startTs)))
sd := &shardDelegator{
collectionID: collectionID,
replicaID: replicaID,
vchannelName: channel,
version: version,
collection: collection,
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: newLifetime(),
distribution: NewDistribution(),
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
pkOracle: pkoracle.NewPkOracle(),
tsafeManager: tsafeManager,
latestTsafe: atomic.NewUint64(0),
loader: loader,
factory: factory,
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
go sd.watchTSafe()
return sd, nil
}
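The concurrency pattern inside executeSubTasks (fan out every sub task, keep only the first error, cancel the remaining workers, then collect results once all goroutines return) can be read in isolation. Below is a simplified, self-contained sketch of the same pattern; the generic types here are illustrative and not the delegator's real request or response types.

package delegator

import (
	"context"
	"sync"
)

// fanOut runs every input concurrently and aborts the rest on the first error.
func fanOut[T any, R any](ctx context.Context, inputs []T, run func(context.Context, T) (R, error)) ([]R, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(len(inputs))
	resultCh := make(chan R, len(inputs))
	errCh := make(chan error, 1)

	for _, in := range inputs {
		go func(in T) {
			defer wg.Done()
			r, err := run(ctx, in)
			if err != nil {
				select {
				case errCh <- err: // keep only the first error
				default:
				}
				cancel() // signal the remaining sub tasks to stop
				return
			}
			resultCh <- r
		}(in)
	}
	wg.Wait()
	close(resultCh)

	select {
	case err := <-errCh:
		return nil, err
	default:
	}

	results := make([]R, 0, len(inputs))
	for r := range resultCh {
		results = append(results, r)
	}
	return results, nil
}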

View File

@ -0,0 +1,615 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"context"
"fmt"
"math/rand"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// delegator data related part
// InsertData
type InsertData struct {
RowIDs []int64
PrimaryKeys []storage.PrimaryKey
Timestamps []uint64
InsertRecord *segcorepb.InsertRecord
StartPosition *msgpb.MsgPosition
PartitionID int64
}
type DeleteData struct {
PartitionID int64
PrimaryKeys []storage.PrimaryKey
Timestamps []uint64
RowCount int64
}
// Append appends another delete data into this one.
func (d *DeleteData) Append(ad DeleteData) {
d.PrimaryKeys = append(d.PrimaryKeys, ad.PrimaryKeys...)
d.Timestamps = append(d.Timestamps, ad.Timestamps...)
d.RowCount += ad.RowCount
}
func (sd *shardDelegator) newGrowing(segmentID int64, insertData *InsertData) segments.Segment {
log := sd.getLogger(context.Background()).With(zap.Int64("segmentID", segmentID))
// try add partition
if sd.collection.GetLoadType() == loadTypeCollection {
sd.collection.AddPartition(insertData.PartitionID)
}
segment, err := segments.NewSegment(sd.collection, segmentID, insertData.PartitionID, sd.collectionID, sd.vchannelName, segments.SegmentTypeGrowing, 0, insertData.StartPosition, insertData.StartPosition)
if err != nil {
log.Error("failed to create new segment", zap.Error(err))
panic(err)
}
sd.pkOracle.Register(segment, paramtable.GetNodeID())
sd.segmentManager.Put(segments.SegmentTypeGrowing, segment)
sd.addGrowing(SegmentEntry{
NodeID: paramtable.GetNodeID(),
SegmentID: segmentID,
PartitionID: insertData.PartitionID,
Version: 0,
})
return segment
}
// ProcessInsert handles insert data in delegator.
func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
log := sd.getLogger(context.Background())
for segmentID, insertData := range insertRecords {
growing := sd.segmentManager.GetGrowing(segmentID)
if growing == nil {
growing = sd.newGrowing(segmentID, insertData)
}
err := growing.Insert(insertData.RowIDs, insertData.Timestamps, insertData.InsertRecord)
if err != nil {
log.Error("failed to insert data into growing segment",
zap.Int64("segmentID", segmentID),
zap.Error(err),
)
// panic here, insert failure
panic(err)
}
growing.UpdateBloomFilter(insertData.PrimaryKeys)
log.Debug("insert into growing segment",
zap.Int64("collectionID", growing.Collection()),
zap.Int64("segmentID", segmentID),
zap.Int("rowCount", len(insertData.RowIDs)),
zap.Uint64("maxTimestamp", insertData.Timestamps[len(insertData.Timestamps)-1]),
)
}
}
// ProcessDelete handles delete data in the delegator.
// The delegator puts deleteData into the buffer first,
// then dispatches the data to segments according to the result of pkOracle.
func (sd *shardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
// block segment loading while handling the delete buffer
sd.deleteMut.Lock()
defer sd.deleteMut.Unlock()
log := sd.getLogger(context.Background())
log.Debug("start to process delete", zap.Uint64("ts", ts))
// add deleteData into buffer.
cacheItems := make([]deletebuffer.BufferItem, 0, len(deleteData))
for _, entry := range deleteData {
cacheItems = append(cacheItems, deletebuffer.BufferItem{
PartitionID: entry.PartitionID,
DeleteData: storage.DeleteData{
Pks: entry.PrimaryKeys,
Tss: entry.Timestamps,
RowCount: entry.RowCount,
},
})
}
sd.deleteBuffer.Put(&deletebuffer.Item{
Ts: ts,
Data: cacheItems,
})
// segment => delete data
delRecords := make(map[int64]DeleteData)
for _, data := range deleteData {
for i, pk := range data.PrimaryKeys {
segmentIDs, err := sd.pkOracle.Get(pk, pkoracle.WithPartitionID(data.PartitionID))
if err != nil {
log.Warn("failed to get delete candidates for pk", zap.Any("pk", pk.GetValue()))
continue
}
for _, segmentID := range segmentIDs {
delRecord := delRecords[segmentID]
delRecord.PrimaryKeys = append(delRecord.PrimaryKeys, pk)
delRecord.Timestamps = append(delRecord.Timestamps, data.Timestamps[i])
delRecord.RowCount++
delRecords[segmentID] = delRecord
}
}
}
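// fan out the deletes to all workers holding affected segments,
// using the current distribution snapshot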
offlineSegments := typeutil.NewConcurrentSet[int64]()
sealed, growing, version := sd.distribution.GetCurrent()
eg, ctx := errgroup.WithContext(context.Background())
for _, entry := range sealed {
entry := entry
eg.Go(func() error {
worker, err := sd.workerManager.GetWorker(entry.NodeID)
if err != nil {
log.Warn("failed to get worker",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Error(err),
)
// skip if the node is down;
// the delete will be processed again after the segment is reloaded
return nil
}
offlineSegments.Upsert(sd.applyDelete(ctx, entry.NodeID, worker, delRecords, entry.Segments)...)
return nil
})
}
if len(growing) > 0 {
eg.Go(func() error {
worker, err := sd.workerManager.GetWorker(paramtable.GetNodeID())
if err != nil {
log.Error("failed to get worker(local)",
zap.Int64("nodeID", paramtable.GetNodeID()),
zap.Error(err),
)
// panic here, the local worker should never return an error
panic(err)
}
offlineSegments.Upsert(sd.applyDelete(ctx, paramtable.GetNodeID(), worker, delRecords, growing)...)
return nil
})
}
// no error is returned from the applyDelete goroutines
_ = eg.Wait()
sd.distribution.FinishUsage(version)
offlineSegIDs := offlineSegments.Collect()
if len(offlineSegIDs) > 0 {
log.Warn("failed to apply delete, mark segment offline", zap.Int64s("offlineSegments", offlineSegIDs))
sd.markSegmentOffline(offlineSegIDs...)
}
}
// applyDelete handles delete records and applies them to the corresponding workers.
func (sd *shardDelegator) applyDelete(ctx context.Context, nodeID int64, worker cluster.Worker, delRecords map[int64]DeleteData, entries []SegmentEntry) []int64 {
var offlineSegments []int64
log := sd.getLogger(ctx)
for _, segmentEntry := range entries {
log := log.With(
zap.Int64("segmentID", segmentEntry.SegmentID),
zap.Int64("workerID", nodeID),
)
delRecord, ok := delRecords[segmentEntry.SegmentID]
if ok {
log.Debug("delegator plan to applyDelete via worker")
err := retry.Do(ctx, func() error {
err := worker.Delete(ctx, &querypb.DeleteRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(nodeID)),
CollectionId: sd.collectionID,
PartitionId: segmentEntry.PartitionID,
VchannelName: sd.vchannelName,
SegmentId: segmentEntry.SegmentID,
PrimaryKeys: storage.ParsePrimaryKeys2IDs(delRecord.PrimaryKeys),
Timestamps: delRecord.Timestamps,
})
if errors.Is(err, merr.ErrSegmentNotFound) {
log.Warn("try to delete data of released segment")
return nil
} else if err != nil {
log.Warn("worker failed to delete on segment",
zap.Error(err),
)
return err
}
return nil
}, retry.Attempts(10))
if err != nil {
log.Warn("apply delete for segment failed, marking it offline")
offlineSegments = append(offlineSegments, segmentEntry.SegmentID)
}
}
}
return offlineSegments
}
// markSegmentOffline marks segments offline and waits for QueryCoord to fix them.
func (sd *shardDelegator) markSegmentOffline(segmentIDs ...int64) {
sd.distribution.AddOfflines(segmentIDs...)
}
// addGrowing adds growing segment records for the delegator.
func (sd *shardDelegator) addGrowing(entries ...SegmentEntry) {
log := sd.getLogger(context.Background())
log.Info("add growing segments to delegator", zap.Int64s("segmentIDs", lo.Map(entries, func(entry SegmentEntry, _ int) int64 {
return entry.SegmentID
})))
sd.distribution.AddGrowing(entries...)
}
// LoadGrowing loads growing segments locally.
func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
log := sd.getLogger(ctx)
loaded, err := sd.loader.Load(ctx, sd.collectionID, segments.SegmentTypeGrowing, version, infos...)
if err != nil {
log.Warn("failed to load growing segment", zap.Error(err))
for _, segment := range loaded {
segments.DeleteSegment(segment.(*segments.LocalSegment))
}
return err
}
for _, candidate := range loaded {
sd.pkOracle.Register(candidate, paramtable.GetNodeID())
}
sd.segmentManager.Put(segments.SegmentTypeGrowing, loaded...)
sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry {
return SegmentEntry{
NodeID: paramtable.GetNodeID(),
SegmentID: segment.ID(),
PartitionID: segment.Partition(),
Version: version,
}
})...)
return nil
}
// LoadSegments loads segments locally or remotely depending on the target node.
func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
log := sd.getLogger(ctx)
targetNodeID := req.GetDstNodeID()
// add common log fields
log = log.With(
zap.Int64("workID", req.GetDstNodeID()),
zap.Int64s("segments", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() })),
)
worker, err := sd.workerManager.GetWorker(targetNodeID)
if err != nil {
log.Warn("delegator failed to find worker", zap.Error(err))
return err
}
// load bloom filter set only when the candidate does not exist yet
infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool {
return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID)
})
candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...)
if err != nil {
log.Warn("failed to load bloom filter set for segment", zap.Error(err))
return err
}
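// retarget the request base at the destination worker before forwarding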
req.Base.TargetID = req.GetDstNodeID()
log.Info("worker loads segments...")
err = worker.LoadSegments(ctx, req)
if err != nil {
log.Warn("worker failed to load segments", zap.Error(err))
return err
}
log.Info("work loads segments done")
log.Info("load delete...")
err = sd.loadStreamDelete(ctx, candidates, infos, targetNodeID, worker)
if err != nil {
log.Warn("load stream delete failed", zap.Error(err))
return err
}
log.Info("load delete done")
// alter distribution
entries := lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) SegmentEntry {
return SegmentEntry{
SegmentID: info.GetSegmentID(),
PartitionID: info.GetPartitionID(),
NodeID: req.GetDstNodeID(),
Version: req.GetVersion(),
}
})
removed := sd.distribution.AddDistributions(entries...)
// release replaced growing segments on the local worker asynchronously
if len(removed) > 0 {
go func() {
worker, err := sd.workerManager.GetWorker(paramtable.GetNodeID())
if err != nil {
log.Warn("failed to get local worker when try to release related growing", zap.Error(err))
return
}
err = worker.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(paramtable.GetNodeID())),
CollectionID: sd.collectionID,
NodeID: paramtable.GetNodeID(),
Scope: querypb.DataScope_Streaming,
SegmentIDs: removed,
Shard: sd.vchannelName,
NeedTransfer: false,
})
if err != nil {
log.Warn("failed to call release segments(local)", zap.Error(err))
}
}()
}
return nil
}
func (sd *shardDelegator) loadStreamDelete(ctx context.Context, candidates []*pkoracle.BloomFilterSet, infos []*querypb.SegmentLoadInfo,
targetNodeID int64, worker cluster.Worker) error {
log := sd.getLogger(ctx)
sd.deleteMut.Lock()
defer sd.deleteMut.Unlock()
idCandidates := lo.SliceToMap(candidates, func(candidate *pkoracle.BloomFilterSet) (int64, *pkoracle.BloomFilterSet) {
return candidate.ID(), candidate
})
// apply buffered deletes for the newly loaded segments
// no goroutines here since qnv2 has no load merging logic
for _, info := range infos {
candidate := idCandidates[info.GetSegmentID()]
deleteData := &storage.DeleteData{}
// the segment's end position is its dml checkpoint;
// if it is before the deleteBuffer's safe ts, some deletes shall be read from the msgstream
if info.GetEndPosition().GetTimestamp() < sd.deleteBuffer.SafeTs() {
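// deletes before the buffer's safe ts may already be evicted,
// so read them back from the msgstream starting at the segment checkpoint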
log.Info("load delete from stream...")
var err error
deleteData, err = sd.readDeleteFromMsgstream(ctx, info.GetEndPosition(), sd.deleteBuffer.SafeTs(), candidate)
if err != nil {
log.Warn("failed to read delete data from msgstream", zap.Error(err))
return err
}
log.Info("load delete from stream done")
}
// list buffered delete
deleteRecords := sd.deleteBuffer.ListAfter(info.GetEndPosition().GetTimestamp())
for _, entry := range deleteRecords {
for _, record := range entry.Data {
if record.PartitionID != common.InvalidPartitionID && candidate.Partition() != record.PartitionID {
continue
}
for i, pk := range record.DeleteData.Pks {
if candidate.MayPkExist(pk) {
deleteData.Pks = append(deleteData.Pks, pk)
deleteData.Tss = append(deleteData.Tss, record.DeleteData.Tss[i])
deleteData.RowCount++
}
}
}
}
// if there is any delete data, apply it
if deleteData.RowCount > 0 {
log.Info("forward delete to worker...", zap.Int64("deleteRowNum", deleteData.RowCount))
err := worker.Delete(ctx, &querypb.DeleteRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(targetNodeID)),
CollectionId: info.GetCollectionID(),
PartitionId: info.GetPartitionID(),
SegmentId: info.GetSegmentID(),
PrimaryKeys: storage.ParsePrimaryKeys2IDs(deleteData.Pks),
Timestamps: deleteData.Tss,
})
if err != nil {
log.Warn("failed to apply delete when LoadSegment", zap.Error(err))
return err
}
}
}
// add candidate after load success
for _, candidate := range candidates {
log.Info("register sealed segment bfs into pko candidates",
zap.Int64("segmentID", candidate.ID()),
)
sd.pkOracle.Register(candidate, targetNodeID)
}
return nil
}
func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) {
log := sd.getLogger(ctx).With(
zap.String("channel", position.ChannelName),
zap.Int64("segmentID", candidate.ID()),
)
stream, err := sd.factory.NewTtMsgStream(ctx)
if err != nil {
return nil, err
}
vchannelName := position.ChannelName
pChannelName := funcutil.ToPhysicalChannel(vchannelName)
position.ChannelName = pChannelName
ts, _ := tsoutil.ParseTS(position.Timestamp)
// Randomize the subname in case we try to load the same delta at the same time
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("subName", subName), zap.Time("positionTs", ts))
stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
err = stream.Seek([]*msgpb.MsgPosition{position})
if err != nil {
return nil, err
}
result := &storage.DeleteData{}
hasMore := true
for hasMore {
select {
case <-ctx.Done():
log.Debug("read delta msg from seek position done", zap.Error(ctx.Err()))
return nil, ctx.Err()
case msgPack, ok := <-stream.Chan():
if !ok {
err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID())
log.Warn("fail to read delta msg",
zap.String("pChannelName", pChannelName),
zap.Binary("msgID", position.GetMsgID()),
zap.Error(err),
)
return nil, err
}
if msgPack == nil {
continue
}
for _, tsMsg := range msgPack.Msgs {
if tsMsg.Type() == commonpb.MsgType_Delete {
dmsg := tsMsg.(*msgstream.DeleteMsg)
if dmsg.CollectionID != sd.collectionID || dmsg.GetPartitionID() != candidate.Partition() {
continue
}
for idx, pk := range storage.ParseIDs2PrimaryKeys(dmsg.GetPrimaryKeys()) {
if candidate.MayPkExist(pk) {
result.Pks = append(result.Pks, pk)
result.Tss = append(result.Tss, dmsg.Timestamps[idx])
}
}
}
}
// stop consuming once the pack end position reaches the safe ts
if safeTs <= msgPack.EndPositions[0].GetTimestamp() {
hasMore = false
break
}
}
}
return result, nil
}
// ReleaseSegments releases segments locally or remotely depending on the target node.
func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
log := sd.getLogger(ctx)
targetNodeID := req.GetNodeID()
// add common log fields
log = log.With(
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.Int64("nodeID", req.GetNodeID()),
zap.String("scope", req.GetScope().String()),
zap.Bool("force", force))
log.Info("delegator start to release segments")
// alter distribution first
if force {
targetNodeID = wildcardNodeID
}
var sealed, growing []SegmentEntry
convertSealed := func(segmentID int64, _ int) SegmentEntry {
return SegmentEntry{
SegmentID: segmentID,
NodeID: targetNodeID,
}
}
convertGrowing := func(segmentID int64, _ int) SegmentEntry {
return SegmentEntry{
SegmentID: segmentID,
}
}
switch req.GetScope() {
case querypb.DataScope_All:
sealed = lo.Map(req.GetSegmentIDs(), convertSealed)
growing = lo.Map(req.GetSegmentIDs(), convertGrowing)
case querypb.DataScope_Streaming:
growing = lo.Map(req.GetSegmentIDs(), convertGrowing)
case querypb.DataScope_Historical:
sealed = lo.Map(req.GetSegmentIDs(), convertSealed)
}
signal := sd.distribution.RemoveDistributions(sealed, growing)
// wait for the cleared signal
<-signal
if len(sealed) > 0 {
sd.pkOracle.Remove(
pkoracle.WithSegmentIDs(lo.Map(sealed, func(entry SegmentEntry, _ int) int64 { return entry.SegmentID })...),
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
pkoracle.WithWorkerID(targetNodeID),
)
}
if len(growing) > 0 {
sd.pkOracle.Remove(
pkoracle.WithSegmentIDs(lo.Map(growing, func(entry SegmentEntry, _ int) int64 { return entry.SegmentID })...),
pkoracle.WithSegmentType(commonpb.SegmentState_Growing),
)
}
if !force {
worker, err := sd.workerManager.GetWorker(targetNodeID)
if err != nil {
log.Warn("delegator failed to find worker",
zap.Error(err),
)
return err
}
err = worker.ReleaseSegments(ctx, req)
if err != nil {
log.Warn("worker failed to release segments",
zap.Error(err),
)
}
return err
}
return nil
}

View File

@ -0,0 +1,657 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/bits-and-blooms/bloom/v3"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type DelegatorDataSuite struct {
suite.Suite
collectionID int64
replicaID int64
vchannelName string
version int64
workerManager *cluster.MockManager
manager *segments.Manager
tsafeManager tsafe.Manager
loader *segments.MockLoader
mq *msgstream.MockMsgStream
delegator ShardDelegator
}
func (s *DelegatorDataSuite) SetupSuite() {
paramtable.Init()
paramtable.SetNodeID(1)
}
func (s *DelegatorDataSuite) SetupTest() {
s.collectionID = 1000
s.replicaID = 65535
s.vchannelName = "rootcoord-dml_1000_v0"
s.version = 2000
s.workerManager = &cluster.MockManager{}
s.manager = segments.NewManager()
s.tsafeManager = tsafe.NewTSafeReplica()
s.loader = &segments.MockLoader{}
// init schema
s.manager.Collection.Put(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
{
Name: "id",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: true,
},
{
Name: "vector",
FieldID: 101,
IsPrimaryKey: false,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
},
},
}, &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
})
s.mq = &msgstream.MockMsgStream{}
var err error
s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000)
s.Require().NoError(err)
}
func (s *DelegatorDataSuite) TestProcessInsert() {
s.Run("normal_insert", func() {
s.delegator.ProcessInsert(map[int64]*InsertData{
100: {
RowIDs: []int64{0, 1},
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)},
Timestamps: []uint64{10, 10},
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{},
InsertRecord: &segcorepb.InsertRecord{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "id",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2},
},
},
},
},
FieldId: 100,
},
{
Type: schemapb.DataType_FloatVector,
FieldName: "vector",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 128,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{Data: make([]float32, 128*2)},
},
},
},
FieldId: 101,
},
},
NumRows: 2,
},
},
})
s.NotNil(s.manager.Segment.GetGrowing(100))
})
s.Run("insert_bad_data", func() {
s.Panics(func() {
s.delegator.ProcessInsert(map[int64]*InsertData{
100: {
RowIDs: []int64{0, 1},
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)},
Timestamps: []uint64{10, 10},
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{},
InsertRecord: &segcorepb.InsertRecord{
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "id",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2},
},
},
},
},
FieldId: 100,
},
},
NumRows: 2,
},
},
})
})
})
}
func (s *DelegatorDataSuite) TestProcessDelete() {
s.loader.EXPECT().
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
ms := &segments.MockSegment{}
ms.EXPECT().ID().Return(info.GetSegmentID())
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
ms.EXPECT().Collection().Return(info.GetCollectionID())
ms.EXPECT().Partition().Return(info.GetPartitionID())
ms.EXPECT().Indexes().Return(nil)
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
ms.EXPECT().MayPkExist(mock.Anything).Call.Return(func(pk storage.PrimaryKey) bool {
return pk.EQ(storage.NewInt64PrimaryKey(10))
})
return ms
})
}, nil)
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
pks := &storage.PkStatistics{
PkFilter: bf,
}
pks.UpdatePKRange(&storage.Int64FieldData{
Data: []int64{10, 20, 30},
})
bfs.AddHistoricalStats(pks)
return bfs
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
// load growing
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadGrowing(ctx, []*querypb.SegmentLoadInfo{
{
SegmentID: 1001,
CollectionID: s.collectionID,
PartitionID: 500,
},
}, 0)
s.Require().NoError(err)
// load sealed
err = s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 1000,
CollectionID: s.collectionID,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.Require().NoError(err)
s.delegator.ProcessDelete([]*DeleteData{
{
PartitionID: 500,
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(10)},
Timestamps: []uint64{10},
RowCount: 1,
},
}, 10)
}
func (s *DelegatorDataSuite) TestLoadSegments() {
s.Run("normal_run", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.NoError(err)
sealed, _, _ := s.delegator.GetDistribution().GetCurrent()
s.Require().Equal(1, len(sealed))
s.Equal(int64(1), sealed[0].NodeID)
s.ElementsMatch([]SegmentEntry{
{
SegmentID: 100,
NodeID: 1,
PartitionID: 500,
},
}, sealed[0].Segments)
})
s.Run("load_segments_with_delete", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
pks := &storage.PkStatistics{
PkFilter: bf,
}
pks.UpdatePKRange(&storage.Int64FieldData{
Data: []int64{10, 20, 30},
})
bfs.AddHistoricalStats(pks)
return bfs
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
s.delegator.ProcessDelete([]*DeleteData{
{
PartitionID: 500,
PrimaryKeys: []storage.PrimaryKey{
storage.NewInt64PrimaryKey(1),
storage.NewInt64PrimaryKey(10),
},
Timestamps: []uint64{10, 10},
RowCount: 2,
},
}, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 200,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.NoError(err)
sealed, _, _ := s.delegator.GetDistribution().GetCurrent()
s.Require().Equal(1, len(sealed))
s.Equal(int64(1), sealed[0].NodeID)
s.ElementsMatch([]SegmentEntry{
{
SegmentID: 100,
NodeID: 1,
PartitionID: 500,
},
{
SegmentID: 200,
NodeID: 1,
PartitionID: 500,
},
}, sealed[0].Segments)
})
s.Run("get_worker_fail", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Return(nil, errors.New("mock error"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.Error(err)
})
s.Run("loader_bfs_fail", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Return(nil, errors.New("mocked error"))
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.Error(err)
})
s.Run("worker_load_fail", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(errors.New("mocked error"))
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.Error(err)
})
}
func (s *DelegatorDataSuite) TestReleaseSegment() {
s.loader.EXPECT().
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
ms := &segments.MockSegment{}
ms.EXPECT().ID().Return(info.GetSegmentID())
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
ms.EXPECT().Partition().Return(info.GetPartitionID())
ms.EXPECT().Collection().Return(info.GetCollectionID())
ms.EXPECT().Indexes().Return(nil)
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
ms.EXPECT().MayPkExist(mock.Anything).Call.Return(func(pk storage.PrimaryKey) bool {
return pk.EQ(storage.NewInt64PrimaryKey(10))
})
return ms
})
}, nil)
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
pks := &storage.PkStatistics{
PkFilter: bf,
}
pks.UpdatePKRange(&storage.Int64FieldData{
Data: []int64{10, 20, 30},
})
bfs.AddHistoricalStats(pks)
return bfs
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
worker1.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Return(nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
// load growing
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadGrowing(ctx, []*querypb.SegmentLoadInfo{
{
SegmentID: 1001,
CollectionID: s.collectionID,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
}, 0)
s.Require().NoError(err)
// load sealed
err = s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 1000,
CollectionID: s.collectionID,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
},
},
})
s.Require().NoError(err)
sealed, growing, version := s.delegator.GetDistribution().GetCurrent()
s.delegator.GetDistribution().FinishUsage(version)
s.Require().Equal(1, len(sealed))
s.Equal(int64(1), sealed[0].NodeID)
s.ElementsMatch([]SegmentEntry{
{
SegmentID: 1000,
NodeID: 1,
PartitionID: 500,
},
}, sealed[0].Segments)
s.ElementsMatch([]SegmentEntry{
{
SegmentID: 1001,
NodeID: 1,
PartitionID: 500,
},
}, growing)
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: 1,
SegmentIDs: []int64{1000},
Scope: querypb.DataScope_Historical,
}, false)
s.NoError(err)
sealed, _, version = s.delegator.GetDistribution().GetCurrent()
s.delegator.GetDistribution().FinishUsage(version)
s.Equal(0, len(sealed))
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: 1,
SegmentIDs: []int64{1001},
Scope: querypb.DataScope_Streaming,
}, false)
s.NoError(err)
_, growing, version = s.delegator.GetDistribution().GetCurrent()
s.delegator.GetDistribution().FinishUsage(version)
s.Equal(0, len(growing))
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
NodeID: 1,
SegmentIDs: []int64{1000},
Scope: querypb.DataScope_All,
}, true)
s.NoError(err)
}
func TestDelegatorDataSuite(t *testing.T) {
suite.Run(t, new(DelegatorDataSuite))
}

View File

@ -0,0 +1,831 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"context"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
)
type DelegatorSuite struct {
suite.Suite
collectionID int64
replicaID int64
vchannelName string
version int64
workerManager *cluster.MockManager
manager *segments.Manager
tsafeManager tsafe.Manager
loader *segments.MockLoader
mq *msgstream.MockMsgStream
delegator ShardDelegator
}
func (s *DelegatorSuite) SetupSuite() {
paramtable.Init()
}
func (s *DelegatorSuite) TearDownSuite() {
}
func (s *DelegatorSuite) SetupTest() {
s.collectionID = 1000
s.replicaID = 65535
s.vchannelName = "rootcoord-dml_1000_v0"
s.version = 2000
s.workerManager = &cluster.MockManager{}
s.manager = segments.NewManager()
s.tsafeManager = tsafe.NewTSafeReplica()
s.loader = &segments.MockLoader{}
s.loader.EXPECT().
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
ms := &segments.MockSegment{}
ms.EXPECT().ID().Return(info.GetSegmentID())
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
ms.EXPECT().Partition().Return(info.GetPartitionID())
ms.EXPECT().Collection().Return(info.GetCollectionID())
ms.EXPECT().Indexes().Return(nil)
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
return ms
})
}, nil)
// init schema
s.manager.Collection.Put(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
{
Name: "id",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: true,
},
{
Name: "vector",
FieldID: 101,
IsPrimaryKey: false,
DataType: schemapb.DataType_BinaryVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
},
},
}, &querypb.LoadMetaInfo{})
s.mq = &msgstream.MockMsgStream{}
var err error
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000)
s.Require().NoError(err)
}
func (s *DelegatorSuite) TearDownTest() {
s.delegator.Close()
s.delegator = nil
}
func (s *DelegatorSuite) TestBasicInfo() {
s.Equal(s.collectionID, s.delegator.Collection())
s.Equal(s.version, s.delegator.Version())
s.False(s.delegator.Serviceable())
s.delegator.Start()
s.True(s.delegator.Serviceable())
}
func (s *DelegatorSuite) TestDistribution() {
sealed, growing, version := s.delegator.GetDistribution().GetCurrent()
s.delegator.GetDistribution().FinishUsage(version)
s.Equal(0, len(sealed))
s.Equal(0, len(growing))
s.delegator.SyncDistribution(context.Background(), SegmentEntry{
NodeID: 1,
SegmentID: 1001,
PartitionID: 500,
Version: 2001,
})
sealed, growing, version = s.delegator.GetDistribution().GetCurrent()
s.EqualValues([]SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1001,
PartitionID: 500,
Version: 2001,
},
},
},
}, sealed)
s.Equal(0, len(growing))
s.delegator.GetDistribution().FinishUsage(version)
}
func (s *DelegatorSuite) TestSearch() {
s.delegator.Start()
// 1 => sealed segment 1000, 1001
// 1 => growing segment 1004
// 2 => sealed segment 1002, 1003
paramtable.SetNodeID(1)
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
{
SegmentID: 1004,
CollectionID: s.collectionID,
PartitionID: 500,
},
}, 0)
s.delegator.SyncDistribution(context.Background(),
SegmentEntry{
NodeID: 1,
SegmentID: 1000,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 1,
SegmentID: 1001,
PartitionID: 501,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1002,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1003,
PartitionID: 501,
Version: 2001,
},
)
s.Run("normal", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(1, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
if req.GetScope() == querypb.DataScope_Streaming {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
}
if req.GetScope() == querypb.DataScope_Historical {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
}
}).Return(&internalpb.SearchResults{}, nil)
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
results, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.NoError(err)
s.Equal(3, len(results))
})
s.Run("worker_return_error", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("worker_return_failure_code", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
},
}, nil)
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
Run(func(_ context.Context, req *querypb.SearchRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.SearchResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("wrong_channel", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{"non_exist_channel"},
})
s.Error(err)
})
s.Run("wait_tsafe_timeout", func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: 100,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("tsafe_behind_max_lag", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{
Base: commonpbutil.NewMsgBase(),
GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 1,
},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("cluster_not_serviceable", func() {
s.delegator.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
}
func (s *DelegatorSuite) TestQuery() {
s.delegator.Start()
// 1 => sealed segment 1000, 1001
// 1 => growing segment 1004
// 2 => sealed segment 1002, 1003
paramtable.SetNodeID(1)
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
{
SegmentID: 1004,
CollectionID: s.collectionID,
PartitionID: 500,
},
}, 0)
s.delegator.SyncDistribution(context.Background(),
SegmentEntry{
NodeID: 1,
SegmentID: 1000,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 1,
SegmentID: 1001,
PartitionID: 501,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1002,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1003,
PartitionID: 501,
Version: 2001,
},
)
s.Run("normal", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Run(func(_ context.Context, req *querypb.QueryRequest) {
s.EqualValues(1, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
if req.GetScope() == querypb.DataScope_Streaming {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
}
if req.GetScope() == querypb.DataScope_Historical {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
}
}).Return(&internalpb.RetrieveResults{}, nil)
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Run(func(_ context.Context, req *querypb.QueryRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
results, err := s.delegator.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.NoError(err)
s.Equal(3, len(results))
})
s.Run("worker_return_error", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).Return(nil, errors.New("mock error"))
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Run(func(_ context.Context, req *querypb.QueryRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("worker_return_failure_code", func() {
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
},
}, nil)
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
Run(func(_ context.Context, req *querypb.QueryRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.RetrieveResults{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("wrong_channel", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{"non_exist_channel"},
})
s.Error(err)
})
s.Run("cluster_not_serviceable", func() {
s.delegator.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
}
func (s *DelegatorSuite) TestGetStats() {
s.delegator.Start()
// 1 => sealed segment 1000, 1001
// 1 => growing segment 1004
// 2 => sealed segment 1002, 1003
paramtable.SetNodeID(1)
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
{
SegmentID: 1004,
CollectionID: s.collectionID,
PartitionID: 500,
},
}, 0)
s.delegator.SyncDistribution(context.Background(),
SegmentEntry{
NodeID: 1,
SegmentID: 1000,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 1,
SegmentID: 1001,
PartitionID: 501,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1002,
PartitionID: 500,
Version: 2001,
},
SegmentEntry{
NodeID: 2,
SegmentID: 1003,
PartitionID: 501,
Version: 2001,
},
)
s.Run("normal", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
s.EqualValues(1, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
if req.GetScope() == querypb.DataScope_Streaming {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
}
if req.GetScope() == querypb.DataScope_Historical {
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
}
}).Return(&internalpb.GetStatisticsResponse{}, nil)
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
results, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.NoError(err)
s.Equal(3, len(results))
})
s.Run("worker_return_error", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
}()
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).Return(nil, errors.New("mock error"))
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("worker_return_failure_code", func() {
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
worker2 := &cluster.MockWorker{}
workers[1] = worker1
workers[2] = worker2
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).Return(&internalpb.GetStatisticsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
},
}, nil)
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
s.EqualValues(2, req.Req.GetBase().GetTargetID())
s.True(req.GetFromShardLeader())
s.Equal(querypb.DataScope_Historical, req.GetScope())
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
}).Return(&internalpb.GetStatisticsResponse{}, nil)
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
s.Run("wrong_channel", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{"non_exist_channel"},
})
s.Error(err)
})
s.Run("cluster_not_serviceable", func() {
s.delegator.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
})
s.Error(err)
})
}
func TestDelegatorSuite(t *testing.T) {
suite.Run(t, new(DelegatorSuite))
}
func TestDelegatorWatchTsafe(t *testing.T) {
channelName := "default_dml_channel"
tsafeManager := tsafe.NewTSafeReplica()
tsafeManager.Add(channelName, 100)
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: newLifetime(),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
go sd.watchTSafe()
err := tsafeManager.Set(channelName, 200)
require.NoError(t, err)
assert.Eventually(t, func() bool {
return sd.latestTsafe.Load() == 200
}, time.Second*10, time.Millisecond*10)
}
func TestDelegatorTSafeListenerClosed(t *testing.T) {
channelName := "default_dml_channel"
tsafeManager := tsafe.NewTSafeReplica()
tsafeManager.Add(channelName, 100)
sd := &shardDelegator{
tsafeManager: tsafeManager,
vchannelName: channelName,
lifetime: newLifetime(),
latestTsafe: atomic.NewUint64(0),
}
defer sd.Close()
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
sd.wg.Add(1)
signal := make(chan struct{})
go func() {
sd.watchTSafe()
close(signal)
}()
select {
case <-signal:
assert.FailNow(t, "watchTsafe quit unexpectedly")
case <-time.After(time.Millisecond * 10):
}
tsafeManager.Remove(channelName)
select {
case <-signal:
case <-time.After(time.Second):
assert.FailNow(t, "watchTsafe still working after listener closed")
}
}

View File

@ -0,0 +1,138 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deletebuffer
import (
"sort"
"sync"
"github.com/cockroachdb/errors"
)
var (
errBufferFull = errors.New("buffer full")
)
type timed interface {
Timestamp() uint64
Size() int64
}
// DeleteBuffer is the interface for delete buffer.
type DeleteBuffer[T timed] interface {
Put(T)
ListAfter(uint64) []T
SafeTs() uint64
}
func NewDoubleCacheDeleteBuffer[T timed](startTs uint64, maxSize int64) DeleteBuffer[T] {
return &doubleCacheBuffer[T]{
head: newDoubleCacheItem[T](startTs, maxSize),
maxSize: maxSize,
ts: startTs,
}
}
// doubleCacheBuffer implements DeleteBuffer with a fixed-size double cache.
type doubleCacheBuffer[T timed] struct {
mut sync.RWMutex
head, tail *doubleCacheItem[T]
maxSize int64
ts uint64
}
func (c *doubleCacheBuffer[T]) SafeTs() uint64 {
return c.ts
}
// Put implements DeleteBuffer.
func (c *doubleCacheBuffer[T]) Put(entry T) {
c.mut.Lock()
defer c.mut.Unlock()
err := c.head.Put(entry)
if errors.Is(err, errBufferFull) {
c.evict(entry.Timestamp())
c.head.Put(entry)
}
}
// ListAfter implements DeleteBuffer.
func (c *doubleCacheBuffer[T]) ListAfter(ts uint64) []T {
c.mut.RLock()
defer c.mut.RUnlock()
var result []T
if c.tail != nil {
result = append(result, c.tail.ListAfter(ts)...)
}
if c.head != nil {
result = append(result, c.head.ListAfter(ts)...)
}
return result
}
// evict rotates the current head into the tail position and drops the previous tail.
func (c *doubleCacheBuffer[T]) evict(newTs uint64) {
c.tail = c.head
c.head = newDoubleCacheItem[T](newTs, c.maxSize/2)
c.ts = c.tail.headTs
}
func newDoubleCacheItem[T timed](ts uint64, maxSize int64) *doubleCacheItem[T] {
return &doubleCacheItem[T]{
headTs: ts,
maxSize: maxSize,
}
}
type doubleCacheItem[T timed] struct {
mut sync.RWMutex
headTs uint64
size int64
maxSize int64
data []T
}
// Put adds an entry into the cache item.
// It returns errBufferFull if the entry does not fit.
func (c *doubleCacheItem[T]) Put(entry T) error {
c.mut.Lock()
defer c.mut.Unlock()
if c.size+entry.Size() > c.maxSize {
return errBufferFull
}
c.data = append(c.data, entry)
c.size += entry.Size()
return nil
}
// ListAfter returns the entries whose timestamp is no earlier than the provided value.
func (c *doubleCacheItem[T]) ListAfter(ts uint64) []T {
c.mut.RLock()
defer c.mut.RUnlock()
idx := sort.Search(len(c.data), func(idx int) bool {
return c.data[idx].Timestamp() >= ts
})
// not found
if idx == len(c.data) {
return nil
}
return c.data[idx:]
}
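A minimal usage sketch of the double-cache buffer above, assuming it sits in the deletebuffer package alongside this code and uses the Item and BufferItem types introduced later in this change (the timestamps, byte budget, and test name are illustrative only):

func TestDoubleCacheSketch(t *testing.T) {
	// the safe timestamp starts at 100; the head cache may hold up to 1 KiB
	buf := NewDoubleCacheDeleteBuffer[*Item](100, 1024)
	buf.Put(&Item{Ts: 110, Data: []BufferItem{{PartitionID: 1}}})
	buf.Put(&Item{Ts: 120, Data: []BufferItem{{PartitionID: 1}}})
	// only entries with Timestamp() >= 115 are listed
	assert.Equal(t, 1, len(buf.ListAfter(115)))
	// nothing has been evicted yet, so SafeTs is still the start timestamp
	assert.EqualValues(t, 100, buf.SafeTs())
}

Once a Put overflows the head cache, evict rotates the head into the tail, drops the previous tail, and SafeTs advances to the new tail's start timestamp.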

View File

@ -0,0 +1,82 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deletebuffer
import (
"testing"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
func TestSkipListDeleteBuffer(t *testing.T) {
db := NewDeleteBuffer()
db.Cache(10, []BufferItem{
{PartitionID: 1},
})
result := db.List(0)
assert.Equal(t, 1, len(result))
assert.Equal(t, int64(1), result[0][0].PartitionID)
db.TruncateBefore(11)
result = db.List(0)
assert.Equal(t, 0, len(result))
}
type DoubleCacheBufferSuite struct {
suite.Suite
}
func (s *DoubleCacheBufferSuite) TestNewBuffer() {
buffer := NewDoubleCacheDeleteBuffer[*Item](10, 1000)
s.EqualValues(10, buffer.SafeTs())
}
func (s *DoubleCacheBufferSuite) TestCache() {
buffer := NewDoubleCacheDeleteBuffer[*Item](10, 1000)
buffer.Put(&Item{
Ts: 11,
Data: []BufferItem{
{
PartitionID: 200,
DeleteData: storage.DeleteData{},
},
},
})
buffer.Put(&Item{
Ts: 12,
Data: []BufferItem{
{
PartitionID: 200,
DeleteData: storage.DeleteData{},
},
},
})
s.Equal(2, len(buffer.ListAfter(11)))
s.Equal(1, len(buffer.ListAfter(12)))
}
func TestDoubleCacheDeleteBuffer(t *testing.T) {
suite.Run(t, new(DoubleCacheBufferSuite))
}

View File

@ -0,0 +1,38 @@
package deletebuffer
import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/samber/lo"
)
// Item wraps cache item as `timed`.
type Item struct {
Ts uint64
Data []BufferItem
}
// Timestamp implements `timed`.
func (item *Item) Timestamp() uint64 {
return item.Ts
}
// Size implements `timed`.
func (item *Item) Size() int64 {
return lo.Reduce(item.Data, func(size int64, item BufferItem, _ int) int64 {
return size + item.Size()
}, int64(0))
}
type BufferItem struct {
PartitionID int64
DeleteData storage.DeleteData
}
func (item *BufferItem) Size() int64 {
var pkSize int64
if len(item.DeleteData.Pks) > 0 {
pkSize = int64(len(item.DeleteData.Pks)) * item.DeleteData.Pks[0].Size()
}
return int64(96) + pkSize + int64(8*len(item.DeleteData.Tss))
}

View File

@ -0,0 +1,22 @@
package deletebuffer
import (
"testing"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/assert"
)
func TestDeleteBufferItem(t *testing.T) {
item := &BufferItem{
PartitionID: 100,
DeleteData: storage.DeleteData{},
}
assert.Equal(t, int64(96), item.Size())
item.DeleteData.Pks = []storage.PrimaryKey{
storage.NewInt64PrimaryKey(10),
}
	item.DeleteData.Tss = []uint64{2000}
	// the size shall grow once a primary key and a timestamp are buffered
	assert.Greater(t, item.Size(), int64(96))
}

View File

@ -0,0 +1,29 @@
package deletebuffer
import "github.com/milvus-io/milvus/internal/util/typeutil"
// deleteBuffer caches L0 delete buffer for remote segments.
type deleteBuffer struct {
// timestamp => DeleteData
cache *typeutil.SkipList[uint64, []BufferItem]
}
// Cache delete data.
func (b *deleteBuffer) Cache(timestamp uint64, data []BufferItem) {
b.cache.Upsert(timestamp, data)
}
func (b *deleteBuffer) List(since uint64) [][]BufferItem {
return b.cache.ListAfter(since, false)
}
func (b *deleteBuffer) TruncateBefore(ts uint64) {
b.cache.TruncateBefore(ts)
}
func NewDeleteBuffer() *deleteBuffer {
cache, _ := typeutil.NewSkipList[uint64, []BufferItem]()
return &deleteBuffer{
cache: cache,
}
}

View File

@ -0,0 +1,242 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"sync"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/atomic"
)
const (
// wildcardNodeID matches any nodeID, used for force distribution correction.
wildcardNodeID = int64(-1)
)
// distribution is the struct to store segment distribution.
// it contains both growing and sealed segments.
type distribution struct {
// segments information
// map[SegmentID]=>segmentEntry
growingSegments map[UniqueID]SegmentEntry
sealedSegments map[UniqueID]SegmentEntry
// version indicator
version int64
	// quick flag indicating whether the current snapshot is serviceable
serviceable *atomic.Bool
offlines typeutil.Set[int64]
snapshots *typeutil.ConcurrentMap[int64, *snapshot]
// current is the snapshot for quick usage for search/query
// generated for each change of distribution
current *snapshot
// protects current & segments
mut sync.RWMutex
}
// SegmentEntry stores the segment meta information.
type SegmentEntry struct {
NodeID int64
SegmentID UniqueID
PartitionID UniqueID
Version int64
}
// NewDistribution creates a new distribution instance with all fields initialized.
func NewDistribution() *distribution {
	dist := &distribution{
		serviceable:     atomic.NewBool(false),
		growingSegments: make(map[UniqueID]SegmentEntry),
		sealedSegments:  make(map[UniqueID]SegmentEntry),
		// the offline set must be initialized before AddOfflines inserts into it
		// (typeutil.NewSet is assumed to be the package's set constructor)
		offlines:  typeutil.NewSet[int64](),
		snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
	}
dist.genSnapshot()
return dist
}
// GetCurrent returns current snapshot.
func (d *distribution) GetCurrent(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64) {
d.mut.RLock()
defer d.mut.RUnlock()
sealed, growing = d.current.Get(partitions...)
version = d.current.version
return
}
// FinishUsage notifies snapshot one reference is released.
func (d *distribution) FinishUsage(version int64) {
snapshot, ok := d.snapshots.Get(version)
if ok {
snapshot.Done(d.getCleanup(snapshot.version))
}
}
// Serviceable returns whether the current snapshot is serviceable.
func (d *distribution) Serviceable() bool {
return d.serviceable.Load()
}
// AddDistributions adds multiple sealed segment entries.
func (d *distribution) AddDistributions(entries ...SegmentEntry) []int64 {
d.mut.Lock()
defer d.mut.Unlock()
// remove growing if sealed is loaded
var removed []int64
for _, entry := range entries {
d.sealedSegments[entry.SegmentID] = entry
d.offlines.Remove(entry.SegmentID)
_, ok := d.growingSegments[entry.SegmentID]
if ok {
removed = append(removed, entry.SegmentID)
delete(d.growingSegments, entry.SegmentID)
}
}
d.genSnapshot()
return removed
}
// AddGrowing adds growing segment distribution.
func (d *distribution) AddGrowing(entries ...SegmentEntry) {
d.mut.Lock()
defer d.mut.Unlock()
for _, entry := range entries {
d.growingSegments[entry.SegmentID] = entry
}
d.genSnapshot()
}
// AddOfflines marks the given sealed segmentIDs as offline.
func (d *distribution) AddOfflines(segmentIDs ...int64) {
d.mut.Lock()
defer d.mut.Unlock()
updated := false
for _, segmentID := range segmentIDs {
_, ok := d.sealedSegments[segmentID]
if !ok {
continue
}
updated = true
d.offlines.Insert(segmentID)
}
if updated {
d.genSnapshot()
}
}
// RemoveDistributions removes segment distributions and returns the clear signal channel.
func (d *distribution) RemoveDistributions(sealedSegments []SegmentEntry, growingSegments []SegmentEntry) chan struct{} {
d.mut.Lock()
defer d.mut.Unlock()
changed := false
for _, sealed := range sealedSegments {
if d.offlines.Contain(sealed.SegmentID) {
d.offlines.Remove(sealed.SegmentID)
changed = true
}
entry, ok := d.sealedSegments[sealed.SegmentID]
if !ok {
continue
}
if entry.NodeID == sealed.NodeID || sealed.NodeID == wildcardNodeID {
delete(d.sealedSegments, sealed.SegmentID)
changed = true
}
}
for _, growing := range growingSegments {
_, ok := d.growingSegments[growing.SegmentID]
if !ok {
continue
}
delete(d.growingSegments, growing.SegmentID)
changed = true
}
if !changed {
// no change made, return closed signal channel
ch := make(chan struct{})
close(ch)
return ch
}
return d.genSnapshot()
}
// genSnapshot converts the current distribution into snapshot format,
// in which callers can directly look up the nodeID=>segmentID list.
// The mutex write lock shall be held before calling this method.
func (d *distribution) genSnapshot() chan struct{} {
nodeSegments := make(map[int64][]SegmentEntry)
for _, entry := range d.sealedSegments {
nodeSegments[entry.NodeID] = append(nodeSegments[entry.NodeID], entry)
}
dist := make([]SnapshotItem, 0, len(nodeSegments))
for nodeID, items := range nodeSegments {
dist = append(dist, SnapshotItem{
NodeID: nodeID,
Segments: items,
})
}
growing := make([]SegmentEntry, 0, len(d.growingSegments))
for _, entry := range d.growingSegments {
growing = append(growing, entry)
}
d.serviceable.Store(d.offlines.Len() == 0)
// stores last snapshot
// ok to be nil
last := d.current
// increase version
d.version++
d.current = NewSnapshot(dist, growing, last, d.version)
// shall be a new one
d.snapshots.GetOrInsert(d.version, d.current)
// first snapshot, return closed chan
if last == nil {
ch := make(chan struct{})
close(ch)
return ch
}
last.Expire(d.getCleanup(last.version))
return last.cleared
}
// getCleanup returns cleanup snapshots function.
func (d *distribution) getCleanup(version int64) snapshotCleanup {
return func() {
d.snapshots.GetAndRemove(version)
}
}
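A sketch of the intended read/write protocol around distribution, assuming it runs inside the delegator package; the segment and node IDs are arbitrary. Readers pin a snapshot with GetCurrent and must release it with FinishUsage, while RemoveDistributions returns a channel that only closes once no pinned snapshot can still observe the removed segments:

dist := NewDistribution()
dist.AddDistributions(SegmentEntry{NodeID: 1, SegmentID: 100})

// reader side: pin the current snapshot
sealed, growing, version := dist.GetCurrent()
_, _ = sealed, growing

// writer side: removal hands back a clear signal channel
done := dist.RemoveDistributions([]SegmentEntry{{NodeID: 1, SegmentID: 100}}, nil)

dist.FinishUsage(version) // release the pinned snapshot
<-done                    // now it is safe to actually release the segment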

View File

@ -0,0 +1,395 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"testing"
"time"
"github.com/stretchr/testify/suite"
)
type DistributionSuite struct {
suite.Suite
dist *distribution
}
func (s *DistributionSuite) SetupTest() {
s.dist = NewDistribution()
}
func (s *DistributionSuite) TearDownTest() {
s.dist = nil
}
func (s *DistributionSuite) TestAddDistribution() {
type testCase struct {
tag string
input []SegmentEntry
expected []SnapshotItem
}
cases := []testCase{
{
tag: "one node",
input: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1,
},
{
NodeID: 1,
SegmentID: 2,
},
},
expected: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1,
},
{
NodeID: 1,
SegmentID: 2,
},
},
},
},
},
{
tag: "multiple nodes",
input: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1,
},
{
NodeID: 2,
SegmentID: 2,
},
{
NodeID: 1,
SegmentID: 3,
},
},
expected: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{
NodeID: 1,
SegmentID: 1,
},
{
NodeID: 1,
SegmentID: 3,
},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{
NodeID: 2,
SegmentID: 2,
},
},
},
},
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
s.SetupTest()
defer s.TearDownTest()
s.dist.AddDistributions(tc.input...)
sealed, _, version := s.dist.GetCurrent()
defer s.dist.FinishUsage(version)
s.compareSnapshotItems(tc.expected, sealed)
})
}
}
func (s *DistributionSuite) compareSnapshotItems(target, value []SnapshotItem) {
if !s.Equal(len(target), len(value)) {
return
}
mapNodeItem := make(map[int64]SnapshotItem)
for _, valueItem := range value {
mapNodeItem[valueItem.NodeID] = valueItem
}
for _, targetItem := range target {
valueItem, ok := mapNodeItem[targetItem.NodeID]
if !s.True(ok) {
return
}
s.ElementsMatch(targetItem.Segments, valueItem.Segments)
}
}
func (s *DistributionSuite) TestAddGrowing() {
type testCase struct {
tag string
input []SegmentEntry
expected []SegmentEntry
}
cases := []testCase{
{
tag: "nil input",
input: nil,
expected: []SegmentEntry{},
},
{
tag: "normal case",
input: []SegmentEntry{
{SegmentID: 1, PartitionID: 1},
{SegmentID: 2, PartitionID: 2},
},
expected: []SegmentEntry{
{SegmentID: 1, PartitionID: 1},
{SegmentID: 2, PartitionID: 2},
},
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
s.SetupTest()
defer s.TearDownTest()
s.dist.AddGrowing(tc.input...)
_, growing, version := s.dist.GetCurrent()
defer s.dist.FinishUsage(version)
s.ElementsMatch(tc.expected, growing)
})
}
}
func (s *DistributionSuite) TestRemoveDistribution() {
type testCase struct {
tag string
presetSealed []SegmentEntry
presetGrowing []SegmentEntry
removalSealed []SegmentEntry
removalGrowing []SegmentEntry
withMockRead bool
expectSealed []SnapshotItem
expectGrowing []SegmentEntry
}
cases := []testCase{
{
tag: "remove with no read",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
},
presetGrowing: []SegmentEntry{
{SegmentID: 4},
{SegmentID: 5},
},
removalSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
},
removalGrowing: []SegmentEntry{
{SegmentID: 5},
},
withMockRead: false,
expectSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
},
},
},
expectGrowing: []SegmentEntry{{SegmentID: 4}},
},
{
tag: "remove with wrong nodeID",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
},
presetGrowing: []SegmentEntry{
{SegmentID: 4},
{SegmentID: 5},
},
removalSealed: []SegmentEntry{
{NodeID: 2, SegmentID: 1},
},
removalGrowing: nil,
withMockRead: false,
expectSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 1, SegmentID: 3},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
},
},
},
expectGrowing: []SegmentEntry{{SegmentID: 4}, {SegmentID: 5}},
},
{
tag: "remove with wildcardNodeID",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
},
presetGrowing: []SegmentEntry{
{SegmentID: 4},
{SegmentID: 5},
},
removalSealed: []SegmentEntry{
{NodeID: wildcardNodeID, SegmentID: 1},
},
removalGrowing: nil,
withMockRead: false,
expectSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
},
},
},
expectGrowing: []SegmentEntry{{SegmentID: 4}, {SegmentID: 5}},
},
{
tag: "remove with read",
presetSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
{NodeID: 2, SegmentID: 2},
{NodeID: 1, SegmentID: 3},
},
presetGrowing: []SegmentEntry{
{SegmentID: 4},
{SegmentID: 5},
},
removalSealed: []SegmentEntry{
{NodeID: 1, SegmentID: 1},
},
removalGrowing: []SegmentEntry{
{SegmentID: 5},
},
withMockRead: true,
expectSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{NodeID: 1, SegmentID: 3},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{NodeID: 2, SegmentID: 2},
},
},
},
expectGrowing: []SegmentEntry{{SegmentID: 4}},
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
s.SetupTest()
defer s.TearDownTest()
s.dist.AddGrowing(tc.presetGrowing...)
s.dist.AddDistributions(tc.presetSealed...)
var version int64
if tc.withMockRead {
_, _, version = s.dist.GetCurrent()
}
ch := s.dist.RemoveDistributions(tc.removalSealed, tc.removalGrowing)
if tc.withMockRead {
// check ch not closed
select {
case <-ch:
s.Fail("ch closed with running read")
default:
}
s.dist.FinishUsage(version)
}
// check ch close very soon
timeout := time.NewTimer(time.Second)
defer timeout.Stop()
select {
case <-timeout.C:
s.Fail("ch not closed after 1 second")
case <-ch:
}
sealed, growing, version := s.dist.GetCurrent()
defer s.dist.FinishUsage(version)
s.compareSnapshotItems(tc.expectSealed, sealed)
s.ElementsMatch(tc.expectGrowing, growing)
})
}
}
func TestDistributionSuite(t *testing.T) {
suite.Run(t, new(DistributionSuite))
}

View File

@ -0,0 +1,597 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
package delegator
import (
context "context"
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
)
// MockShardDelegator is an autogenerated mock type for the ShardDelegator type
type MockShardDelegator struct {
mock.Mock
}
type MockShardDelegator_Expecter struct {
mock *mock.Mock
}
func (_m *MockShardDelegator) EXPECT() *MockShardDelegator_Expecter {
return &MockShardDelegator_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with given fields:
func (_m *MockShardDelegator) Close() {
_m.Called()
}
// MockShardDelegator_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockShardDelegator_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) Close() *MockShardDelegator_Close_Call {
return &MockShardDelegator_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockShardDelegator_Close_Call) Run(run func()) *MockShardDelegator_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_Close_Call) Return() *MockShardDelegator_Close_Call {
_c.Call.Return()
return _c
}
// Collection provides a mock function with given fields:
func (_m *MockShardDelegator) Collection() int64 {
ret := _m.Called()
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockShardDelegator_Collection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Collection'
type MockShardDelegator_Collection_Call struct {
*mock.Call
}
// Collection is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) Collection() *MockShardDelegator_Collection_Call {
return &MockShardDelegator_Collection_Call{Call: _e.mock.On("Collection")}
}
func (_c *MockShardDelegator_Collection_Call) Run(run func()) *MockShardDelegator_Collection_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_Collection_Call) Return(_a0 int64) *MockShardDelegator_Collection_Call {
_c.Call.Return(_a0)
return _c
}
// GetDistribution provides a mock function with given fields:
func (_m *MockShardDelegator) GetDistribution() *distribution {
ret := _m.Called()
var r0 *distribution
if rf, ok := ret.Get(0).(func() *distribution); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*distribution)
}
}
return r0
}
// MockShardDelegator_GetDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDistribution'
type MockShardDelegator_GetDistribution_Call struct {
*mock.Call
}
// GetDistribution is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) GetDistribution() *MockShardDelegator_GetDistribution_Call {
return &MockShardDelegator_GetDistribution_Call{Call: _e.mock.On("GetDistribution")}
}
func (_c *MockShardDelegator_GetDistribution_Call) Run(run func()) *MockShardDelegator_GetDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_GetDistribution_Call) Return(_a0 *distribution) *MockShardDelegator_GetDistribution_Call {
_c.Call.Return(_a0)
return _c
}
// GetStatistics provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
ret := _m.Called(ctx, req)
var r0 []*internalpb.GetStatisticsResponse
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) []*internalpb.GetStatisticsResponse); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*internalpb.GetStatisticsResponse)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardDelegator_GetStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatistics'
type MockShardDelegator_GetStatistics_Call struct {
*mock.Call
}
// GetStatistics is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.GetStatisticsRequest
func (_e *MockShardDelegator_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockShardDelegator_GetStatistics_Call {
return &MockShardDelegator_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)}
}
func (_c *MockShardDelegator_GetStatistics_Call) Run(run func(ctx context.Context, req *querypb.GetStatisticsRequest)) *MockShardDelegator_GetStatistics_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest))
})
return _c
}
func (_c *MockShardDelegator_GetStatistics_Call) Return(_a0 []*internalpb.GetStatisticsResponse, _a1 error) *MockShardDelegator_GetStatistics_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// LoadGrowing provides a mock function with given fields: ctx, infos, version
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
ret := _m.Called(ctx, infos, version)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []*querypb.SegmentLoadInfo, int64) error); ok {
r0 = rf(ctx, infos, version)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockShardDelegator_LoadGrowing_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadGrowing'
type MockShardDelegator_LoadGrowing_Call struct {
*mock.Call
}
// LoadGrowing is a helper method to define mock.On call
// - ctx context.Context
// - infos []*querypb.SegmentLoadInfo
// - version int64
func (_e *MockShardDelegator_Expecter) LoadGrowing(ctx interface{}, infos interface{}, version interface{}) *MockShardDelegator_LoadGrowing_Call {
return &MockShardDelegator_LoadGrowing_Call{Call: _e.mock.On("LoadGrowing", ctx, infos, version)}
}
func (_c *MockShardDelegator_LoadGrowing_Call) Run(run func(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64)) *MockShardDelegator_LoadGrowing_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]*querypb.SegmentLoadInfo), args[2].(int64))
})
return _c
}
func (_c *MockShardDelegator_LoadGrowing_Call) Return(_a0 error) *MockShardDelegator_LoadGrowing_Call {
_c.Call.Return(_a0)
return _c
}
// LoadSegments provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
ret := _m.Called(ctx, req)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok {
r0 = rf(ctx, req)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockShardDelegator_LoadSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadSegments'
type MockShardDelegator_LoadSegments_Call struct {
*mock.Call
}
// LoadSegments is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.LoadSegmentsRequest
func (_e *MockShardDelegator_Expecter) LoadSegments(ctx interface{}, req interface{}) *MockShardDelegator_LoadSegments_Call {
return &MockShardDelegator_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, req)}
}
func (_c *MockShardDelegator_LoadSegments_Call) Run(run func(ctx context.Context, req *querypb.LoadSegmentsRequest)) *MockShardDelegator_LoadSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest))
})
return _c
}
func (_c *MockShardDelegator_LoadSegments_Call) Return(_a0 error) *MockShardDelegator_LoadSegments_Call {
_c.Call.Return(_a0)
return _c
}
// ProcessDelete provides a mock function with given fields: deleteData, ts
func (_m *MockShardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
_m.Called(deleteData, ts)
}
// MockShardDelegator_ProcessDelete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessDelete'
type MockShardDelegator_ProcessDelete_Call struct {
*mock.Call
}
// ProcessDelete is a helper method to define mock.On call
// - deleteData []*DeleteData
// - ts uint64
func (_e *MockShardDelegator_Expecter) ProcessDelete(deleteData interface{}, ts interface{}) *MockShardDelegator_ProcessDelete_Call {
return &MockShardDelegator_ProcessDelete_Call{Call: _e.mock.On("ProcessDelete", deleteData, ts)}
}
func (_c *MockShardDelegator_ProcessDelete_Call) Run(run func(deleteData []*DeleteData, ts uint64)) *MockShardDelegator_ProcessDelete_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]*DeleteData), args[1].(uint64))
})
return _c
}
func (_c *MockShardDelegator_ProcessDelete_Call) Return() *MockShardDelegator_ProcessDelete_Call {
_c.Call.Return()
return _c
}
// ProcessInsert provides a mock function with given fields: insertRecords
func (_m *MockShardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
_m.Called(insertRecords)
}
// MockShardDelegator_ProcessInsert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessInsert'
type MockShardDelegator_ProcessInsert_Call struct {
*mock.Call
}
// ProcessInsert is a helper method to define mock.On call
// - insertRecords map[int64]*InsertData
func (_e *MockShardDelegator_Expecter) ProcessInsert(insertRecords interface{}) *MockShardDelegator_ProcessInsert_Call {
return &MockShardDelegator_ProcessInsert_Call{Call: _e.mock.On("ProcessInsert", insertRecords)}
}
func (_c *MockShardDelegator_ProcessInsert_Call) Run(run func(insertRecords map[int64]*InsertData)) *MockShardDelegator_ProcessInsert_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[int64]*InsertData))
})
return _c
}
func (_c *MockShardDelegator_ProcessInsert_Call) Return() *MockShardDelegator_ProcessInsert_Call {
_c.Call.Return()
return _c
}
// Query provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
ret := _m.Called(ctx, req)
var r0 []*internalpb.RetrieveResults
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) []*internalpb.RetrieveResults); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*internalpb.RetrieveResults)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardDelegator_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query'
type MockShardDelegator_Query_Call struct {
*mock.Call
}
// Query is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.QueryRequest
func (_e *MockShardDelegator_Expecter) Query(ctx interface{}, req interface{}) *MockShardDelegator_Query_Call {
return &MockShardDelegator_Query_Call{Call: _e.mock.On("Query", ctx, req)}
}
func (_c *MockShardDelegator_Query_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockShardDelegator_Query_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.QueryRequest))
})
return _c
}
func (_c *MockShardDelegator_Query_Call) Return(_a0 []*internalpb.RetrieveResults, _a1 error) *MockShardDelegator_Query_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// ReleaseSegments provides a mock function with given fields: ctx, req, force
func (_m *MockShardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
ret := _m.Called(ctx, req, force)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest, bool) error); ok {
r0 = rf(ctx, req, force)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockShardDelegator_ReleaseSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseSegments'
type MockShardDelegator_ReleaseSegments_Call struct {
*mock.Call
}
// ReleaseSegments is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.ReleaseSegmentsRequest
// - force bool
func (_e *MockShardDelegator_Expecter) ReleaseSegments(ctx interface{}, req interface{}, force interface{}) *MockShardDelegator_ReleaseSegments_Call {
return &MockShardDelegator_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, req, force)}
}
func (_c *MockShardDelegator_ReleaseSegments_Call) Run(run func(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool)) *MockShardDelegator_ReleaseSegments_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest), args[2].(bool))
})
return _c
}
func (_c *MockShardDelegator_ReleaseSegments_Call) Return(_a0 error) *MockShardDelegator_ReleaseSegments_Call {
_c.Call.Return(_a0)
return _c
}
// Search provides a mock function with given fields: ctx, req
func (_m *MockShardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
ret := _m.Called(ctx, req)
var r0 []*internalpb.SearchResults
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) []*internalpb.SearchResults); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*internalpb.SearchResults)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardDelegator_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search'
type MockShardDelegator_Search_Call struct {
*mock.Call
}
// Search is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.SearchRequest
func (_e *MockShardDelegator_Expecter) Search(ctx interface{}, req interface{}) *MockShardDelegator_Search_Call {
return &MockShardDelegator_Search_Call{Call: _e.mock.On("Search", ctx, req)}
}
func (_c *MockShardDelegator_Search_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockShardDelegator_Search_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SearchRequest))
})
return _c
}
func (_c *MockShardDelegator_Search_Call) Return(_a0 []*internalpb.SearchResults, _a1 error) *MockShardDelegator_Search_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// Serviceable provides a mock function with given fields:
func (_m *MockShardDelegator) Serviceable() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockShardDelegator_Serviceable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Serviceable'
type MockShardDelegator_Serviceable_Call struct {
*mock.Call
}
// Serviceable is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) Serviceable() *MockShardDelegator_Serviceable_Call {
return &MockShardDelegator_Serviceable_Call{Call: _e.mock.On("Serviceable")}
}
func (_c *MockShardDelegator_Serviceable_Call) Run(run func()) *MockShardDelegator_Serviceable_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_Serviceable_Call) Return(_a0 bool) *MockShardDelegator_Serviceable_Call {
_c.Call.Return(_a0)
return _c
}
// Start provides a mock function with given fields:
func (_m *MockShardDelegator) Start() {
_m.Called()
}
// MockShardDelegator_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type MockShardDelegator_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) Start() *MockShardDelegator_Start_Call {
return &MockShardDelegator_Start_Call{Call: _e.mock.On("Start")}
}
func (_c *MockShardDelegator_Start_Call) Run(run func()) *MockShardDelegator_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_Start_Call) Return() *MockShardDelegator_Start_Call {
_c.Call.Return()
return _c
}
// SyncDistribution provides a mock function with given fields: ctx, entries
func (_m *MockShardDelegator) SyncDistribution(ctx context.Context, entries ...SegmentEntry) {
_va := make([]interface{}, len(entries))
for _i := range entries {
_va[_i] = entries[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
_m.Called(_ca...)
}
// MockShardDelegator_SyncDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncDistribution'
type MockShardDelegator_SyncDistribution_Call struct {
*mock.Call
}
// SyncDistribution is a helper method to define mock.On call
// - ctx context.Context
// - entries ...SegmentEntry
func (_e *MockShardDelegator_Expecter) SyncDistribution(ctx interface{}, entries ...interface{}) *MockShardDelegator_SyncDistribution_Call {
return &MockShardDelegator_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution",
append([]interface{}{ctx}, entries...)...)}
}
func (_c *MockShardDelegator_SyncDistribution_Call) Run(run func(ctx context.Context, entries ...SegmentEntry)) *MockShardDelegator_SyncDistribution_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]SegmentEntry, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(SegmentEntry)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockShardDelegator_SyncDistribution_Call) Return() *MockShardDelegator_SyncDistribution_Call {
_c.Call.Return()
return _c
}
// Version provides a mock function with given fields:
func (_m *MockShardDelegator) Version() int64 {
ret := _m.Called()
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockShardDelegator_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version'
type MockShardDelegator_Version_Call struct {
*mock.Call
}
// Version is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) Version() *MockShardDelegator_Version_Call {
return &MockShardDelegator_Version_Call{Call: _e.mock.On("Version")}
}
func (_c *MockShardDelegator_Version_Call) Run(run func()) *MockShardDelegator_Version_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_Version_Call) Return(_a0 int64) *MockShardDelegator_Version_Call {
_c.Call.Return(_a0)
return _c
}
type mockConstructorTestingTNewMockShardDelegator interface {
mock.TestingT
Cleanup(func())
}
// NewMockShardDelegator creates a new instance of MockShardDelegator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockShardDelegator(t mockConstructorTestingTNewMockShardDelegator) *MockShardDelegator {
mock := &MockShardDelegator{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
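A hedged sketch of how this generated mock is typically wired into tests within the delegator package; the request values are placeholders and the usual imports (testing, context, mock, querypb, internalpb) are assumed:

func TestShardDelegatorMockSketch(t *testing.T) {
	sd := NewMockShardDelegator(t) // expectations are asserted automatically on cleanup
	sd.EXPECT().Serviceable().Return(true)
	sd.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
		Return([]*internalpb.SearchResults{}, nil)

	if !sd.Serviceable() {
		t.Fatal("expected serviceable delegator")
	}
	if _, err := sd.Search(context.Background(), &querypb.SearchRequest{}); err != nil {
		t.Fatalf("unexpected search error: %v", err)
	}
}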

View File

@ -0,0 +1,132 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"sync"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/samber/lo"
"go.uber.org/atomic"
)
// SnapshotItem groups the segment entries that belong to one node.
type SnapshotItem struct {
NodeID int64
Segments []SegmentEntry
}
// snapshotCleanup cleanup function signature.
type snapshotCleanup func()
// snapshot records segment distribution with ref count.
type snapshot struct {
dist []SnapshotItem
growing []SegmentEntry
// version ID for tracking
version int64
// signal channel for this snapshot cleared
cleared chan struct{}
once sync.Once
// reference to last version
last *snapshot
// reference count for this snapshot
inUse atomic.Int64
// expired flag
expired bool
}
// NewSnapshot returns a prepared snapshot with channel initialized.
func NewSnapshot(sealed []SnapshotItem, growing []SegmentEntry, last *snapshot, version int64) *snapshot {
return &snapshot{
version: version,
growing: growing,
dist: sealed,
last: last,
cleared: make(chan struct{}),
}
}
// Expire sets expired flag to true.
func (s *snapshot) Expire(cleanup snapshotCleanup) {
s.expired = true
s.checkCleared(cleanup)
}
// Get returns segment distributions with provided partition ids.
func (s *snapshot) Get(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry) {
s.inUse.Inc()
filter := func(entry SegmentEntry, idx int) bool {
return len(partitions) == 0 || funcutil.SliceContain(partitions, entry.PartitionID)
}
sealed = make([]SnapshotItem, 0, len(s.dist))
for _, item := range s.dist {
segments := lo.Filter(item.Segments, filter)
sealed = append(sealed, SnapshotItem{
NodeID: item.NodeID,
Segments: segments,
})
}
growing = lo.Filter(s.growing, filter)
return
}
// Done decreases inUse count for snapshot.
// also performs cleared check.
func (s *snapshot) Done(cleanup snapshotCleanup) {
s.inUse.Dec()
s.checkCleared(cleanup)
}
// checkCleared closes the cleared signal once the snapshot is expired and no longer in use.
func (s *snapshot) checkCleared(cleanup snapshotCleanup) {
if s.expired && s.inUse.Load() == 0 {
s.once.Do(func() {
// first snapshot
if s.last == nil {
close(s.cleared)
cleanup()
return
}
// wait last version cleared
go func() {
<-s.last.cleared
s.last = nil
cleanup()
close(s.cleared)
}()
})
}
}
func inList(list []int64, target int64) bool {
for _, i := range list {
if i == target {
return true
}
}
return false
}
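The reference-counting contract here is: Get pins a snapshot, Done releases it, and the cleared channel only closes once the snapshot is expired, unused, and its predecessor has itself been cleared. A minimal sketch, assuming it runs next to the code above:

prev := NewSnapshot(nil, nil, nil, 0)
prev.Expire(func() {}) // no users and no predecessor: cleared closes immediately

snap := NewSnapshot(nil, nil, prev, 1)
snap.Get()             // pin
snap.Expire(func() {}) // expired but still in use: cleared stays open
snap.Done(func() {})   // last user gone: cleared closes after prev is cleared
<-snap.cleared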

View File

@ -0,0 +1,254 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"testing"
"time"
"github.com/stretchr/testify/suite"
)
type SnapshotSuite struct {
suite.Suite
snapshot *snapshot
}
func (s *SnapshotSuite) SetupTest() {
last := NewSnapshot(nil, nil, nil, 0)
last.Expire(func() {})
dist := []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{
SegmentID: 1,
PartitionID: 1,
},
{
SegmentID: 2,
PartitionID: 2,
},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{
SegmentID: 3,
PartitionID: 1,
},
{
SegmentID: 4,
PartitionID: 2,
},
},
},
}
growing := []SegmentEntry{
{
SegmentID: 5,
PartitionID: 1,
},
{
SegmentID: 6,
PartitionID: 2,
},
}
s.snapshot = NewSnapshot(dist, growing, last, 1)
}
func (s *SnapshotSuite) TearDownTest() {
s.snapshot = nil
}
func (s *SnapshotSuite) TestGet() {
type testCase struct {
tag string
partitions []int64
expectedSealed []SnapshotItem
expectedGrowing []SegmentEntry
}
cases := []testCase{
{
tag: "nil partition",
partitions: nil,
expectedSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{SegmentID: 1, PartitionID: 1},
{SegmentID: 2, PartitionID: 2},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{SegmentID: 3, PartitionID: 1},
{SegmentID: 4, PartitionID: 2},
},
},
},
expectedGrowing: []SegmentEntry{
{
SegmentID: 5,
PartitionID: 1,
},
{
SegmentID: 6,
PartitionID: 2,
},
},
},
{
tag: "partition_1",
partitions: []int64{1},
expectedSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{SegmentID: 1, PartitionID: 1},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{SegmentID: 3, PartitionID: 1},
},
},
},
expectedGrowing: []SegmentEntry{
{
SegmentID: 5,
PartitionID: 1,
},
},
},
{
tag: "partition_2",
partitions: []int64{2},
expectedSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{
{SegmentID: 2, PartitionID: 2},
},
},
{
NodeID: 2,
Segments: []SegmentEntry{
{SegmentID: 4, PartitionID: 2},
},
},
},
expectedGrowing: []SegmentEntry{
{
SegmentID: 6,
PartitionID: 2,
},
},
},
{
tag: "partition not exists",
partitions: []int64{3},
expectedSealed: []SnapshotItem{
{
NodeID: 1,
Segments: []SegmentEntry{},
},
{
NodeID: 2,
Segments: []SegmentEntry{},
},
},
expectedGrowing: []SegmentEntry{},
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
before := s.snapshot.inUse.Load()
sealed, growing := s.snapshot.Get(tc.partitions...)
after := s.snapshot.inUse.Load()
s.ElementsMatch(tc.expectedSealed, sealed)
s.ElementsMatch(tc.expectedGrowing, growing)
s.EqualValues(1, after-before)
})
}
}
func (s *SnapshotSuite) TestDone() {
s.Run("done not expired snapshot", func() {
inUse := s.snapshot.inUse.Load()
s.Require().EqualValues(0, inUse)
s.snapshot.Get()
inUse = s.snapshot.inUse.Load()
s.Require().EqualValues(1, inUse)
s.snapshot.Done(func() {})
inUse = s.snapshot.inUse.Load()
s.EqualValues(0, inUse)
// check cleared channel closed
select {
case <-s.snapshot.cleared:
s.Fail("snapshot channel closed in non-expired state")
default:
}
})
s.Run("done expired snapshot", func() {
inUse := s.snapshot.inUse.Load()
s.Require().EqualValues(0, inUse)
s.snapshot.Get()
inUse = s.snapshot.inUse.Load()
s.Require().EqualValues(1, inUse)
s.snapshot.Expire(func() {})
// check cleared channel closed
select {
case <-s.snapshot.cleared:
s.FailNow("snapshot channel closed in non-expired state")
default:
}
signal := make(chan struct{})
s.snapshot.Done(func() { close(signal) })
inUse = s.snapshot.inUse.Load()
s.EqualValues(0, inUse)
timeout := time.NewTimer(time.Second)
defer timeout.Stop()
select {
case <-timeout.C:
s.FailNow("cleanup never called")
case <-signal:
}
timeout = time.NewTimer(10 * time.Millisecond)
defer timeout.Stop()
select {
case <-timeout.C:
s.FailNow("snapshot channel not closed after expired and no use")
case <-s.snapshot.cleared:
}
})
}
func TestSnapshot(t *testing.T) {
suite.Run(t, new(SnapshotSuite))
}

View File

@ -0,0 +1,61 @@
package delegator
import (
"fmt"
"time"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
rowIDFieldID FieldID = 0
timestampFieldID FieldID = 1
)
type (
// UniqueID is an identifier that is guaranteed to be unique among all the collections, partitions and segments
UniqueID = typeutil.UniqueID
// Timestamp is timestamp
Timestamp = typeutil.Timestamp
// FieldID is to uniquely identify the field
FieldID = int64
// IntPrimaryKey is the primary key of int type
IntPrimaryKey = typeutil.IntPrimaryKey
// DSL is the Domain Specific Language
DSL = string
// ConsumeSubName is consumer's subscription name of the message stream
ConsumeSubName = string
)
// TimeRange is a time range bounded by a min and a max timestamp
type TimeRange struct {
timestampMin Timestamp
timestampMax Timestamp
}
// loadType is load collection or load partition
type loadType = querypb.LoadType
const (
loadTypeCollection = querypb.LoadType_LoadCollection
loadTypePartition = querypb.LoadType_LoadPartition
)
// TSafeUpdater is the interface for types that provide tsafe update events
type TSafeUpdater interface {
RegisterChannel(string) chan Timestamp
UnregisterChannel(string) error
}
var (
	// ErrTsLagTooLarge means the lag between the serviceable and guarantee timestamps is too large.
ErrTsLagTooLarge = errors.New("Timestamp lag too large")
)
// WrapErrTsLagTooLarge wraps ErrTsLagTooLarge with lag and max value.
func WrapErrTsLagTooLarge(duration time.Duration, maxLag time.Duration) error {
return fmt.Errorf("%w lag(%s) max(%s)", ErrTsLagTooLarge, duration, maxLag)
}

View File

@ -0,0 +1,42 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"fmt"
"github.com/cockroachdb/errors"
)
var (
ErrNodeUnhealthy = errors.New("NodeIsUnhealthy")
	ErrGetDelegatorFailed = errors.New("GetShardDelegatorFailed")
ErrInitPipelineFailed = errors.New("InitPipelineFailed")
)
// WrapErrNodeUnhealthy wraps ErrNodeUnhealthy with nodeID.
func WrapErrNodeUnhealthy(nodeID int64) error {
return fmt.Errorf("%w id: %d", ErrNodeUnhealthy, nodeID)
}
func WrapErrInitPipelineFailed(err error) error {
return fmt.Errorf("%w err: %s", ErrInitPipelineFailed, err.Error())
}
func msgQueryNodeIsUnhealthy(nodeID int64) string {
return fmt.Sprintf("query node %d is not ready", nodeID)
}
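Callers are expected to match these wrapped errors against the sentinel values rather than against message strings; a brief sketch of that pattern, using errors.Is from the same errors package imported above:

err := WrapErrNodeUnhealthy(7)
if errors.Is(err, ErrNodeUnhealthy) {
	// the wrapped message still carries the node ID: "NodeIsUnhealthy id: 7"
}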

View File

@ -0,0 +1,455 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"fmt"
"strconv"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/timerecord"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
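// loadGrowingSegments collects the unflushed segments referenced by the WatchDmChannels request
// and loads the ones that already have binlogs into the delegator as growing segments.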
func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator, req *querypb.WatchDmChannelsRequest) error {
// load growing segments
growingSegments := make([]*querypb.SegmentLoadInfo, 0, len(req.Infos))
for _, info := range req.Infos {
for _, segmentID := range info.GetUnflushedSegmentIds() {
			// an unflushed segment may have no binlogs yet, skip loading it in that case
segmentInfo := req.GetSegmentInfos()[segmentID]
if segmentInfo == nil {
log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", segmentID))
continue
}
if len(segmentInfo.GetBinlogs()) > 0 {
growingSegments = append(growingSegments, &querypb.SegmentLoadInfo{
SegmentID: segmentInfo.ID,
PartitionID: segmentInfo.PartitionID,
CollectionID: segmentInfo.CollectionID,
BinlogPaths: segmentInfo.Binlogs,
NumOfRows: segmentInfo.NumOfRows,
Statslogs: segmentInfo.Statslogs,
Deltalogs: segmentInfo.Deltalogs,
InsertChannel: segmentInfo.InsertChannel,
})
} else {
log.Info("skip segment which binlog is empty", zap.Int64("segmentID", segmentInfo.ID))
}
}
}
return delegator.LoadGrowing(ctx, growingSegments, req.GetVersion())
}
func (node *QueryNode) loadDeltaLogs(ctx context.Context, req *querypb.LoadSegmentsRequest) *commonpb.Status {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
loadedSegments := make([]int64, 0, len(req.GetInfos()))
var finalErr error
for _, info := range req.GetInfos() {
segment := node.manager.Segment.GetSealed(info.GetSegmentID())
if segment == nil {
continue
}
local := segment.(*segments.LocalSegment)
err := node.loader.LoadDeltaLogs(ctx, local, info.GetDeltalogs())
if err != nil {
if finalErr == nil {
finalErr = err
}
continue
}
loadedSegments = append(loadedSegments, info.GetSegmentID())
}
if finalErr != nil {
log.Warn("failed to load delta logs", zap.Error(finalErr))
return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to load delta logs", finalErr)
}
return util.SuccessStatus()
}
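// queryChannel handles a retrieve request on a single DML channel.
// Requests from the shard leader are executed directly against local segments via querySegments;
// other requests are forwarded to the shard delegator and the partial results are merged.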
func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryRequest, channel string) (*internalpb.RetrieveResults, error) {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc()
failRet := WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "")
msgID := req.Req.Base.GetMsgID()
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
log := log.Ctx(ctx).With(
zap.Int64("msgID", msgID),
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
defer func() {
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
			// this is the query path, so count failures under the query label
			metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel).Inc()
}
}()
if !node.lifetime.Add(commonpbutil.IsHealthy) {
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
return failRet, nil
}
defer node.lifetime.Done()
log.Debug("start do query with channel",
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
)
// add cancel when error occurs
queryCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO From Shard Delegator
if req.FromShardLeader {
tr := timerecord.NewTimeRecorder("queryChannel")
results, err := node.querySegments(queryCtx, req)
if err != nil {
log.Warn("failed to query channel", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
tr.CtxElapse(ctx, fmt.Sprintf("do query done, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return results, nil
}
// From Proxy
tr := timerecord.NewTimeRecorder("queryDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
failRet.Status.Reason = ErrGetDelegatorFailed.Error()
return failRet, nil
}
// do query
results, err := sd.Query(queryCtx, req)
if err != nil {
log.Warn("failed to query on delegator", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
// reduce result
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
log.Warn("Query failed, failed to get collection")
failRet.Status.Reason = segments.WrapCollectionNotFound(req.Req.CollectionID).Error()
return failRet, nil
}
ret, err := segments.MergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), collection.Schema())
if err != nil {
failRet.Status.Reason = err.Error()
return failRet, nil
}
tr.CtxElapse(ctx, fmt.Sprintf("do query with channel done , vChannel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
return ret, nil
}
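// querySegments builds a retrieve plan, retrieves from historical or streaming segments depending
// on the data scope of the request, and merges the per-segment results.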
func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, segments.ErrCollectionNotFound
}
// build plan
retrievePlan, err := segments.NewRetrievePlan(
collection,
req.Req.GetSerializedExprPlan(),
req.Req.GetTravelTimestamp(),
req.Req.Base.GetMsgID(),
)
if err != nil {
return nil, err
}
defer retrievePlan.Delete()
var results []*segcorepb.RetrieveResults
if req.GetScope() == querypb.DataScope_Historical {
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
} else {
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
}
if err != nil {
return nil, err
}
reducedResult, err := segments.MergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.Req.GetOutputFieldsId(), collection.Schema())
if err != nil {
return nil, err
}
return &internalpb.RetrieveResults{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Ids: reducedResult.Ids,
FieldsData: reducedResult.FieldsData,
}, nil
}
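// searchChannel performs a search on the given vChannel: requests from the shard leader are
// scheduled as local search tasks, while requests from the proxy are routed through the shard
// delegator and the partial results are reduced before returning.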
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
if !node.lifetime.Add(commonpbutil.IsHealthy) {
return nil, WrapErrNodeUnhealthy(paramtable.GetNodeID())
}
defer node.lifetime.Done()
log.Debug("start to search channel",
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
)
searchCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO From Shard Delegator
if req.GetFromShardLeader() {
tr := timerecord.NewTimeRecorder("searchChannel")
log.Debug("search channel...")
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
log.Warn("failed to search channel", zap.Error(segments.ErrCollectionNotFound))
return nil, segments.WrapCollectionNotFound(req.GetReq().GetCollectionID())
}
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req)
if !node.scheduler.Add(task) {
log.Warn("failed to search channel", zap.Error(tasks.ErrTaskQueueFull))
return nil, tasks.ErrTaskQueueFull
}
err := task.Wait()
if err != nil {
log.Warn("failed to search channel", zap.Error(err))
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("search channel done, channel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
return task.Result(), nil
}
// From Proxy
tr := timerecord.NewTimeRecorder("searchDelegator")
// get delegator
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
return nil, ErrGetDelegatorFailed
}
// do search
results, err := sd.Search(searchCtx, req)
if err != nil {
log.Warn("failed to search on delegator", zap.Error(err))
return nil, err
}
// reduce result
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
traceID,
req.GetFromShardLeader(),
channel,
req.GetSegmentIDs(),
))
ret, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
return nil, err
}
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , vChannel = %s, segmentIDs = %v",
channel,
req.GetSegmentIDs(),
))
// update metric to prometheus
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
return ret, nil
}
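// getChannelStatistics returns segment statistics for the given vChannel, collected directly from
// local segments for shard-leader requests or through the shard delegator otherwise.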
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.Req.GetCollectionID()),
zap.String("channel", channel),
zap.String("scope", req.GetScope().String()),
)
failRet := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}
if req.GetFromShardLeader() {
var results []segments.SegmentStats
var err error
switch req.GetScope() {
case querypb.DataScope_Historical:
results, _, _, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
case querypb.DataScope_Streaming:
results, _, _, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
}
if err != nil {
log.Warn("get segments statistics failed", zap.Error(err))
return nil, err
}
return segmentStatsResponse(results), nil
}
sd, ok := node.delegators.Get(channel)
if !ok {
log.Warn("GetStatistics failed, failed to get query shard delegator")
return failRet, nil
}
results, err := sd.GetStatistics(ctx, req)
if err != nil {
log.Warn("failed to get statistics from delegator", zap.Error(err))
failRet.Status.Reason = err.Error()
return failRet, nil
}
ret, err := reduceStatisticResponse(results)
if err != nil {
failRet.Status.Reason = err.Error()
return failRet, nil
}
return ret, nil
}
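// segmentStatsResponse sums the row counts of the given segment stats into a GetStatisticsResponse.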
func segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStatisticsResponse {
var totalRowNum int64
for _, stats := range segStats {
totalRowNum += stats.RowCount
}
resultMap := make(map[string]string)
resultMap["row_count"] = strconv.FormatInt(totalRowNum, 10)
ret := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Stats: funcutil.Map2KeyValuePair(resultMap),
}
return ret
}
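// reduceStatisticResponse merges multiple statistics responses by accumulating the supported
// fields (currently only row_count).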
func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*internalpb.GetStatisticsResponse, error) {
mergedResults := map[string]interface{}{
"row_count": int64(0),
}
fieldMethod := map[string]func(string) error{
"row_count": func(str string) error {
count, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return err
}
mergedResults["row_count"] = mergedResults["row_count"].(int64) + count
return nil
},
}
for _, partialResult := range results {
for _, pair := range partialResult.Stats {
fn, ok := fieldMethod[pair.Key]
if !ok {
return nil, fmt.Errorf("unknown statistic field: %s", pair.Key)
}
if err := fn(pair.Value); err != nil {
return nil, err
}
}
}
stringMap := make(map[string]string)
for k, v := range mergedResults {
stringMap[k] = fmt.Sprint(v)
}
ret := &internalpb.GetStatisticsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Stats: funcutil.Map2KeyValuePair(stringMap),
}
return ret, nil
}

View File

@ -0,0 +1,138 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"os"
"testing"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
)
type HandlersSuite struct {
suite.Suite
// Data
collectionID int64
collectionName string
segmentID int64
channel string
// Dependency
params *paramtable.ComponentParam
node *QueryNode
etcd *clientv3.Client
chunkManagerFactory *storage.ChunkManagerFactory
// Mock
factory *dependency.MockFactory
}
func (suite *HandlersSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionName = "test-collection"
suite.segmentID = 1
suite.channel = "test-channel"
}
func (suite *HandlersSuite) SetupTest() {
var err error
paramtable.Init()
suite.params = paramtable.Get()
suite.params.Save(suite.params.QueryNodeCfg.GCEnabled.Key, "false")
// mock factory
suite.factory = dependency.NewMockFactory(suite.T())
suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test"))
// new node
suite.node = NewQueryNode(context.Background(), suite.factory)
// init etcd
suite.etcd, err = etcd.GetEtcdClient(
suite.params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
suite.params.EtcdCfg.EtcdUseSSL.GetAsBool(),
suite.params.EtcdCfg.Endpoints.GetAsStrings(),
suite.params.EtcdCfg.EtcdTLSCert.GetValue(),
suite.params.EtcdCfg.EtcdTLSKey.GetValue(),
suite.params.EtcdCfg.EtcdTLSCACert.GetValue(),
suite.params.EtcdCfg.EtcdTLSMinVersion.GetValue())
suite.NoError(err)
}
func (suite *HandlersSuite) TearDownTest() {
suite.etcd.Close()
os.RemoveAll("/tmp/milvus_test")
}
func (suite *HandlersSuite) TestLoadGrowingSegments() {
ctx := context.Background()
var err error
// mock
loadSegments := []int64{}
delegator := delegator.NewMockShardDelegator(suite.T())
delegator.EXPECT().LoadGrowing(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) {
for _, info := range infos {
loadSegments = append(loadSegments, info.SegmentID)
}
}).Return(nil)
req := &querypb.WatchDmChannelsRequest{
Infos: []*datapb.VchannelInfo{
{
CollectionID: suite.collectionID,
ChannelName: suite.channel,
UnflushedSegmentIds: []int64{suite.segmentID},
},
},
SegmentInfos: make(map[int64]*datapb.SegmentInfo),
}
// unflushed segment not in segmentInfos, will skip
err = loadGrowingSegments(ctx, delegator, req)
suite.NoError(err)
suite.Equal(0, len(loadSegments))
// binlog was empty, will skip
req.SegmentInfos[suite.segmentID] = &datapb.SegmentInfo{
ID: suite.segmentID,
CollectionID: suite.collectionID,
Binlogs: make([]*datapb.FieldBinlog, 0),
}
err = loadGrowingSegments(ctx, delegator, req)
suite.NoError(err)
suite.Equal(0, len(loadSegments))
// normal load
binlog := &datapb.FieldBinlog{}
req.SegmentInfos[suite.segmentID].Binlogs = append(req.SegmentInfos[suite.segmentID].Binlogs, binlog)
err = loadGrowingSegments(ctx, delegator, req)
suite.NoError(err)
suite.Equal(1, len(loadSegments))
}
func TestHandlersSuite(t *testing.T) {
suite.Run(t, new(HandlersSuite))
}

View File

@ -0,0 +1,103 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/samber/lo"
"go.uber.org/zap"
)
var _ cluster.Worker = &LocalWorker{}
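// LocalWorker implements cluster.Worker by delegating requests to the local QueryNode directly.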
type LocalWorker struct {
node *QueryNode
}
func NewLocalWorker(node *QueryNode) *LocalWorker {
return &LocalWorker{
node: node,
}
}
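// LoadSegments loads sealed segments via the node's loader and registers them with the segment manager.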
func (w *LocalWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
log := log.Ctx(ctx)
log.Info("start to load segments...")
loaded, err := w.node.loader.Load(ctx,
req.GetCollectionID(),
segments.SegmentTypeSealed,
req.GetVersion(),
req.GetInfos()...,
)
if err != nil {
return err
}
log.Info("save loaded segments...",
zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() })))
w.node.manager.Segment.Put(segments.SegmentTypeSealed, loaded...)
return nil
}
func (w *LocalWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
log := log.Ctx(ctx)
log.Info("start to release segments")
for _, id := range req.GetSegmentIDs() {
w.node.manager.Segment.Remove(id, req.GetScope())
}
return nil
}
func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
log := log.Ctx(ctx)
log.Info("start to process segment delete")
status, err := w.node.Delete(ctx, req)
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("%s", status.GetReason())
}
return nil
}
func (w *LocalWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
return w.node.Search(ctx, req)
}
func (w *LocalWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return w.node.Query(ctx, req)
}
func (w *LocalWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
return w.node.GetStatistics(ctx, req)
}
func (w *LocalWorker) IsHealthy() bool {
return true
}
func (w *LocalWorker) Stop() {
}

View File

@ -0,0 +1,133 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo"
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
)
type LocalWorkerTestSuite struct {
suite.Suite
params *paramtable.ComponentParam
// data
collectionID int64
collectionName string
channel string
partitionIDs []int64
segmentIDs []int64
schema *schemapb.CollectionSchema
// dependency
node *QueryNode
worker *LocalWorker
etcdClient *clientv3.Client
// context
ctx context.Context
cancel context.CancelFunc
}
func (suite *LocalWorkerTestSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionName = "test-collection"
suite.channel = "test-channel"
suite.partitionIDs = []int64{11, 22}
suite.segmentIDs = []int64{0, 1}
}
func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) {
var err error
// init param
paramtable.Init()
suite.params = paramtable.Get()
// disable GC in tests to avoid data race
suite.params.Save(suite.params.QueryNodeCfg.GCEnabled.Key, "false")
suite.ctx, suite.cancel = context.WithCancel(context.Background())
// init node
factory := dependency.NewDefaultFactory(true)
suite.node = NewQueryNode(suite.ctx, factory)
// init etcd
suite.etcdClient, err = etcd.GetEtcdClient(
suite.params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
suite.params.EtcdCfg.EtcdUseSSL.GetAsBool(),
suite.params.EtcdCfg.Endpoints.GetAsStrings(),
suite.params.EtcdCfg.EtcdTLSCert.GetValue(),
suite.params.EtcdCfg.EtcdTLSKey.GetValue(),
suite.params.EtcdCfg.EtcdTLSCACert.GetValue(),
suite.params.EtcdCfg.EtcdTLSMinVersion.GetValue())
suite.NoError(err)
suite.node.SetEtcdClient(suite.etcdClient)
err = suite.node.Init()
suite.NoError(err)
err = suite.node.Start()
suite.NoError(err)
suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
collection := segments.NewCollection(suite.collectionID, suite.schema, querypb.LoadType_LoadCollection)
loadMata := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
}
suite.node.manager.Collection.Put(suite.collectionID, collection.Schema(), loadMata)
suite.worker = NewLocalWorker(suite.node)
}
func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) {
suite.node.Stop()
suite.etcdClient.Close()
suite.cancel()
}
func (suite *LocalWorkerTestSuite) TestLoadSegment() {
// load segments with empty binlog infos
req := &querypb.LoadSegmentsRequest{
CollectionID: suite.collectionID,
Infos: lo.Map(suite.segmentIDs, func(segID int64, _ int) *querypb.SegmentLoadInfo {
return &querypb.SegmentLoadInfo{
CollectionID: suite.collectionID,
PartitionID: suite.partitionIDs[segID%2],
SegmentID: segID,
}
}),
}
err := suite.worker.LoadSegments(suite.ctx, req)
suite.NoError(err)
}
func (suite *LocalWorkerTestSuite) TestReleaseSegment() {
req := &querypb.ReleaseSegmentsRequest{
CollectionID: suite.collectionID,
SegmentIDs: suite.segmentIDs,
}
err := suite.worker.ReleaseSegments(suite.ctx, req)
suite.NoError(err)
}
func TestLocalWorker(t *testing.T) {
suite.Run(t, new(LocalWorkerTestSuite))
}

View File

@ -0,0 +1,183 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"context"
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/querynodev2/collector"
"github.com/milvus-io/milvus/internal/util/hardware"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/ratelimitutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
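// getRateMetric reads the average rate over the default window for each registered rate label from the collector.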
func getRateMetric() ([]metricsinfo.RateMetric, error) {
rms := make([]metricsinfo.RateMetric, 0)
for _, label := range collector.RateMetrics() {
rate, err := collector.Rate.Rate(label, ratelimitutil.DefaultAvgDuration)
if err != nil {
return nil, err
}
rms = append(rms, metricsinfo.RateMetric{
Label: label,
Rate: rate,
})
}
return rms, nil
}
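// getSearchNQInQueue reports the search queue counters (unsolved, ready, receive, execute) and the average queue duration.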
func getSearchNQInQueue() (metricsinfo.ReadInfoInQueue, error) {
average, err := collector.Average.Average(metricsinfo.SearchQueueMetric)
if err != nil {
return metricsinfo.ReadInfoInQueue{}, err
}
unsolvedQueueLabel := collector.ConstructLabel(metricsinfo.UnsolvedQueueType, metricsinfo.SearchQueueMetric)
readyQueueLabel := collector.ConstructLabel(metricsinfo.ReadyQueueType, metricsinfo.SearchQueueMetric)
receiveQueueLabel := collector.ConstructLabel(metricsinfo.ReceiveQueueType, metricsinfo.SearchQueueMetric)
executeQueueLabel := collector.ConstructLabel(metricsinfo.ExecuteQueueType, metricsinfo.SearchQueueMetric)
return metricsinfo.ReadInfoInQueue{
UnsolvedQueue: collector.Counter.Get(unsolvedQueueLabel),
ReadyQueue: collector.Counter.Get(readyQueueLabel),
ReceiveChan: collector.Counter.Get(receiveQueueLabel),
ExecuteChan: collector.Counter.Get(executeQueueLabel),
AvgQueueDuration: time.Duration(int64(average)),
}, nil
}
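// getQueryTasksInQueue reports the query queue counters (unsolved, ready, receive, execute) and the average queue duration.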
func getQueryTasksInQueue() (metricsinfo.ReadInfoInQueue, error) {
average, err := collector.Average.Average(metricsinfo.QueryQueueMetric)
if err != nil {
return metricsinfo.ReadInfoInQueue{}, err
}
unsolvedQueueLabel := collector.ConstructLabel(metricsinfo.UnsolvedQueueType, metricsinfo.QueryQueueMetric)
readyQueueLabel := collector.ConstructLabel(metricsinfo.ReadyQueueType, metricsinfo.QueryQueueMetric)
receiveQueueLabel := collector.ConstructLabel(metricsinfo.ReceiveQueueType, metricsinfo.QueryQueueMetric)
executeQueueLabel := collector.ConstructLabel(metricsinfo.ExecuteQueueType, metricsinfo.QueryQueueMetric)
return metricsinfo.ReadInfoInQueue{
UnsolvedQueue: collector.Counter.Get(unsolvedQueueLabel),
ReadyQueue: collector.Counter.Get(readyQueueLabel),
ReceiveChan: collector.Counter.Get(receiveQueueLabel),
ExecuteChan: collector.Counter.Get(executeQueueLabel),
AvgQueueDuration: time.Duration(int64(average)),
}, nil
}
// getQuotaMetrics returns QueryNodeQuotaMetrics.
func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error) {
rms, err := getRateMetric()
if err != nil {
return nil, err
}
sqms, err := getSearchNQInQueue()
if err != nil {
return nil, err
}
qqms, err := getQueryTasksInQueue()
if err != nil {
return nil, err
}
minTsafeChannel, minTsafe := node.tSafeManager.Min()
return &metricsinfo.QueryNodeQuotaMetrics{
Hms: metricsinfo.HardwareMetrics{},
Rms: rms,
Fgm: metricsinfo.FlowGraphMetric{
MinFlowGraphChannel: minTsafeChannel,
MinFlowGraphTt: minTsafe,
NumFlowGraph: node.pipelineManager.Num(),
},
SearchQueue: sqms,
QueryQueue: qqms,
}, nil
}
// getSystemInfoMetrics returns metrics info of QueryNode
func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node *QueryNode) (*milvuspb.GetMetricsResponse, error) {
usedMem := hardware.GetUsedMemoryCount()
totalMem := hardware.GetMemoryCount()
quotaMetrics, err := getQuotaMetrics(node)
if err != nil {
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
}, nil
}
hardwareInfos := metricsinfo.HardwareMetrics{
IP: node.session.Address,
CPUCoreCount: hardware.GetCPUNum(),
CPUCoreUsage: hardware.GetCPUUsage(),
Memory: totalMem,
MemoryUsage: usedMem,
Disk: hardware.GetDiskCount(),
DiskUsage: hardware.GetDiskUsage(),
}
quotaMetrics.Hms = hardwareInfos
nodeInfos := metricsinfo.QueryNodeInfos{
BaseComponentInfos: metricsinfo.BaseComponentInfos{
Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
HardwareInfos: hardwareInfos,
SystemInfo: metricsinfo.DeployMetrics{},
CreatedTime: paramtable.GetCreateTime().String(),
UpdatedTime: paramtable.GetUpdateTime().String(),
Type: typeutil.QueryNodeRole,
ID: node.session.ServerID,
},
SystemConfigurations: metricsinfo.QueryNodeConfiguration{
SimdType: paramtable.Get().CommonCfg.SimdType.GetValue(),
},
QuotaMetrics: quotaMetrics,
}
metricsinfo.FillDeployMetricsWithEnv(&nodeInfos.SystemInfo)
resp, err := metricsinfo.MarshalComponentInfos(nodeInfos)
if err != nil {
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
Response: "",
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
}, nil
}
return &milvuspb.GetMetricsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
Response: resp,
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
}, nil
}

View File

@ -0,0 +1,230 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynodev2
import (
"fmt"
"math"
"math/rand"
"strconv"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// ---------- unittest util functions ----------
// common definitions
const (
metricTypeKey = common.MetricTypeKey
defaultTopK = int64(10)
defaultRoundDecimal = int64(6)
defaultDim = 128
defaultNProb = 10
defaultEf = 10
defaultMetricType = "L2"
defaultNQ = 10
)
const (
// index type
IndexFaissIDMap = "FLAT"
IndexFaissIVFFlat = "IVF_FLAT"
IndexFaissIVFPQ = "IVF_PQ"
IndexFaissIVFSQ8 = "IVF_SQ8"
IndexFaissBinIDMap = "BIN_FLAT"
IndexFaissBinIVFFlat = "BIN_IVF_FLAT"
IndexHNSW = "HNSW"
IndexANNOY = "ANNOY"
)
// ---------- unittest util functions ----------
// functions of messages and requests
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
nProbStr := strconv.Itoa(nProb)
topKStr := strconv.FormatInt(topK, 10)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
efStr := strconv.Itoa(ef)
topKStr := strconv.FormatInt(topK, 10)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
var vecFieldName string
var metricType string
topKStr := strconv.FormatInt(topK, 10)
nProbStr := strconv.Itoa(defaultNProb)
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
for _, p := range f.IndexParams {
if p.Key == metricTypeKey {
metricType = p.Value
}
}
}
}
if vecFieldName == "" || metricType == "" {
err := errors.New("invalid vector field name or metric type")
return "", err
}
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
"\": {\n \"metric_type\": \"" + metricType +
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
" \n,\"round_decimal\": " + roundDecimalStr +
"\n } \n } \n } \n }", nil
}
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
if indexType == IndexFaissIDMap { // float vector
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIDMap {
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissIVFFlat {
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexFaissBinIVFFlat { // binary vector
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
} else if indexType == IndexHNSW {
return genHNSWDSL(schema, defaultEf, defaultTopK, defaultRoundDecimal)
}
return "", fmt.Errorf("Invalid indexType")
}
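// genPlaceHolderGroup generates a serialized placeholder group containing nq random float vectors of defaultDim dimensions.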
func genPlaceHolderGroup(nq int64) ([]byte, error) {
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_FloatVector,
Values: make([][]byte, 0),
}
for i := int64(0); i < nq; i++ {
var vec = make([]float32, defaultDim)
for j := 0; j < defaultDim; j++ {
vec[j] = rand.Float32()
}
var rawData []byte
for k, ele := range vec {
buf := make([]byte, 4)
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(k*2)))
rawData = append(rawData, buf...)
}
placeholderValue.Values = append(placeholderValue.Values, rawData)
}
// generate placeholder
placeholderGroup := commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
}
placeGroupByte, err := proto.Marshal(&placeholderGroup)
if err != nil {
return nil, err
}
return placeGroupByte, nil
}
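// genSimpleRetrievePlanExpr builds a serialized retrieve plan that matches the primary key field against a fixed set of values.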
func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error) {
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
return nil, err
}
planNode := &planpb.PlanNode{
Node: &planpb.PlanNode_Predicates{
Predicates: &planpb.Expr{
Expr: &planpb.Expr_TermExpr{
TermExpr: &planpb.TermExpr{
ColumnInfo: &planpb.ColumnInfo{
FieldId: pkField.FieldID,
DataType: pkField.DataType,
},
Values: []*planpb.GenericValue{
{
Val: &planpb.GenericValue_Int64Val{
Int64Val: 1,
},
},
{
Val: &planpb.GenericValue_Int64Val{
Int64Val: 2,
},
},
{
Val: &planpb.GenericValue_Int64Val{
Int64Val: 3,
},
},
},
},
},
},
},
OutputFieldIds: []int64{pkField.FieldID, 100},
}
planExpr, err := proto.Marshal(planNode)
return planExpr, err
}

View File

@ -0,0 +1,99 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/storage"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/samber/lo"
"go.uber.org/zap"
)
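// deleteNode applies delete messages of a channel through the shard delegator and advances the channel tSafe.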
type deleteNode struct {
*BaseNode
collectionID UniqueID
channel string
manager *DataManager
tSafeManager TSafeManager
delegator delegator.ShardDelegator
}
// addDeleteData accumulates the primary keys and timestamps of a DeleteMsg into the per-partition DeleteData map
func (dNode *deleteNode) addDeleteData(deleteDatas map[UniqueID]*delegator.DeleteData, msg *DeleteMsg) {
deleteData, ok := deleteDatas[msg.PartitionID]
if !ok {
deleteData = &delegator.DeleteData{
PartitionID: msg.PartitionID,
}
deleteDatas[msg.PartitionID] = deleteData
}
pks := storage.ParseIDs2PrimaryKeys(msg.PrimaryKeys)
deleteData.PrimaryKeys = append(deleteData.PrimaryKeys, pks...)
deleteData.Timestamps = append(deleteData.Timestamps, msg.Timestamps...)
deleteData.RowCount += int64(len(pks))
log.Info("pipeline fetch delete msg",
zap.Int64("collectionID", dNode.collectionID),
zap.Int64("partitionID", msg.PartitionID),
zap.Int("insertRowNum", len(pks)),
zap.Uint64("timestampMin", msg.BeginTimestamp),
zap.Uint64("timestampMax", msg.EndTimestamp))
}
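// Operate groups delete messages by partition, forwards them to the shard delegator with the max
// timestamp of the time range, and then advances the channel tSafe.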
func (dNode *deleteNode) Operate(in Msg) Msg {
nodeMsg := in.(*deleteNodeMsg)
// partition id = > DeleteData
deleteDatas := make(map[UniqueID]*delegator.DeleteData)
for _, msg := range nodeMsg.deleteMsgs {
dNode.addDeleteData(deleteDatas, msg)
}
if len(deleteDatas) > 0 {
//do Delete, use ts range max as ts
dNode.delegator.ProcessDelete(lo.Values(deleteDatas), nodeMsg.timeRange.timestampMax)
}
//update tSafe
err := dNode.tSafeManager.Set(dNode.channel, nodeMsg.timeRange.timestampMax)
if err != nil {
// should not happen, QueryNode should add the tSafe before starting the pipeline
panic(fmt.Errorf("deleteNode setTSafe failed, collectionID = %d, err = %s", dNode.collectionID, err))
}
return nil
}
func newDeleteNode(
collectionID UniqueID, channel string,
manager *DataManager, tSafeManager TSafeManager, delegator delegator.ShardDelegator,
maxQueueLength int32,
) *deleteNode {
return &deleteNode{
BaseNode: base.NewBaseNode(fmt.Sprintf("DeleteNode-%s", channel), maxQueueLength),
collectionID: collectionID,
channel: channel,
manager: manager,
tSafeManager: tSafeManager,
delegator: delegator,
}
}

View File

@ -0,0 +1,108 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"testing"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type DeleteNodeSuite struct {
suite.Suite
//datas
collectionID int64
collectionName string
partitionIDs []int64
deletePKs []int64
channel string
timeRange TimeRange
//dependency
tSafeManager TSafeManager
//mocks
manager *segments.Manager
delegator *delegator.MockShardDelegator
}
func (suite *DeleteNodeSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionName = "test-collection"
suite.partitionIDs = []int64{11, 22}
suite.channel = "test-channel"
// each segment owns rows whose primary keys equal the segment ID
suite.deletePKs = []int64{1, 2, 3, 4}
suite.timeRange = TimeRange{
timestampMin: 0,
timestampMax: 1,
}
}
func (suite *DeleteNodeSuite) buildDeleteNodeMsg() *deleteNodeMsg {
nodeMsg := &deleteNodeMsg{
deleteMsgs: []*DeleteMsg{},
timeRange: suite.timeRange,
}
for i, pk := range suite.deletePKs {
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[i%len(suite.partitionIDs)], suite.channel, 1)
deleteMsg.PrimaryKeys = genDeletePK(pk)
nodeMsg.deleteMsgs = append(nodeMsg.deleteMsgs, deleteMsg)
}
return nodeMsg
}
func (suite *DeleteNodeSuite) TestBasic() {
//mock
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
suite.manager = &segments.Manager{
Collection: mockCollectionManager,
Segment: mockSegmentManager,
}
suite.delegator = delegator.NewMockShardDelegator(suite.T())
suite.delegator.EXPECT().ProcessDelete(mock.Anything, mock.Anything).Run(
func(deleteData []*delegator.DeleteData, ts uint64) {
for _, data := range deleteData {
for _, pk := range data.PrimaryKeys {
suite.True(lo.Contains(suite.deletePKs, pk.GetValue().(int64)))
}
}
})
//init dependency
suite.tSafeManager = tsafe.NewTSafeReplica()
suite.tSafeManager.Add(suite.channel, 0)
//build delete node and data
node := newDeleteNode(suite.collectionID, suite.channel, suite.manager, suite.tSafeManager, suite.delegator, 8)
in := suite.buildDeleteNodeMsg()
//run
out := node.Operate(in)
suite.Nil(out)
//check tsafe
tt, err := suite.tSafeManager.Get(suite.channel)
suite.NoError(err)
suite.Equal(suite.timeRange.timestampMax, tt)
}
func TestDeleteNode(t *testing.T) {
suite.Run(t, new(DeleteNodeSuite))
}

View File

@ -0,0 +1,61 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"github.com/cockroachdb/errors"
)
var (
ErrMsgInvalidType = errors.New("InvalidMessageType")
ErrMsgNotAligned = errors.New("CheckAlignedFailed")
ErrMsgEmpty = errors.New("EmptyMessage")
ErrMsgNotTarget = errors.New("NotTarget")
ErrMsgExcluded = errors.New("SegmentExcluded")
ErrCollectionNotFound = errors.New("CollectionNotFound")
ErrShardDelegatorNotFound = errors.New("ShardDelegatorNotFound")
ErrNewPipelineFailed = errors.New("FailedCreateNewPipeline")
ErrStartPipeline = errors.New("PipineStartFailed")
)
func WrapErrMsgNotAligned(err error) error {
return fmt.Errorf("%w :%s", ErrMsgNotAligned, err)
}
func WrapErrMsgNotTarget(reason string) error {
return fmt.Errorf("%w%s", ErrMsgNotTarget, reason)
}
func WrapErrMsgExcluded(segmentID int64) error {
return fmt.Errorf("%w ID:%d", ErrMsgExcluded, segmentID)
}
func WrapErrNewPipelineFailed(err error) error {
return fmt.Errorf("%w :%s", ErrNewPipelineFailed, err)
}
func WrapErrStartPipeline(reason string) error {
return fmt.Errorf("%w :%s", ErrStartPipeline, reason)
}
func WrapErrShardDelegatorNotFound(channel string) error {
return fmt.Errorf("%w channel:%s", ErrShardDelegatorNotFound, channel)
}

View File

@ -0,0 +1,158 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"reflect"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/paramtable"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// filterNode filters out invalid messages in the pipeline
type filterNode struct {
*BaseNode
collectionID UniqueID
manager *DataManager
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
channel string
InsertMsgPolicys []InsertMsgFilter
DeleteMsgPolicys []DeleteMsgFilter
}
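// Operate validates each message in the MsgPack against the configured filter policies and
// collects the valid ones into an insertNodeMsg for the next node.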
func (fNode *filterNode) Operate(in Msg) Msg {
if in == nil {
log.Debug("type assertion failed for Msg in filterNode because it's nil",
zap.String("name", fNode.Name()))
return nil
}
streamMsgPack, ok := in.(*msgstream.MsgPack)
if !ok {
log.Warn("type assertion failed for MsgPack",
zap.String("msgType", reflect.TypeOf(in).Name()),
zap.String("name", fNode.Name()))
return nil
}
metrics.QueryNodeConsumerMsgCount.
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel, fmt.Sprint(fNode.collectionID)).
Inc()
metrics.QueryNodeConsumeTimeTickLag.
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel, fmt.Sprint(fNode.collectionID)).
Set(float64(streamMsgPack.EndTs))
//Get collection from collection manager
collection := fNode.manager.Collection.Get(fNode.collectionID)
if collection == nil {
err := segments.WrapCollectionNotFound(fNode.collectionID)
log.Error(err.Error())
panic(err)
}
out := &insertNodeMsg{
insertMsgs: []*InsertMsg{},
deleteMsgs: []*DeleteMsg{},
timeRange: TimeRange{
timestampMin: streamMsgPack.BeginTs,
timestampMax: streamMsgPack.EndTs,
},
}
// add msg to out if it passes all filter checks
for _, msg := range streamMsgPack.Msgs {
err := fNode.filtrate(collection, msg)
if err != nil {
log.Debug("filter invalid message",
zap.String("message type", msg.Type().String()),
zap.String("channel", fNode.channel),
zap.Int64("collectionID", fNode.collectionID),
zap.Error(err),
)
} else {
out.append(msg)
}
}
return out
}
// filtrate checks a message against the configured filter policies
func (fNode *filterNode) filtrate(c *Collection, msg msgstream.TsMsg) error {
switch msg.Type() {
case commonpb.MsgType_Insert:
insertMsg := msg.(*msgstream.InsertMsg)
metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(insertMsg)))
for _, policy := range fNode.InsertMsgPolicys {
err := policy(fNode, c, insertMsg)
if err != nil {
return err
}
}
case commonpb.MsgType_Delete:
deleteMsg := msg.(*msgstream.DeleteMsg)
metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(deleteMsg)))
for _, policy := range fNode.DeleteMsgPolicys {
err := policy(fNode, c, deleteMsg)
if err != nil {
return err
}
}
default:
return ErrMsgInvalidType
}
return nil
}
func newFilterNode(
collectionID int64,
channel string,
manager *DataManager,
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo],
maxQueueLength int32,
) *filterNode {
return &filterNode{
BaseNode: base.NewBaseNode(fmt.Sprintf("FilterNode-%s", channel), maxQueueLength),
collectionID: collectionID,
manager: manager,
channel: channel,
excludedSegments: excludedSegments,
InsertMsgPolicys: []InsertMsgFilter{
InsertNotAligned,
InsertEmpty,
InsertOutOfTarget,
InsertExcluded,
},
DeleteMsgPolicys: []DeleteMsgFilter{
DeleteNotAligned,
DeleteEmpty,
DeleteOutOfTarget,
},
}
}

View File

@ -0,0 +1,200 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
"github.com/stretchr/testify/suite"
)
//test of filter node
type FilterNodeSuite struct {
suite.Suite
//datas
collectionID int64
partitionIDs []int64
channel string
validSegmentIDs []int64
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
excludedSegmentIDs []int64
insertSegmentIDs []int64
deleteSegmentSum int
// segmentID of msgs that are invalid because they are empty or not aligned
errSegmentID int64
//mocks
manager *segments.Manager
}
func (suite *FilterNodeSuite) SetupSuite() {
paramtable.Init()
suite.collectionID = 111
suite.partitionIDs = []int64{11, 22}
suite.channel = "test-channel"
// the first one is invalid because its max insert timestamp is before the dmlPosition timestamp
suite.excludedSegmentIDs = []int64{1, 2}
suite.insertSegmentIDs = []int64{3, 4, 5, 6}
suite.deleteSegmentSum = 4
suite.errSegmentID = 7
//init excludedSegment
suite.excludedSegments = typeutil.NewConcurrentMap[int64, *datapb.SegmentInfo]()
for _, id := range suite.excludedSegmentIDs {
suite.excludedSegments.Insert(id, &datapb.SegmentInfo{
DmlPosition: &msgpb.MsgPosition{
Timestamp: 1,
},
})
}
}
// test filter node when the collection is loaded with LoadType LoadCollection
func (suite *FilterNodeSuite) TestWithLoadCollection() {
//data
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
//mock
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection)
for _, partitionID := range suite.partitionIDs {
collection.AddPartition(partitionID)
}
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
suite.manager = &segments.Manager{
Collection: mockCollectionManager,
Segment: mockSegmentManager,
}
node := newFilterNode(suite.collectionID, suite.channel, suite.manager, suite.excludedSegments, 8)
in := suite.buildMsgPack()
out := node.Operate(in)
nodeMsg, ok := out.(*insertNodeMsg)
suite.True(ok)
suite.Equal(len(suite.validSegmentIDs), len(nodeMsg.insertMsgs))
for _, msg := range nodeMsg.insertMsgs {
suite.True(lo.Contains(suite.validSegmentIDs, msg.SegmentID))
}
suite.Equal(suite.deleteSegmentSum, len(nodeMsg.deleteMsgs))
}
// test filter node when the collection is loaded with LoadType LoadPartition
func (suite *FilterNodeSuite) TestWithLoadPartition() {
//data
suite.validSegmentIDs = []int64{2, 4, 6}
//mock
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition)
collection.AddPartition(suite.partitionIDs[0])
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
suite.manager = &segments.Manager{
Collection: mockCollectionManager,
Segment: mockSegmentManager,
}
node := newFilterNode(suite.collectionID, suite.channel, suite.manager, suite.excludedSegments, 8)
in := suite.buildMsgPack()
out := node.Operate(in)
nodeMsg, ok := out.(*insertNodeMsg)
suite.True(ok)
suite.Equal(len(suite.validSegmentIDs), len(nodeMsg.insertMsgs))
for _, msg := range nodeMsg.insertMsgs {
suite.True(lo.Contains(suite.validSegmentIDs, msg.SegmentID))
}
suite.Equal(suite.deleteSegmentSum/2, len(nodeMsg.deleteMsgs))
}
func (suite *FilterNodeSuite) buildMsgPack() *msgstream.MsgPack {
msgPack := &msgstream.MsgPack{
BeginTs: 0,
EndTs: 0,
Msgs: []msgstream.TsMsg{},
}
//add valid insert
for _, id := range suite.insertSegmentIDs {
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1)
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
//add valid delete
for i := 0; i < suite.deleteSegmentSum; i++ {
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[i%2], suite.channel, 1)
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
}
//add invalid msg
//segment in excludedSegments
// any message whose end timestamp is before the dmlPosition timestamp will be invalid
for _, id := range suite.excludedSegmentIDs {
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1)
insertMsg.EndTimestamp = uint64(id)
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
//empty msg
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 0)
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[0], suite.channel, 0)
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
//msg not target
insertMsg = buildInsertMsg(suite.collectionID+1, 1, 0, "Unknown", 1)
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
deleteMsg = buildDeleteMsg(suite.collectionID+1, 1, "Unknown", 1)
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
//msg not aligned
insertMsg = buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 1)
insertMsg.Timestamps = []uint64{}
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
deleteMsg = buildDeleteMsg(suite.collectionID, suite.partitionIDs[0], suite.channel, 1)
deleteMsg.Timestamps = append(deleteMsg.Timestamps, 1)
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
return msgPack
}
func TestFilterNode(t *testing.T) {
suite.Run(t, new(FilterNodeSuite))
}

View File

@ -0,0 +1,91 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import "github.com/milvus-io/milvus/internal/common"
// MsgFilter returns an error if the Msg is invalid
type InsertMsgFilter = func(n *filterNode, c *Collection, msg *InsertMsg) error
type DeleteMsgFilter = func(n *filterNode, c *Collection, msg *DeleteMsg) error
// Check msg is aligned --
// the lengths of all per-row slices in InsertMsg must match each other
func InsertNotAligned(n *filterNode, c *Collection, msg *InsertMsg) error {
err := msg.CheckAligned()
if err != nil {
return WrapErrMsgNotAligned(err)
}
return nil
}
func InsertEmpty(n *filterNode, c *Collection, msg *InsertMsg) error {
if len(msg.GetTimestamps()) <= 0 {
return ErrMsgEmpty
}
return nil
}
func InsertOutOfTarget(n *filterNode, c *Collection, msg *InsertMsg) error {
if msg.GetCollectionID() != c.ID() {
return WrapErrMsgNotTarget("Collection")
}
if c.GetLoadType() == loadTypePartition {
if msg.PartitionID != common.InvalidPartitionID && !c.ExistPartition(msg.PartitionID) {
return WrapErrMsgNotTarget("Partition")
}
}
return nil
}
func InsertExcluded(n *filterNode, c *Collection, msg *InsertMsg) error {
segInfo, ok := n.excludedSegments.Get(msg.SegmentID)
if !ok {
return nil
}
if msg.EndTimestamp <= segInfo.GetDmlPosition().Timestamp {
return WrapErrMsgExcluded(msg.SegmentID)
}
return nil
}
func DeleteNotAligned(n *filterNode, c *Collection, msg *DeleteMsg) error {
err := msg.CheckAligned()
if err != nil {
return WrapErrMsgNotAligned(err)
}
return nil
}
func DeleteEmpty(n *filterNode, c *Collection, msg *DeleteMsg) error {
if len(msg.GetTimestamps()) <= 0 {
return ErrMsgEmpty
}
return nil
}
func DeleteOutOfTarget(n *filterNode, c *Collection, msg *DeleteMsg) error {
if msg.GetCollectionID() != c.ID() {
return WrapErrMsgNotTarget("Collection")
}
if c.GetLoadType() == loadTypePartition {
if msg.PartitionID != common.InvalidPartitionID && !c.ExistPartition(msg.PartitionID) {
return WrapErrMsgNotTarget("Partition")
}
}
return nil
}

View File

@ -0,0 +1,123 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"sort"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
type insertNode struct {
*BaseNode
collectionID int64
channel string
manager *DataManager
delegator delegator.ShardDelegator
}
func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) {
insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
if err != nil {
err = fmt.Errorf("failed to get primary keys, err = %d", err)
log.Error(err.Error(), zap.Int64("collectionID", iNode.collectionID), zap.String("channel", iNode.channel))
panic(err)
}
iData, ok := insertDatas[msg.SegmentID]
if !ok {
iData = &delegator.InsertData{
PartitionID: msg.PartitionID,
InsertRecord: insertRecord,
StartPosition: &msgpb.MsgPosition{
Timestamp: msg.BeginTs(),
ChannelName: msg.GetShardName(),
},
}
insertDatas[msg.SegmentID] = iData
} else {
typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
iData.InsertRecord.NumRows += insertRecord.NumRows
}
pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
if err != nil {
log.Error("failed to get primary keys from insert message", zap.Error(err))
panic(err)
}
iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
log.Info("pipeline fetch insert msg",
zap.Int64("collectionID", iNode.collectionID),
zap.Int64("segmentID", msg.SegmentID),
zap.Int("insertRowNum", len(pks)),
zap.Uint64("timestampMin", msg.BeginTimestamp),
zap.Uint64("timestampMax", msg.EndTimestamp))
}
// Operate sorts insert messages by BeginTs, merges per-segment insert data, forwards it to the delegator, and passes delete messages to the next node
func (iNode *insertNode) Operate(in Msg) Msg {
nodeMsg := in.(*insertNodeMsg)
sort.Slice(nodeMsg.insertMsgs, func(i, j int) bool {
return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
})
insertDatas := make(map[UniqueID]*delegator.InsertData)
collection := iNode.manager.Collection.Get(iNode.collectionID)
if collection == nil {
log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
panic("insertNode with collection not exist")
}
// get InsertData and merge data of the same segment
for _, msg := range nodeMsg.insertMsgs {
iNode.addInsertData(insertDatas, msg, collection)
}
iNode.delegator.ProcessInsert(insertDatas)
return &deleteNodeMsg{
deleteMsgs: nodeMsg.deleteMsgs,
timeRange: nodeMsg.timeRange,
}
}
func newInsertNode(
collectionID UniqueID,
channel string,
manager *DataManager,
delegator delegator.ShardDelegator,
maxQueueLength int32,
) *insertNode {
return &insertNode{
BaseNode: base.NewBaseNode(fmt.Sprintf("InsertNode-%s", channel), maxQueueLength),
collectionID: collectionID,
channel: channel,
manager: manager,
delegator: delegator,
}
}
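
Operate above sorts the insert messages and merges the data of each segment before handing it to the delegator. A minimal sketch of that grouping step in isolation, assuming iNode and collection already exist (groupBySegment is a hypothetical helper, not part of this commit):

// groupBySegment collects the per-segment InsertData for a batch of
// already-sorted insert messages.
func (iNode *insertNode) groupBySegment(msgs []*InsertMsg, collection *Collection) map[UniqueID]*delegator.InsertData {
    insertDatas := make(map[UniqueID]*delegator.InsertData)
    for _, msg := range msgs {
        iNode.addInsertData(insertDatas, msg, collection)
    }
    return insertDatas
}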

View File

@ -0,0 +1,116 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type InsertNodeSuite struct {
suite.Suite
// data
collectionName string
collectionID int64
partitionID int64
channel string
insertSegmentIDs []int64
deleteSegmentSum int
//mocks
manager *segments.Manager
delegator *delegator.MockShardDelegator
}
func (suite *InsertNodeSuite) SetupSuite() {
suite.collectionName = "test-collection"
suite.collectionID = 111
suite.partitionID = 11
suite.channel = "test_channel"
suite.insertSegmentIDs = []int64{4, 3}
suite.deleteSegmentSum = 2
}
func (suite *InsertNodeSuite) TestBasic() {
//data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
in := suite.buildInsertNodeMsg(schema)
collection := segments.NewCollection(suite.collectionID, schema, querypb.LoadType_LoadCollection)
collection.AddPartition(suite.partitionID)
//init mock
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
suite.manager = &segments.Manager{
Collection: mockCollectionManager,
Segment: mockSegmentManager,
}
suite.delegator = delegator.NewMockShardDelegator(suite.T())
suite.delegator.EXPECT().ProcessInsert(mock.Anything).Run(func(insertRecords map[int64]*delegator.InsertData) {
for segID := range insertRecords {
suite.True(lo.Contains(suite.insertSegmentIDs, segID))
}
})
// TODO: mock a delegator for test
node := newInsertNode(suite.collectionID, suite.channel, suite.manager, suite.delegator, 8)
out := node.Operate(in)
nodeMsg, ok := out.(*deleteNodeMsg)
suite.True(ok)
suite.Equal(suite.deleteSegmentSum, len(nodeMsg.deleteMsgs))
}
func (suite *InsertNodeSuite) buildInsertNodeMsg(schema *schemapb.CollectionSchema) *insertNodeMsg {
nodeMsg := insertNodeMsg{
insertMsgs: []*InsertMsg{},
deleteMsgs: []*DeleteMsg{},
timeRange: TimeRange{
timestampMin: 0,
timestampMax: 0,
},
}
for _, segmentID := range suite.insertSegmentIDs {
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionID, segmentID, suite.channel, 1)
insertMsg.FieldsData = genFiledDataWithSchema(schema, 1)
nodeMsg.insertMsgs = append(nodeMsg.insertMsgs, insertMsg)
}
for i := 0; i < suite.deleteSegmentSum; i++ {
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionID, suite.channel, 1)
nodeMsg.deleteMsgs = append(nodeMsg.deleteMsgs, deleteMsg)
}
return &nodeMsg
}
func TestInsertNode(t *testing.T) {
suite.Run(t, new(InsertNodeSuite))
}

View File

@ -0,0 +1,163 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// Manager manages the pipelines in querynode
type Manager interface {
Num() int
Add(collectionID UniqueID, channel string) (Pipeline, error)
Get(channel string) Pipeline
Remove(channels ...string)
Start(channels ...string) error
Close()
}
type manager struct {
channel2Pipeline map[string]Pipeline
dataManager *DataManager
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
tSafeManager TSafeManager
dispatcher msgdispatcher.Client
mu sync.Mutex
}
func (m *manager) Num() int {
return len(m.channel2Pipeline)
}
// Add creates a pipeline for the given channel of the collection
func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) {
m.mu.Lock()
defer m.mu.Unlock()
log.Debug("start create pipeine",
zap.Int64("collectionID", collectionID),
zap.String("channel", channel),
)
collection := m.dataManager.Collection.Get(collectionID)
if collection == nil {
return nil, segments.WrapCollectionNotFound(collectionID)
}
if pipeline, ok := m.channel2Pipeline[channel]; ok {
return pipeline, nil
}
// get shard delegator for adding growing segments in the pipeline
delegator, ok := m.delegators.Get(channel)
if !ok {
return nil, WrapErrShardDelegatorNotFound(channel)
}
newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.tSafeManager, m.dispatcher, delegator)
if err != nil {
return nil, WrapErrNewPipelineFailed(err)
}
m.channel2Pipeline[channel] = newPipeLine
metrics.QueryNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
metrics.QueryNodeNumDmlChannels.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
return newPipeLine, nil
}
func (m *manager) Get(channel string) Pipeline {
m.mu.Lock()
defer m.mu.Unlock()
pipeline, ok := m.channel2Pipeline[channel]
if !ok {
log.Warn("pipeline not existed",
zap.String("channel", channel),
)
return nil
}
return pipeline
}
// Remove removes pipelines from the Manager by channel
func (m *manager) Remove(channels ...string) {
m.mu.Lock()
defer m.mu.Unlock()
for _, channel := range channels {
if pipeline, ok := m.channel2Pipeline[channel]; ok {
pipeline.Close()
delete(m.channel2Pipeline, channel)
metrics.QueryNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
metrics.QueryNodeNumDmlChannels.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
} else {
log.Warn("pipeline to be removed doesn't exist", zap.Any("channel", channel))
}
}
}
//Start pipeline by channel
func (m *manager) Start(channels ...string) error {
m.mu.Lock()
defer m.mu.Unlock()
// check that all pipelines exist before starting
for _, channel := range channels {
if _, ok := m.channel2Pipeline[channel]; !ok {
return WrapErrStartPipeline(fmt.Sprintf("pipeline with channel %s not exist", channel))
}
}
for _, channel := range channels {
m.channel2Pipeline[channel].Start()
}
return nil
}
// Close closes all pipelines of the Manager
func (m *manager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
for _, pipeline := range m.channel2Pipeline {
pipeline.Close()
}
}
func NewManager(dataManager *DataManager,
tSafeManager TSafeManager,
dispatcher msgdispatcher.Client,
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator],
) Manager {
return &manager{
channel2Pipeline: make(map[string]Pipeline),
dataManager: dataManager,
delegators: delegators,
tSafeManager: tSafeManager,
dispatcher: dispatcher,
mu: sync.Mutex{},
}
}
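
A minimal usage sketch of the Manager for a single channel; the collection ID, channel name and checkpoint are placeholders, the msgpb import is assumed, and startPipeline is a hypothetical helper, not part of this commit:

// startPipeline shows the typical Manager lifecycle: Add the pipeline, seek
// the consumer to the checkpoint, then Start it. On failure the pipeline is
// removed again.
func startPipeline(m Manager, collectionID UniqueID, channel string, position *msgpb.MsgPosition) (Pipeline, error) {
    pipeline, err := m.Add(collectionID, channel)
    if err != nil {
        return nil, err
    }
    if err := pipeline.ConsumeMsgStream(position); err != nil {
        m.Remove(channel)
        return nil, err
    }
    if err := m.Start(channel); err != nil {
        m.Remove(channel)
        return nil, err
    }
    return pipeline, nil
}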

View File

@ -0,0 +1,116 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type PipelineManagerTestSuite struct {
suite.Suite
//data
collectionID int64
channel string
//dependencies
tSafeManager TSafeManager
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
//mocks
segmentManager *segments.MockSegmentManager
collectionManager *segments.MockCollectionManager
delegator *delegator.MockShardDelegator
msgDispatcher *msgdispatcher.MockClient
msgChan chan *msgstream.MsgPack
}
func (suite *PipelineManagerTestSuite) SetupSuite() {
suite.collectionID = 111
suite.msgChan = make(chan *msgstream.MsgPack, 1)
}
func (suite *PipelineManagerTestSuite) SetupTest() {
paramtable.Init()
//init dependency
// init tsafeManager
suite.tSafeManager = tsafe.NewTSafeReplica()
suite.tSafeManager.Add(suite.channel, 0)
suite.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]()
//init mock
// init manager
suite.collectionManager = segments.NewMockCollectionManager(suite.T())
suite.segmentManager = segments.NewMockSegmentManager(suite.T())
// init delegator
suite.delegator = delegator.NewMockShardDelegator(suite.T())
suite.delegators.Insert(suite.channel, suite.delegator)
// init mq dispatcher
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
}
func (suite *PipelineManagerTestSuite) TestBasic() {
//init mock
// mock collection manager
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
// mock mq factory
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
//build manager
manager := &segments.Manager{
Collection: suite.collectionManager,
Segment: suite.segmentManager,
}
pipelineManager := NewManager(manager, suite.tSafeManager, suite.msgDispatcher, suite.delegators)
defer pipelineManager.Close()
//Add pipeline
_, err := pipelineManager.Add(suite.collectionID, suite.channel)
suite.NoError(err)
suite.Equal(1, pipelineManager.Num())
//Get pipeline
pipeline := pipelineManager.Get(suite.channel)
suite.NotNil(pipeline)
//Init Consumer
err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{})
suite.NoError(err)
//Start pipeline
err = pipelineManager.Start(suite.channel)
suite.NoError(err)
//Remove pipeline
pipelineManager.Remove(suite.channel)
suite.Equal(0, pipelineManager.Num())
}
func TestQueryNodePipelineManager(t *testing.T) {
suite.Run(t, new(PipelineManagerTestSuite))
}

View File

@ -0,0 +1,52 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/querynodev2/collector"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
)
type insertNodeMsg struct {
insertMsgs []*InsertMsg
deleteMsgs []*DeleteMsg
timeRange TimeRange
}
type deleteNodeMsg struct {
deleteMsgs []*DeleteMsg
timeRange TimeRange
}
func (msg *insertNodeMsg) append(taskMsg msgstream.TsMsg) error {
switch taskMsg.Type() {
case commonpb.MsgType_Insert:
insertMsg := taskMsg.(*InsertMsg)
msg.insertMsgs = append(msg.insertMsgs, insertMsg)
collector.Rate.Add(metricsinfo.InsertConsumeThroughput, float64(proto.Size(&insertMsg.InsertRequest)))
case commonpb.MsgType_Delete:
deleteMsg := taskMsg.(*DeleteMsg)
msg.deleteMsgs = append(msg.deleteMsgs, deleteMsg)
collector.Rate.Add(metricsinfo.DeleteConsumeThroughput, float64(proto.Size(&deleteMsg.DeleteRequest)))
default:
return ErrMsgInvalidType
}
return nil
}
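
A minimal sketch, assuming a MsgPack consumed from the msg dispatcher, of how an upstream node could build an insertNodeMsg with append (buildNodeMsg is hypothetical, not part of this commit):

// buildNodeMsg turns a consumed MsgPack into an insertNodeMsg, rejecting
// unsupported message types via append.
func buildNodeMsg(pack *msgstream.MsgPack) (*insertNodeMsg, error) {
    nodeMsg := &insertNodeMsg{
        insertMsgs: make([]*InsertMsg, 0),
        deleteMsgs: make([]*DeleteMsg, 0),
        timeRange: TimeRange{
            timestampMin: pack.BeginTs,
            timestampMax: pack.EndTs,
        },
    }
    for _, msg := range pack.Msgs {
        if err := nodeMsg.append(msg); err != nil {
            return nil, err
        }
    }
    return nodeMsg, nil
}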

View File

@ -0,0 +1,173 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"math/rand"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
)
const defaultDim = 128
func buildDeleteMsg(collectionID int64, partitionID int64, channel string, rowSum int) *msgstream.DeleteMsg {
deleteMsg := emptyDeleteMsg(collectionID, partitionID, channel)
for i := 1; i <= rowSum; i++ {
deleteMsg.Timestamps = append(deleteMsg.Timestamps, 0)
deleteMsg.HashValues = append(deleteMsg.HashValues, 0)
deleteMsg.NumRows++
}
deleteMsg.PrimaryKeys = genDefaultDeletePK(rowSum)
return deleteMsg
}
func buildInsertMsg(collectionID int64, partitionID int64, segmentID int64, channel string, rowSum int) *msgstream.InsertMsg {
insertMsg := emptyInsertMsg(collectionID, partitionID, segmentID, channel)
for i := 1; i <= rowSum; i++ {
insertMsg.HashValues = append(insertMsg.HashValues, 0)
insertMsg.Timestamps = append(insertMsg.Timestamps, 0)
insertMsg.RowIDs = append(insertMsg.RowIDs, rand.Int63n(100))
insertMsg.NumRows++
}
insertMsg.FieldsData = genDefaultFiledData(rowSum)
return insertMsg
}
func emptyDeleteMsg(collectionID int64, partitionID int64, channel string) *msgstream.DeleteMsg {
deleteReq := msgpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithTimeStamp(0),
),
CollectionID: collectionID,
PartitionID: partitionID,
ShardName: channel,
}
return &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{},
DeleteRequest: deleteReq,
}
}
func emptyInsertMsg(collectionID int64, partitionID int64, segmentID int64, channel string) *msgstream.InsertMsg {
insertReq := msgpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
commonpbutil.WithTimeStamp(0),
),
CollectionID: collectionID,
PartitionID: partitionID,
SegmentID: segmentID,
ShardName: channel,
Version: msgpb.InsertDataVersion_ColumnBased,
}
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{},
InsertRequest: insertReq,
}
return insertMsg
}
// gen IDs with sequential pks for DeleteMsg
func genDefaultDeletePK(rowSum int) *schemapb.IDs {
pkDatas := []int64{}
for i := 1; i <= rowSum; i++ {
pkDatas = append(pkDatas, int64(i))
}
return &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: pkDatas,
},
},
}
}
// gen IDs with the specified pks
func genDeletePK(pks ...int64) *schemapb.IDs {
pkDatas := make([]int64, 0, len(pks))
pkDatas = append(pkDatas, pks...)
return &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: pkDatas,
},
},
}
}
func genDefaultFiledData(numRows int) []*schemapb.FieldData {
pkDatas := []int64{}
vectorDatas := []byte{}
for i := 1; i <= numRows; i++ {
pkDatas = append(pkDatas, int64(i))
vectorDatas = append(vectorDatas, uint8(i))
}
return []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "pk",
FieldId: 100,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: pkDatas,
},
},
},
},
},
{
Type: schemapb.DataType_BinaryVector,
FieldName: "vector",
FieldId: 101,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 8,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: vectorDatas,
},
},
},
},
}
}
func genFiledDataWithSchema(schema *schemapb.CollectionSchema, numRows int) []*schemapb.FieldData {
fieldsData := make([]*schemapb.FieldData, 0)
for _, field := range schema.Fields {
// data types with enum values below 100 are scalar types; the rest are vector types
if field.DataType < 100 {
fieldsData = append(fieldsData, segments.GenTestScalarFieldData(field.DataType, field.DataType.String(), numRows))
} else {
fieldsData = append(fieldsData, segments.GenTestVectorFiledData(field.DataType, field.DataType.String(), numRows, defaultDim))
}
}
return fieldsData
}

View File

@ -0,0 +1,80 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/util/paramtable"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
// Pipeline is the pipeline used by querynode
type Pipeline interface {
base.StreamPipeline
ExcludedSegments(segInfos ...*datapb.SegmentInfo)
}
type pipeline struct {
base.StreamPipeline
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
collectionID UniqueID
}
func (p *pipeline) ExcludedSegments(segInfos ...*datapb.SegmentInfo) {
for _, segInfo := range segInfos {
log.Debug("pipeline add exclude info",
zap.Int64("segmentID", segInfo.GetID()),
zap.Uint64("tss", segInfo.GetDmlPosition().Timestamp),
)
p.excludedSegments.Insert(segInfo.GetID(), segInfo)
}
}
func (p *pipeline) Close() {
p.StreamPipeline.Close()
metrics.CleanupQueryNodeCollectionMetrics(paramtable.GetNodeID(), p.collectionID)
}
func NewPipeLine(
collectionID UniqueID,
channel string,
manager *DataManager,
tSafeManager TSafeManager,
dispatcher msgdispatcher.Client,
delegator delegator.ShardDelegator,
) (Pipeline, error) {
pipelineQueueLength := paramtable.Get().QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
excludedSegments := typeutil.NewConcurrentMap[int64, *datapb.SegmentInfo]()
p := &pipeline{
collectionID: collectionID,
excludedSegments: excludedSegments,
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel),
}
filterNode := newFilterNode(collectionID, channel, manager, excludedSegments, pipelineQueueLength)
insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)
p.Add(filterNode, insertNode, deleteNode)
return p, nil
}

View File

@ -0,0 +1,167 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type PipelineTestSuite struct {
suite.Suite
// data
collectionName string
collectionID int64
partitionIDs []int64
channel string
insertSegmentIDs []int64
deletePKs []int64
//dependencies
tSafeManager TSafeManager
//mocks
segmentManager *segments.MockSegmentManager
collectionManager *segments.MockCollectionManager
delegator *delegator.MockShardDelegator
msgDispatcher *msgdispatcher.MockClient
msgChan chan *msgstream.MsgPack
}
func (suite *PipelineTestSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionName = "test-collection"
suite.channel = "test-channel"
suite.partitionIDs = []int64{11, 22}
suite.insertSegmentIDs = []int64{1, 2, 3}
suite.deletePKs = []int64{1, 2, 3}
suite.msgChan = make(chan *msgstream.MsgPack, 1)
}
func (suite *PipelineTestSuite) buildMsgPack(schema *schemapb.CollectionSchema) *msgstream.MsgPack {
msgPack := &msgstream.MsgPack{
BeginTs: 0,
EndTs: 1,
Msgs: []msgstream.TsMsg{},
}
for id, segmentID := range suite.insertSegmentIDs {
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%len(suite.partitionIDs)], segmentID, suite.channel, 1)
insertMsg.FieldsData = genFiledDataWithSchema(schema, 1)
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
}
for id, pk := range suite.deletePKs {
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[id%len(suite.partitionIDs)], suite.channel, 1)
deleteMsg.PrimaryKeys = genDeletePK(pk)
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
}
return msgPack
}
func (suite *PipelineTestSuite) SetupTest() {
paramtable.Init()
//init mock
// init manager
suite.collectionManager = segments.NewMockCollectionManager(suite.T())
suite.segmentManager = segments.NewMockSegmentManager(suite.T())
// init delegator
suite.delegator = delegator.NewMockShardDelegator(suite.T())
// init mq dispatcher
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
//init dependency
// init tsafeManager
suite.tSafeManager = tsafe.NewTSafeReplica()
suite.tSafeManager.Add(suite.channel, 0)
}
func (suite *PipelineTestSuite) TestBasic() {
//init mock
// mock collection manager
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
collection := segments.NewCollection(suite.collectionID, schema, querypb.LoadType_LoadCollection)
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)
// mock mq factory
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
// mock delegator
suite.delegator.EXPECT().ProcessInsert(mock.Anything).Run(
func(insertRecords map[int64]*delegator.InsertData) {
for segmentID := range insertRecords {
suite.True(lo.Contains(suite.insertSegmentIDs, segmentID))
}
})
suite.delegator.EXPECT().ProcessDelete(mock.Anything, mock.Anything).Run(
func(deleteData []*delegator.DeleteData, ts uint64) {
for _, data := range deleteData {
for _, pk := range data.PrimaryKeys {
suite.True(lo.Contains(suite.deletePKs, pk.GetValue().(int64)))
}
}
})
// build pipeline
manager := &segments.Manager{
Collection: suite.collectionManager,
Segment: suite.segmentManager,
}
pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator)
suite.NoError(err)
//Init Consumer
err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{})
suite.NoError(err)
err = pipeline.Start()
suite.NoError(err)
defer pipeline.Close()
// watch tsafe manager
listener := suite.tSafeManager.WatchChannel(suite.channel)
// build input msg
in := suite.buildMsgPack(schema)
suite.msgChan <- in
// wait pipeline work
<-listener.On()
//check tsafe
tsafe, err := suite.tSafeManager.Get(suite.channel)
suite.NoError(err)
suite.Equal(in.EndTs, tsafe)
}
func TestQueryNodePipeline(t *testing.T) {
suite.Run(t, new(PipelineTestSuite))
}

View File

@ -0,0 +1,63 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"time"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
// TODO: better to be configured
nodeCtxTtInterval = 2 * time.Minute
enableTtChecker = true
loadTypeCollection = querypb.LoadType_LoadCollection
loadTypePartition = querypb.LoadType_LoadPartition
segmentTypeGrowing = commonpb.SegmentState_Growing
segmentTypeSealed = commonpb.SegmentState_Sealed
)
type (
UniqueID = typeutil.UniqueID
Timestamp = typeutil.Timestamp
PrimaryKey = storage.PrimaryKey
InsertMsg = msgstream.InsertMsg
DeleteMsg = msgstream.DeleteMsg
Collection = segments.Collection
DataManager = segments.Manager
Segment = segments.Segment
TSafeManager = tsafe.Manager
BaseNode = base.BaseNode
Msg = base.Msg
)
type TimeRange struct {
timestampMin Timestamp
timestampMax Timestamp
}

View File

@ -0,0 +1,131 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pkoracle
import (
"sync"
"github.com/bits-and-blooms/bloom/v3"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/storage"
"go.uber.org/zap"
)
var _ Candidate = (*BloomFilterSet)(nil)
// BloomFilterSet is one implementation of Candidate with bloom filter in statslog.
type BloomFilterSet struct {
statsMutex sync.RWMutex
segmentID int64
paritionID int64
segType commonpb.SegmentState
currentStat *storage.PkStatistics
historyStats []*storage.PkStatistics
}
// MayPkExist returns whether any bloom filter returns positive.
func (s *BloomFilterSet) MayPkExist(pk storage.PrimaryKey) bool {
s.statsMutex.RLock()
defer s.statsMutex.RUnlock()
if s.currentStat != nil && s.currentStat.PkExist(pk) {
return true
}
// for sealed, if one of the stats shows it exists, then we have to check it
for _, historyStat := range s.historyStats {
if historyStat.PkExist(pk) {
return true
}
}
return false
}
// ID implement candidate.
func (s *BloomFilterSet) ID() int64 {
return s.segmentID
}
// Partition implements candidate.
func (s *BloomFilterSet) Partition() int64 {
return s.paritionID
}
// Type implements candidate.
func (s *BloomFilterSet) Type() commonpb.SegmentState {
return s.segType
}
// UpdateBloomFilter updates currentStats with provided pks.
func (s *BloomFilterSet) UpdateBloomFilter(pks []storage.PrimaryKey) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
if s.currentStat == nil {
s.currentStat = &storage.PkStatistics{
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
}
}
buf := make([]byte, 8)
for _, pk := range pks {
s.currentStat.UpdateMinMax(pk)
switch pk.Type() {
case schemapb.DataType_Int64:
int64Value := pk.(*storage.Int64PrimaryKey).Value
common.Endian.PutUint64(buf, uint64(int64Value))
s.currentStat.PkFilter.Add(buf)
case schemapb.DataType_VarChar:
stringValue := pk.(*storage.VarCharPrimaryKey).Value
s.currentStat.PkFilter.AddString(stringValue)
default:
log.Error("failed to update bloomfilter", zap.Any("PK type", pk.Type()))
panic("failed to update bloomfilter")
}
}
}
// AddHistoricalStats adds loaded historical stats.
func (s *BloomFilterSet) AddHistoricalStats(stats *storage.PkStatistics) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
s.historyStats = append(s.historyStats, stats)
}
// initCurrentStat initializes currentStat if nil.
// Note: invoker shall acquire statsMutex lock first.
func (s *BloomFilterSet) initCurrentStat() {
if s.currentStat == nil {
s.currentStat = &storage.PkStatistics{
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
}
}
}
// NewBloomFilterSet returns a new BloomFilterSet.
func NewBloomFilterSet(segmentID int64, paritionID int64, segType commonpb.SegmentState) *BloomFilterSet {
bfs := &BloomFilterSet{
segmentID: segmentID,
paritionID: paritionID,
segType: segType,
}
// does not need to init current
return bfs
}
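
A minimal sketch of building and probing a BloomFilterSet; the segment and partition IDs are placeholders and probeSegment is a hypothetical helper, not part of this commit:

// probeSegment feeds pks into the bloom filter of a growing segment and
// returns those that may exist there. False positives are possible,
// false negatives are not.
func probeSegment(pks []storage.PrimaryKey) []storage.PrimaryKey {
    bfs := NewBloomFilterSet(1, 10, commonpb.SegmentState_Growing)
    bfs.UpdateBloomFilter(pks)
    hits := make([]storage.PrimaryKey, 0, len(pks))
    for _, pk := range pks {
        if bfs.MayPkExist(pk) {
            hits = append(hits, pk)
        }
    }
    return hits
}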

View File

@ -0,0 +1,72 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pkoracle
import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// Candidate is the interface for pk oracle candidate.
type Candidate interface {
// MayPkExist checks whether the primary key could exist in this candidate.
MayPkExist(pk storage.PrimaryKey) bool
ID() int64
Partition() int64
Type() commonpb.SegmentState
}
type candidateWithWorker struct {
Candidate
workerID int64
}
// CandidateFilter filter type for candidate.
type CandidateFilter func(candidate candidateWithWorker) bool
// WithSegmentType returns a CandidateFilter with the provided segment type.
func WithSegmentType(typ commonpb.SegmentState) CandidateFilter {
return func(candidate candidateWithWorker) bool {
return candidate.Type() == typ
}
}
// WithWorkerID returns CandidateFilter with provided worker id.
func WithWorkerID(workerID int64) CandidateFilter {
return func(candidate candidateWithWorker) bool {
return candidate.workerID == workerID
}
}
// WithSegmentIDs returns CandidateFilter with provided segment ids.
func WithSegmentIDs(segmentIDs ...int64) CandidateFilter {
set := typeutil.NewSet[int64]()
set.Insert(segmentIDs...)
return func(candidate candidateWithWorker) bool {
return set.Contain(candidate.ID())
}
}
// WithPartitionID returns CandidateFilter with provided partitionID.
func WithPartitionID(partitionID int64) CandidateFilter {
return func(candidate candidateWithWorker) bool {
return candidate.Partition() == partitionID || partitionID == common.InvalidPartitionID
}
}

View File

@ -0,0 +1,58 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pkoracle
import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/storage"
)
type candidateKey struct {
segmentID int64
partitionID int64
typ commonpb.SegmentState
}
// MayPkExist checks whether the primary key could exist in this candidate.
func (k candidateKey) MayPkExist(pk storage.PrimaryKey) bool {
// always return true to prevent misuse
return true
}
// ID implements Candidate.
func (k candidateKey) ID() int64 {
return k.segmentID
}
// Partition implements Candidate.
func (k candidateKey) Partition() int64 {
return k.partitionID
}
// Type implements Candidate.
func (k candidateKey) Type() commonpb.SegmentState {
return k.typ
}
// NewCandidateKey creates a candidateKey and returns as Candidate.
func NewCandidateKey(id int64, partitionID int64, typ commonpb.SegmentState) Candidate {
return candidateKey{
segmentID: id,
partitionID: partitionID,
typ: typ,
}
}

View File

@ -0,0 +1,103 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package pkoracle contains pk-to-segment mapping logic.
package pkoracle
import (
"fmt"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// PkOracle interface for pk oracle.
type PkOracle interface {
// Get returns the IDs of segment candidates to which the pk might belong.
Get(pk storage.PrimaryKey, filters ...CandidateFilter) ([]int64, error)
// Register adds a candidate into the pkOracle.
Register(candidate Candidate, workerID int64) error
// Remove removes candidates matching the provided filters.
Remove(filters ...CandidateFilter) error
// Exists checks whether a candidate with the provided key exists.
Exists(candidate Candidate, workerID int64) bool
}
var _ PkOracle = (*pkOracle)(nil)
// pkOracle implementation.
type pkOracle struct {
candidates *typeutil.ConcurrentMap[string, candidateWithWorker]
}
// Get implements PkOracle.
func (pko *pkOracle) Get(pk storage.PrimaryKey, filters ...CandidateFilter) ([]int64, error) {
var result []int64
pko.candidates.Range(func(key string, candidate candidateWithWorker) bool {
for _, filter := range filters {
if !filter(candidate) {
return true
}
}
if candidate.MayPkExist(pk) {
result = append(result, candidate.ID())
}
return true
})
return result, nil
}
func (pko *pkOracle) candidateKey(candidate Candidate, workerID int64) string {
return fmt.Sprintf("%s-%d-%d", candidate.Type().String(), workerID, candidate.ID())
}
// Register registers a candidate.
func (pko *pkOracle) Register(candidate Candidate, workerID int64) error {
pko.candidates.Insert(pko.candidateKey(candidate, workerID), candidateWithWorker{
Candidate: candidate,
workerID: workerID,
})
return nil
}
// Remove removes candidate from pko.
func (pko *pkOracle) Remove(filters ...CandidateFilter) error {
pko.candidates.Range(func(key string, candidate candidateWithWorker) bool {
for _, filter := range filters {
if !filter(candidate) {
return true
}
}
pko.candidates.GetAndRemove(pko.candidateKey(candidate, candidate.workerID))
return true
})
return nil
}
func (pko *pkOracle) Exists(candidate Candidate, workerID int64) bool {
_, ok := pko.candidates.Get(pko.candidateKey(candidate, workerID))
return ok
}
// NewPkOracle returns pkOracle as PkOracle interface.
func NewPkOracle() PkOracle {
return &pkOracle{
candidates: typeutil.NewConcurrentMap[string, candidateWithWorker](),
}
}
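
A minimal sketch of registering a candidate and querying the oracle with filters; it assumes the commonpb import and a Candidate built elsewhere (e.g. a BloomFilterSet), and lookupSealed is a hypothetical helper, not part of this commit:

// lookupSealed registers a sealed candidate for a worker and asks which
// sealed segments on that worker may contain the pk.
func lookupSealed(oracle PkOracle, sealed Candidate, workerID int64, pk storage.PrimaryKey) ([]int64, error) {
    if err := oracle.Register(sealed, workerID); err != nil {
        return nil, err
    }
    return oracle.Get(pk,
        WithSegmentType(commonpb.SegmentState_Sealed),
        WithWorkerID(workerID),
    )
}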

View File

@ -0,0 +1,98 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
import (
"sync"
"github.com/bits-and-blooms/bloom/v3"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
storage "github.com/milvus-io/milvus/internal/storage"
"go.uber.org/zap"
)
type bloomFilterSet struct {
statsMutex sync.RWMutex
currentStat *storage.PkStatistics
historyStats []*storage.PkStatistics
}
func newBloomFilterSet() *bloomFilterSet {
return &bloomFilterSet{}
}
// MayPkExist returns whether any bloom filter returns positive.
func (s *bloomFilterSet) MayPkExist(pk storage.PrimaryKey) bool {
s.statsMutex.RLock()
defer s.statsMutex.RUnlock()
if s.currentStat != nil && s.currentStat.PkExist(pk) {
return true
}
// for sealed, if one of the stats shows it exists, then we have to check it
for _, historyStat := range s.historyStats {
if historyStat.PkExist(pk) {
return true
}
}
return false
}
// UpdateBloomFilter updates currentStats with provided pks.
func (s *bloomFilterSet) UpdateBloomFilter(pks []storage.PrimaryKey) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
if s.currentStat == nil {
s.initCurrentStat()
}
buf := make([]byte, 8)
for _, pk := range pks {
s.currentStat.UpdateMinMax(pk)
switch pk.Type() {
case schemapb.DataType_Int64:
int64Value := pk.(*storage.Int64PrimaryKey).Value
common.Endian.PutUint64(buf, uint64(int64Value))
s.currentStat.PkFilter.Add(buf)
case schemapb.DataType_VarChar:
stringValue := pk.(*storage.VarCharPrimaryKey).Value
s.currentStat.PkFilter.AddString(stringValue)
default:
log.Error("failed to update bloomfilter", zap.Any("PK type", pk.Type()))
panic("failed to update bloomfilter")
}
}
}
// AddHistoricalStats adds loaded historical stats.
func (s *bloomFilterSet) AddHistoricalStats(stats *storage.PkStatistics) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
s.historyStats = append(s.historyStats, stats)
}
// initCurrentStat initializes currentStat if nil.
// Note: invoker shall acquire statsMutex lock first.
func (s *bloomFilterSet) initCurrentStat() {
s.currentStat = &storage.PkStatistics{
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
}
}

View File

@ -0,0 +1,88 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
import (
"testing"
"github.com/milvus-io/milvus/internal/storage"
"github.com/stretchr/testify/suite"
)
type BloomFilterSetSuite struct {
suite.Suite
intPks []int64
stringPks []string
set *bloomFilterSet
}
func (suite *BloomFilterSetSuite) SetupTest() {
suite.intPks = []int64{1, 2, 3}
suite.stringPks = []string{"1", "2", "3"}
suite.set = newBloomFilterSet()
}
func (suite *BloomFilterSetSuite) TestInt64PkBloomFilter() {
pks, err := storage.GenInt64PrimaryKeys(suite.intPks...)
suite.NoError(err)
suite.set.UpdateBloomFilter(pks)
for _, pk := range pks {
exist := suite.set.MayPkExist(pk)
suite.True(exist)
}
}
func (suite *BloomFilterSetSuite) TestStringPkBloomFilter() {
pks, err := storage.GenVarcharPrimaryKeys(suite.stringPks...)
suite.NoError(err)
suite.set.UpdateBloomFilter(pks)
for _, pk := range pks {
exist := suite.set.MayPkExist(pk)
suite.True(exist)
}
}
func (suite *BloomFilterSetSuite) TestHistoricalBloomFilter() {
pks, err := storage.GenVarcharPrimaryKeys(suite.stringPks...)
suite.NoError(err)
suite.set.UpdateBloomFilter(pks)
for _, pk := range pks {
exist := suite.set.MayPkExist(pk)
suite.True(exist)
}
old := suite.set.currentStat
suite.set.currentStat = nil
for _, pk := range pks {
exist := suite.set.MayPkExist(pk)
suite.False(exist)
}
suite.set.AddHistoricalStats(old)
for _, pk := range pks {
exist := suite.set.MayPkExist(pk)
suite.True(exist)
}
}
func TestBloomFilterSet(t *testing.T) {
suite.Run(t, &BloomFilterSetSuite{})
}

View File

@ -0,0 +1,97 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
/*
#cgo pkg-config: milvus_segcore milvus_storage
#include "segcore/collection_c.h"
#include "common/type_c.h"
#include "segcore/segment_c.h"
#include "storage/storage_c.h"
*/
import "C"
import (
"fmt"
"unsafe"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/cgoconverter"
)
// HandleCStatus deals with the error returned from CGO
func HandleCStatus(status *C.CStatus, extraInfo string) error {
if status.error_code == 0 {
return nil
}
errorCode := status.error_code
errorName, ok := commonpb.ErrorCode_name[int32(errorCode)]
if !ok {
errorName = "UnknownError"
}
errorMsg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
finalMsg := fmt.Sprintf("[%s] %s", errorName, errorMsg)
logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, finalMsg)
log.Warn(logMsg)
return errors.New(finalMsg)
}
// HandleCProto deals with the result proto returned from CGO
func HandleCProto(cRes *C.CProto, msg proto.Message) error {
// The CProto is a protobuf blob created and passed from the C side;
// its memory is managed manually.
lease, blob := cgoconverter.UnsafeGoBytes(&cRes.proto_blob, int(cRes.proto_size))
defer cgoconverter.Release(lease)
return proto.Unmarshal(blob, msg)
}
// CopyCProtoBlob returns the copy of C memory
func CopyCProtoBlob(cProto *C.CProto) []byte {
blob := C.GoBytes(cProto.proto_blob, C.int32_t(cProto.proto_size))
C.free(cProto.proto_blob)
return blob
}
// GetCProtoBlob returns the raw C memory, invoker should release it itself
func GetCProtoBlob(cProto *C.CProto) []byte {
lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
cgoconverter.Extract(lease)
return blob
}
func GetLocalUsedSize() (int64, error) {
var availableSize int64
cSize := C.int64_t(availableSize)
status := C.GetLocalUsedSize(&cSize)
err := HandleCStatus(&status, "get local used size failed")
if err != nil {
return 0, err
}
return int64(cSize), nil
}

View File

@ -0,0 +1,156 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
/*
#cgo pkg-config: milvus_segcore
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
*/
import "C"
import (
"sync"
"unsafe"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type CollectionManager interface {
Get(collectionID int64) *Collection
Put(collectionID int64, schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo)
}
type collectionManager struct {
mut sync.RWMutex
collections map[int64]*Collection
}
func NewCollectionManager() *collectionManager {
return &collectionManager{
collections: make(map[int64]*Collection),
}
}
func (m *collectionManager) Get(collectionID int64) *Collection {
m.mut.RLock()
defer m.mut.RUnlock()
return m.collections[collectionID]
}
func (m *collectionManager) Put(collectionID int64, schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo) {
m.mut.Lock()
defer m.mut.Unlock()
if _, ok := m.collections[collectionID]; ok {
return
}
collection := NewCollection(collectionID, schema, loadMeta.GetLoadType())
collection.AddPartition(loadMeta.GetPartitionIDs()...)
m.collections[collectionID] = collection
}
// Collection is a wrapper of the underlying C-structure C.CCollection
type Collection struct {
mu sync.RWMutex // protects collectionPtr
collectionPtr C.CCollection
id int64
partitions *typeutil.ConcurrentSet[int64]
loadType querypb.LoadType
schema *schemapb.CollectionSchema
}
// ID returns collection id
func (c *Collection) ID() int64 {
return c.id
}
// Schema returns the schema of collection
func (c *Collection) Schema() *schemapb.CollectionSchema {
return c.schema
}
// GetPartitions returns the partition IDs of the collection
func (c *Collection) GetPartitions() []int64 {
return c.partitions.Collect()
}
func (c *Collection) ExistPartition(partitionIDs ...int64) bool {
return c.partitions.Contain(partitionIDs...)
}
// AddPartition adds partition ids to the partition id list of the collection
func (c *Collection) AddPartition(partitions ...int64) {
for i := range partitions {
c.partitions.Insert(partitions[i])
}
}
// RemovePartition removes the partition id from the partition id list of the collection
func (c *Collection) RemovePartition(partitionID int64) {
c.partitions.Remove(partitionID)
}
// GetLoadType returns the loadType of the collection, which is loadTypeCollection or loadTypePartition
func (c *Collection) GetLoadType() querypb.LoadType {
return c.loadType
}
// NewCollection returns a new Collection
func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, loadType querypb.LoadType) *Collection {
/*
CCollection
NewCollection(const char* schema_proto_blob);
*/
schemaBlob := proto.MarshalTextString(schema)
cSchemaBlob := C.CString(schemaBlob)
defer C.free(unsafe.Pointer(cSchemaBlob))
collection := C.NewCollection(cSchemaBlob)
return &Collection{
collectionPtr: collection,
id: collectionID,
schema: schema,
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
}
}
func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection {
return &Collection{
id: collectionID,
partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadType,
}
}
// DeleteCollection deletes the collection and frees the collection memory
func DeleteCollection(collection *Collection) {
/*
void
deleteCollection(CCollection collection);
*/
cPtr := collection.collectionPtr
C.DeleteCollection(cPtr)
collection.collectionPtr = nil
}
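
A minimal sketch of CollectionManager usage; the IDs, schema and load meta are placeholders supplied by the caller, and ensureCollection is a hypothetical helper, not part of this commit:

// ensureCollection registers the collection and reads it back. Put is a no-op
// when the collection is already registered, and Get returns nil for unknown ids.
func ensureCollection(m CollectionManager, collectionID int64, schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo) *Collection {
    m.Put(collectionID, schema, loadMeta)
    return m.Get(collectionID)
}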

View File

@ -0,0 +1,49 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
import (
"fmt"
"github.com/cockroachdb/errors"
)
var (
// Manager related errors
ErrCollectionNotFound = errors.New("CollectionNotFound")
ErrPartitionNotFound = errors.New("PartitionNotFound")
ErrSegmentNotFound = errors.New("SegmentNotFound")
ErrFieldNotFound = errors.New("FieldNotFound")
ErrSegmentReleased = errors.New("SegmentReleased")
)
func WrapSegmentNotFound(segmentID int64) error {
return fmt.Errorf("%w(%v)", ErrSegmentNotFound, segmentID)
}
func WrapCollectionNotFound(collectionID int64) error {
return fmt.Errorf("%w(%v)", ErrCollectionNotFound, collectionID)
}
func WrapFieldNotFound(fieldID int64) error {
return fmt.Errorf("%w(%v)", ErrFieldNotFound, fieldID)
}
// WrapSegmentReleased wraps ErrSegmentReleased with segmentID.
func WrapSegmentReleased(segmentID int64) error {
return fmt.Errorf("%w(%d)", ErrSegmentReleased, segmentID)
}
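// Illustrative usage sketch (not part of the diff): because the Wrap*
// helpers use the %w verb, callers can match the sentinel errors with
// errors.Is. The segment ID below is hypothetical.
//
//	err := WrapSegmentNotFound(42)
//	if errors.Is(err, ErrSegmentNotFound) {
//		// segment is missing; e.g. report NotFound to the caller
//	}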

View File

@ -0,0 +1,202 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
/*
#cgo pkg-config: milvus_common milvus_segcore
#include "segcore/load_index_c.h"
#include "common/binary_set_c.h"
*/
import "C"
import (
"encoding/json"
"fmt"
"path/filepath"
"unsafe"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/indexparams"
"github.com/milvus-io/milvus/internal/util/paramtable"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
)
// LoadIndexInfo is a wrapper of the underlying C-structure C.CLoadIndexInfo
type LoadIndexInfo struct {
cLoadIndexInfo C.CLoadIndexInfo
}
// newLoadIndexInfo creates a new LoadIndexInfo with the storage config read from paramtable
func newLoadIndexInfo() (*LoadIndexInfo, error) {
var cLoadIndexInfo C.CLoadIndexInfo
// TODO::xige-16 support embedded milvus
storageType := "minio"
cAddress := C.CString(paramtable.Get().MinioCfg.Address.GetValue())
cBucketName := C.CString(paramtable.Get().MinioCfg.BucketName.GetValue())
cAccessKey := C.CString(paramtable.Get().MinioCfg.AccessKeyID.GetValue())
cAccessValue := C.CString(paramtable.Get().MinioCfg.SecretAccessKey.GetValue())
cRootPath := C.CString(paramtable.Get().MinioCfg.RootPath.GetValue())
cStorageType := C.CString(storageType)
cIamEndPoint := C.CString(paramtable.Get().MinioCfg.IAMEndpoint.GetValue())
defer C.free(unsafe.Pointer(cAddress))
defer C.free(unsafe.Pointer(cBucketName))
defer C.free(unsafe.Pointer(cAccessKey))
defer C.free(unsafe.Pointer(cAccessValue))
defer C.free(unsafe.Pointer(cRootPath))
defer C.free(unsafe.Pointer(cStorageType))
defer C.free(unsafe.Pointer(cIamEndPoint))
storageConfig := C.CStorageConfig{
address: cAddress,
bucket_name: cBucketName,
access_key_id: cAccessKey,
access_key_value: cAccessValue,
remote_root_path: cRootPath,
storage_type: cStorageType,
iam_endpoint: cIamEndPoint,
useSSL: C.bool(paramtable.Get().MinioCfg.UseSSL.GetAsBool()),
useIAM: C.bool(paramtable.Get().MinioCfg.UseIAM.GetAsBool()),
}
status := C.NewLoadIndexInfo(&cLoadIndexInfo, storageConfig)
if err := HandleCStatus(&status, "NewLoadIndexInfo failed"); err != nil {
return nil, err
}
return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil
}
// deleteLoadIndexInfo frees the underlying C.CLoadIndexInfo
func deleteLoadIndexInfo(info *LoadIndexInfo) {
C.DeleteLoadIndexInfo(info.cLoadIndexInfo)
}
func (li *LoadIndexInfo) appendLoadIndexInfo(bytesIndex [][]byte, indexInfo *querypb.FieldIndexInfo, collectionID int64, partitionID int64, segmentID int64, fieldType schemapb.DataType) error {
fieldID := indexInfo.FieldID
indexPaths := indexInfo.IndexFilePaths
err := li.appendFieldInfo(collectionID, partitionID, segmentID, fieldID, fieldType)
if err != nil {
return err
}
err = li.appendIndexInfo(indexInfo.IndexID, indexInfo.BuildID, indexInfo.IndexVersion)
if err != nil {
return err
}
// some build params also exist in indexParams but are useless during the loading process
indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
indexparams.SetDiskIndexLoadParams(indexParams, indexInfo.GetNumRows())
jsonIndexParams, err := json.Marshal(indexParams)
if err != nil {
err = fmt.Errorf("failed to json marshal index params: %w", err)
return err
}
log.Info("start append index params", zap.String("index params", string(jsonIndexParams)))
for key, value := range indexParams {
err = li.appendIndexParam(key, value)
if err != nil {
return err
}
}
err = li.appendIndexData(bytesIndex, indexPaths)
return err
}
// appendIndexParam appends an index param to the load-index info
func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error {
cIndexKey := C.CString(indexKey)
defer C.free(unsafe.Pointer(cIndexKey))
cIndexValue := C.CString(indexValue)
defer C.free(unsafe.Pointer(cIndexValue))
status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue)
return HandleCStatus(&status, "AppendIndexParam failed")
}
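// appendIndexInfo appends indexID, buildID and indexVersion to the load-index info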
func (li *LoadIndexInfo) appendIndexInfo(indexID int64, buildID int64, indexVersion int64) error {
cIndexID := C.int64_t(indexID)
cBuildID := C.int64_t(buildID)
cIndexVersion := C.int64_t(indexVersion)
status := C.AppendIndexInfo(li.cLoadIndexInfo, cIndexID, cBuildID, cIndexVersion)
return HandleCStatus(&status, "AppendIndexInfo failed")
}
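// cleanLocalData cleans the index data cached on the local disk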
func (li *LoadIndexInfo) cleanLocalData() error {
status := C.CleanLoadedIndex(li.cLoadIndexInfo)
return HandleCStatus(&status, "failed to clean cached data on disk")
}
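// appendIndexFile appends an index file path to the load-index info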
func (li *LoadIndexInfo) appendIndexFile(filePath string) error {
cIndexFilePath := C.CString(filePath)
defer C.free(unsafe.Pointer(cIndexFilePath))
status := C.AppendIndexFilePath(li.cLoadIndexInfo, cIndexFilePath)
return HandleCStatus(&status, "AppendIndexFilePath failed")
}
// appendFieldInfo appends collection/partition/segment/field IDs and the field type to the load-index info
func (li *LoadIndexInfo) appendFieldInfo(collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType) error {
cColID := C.int64_t(collectionID)
cParID := C.int64_t(partitionID)
cSegID := C.int64_t(segmentID)
cFieldID := C.int64_t(fieldID)
cintDType := uint32(fieldType)
status := C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType)
return HandleCStatus(&status, "AppendFieldInfo failed")
}
// appendIndexData appends binarySet index to cLoadIndexInfo
func (li *LoadIndexInfo) appendIndexData(bytesIndex [][]byte, indexKeys []string) error {
for _, indexPath := range indexKeys {
err := li.appendIndexFile(indexPath)
if err != nil {
return err
}
}
var cBinarySet C.CBinarySet
status := C.NewBinarySet(&cBinarySet)
defer C.DeleteBinarySet(cBinarySet)
if err := HandleCStatus(&status, "NewBinarySet failed"); err != nil {
return err
}
for i, byteIndex := range bytesIndex {
indexPtr := unsafe.Pointer(&byteIndex[0])
indexLen := C.int64_t(len(byteIndex))
binarySetKey := filepath.Base(indexKeys[i])
indexKey := C.CString(binarySetKey)
status = C.AppendIndexBinary(cBinarySet, indexPtr, indexLen, indexKey)
C.free(unsafe.Pointer(indexKey))
if err := HandleCStatus(&status, "LoadIndexInfo AppendIndexBinary failed"); err != nil {
return err
}
}
status = C.AppendIndex(li.cLoadIndexInfo, cBinarySet)
return HandleCStatus(&status, "AppendIndex failed")
}
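// Illustrative usage sketch (not part of the diff): the expected call
// pattern for loading an index. The bytesIndex/indexInfo/ID values and the
// field type are hypothetical; error handling is abbreviated.
//
//	info, err := newLoadIndexInfo()
//	if err != nil {
//		return err
//	}
//	defer deleteLoadIndexInfo(info)
//	return info.appendLoadIndexInfo(bytesIndex, indexInfo,
//		collectionID, partitionID, segmentID, schemapb.DataType_FloatVector)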

View File

@ -0,0 +1,308 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segments
/*
#cgo pkg-config: milvus_segcore
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
*/
import "C"
import (
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/paramtable"
. "github.com/milvus-io/milvus/internal/util/typeutil"
)
type SegmentFilter func(segment Segment) bool
func WithPartition(partitionID UniqueID) SegmentFilter {
return func(segment Segment) bool {
return segment.Partition() == partitionID
}
}
func WithChannel(channel string) SegmentFilter {
return func(segment Segment) bool {
return segment.Shard() == channel
}
}
func WithType(typ SegmentType) SegmentFilter {
return func(segment Segment) bool {
return segment.Type() == typ
}
}
func WithID(id int64) SegmentFilter {
return func(segment Segment) bool {
return segment.ID() == id
}
}
type actionType int32
const (
removeAction actionType = iota
addAction
)
type Manager struct {
Collection CollectionManager
Segment SegmentManager
}
func NewManager() *Manager {
return &Manager{
Collection: NewCollectionManager(),
Segment: NewSegmentManager(),
}
}
type SegmentManager interface {
// Put puts the given segments in,
// and increases the ref count of the corresponding collection;
// duplicate segments will not increase the ref count
Put(segmentType SegmentType, segments ...Segment)
Get(segmentID UniqueID) Segment
GetWithType(segmentID UniqueID, typ SegmentType) Segment
GetBy(filters ...SegmentFilter) []Segment
GetSealed(segmentID UniqueID) Segment
GetGrowing(segmentID UniqueID) Segment
// Remove removes the given segment,
// and decreases the ref count of the corresponding collection;
// the ref count will not be decreased if the given segment does not exist
Remove(segmentID UniqueID, scope querypb.DataScope)
RemoveBy(filters ...SegmentFilter)
}
var _ SegmentManager = (*segmentManager)(nil)
// segmentManager manages all segments of this node
type segmentManager struct {
mu sync.RWMutex // guards all
growingSegments map[UniqueID]Segment
sealedSegments map[UniqueID]Segment
}
func NewSegmentManager() *segmentManager {
return &segmentManager{
growingSegments: make(map[int64]Segment),
sealedSegments: make(map[int64]Segment),
}
}
func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
targetMap := mgr.growingSegments
switch segmentType {
case SegmentTypeGrowing:
targetMap = mgr.growingSegments
case SegmentTypeSealed:
targetMap = mgr.sealedSegments
default:
panic("unexpected segment type")
}
for _, segment := range segments {
if _, ok := targetMap[segment.ID()]; ok {
continue
}
targetMap[segment.ID()] = segment
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.Collection()),
fmt.Sprint(segment.Partition()),
segment.Type().String(),
fmt.Sprint(len(segment.Indexes())),
).Inc()
if segment.RowNum() > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.Collection()),
fmt.Sprint(segment.Partition()),
segment.Type().String(),
fmt.Sprint(len(segment.Indexes())),
).Add(float64(segment.RowNum()))
}
}
mgr.updateMetric()
}
func (mgr *segmentManager) Get(segmentID UniqueID) Segment {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
if segment, ok := mgr.growingSegments[segmentID]; ok {
return segment
} else if segment, ok = mgr.sealedSegments[segmentID]; ok {
return segment
}
return nil
}
func (mgr *segmentManager) GetWithType(segmentID UniqueID, typ SegmentType) Segment {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
switch typ {
case SegmentTypeSealed:
return mgr.sealedSegments[segmentID]
case SegmentTypeGrowing:
return mgr.growingSegments[segmentID]
default:
return nil
}
}
func (mgr *segmentManager) GetBy(filters ...SegmentFilter) []Segment {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
ret := make([]Segment, 0)
for _, segment := range mgr.growingSegments {
if filter(segment, filters...) {
ret = append(ret, segment)
}
}
for _, segment := range mgr.sealedSegments {
if filter(segment, filters...) {
ret = append(ret, segment)
}
}
return ret
}
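// filter returns true only if the segment passes all of the given filters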
func filter(segment Segment, filters ...SegmentFilter) bool {
for _, filter := range filters {
if !filter(segment) {
return false
}
}
return true
}
func (mgr *segmentManager) GetSealed(segmentID UniqueID) Segment {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
if segment, ok := mgr.sealedSegments[segmentID]; ok {
return segment
}
return nil
}
func (mgr *segmentManager) GetGrowing(segmentID UniqueID) Segment {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
if segment, ok := mgr.growingSegments[segmentID]; ok {
return segment
}
return nil
}
func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
switch scope {
case querypb.DataScope_Streaming:
remove(segmentID, mgr.growingSegments)
case querypb.DataScope_Historical:
remove(segmentID, mgr.sealedSegments)
case querypb.DataScope_All:
remove(segmentID, mgr.growingSegments)
remove(segmentID, mgr.sealedSegments)
}
mgr.updateMetric()
}
func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
for id, segment := range mgr.growingSegments {
if filter(segment, filters...) {
remove(id, mgr.growingSegments)
}
}
for id, segment := range mgr.sealedSegments {
if filter(segment, filters...) {
remove(id, mgr.sealedSegments)
}
}
mgr.updateMetric()
}
func (mgr *segmentManager) updateMetric() {
// update collection and partition metrics
var collections, partitions = make(Set[int64]), make(Set[int64])
for _, seg := range mgr.growingSegments {
collections.Insert(seg.Collection())
partitions.Insert(seg.Partition())
}
for _, seg := range mgr.sealedSegments {
collections.Insert(seg.Collection())
partitions.Insert(seg.Partition())
}
metrics.QueryNodeNumCollections.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(collections.Len()))
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partitions.Len()))
}
func remove(segmentID int64, container map[int64]Segment) {
segment, ok := container[segmentID]
if !ok {
return
}
delete(container, segmentID)
rowNum := segment.RowNum()
DeleteSegment(segment.(*LocalSegment))
metrics.QueryNodeNumSegments.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.Collection()),
fmt.Sprint(segment.Partition()),
segment.Type().String(),
fmt.Sprint(len(segment.Indexes())),
).Dec()
if rowNum > 0 {
metrics.QueryNodeNumEntities.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()),
fmt.Sprint(segment.Collection()),
fmt.Sprint(segment.Partition()),
segment.Type().String(),
fmt.Sprint(len(segment.Indexes())),
).Sub(float64(rowNum))
}
}
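// Illustrative usage sketch (not part of the diff): composing SegmentFilters
// with the manager API above. The segment, channel name and data scope are
// hypothetical.
//
//	mgr := NewManager()
//	mgr.Segment.Put(SegmentTypeSealed, segment)
//	sealed := mgr.Segment.GetBy(WithChannel("dml1"), WithType(SegmentTypeSealed))
//	_ = sealed
//	mgr.Segment.Remove(segment.ID(), querypb.DataScope_Historical)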

View File

@ -0,0 +1,108 @@
package segments
import (
"testing"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/stretchr/testify/suite"
)
type ManagerSuite struct {
suite.Suite
// Data
segmentIDs []int64
collectionIDs []int64
partitionIDs []int64
channels []string
types []SegmentType
segments []*LocalSegment
mgr *segmentManager
}
func (s *ManagerSuite) SetupSuite() {
paramtable.Init()
s.segmentIDs = []int64{1, 2, 3}
s.collectionIDs = []int64{100, 200, 300}
s.partitionIDs = []int64{10, 11, 12}
s.channels = []string{"dml1", "dml2", "dml3"}
s.types = []SegmentType{SegmentTypeSealed, SegmentTypeGrowing, SegmentTypeSealed}
}
func (s *ManagerSuite) SetupTest() {
s.mgr = NewSegmentManager()
for i, id := range s.segmentIDs {
segment, err := NewSegment(
NewCollection(s.collectionIDs[i], GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64), querypb.LoadType_LoadCollection),
id,
s.partitionIDs[i],
s.collectionIDs[i],
s.channels[i],
s.types[i],
0,
nil,
nil,
)
s.Require().NoError(err)
s.segments = append(s.segments, segment)
s.mgr.Put(s.types[i], segment)
}
}
func (s *ManagerSuite) TestGetBy() {
for i, partitionID := range s.partitionIDs {
segments := s.mgr.GetBy(WithPartition(partitionID))
s.Contains(segments, s.segments[i])
}
for i, channel := range s.channels {
segments := s.mgr.GetBy(WithChannel(channel))
s.Contains(segments, s.segments[i])
}
for i, typ := range s.types {
segments := s.mgr.GetBy(WithType(typ))
s.Contains(segments, s.segments[i])
}
}
func (s *ManagerSuite) TestRemoveGrowing() {
for i, id := range s.segmentIDs {
isGrowing := s.types[i] == SegmentTypeGrowing
s.mgr.Remove(id, querypb.DataScope_Streaming)
s.Equal(s.mgr.Get(id) == nil, isGrowing)
}
}
func (s *ManagerSuite) TestRemoveSealed() {
for i, id := range s.segmentIDs {
isSealed := s.types[i] == SegmentTypeSealed
s.mgr.Remove(id, querypb.DataScope_Historical)
s.Equal(s.mgr.Get(id) == nil, isSealed)
}
}
func (s *ManagerSuite) TestRemoveAll() {
for _, id := range s.segmentIDs {
s.mgr.Remove(id, querypb.DataScope_All)
s.Nil(s.mgr.Get(id))
}
}
func (s *ManagerSuite) TestRemoveBy() {
for _, id := range s.segmentIDs {
s.mgr.RemoveBy(WithID(id))
s.Nil(s.mgr.Get(id))
}
}
func TestManager(t *testing.T) {
suite.Run(t, new(ManagerSuite))
}
