mirror of https://github.com/milvus-io/milvus.git
Refactor QueryNode (#21625)
Signed-off-by: yah01 <yang.cen@zilliz.com> Co-authored-by: Congqi Xia <congqi.xia@zilliz.com> Co-authored-by: aoiasd <zhicheng.yue@zilliz.com>pull/22298/head
parent
977943463e
commit
081572d31c
configs
internal
core
src
unittest
distributed/querynode
mq
msgdispatcher
proto
querycoordv2
querynodev2
5
Makefile
5
Makefile
|
@ -365,5 +365,10 @@ generate-mockery: getdeps
|
|||
#internal/types
|
||||
$(PWD)/bin/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord --outpkg=types --inpackage
|
||||
$(PWD)/bin/mockery --name=QueryNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/types --filename=mock_querynode.go --with-expecter --structname=MockQueryNode --outpkg=types --inpackage
|
||||
# internal/querynodev2
|
||||
$(PWD)/bin/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage
|
||||
$(PWD)/bin/mockery --name=Loader --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_loader.go --with-expecter --outpkg=segments --structname=MockLoader --inpackage
|
||||
$(PWD)/bin/mockery --name=Worker --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_worker.go --with-expecter --outpkg=worker --structname=MockWorker --inpackage
|
||||
|
||||
ci-ut: build-cpp-with-coverage generated-proto-go-without-cpp codecov-cpp codecov-go
|
||||
|
||||
|
|
|
@ -182,13 +182,15 @@ queryCoord:
|
|||
heartbeatAvailableInterval: 10000 # 10s, Only QueryNodes which fetched heartbeats within the duration are available
|
||||
loadTimeoutSeconds: 600
|
||||
checkHandoffInterval: 5000
|
||||
enableActiveStandby: false
|
||||
port: 19531
|
||||
grpc:
|
||||
serverMaxSendSize: 536870912
|
||||
serverMaxRecvSize: 536870912
|
||||
clientMaxSendSize: 268435456
|
||||
clientMaxRecvSize: 268435456
|
||||
taskMergeCap: 1
|
||||
taskExecutionCap: 256
|
||||
enableActiveStandby: false # Enable active-standby
|
||||
|
||||
# Related configuration of queryNode, used to run hybrid search between vector and scalar data.
|
||||
queryNode:
|
||||
|
|
|
@ -46,7 +46,7 @@ datatype_sizeof(DataType data_type, int dim = 1) {
|
|||
case DataType::VECTOR_FLOAT:
|
||||
return sizeof(float) * dim;
|
||||
case DataType::VECTOR_BINARY: {
|
||||
Assert(dim % 8 == 0);
|
||||
AssertInfo(dim % 8 == 0, "dim=" + std::to_string(dim));
|
||||
return dim / 8;
|
||||
}
|
||||
default: {
|
||||
|
|
|
@ -15,11 +15,12 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "index/VectorDiskIndex.h"
|
||||
|
||||
#include "common/Utils.h"
|
||||
#include "config/ConfigKnowhere.h"
|
||||
#include "index/Meta.h"
|
||||
#include "index/Utils.h"
|
||||
|
||||
#include "storage/LocalChunkManager.h"
|
||||
#include "config/ConfigKnowhere.h"
|
||||
#include "storage/Util.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/Utils.h"
|
||||
|
|
|
@ -54,6 +54,6 @@ target_link_libraries(milvus_segcore
|
|||
|
||||
install(TARGETS milvus_segcore DESTINATION "${CMAKE_INSTALL_LIBDIR}")
|
||||
|
||||
add_executable(velox_demo VeloxDemo.cpp)
|
||||
target_link_libraries(velox_demo ${CONAN_LIBS} velox_bundled)
|
||||
install(TARGETS velox_demo DESTINATION "${CMAKE_INSTALL_BINDIR}")
|
||||
# add_executable(velox_demo VeloxDemo.cpp)
|
||||
# target_link_libraries(velox_demo ${CONAN_LIBS} velox_bundled)
|
||||
# install(TARGETS velox_demo DESTINATION "${CMAKE_INSTALL_BINDIR}")
|
||||
|
|
|
@ -274,9 +274,27 @@ SegmentSealedImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) {
|
|||
auto timestamps = reinterpret_cast<const Timestamp*>(info.timestamps);
|
||||
|
||||
// step 2: fill pks and timestamps
|
||||
ssize_t n = deleted_record_.ack_responder_.GetAck();
|
||||
ssize_t divide_point = 0;
|
||||
// Truncate the overlapping prefix
|
||||
if (n > 0) {
|
||||
auto last = deleted_record_.timestamps_[n - 1];
|
||||
divide_point =
|
||||
std::lower_bound(timestamps, timestamps + size, last + 1) -
|
||||
timestamps;
|
||||
}
|
||||
|
||||
// All these delete records have been loaded
|
||||
if (divide_point == size) {
|
||||
return;
|
||||
}
|
||||
|
||||
size -= divide_point;
|
||||
auto reserved_begin = deleted_record_.reserved.fetch_add(size);
|
||||
deleted_record_.pks_.set_data_raw(reserved_begin, pks.data(), size);
|
||||
deleted_record_.timestamps_.set_data_raw(reserved_begin, timestamps, size);
|
||||
deleted_record_.pks_.set_data_raw(
|
||||
reserved_begin, pks.data() + divide_point, size);
|
||||
deleted_record_.timestamps_.set_data_raw(
|
||||
reserved_begin, timestamps + divide_point, size);
|
||||
deleted_record_.ack_responder_.AddSegment(reserved_begin,
|
||||
reserved_begin + size);
|
||||
}
|
||||
|
|
|
@ -9,15 +9,16 @@
|
|||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "segcore/load_index_c.h"
|
||||
|
||||
#include "common/CDataType.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/Utils.h"
|
||||
#include "index/IndexFactory.h"
|
||||
#include "index/Meta.h"
|
||||
#include "index/Utils.h"
|
||||
#include "index/IndexFactory.h"
|
||||
#include "storage/Util.h"
|
||||
#include "segcore/load_index_c.h"
|
||||
#include "segcore/Types.h"
|
||||
#include "storage/Util.h"
|
||||
|
||||
CStatus
|
||||
NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info,
|
||||
|
|
|
@ -739,7 +739,6 @@ TEST(Sealed, Delete) {
|
|||
LoadDeletedRecordInfo info = {timestamps.data(), ids.get(), row_count};
|
||||
segment->LoadDeletedRecord(info);
|
||||
|
||||
std::vector<uint8_t> tmp_block{0, 0};
|
||||
BitsetType bitset(N, false);
|
||||
segment->mask_with_delete(bitset, 10, 11);
|
||||
ASSERT_EQ(bitset.count(), pks.size());
|
||||
|
@ -758,6 +757,90 @@ TEST(Sealed, Delete) {
|
|||
reinterpret_cast<const Timestamp*>(new_timestamps.data()));
|
||||
}
|
||||
|
||||
TEST(Sealed, OverlapDelete) {
|
||||
auto dim = 16;
|
||||
auto topK = 5;
|
||||
auto N = 10;
|
||||
auto metric_type = knowhere::metric::L2;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto fakevec_id = schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, dim, metric_type);
|
||||
auto counter_id = schema->AddDebugField("counter", DataType::INT64);
|
||||
auto double_id = schema->AddDebugField("double", DataType::DOUBLE);
|
||||
auto nothing_id = schema->AddDebugField("nothing", DataType::INT32);
|
||||
schema->set_primary_field_id(counter_id);
|
||||
|
||||
auto dataset = DataGen(schema, N);
|
||||
|
||||
auto fakevec = dataset.get_col<float>(fakevec_id);
|
||||
|
||||
auto segment = CreateSealedSegment(schema);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"double": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5,
|
||||
"round_decimal": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
|
||||
Timestamp time = 1000000;
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024);
|
||||
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
ASSERT_ANY_THROW(segment->Search(plan.get(), ph_group.get(), time));
|
||||
|
||||
SealedLoadFieldData(dataset, *segment);
|
||||
|
||||
int64_t row_count = 5;
|
||||
std::vector<idx_t> pks{1, 2, 3, 4, 5};
|
||||
auto ids = std::make_unique<IdArray>();
|
||||
ids->mutable_int_id()->mutable_data()->Add(pks.begin(), pks.end());
|
||||
std::vector<Timestamp> timestamps{10, 10, 10, 10, 10};
|
||||
|
||||
LoadDeletedRecordInfo info = {timestamps.data(), ids.get(), row_count};
|
||||
segment->LoadDeletedRecord(info);
|
||||
ASSERT_EQ(segment->get_deleted_count(), pks.size())
|
||||
<< "deleted_count=" << segment->get_deleted_count() << " pks_count=" << pks.size() << std::endl;
|
||||
|
||||
// Load overlapping delete records
|
||||
row_count += 3;
|
||||
pks.insert(pks.end(), {6, 7, 8});
|
||||
auto new_ids = std::make_unique<IdArray>();
|
||||
new_ids->mutable_int_id()->mutable_data()->Add(pks.begin(), pks.end());
|
||||
timestamps.insert(timestamps.end(), {11, 11, 11});
|
||||
LoadDeletedRecordInfo overlap_info = {timestamps.data(), new_ids.get(), row_count};
|
||||
segment->LoadDeletedRecord(overlap_info);
|
||||
|
||||
BitsetType bitset(N, false);
|
||||
// NOTE: need to change delete timestamp, so not to hit the cache
|
||||
ASSERT_EQ(segment->get_deleted_count(), pks.size())
|
||||
<< "deleted_count=" << segment->get_deleted_count() << " pks_count=" << pks.size() << std::endl;
|
||||
segment->mask_with_delete(bitset, 10, 12);
|
||||
ASSERT_EQ(bitset.count(), pks.size())
|
||||
<< "bitset_count=" << bitset.count() << " pks_count=" << pks.size() << std::endl;
|
||||
}
|
||||
|
||||
auto
|
||||
GenMaxFloatVecs(int N, int dim) {
|
||||
std::vector<float> vecs;
|
||||
|
|
|
@ -255,7 +255,7 @@ func (b *binlogIO) genDeltaBlobs(data *DeleteData, collID, partID, segID UniqueI
|
|||
|
||||
// genInsertBlobs returns kvs, insert-paths, stats-paths
|
||||
func (b *binlogIO) genInsertBlobs(data *InsertData, partID, segID UniqueID, meta *etcdpb.CollectionMeta) (map[string][]byte, map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) {
|
||||
inCodec := storage.NewInsertCodec(meta)
|
||||
inCodec := storage.NewInsertCodecWithSchema(meta)
|
||||
inlogs, statslogs, err := inCodec.Serialize(partID, segID, data)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
|
|
|
@ -554,7 +554,7 @@ func getInt64DeltaBlobs(segID UniqueID, pks []UniqueID, tss []Timestamp) ([]*Blo
|
|||
}
|
||||
|
||||
func getInsertBlobs(segID UniqueID, iData *InsertData, meta *etcdpb.CollectionMeta) ([]*Blob, error) {
|
||||
iCodec := storage.NewInsertCodec(meta)
|
||||
iCodec := storage.NewInsertCodecWithSchema(meta)
|
||||
|
||||
iblobs, _, err := iCodec.Serialize(10, segID, iData)
|
||||
return iblobs, err
|
||||
|
|
|
@ -364,7 +364,7 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni
|
|||
}
|
||||
|
||||
// encode data and convert output data
|
||||
inCodec := storage.NewInsertCodec(meta)
|
||||
inCodec := storage.NewInsertCodecWithSchema(meta)
|
||||
|
||||
binLogs, statsBinlogs, err := inCodec.Serialize(partID, segmentID, data.buffer)
|
||||
if err != nil {
|
||||
|
|
|
@ -917,7 +917,7 @@ func createBinLogs(rowNum int, schema *schemapb.CollectionSchema, ts Timestamp,
|
|||
ID: colID,
|
||||
Schema: schema,
|
||||
}
|
||||
binLogs, statsBinLogs, err := storage.NewInsertCodec(meta).Serialize(partID, segmentID, data.buffer)
|
||||
binLogs, statsBinLogs, err := storage.NewInsertCodecWithSchema(meta).Serialize(partID, segmentID, data.buffer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -414,3 +414,23 @@ func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu
|
|||
}
|
||||
return ret.(*commonpb.Status), err
|
||||
}
|
||||
|
||||
// Delete is used to forward delete message between delegator and workers.
|
||||
func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
req = typeutil.Clone(req)
|
||||
commonpbutil.UpdateMsgBase(
|
||||
req.GetBase(),
|
||||
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()),
|
||||
)
|
||||
ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) {
|
||||
if !funcutil.CheckCtxValid(ctx) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return client.Delete(ctx, req)
|
||||
})
|
||||
|
||||
if err != nil || ret == nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret.(*commonpb.Status), err
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ import (
|
|||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
qn "github.com/milvus-io/milvus/internal/querynode"
|
||||
qn "github.com/milvus-io/milvus/internal/querynodev2"
|
||||
"github.com/milvus-io/milvus/internal/tracer"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
|
@ -331,3 +331,8 @@ func (s *Server) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi
|
|||
func (s *Server) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
|
||||
return s.querynode.SyncDistribution(ctx, req)
|
||||
}
|
||||
|
||||
// Delete is used to forward delete message between delegator and workers.
|
||||
func (s *Server) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
return s.querynode.Delete(ctx, req)
|
||||
}
|
||||
|
|
|
@ -161,6 +161,10 @@ func (m *MockQueryNode) SyncDistribution(context.Context, *querypb.SyncDistribut
|
|||
return m.status, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryNode) Delete(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
return m.status, m.err
|
||||
}
|
||||
|
||||
type MockRootCoord struct {
|
||||
types.RootCoord
|
||||
initErr error
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
// Code generated by mockery v2.15.0. DO NOT EDIT.
|
||||
|
||||
package msgdispatcher
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
mqwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
|
||||
msgstream "github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
)
|
||||
|
||||
// MockClient is an autogenerated mock type for the Client type
|
||||
type MockClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockClient) EXPECT() *MockClient_Expecter {
|
||||
return &MockClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Deregister provides a mock function with given fields: vchannel
|
||||
func (_m *MockClient) Deregister(vchannel string) {
|
||||
_m.Called(vchannel)
|
||||
}
|
||||
|
||||
// MockClient_Deregister_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Deregister'
|
||||
type MockClient_Deregister_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Deregister is a helper method to define mock.On call
|
||||
// - vchannel string
|
||||
func (_e *MockClient_Expecter) Deregister(vchannel interface{}) *MockClient_Deregister_Call {
|
||||
return &MockClient_Deregister_Call{Call: _e.mock.On("Deregister", vchannel)}
|
||||
}
|
||||
|
||||
func (_c *MockClient_Deregister_Call) Run(run func(vchannel string)) *MockClient_Deregister_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_Deregister_Call) Return() *MockClient_Deregister_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// Register provides a mock function with given fields: vchannel, pos, subPos
|
||||
func (_m *MockClient) Register(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) {
|
||||
ret := _m.Called(vchannel, pos, subPos)
|
||||
|
||||
var r0 <-chan *msgstream.MsgPack
|
||||
if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok {
|
||||
r0 = rf(vchannel, pos, subPos)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan *msgstream.MsgPack)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); ok {
|
||||
r1 = rf(vchannel, pos, subPos)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockClient_Register_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Register'
|
||||
type MockClient_Register_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Register is a helper method to define mock.On call
|
||||
// - vchannel string
|
||||
// - pos *msgpb.MsgPosition
|
||||
// - subPos mqwrapper.SubscriptionInitialPosition
|
||||
func (_e *MockClient_Expecter) Register(vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call {
|
||||
return &MockClient_Register_Call{Call: _e.mock.On("Register", vchannel, pos, subPos)}
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) Run(run func(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(*msgpb.MsgPosition), args[2].(mqwrapper.SubscriptionInitialPosition))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 error) *MockClient_Register_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockClient interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockClient(t mockConstructorTestingTNewMockClient) *MockClient {
|
||||
mock := &MockClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -1,20 +1,393 @@
|
|||
// Code generated by mockery v2.15.0. DO NOT EDIT.
|
||||
|
||||
package msgstream
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
mqwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
)
|
||||
|
||||
// MockMsgStream is an autogenerated mock type for the MsgStream type
|
||||
type MockMsgStream struct {
|
||||
MsgStream
|
||||
AsProducerFunc func(channels []string)
|
||||
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
|
||||
BroadcastFunc func(*MsgPack) error
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func NewMockMsgStream() *MockMsgStream {
|
||||
return &MockMsgStream{}
|
||||
type MockMsgStream_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (m MockMsgStream) AsProducer(channels []string) {
|
||||
m.AsProducerFunc(channels)
|
||||
func (_m *MockMsgStream) EXPECT() *MockMsgStream_Expecter {
|
||||
return &MockMsgStream_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
func (m MockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
|
||||
return m.BroadcastMarkFunc(pack)
|
||||
// AsConsumer provides a mock function with given fields: channels, subName, position
|
||||
func (_m *MockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) {
|
||||
_m.Called(channels, subName, position)
|
||||
}
|
||||
|
||||
// MockMsgStream_AsConsumer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsConsumer'
|
||||
type MockMsgStream_AsConsumer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AsConsumer is a helper method to define mock.On call
|
||||
// - channels []string
|
||||
// - subName string
|
||||
// - position mqwrapper.SubscriptionInitialPosition
|
||||
func (_e *MockMsgStream_Expecter) AsConsumer(channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call {
|
||||
return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", channels, subName, position)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsConsumer_Call) Run(run func(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]string), args[1].(string), args[2].(mqwrapper.SubscriptionInitialPosition))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsConsumer_Call) Return() *MockMsgStream_AsConsumer_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// AsProducer provides a mock function with given fields: channels
|
||||
func (_m *MockMsgStream) AsProducer(channels []string) {
|
||||
_m.Called(channels)
|
||||
}
|
||||
|
||||
// MockMsgStream_AsProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsProducer'
|
||||
type MockMsgStream_AsProducer_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AsProducer is a helper method to define mock.On call
|
||||
// - channels []string
|
||||
func (_e *MockMsgStream_Expecter) AsProducer(channels interface{}) *MockMsgStream_AsProducer_Call {
|
||||
return &MockMsgStream_AsProducer_Call{Call: _e.mock.On("AsProducer", channels)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsProducer_Call) Run(run func(channels []string)) *MockMsgStream_AsProducer_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// Broadcast provides a mock function with given fields: _a0
|
||||
func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]mqwrapper.MessageID, error) {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 map[string][]mqwrapper.MessageID
|
||||
if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]mqwrapper.MessageID); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string][]mqwrapper.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(*MsgPack) error); ok {
|
||||
r1 = rf(_a0)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockMsgStream_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast'
|
||||
type MockMsgStream_Broadcast_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Broadcast is a helper method to define mock.On call
|
||||
// - _a0 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Broadcast(_a0 interface{}) *MockMsgStream_Broadcast_Call {
|
||||
return &MockMsgStream_Broadcast_Call{Call: _e.mock.On("Broadcast", _a0)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Broadcast_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Broadcast_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*MsgPack))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]mqwrapper.MessageID, _a1 error) *MockMsgStream_Broadcast_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Chan provides a mock function with given fields:
|
||||
func (_m *MockMsgStream) Chan() <-chan *MsgPack {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 <-chan *MsgPack
|
||||
if rf, ok := ret.Get(0).(func() <-chan *MsgPack); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan *MsgPack)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMsgStream_Chan_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Chan'
|
||||
type MockMsgStream_Chan_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Chan is a helper method to define mock.On call
|
||||
func (_e *MockMsgStream_Expecter) Chan() *MockMsgStream_Chan_Call {
|
||||
return &MockMsgStream_Chan_Call{Call: _e.mock.On("Chan")}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Chan_Call) Run(run func()) *MockMsgStream_Chan_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Chan_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *MockMsgStream) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockMsgStream_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type MockMsgStream_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *MockMsgStream_Expecter) Close() *MockMsgStream_Close_Call {
|
||||
return &MockMsgStream_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Close_Call) Run(run func()) *MockMsgStream_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Close_Call) Return() *MockMsgStream_Close_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetLatestMsgID provides a mock function with given fields: channel
|
||||
func (_m *MockMsgStream) GetLatestMsgID(channel string) (mqwrapper.MessageID, error) {
|
||||
ret := _m.Called(channel)
|
||||
|
||||
var r0 mqwrapper.MessageID
|
||||
if rf, ok := ret.Get(0).(func(string) mqwrapper.MessageID); ok {
|
||||
r0 = rf(channel)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(mqwrapper.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(channel)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockMsgStream_GetLatestMsgID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestMsgID'
|
||||
type MockMsgStream_GetLatestMsgID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetLatestMsgID is a helper method to define mock.On call
|
||||
// - channel string
|
||||
func (_e *MockMsgStream_Expecter) GetLatestMsgID(channel interface{}) *MockMsgStream_GetLatestMsgID_Call {
|
||||
return &MockMsgStream_GetLatestMsgID_Call{Call: _e.mock.On("GetLatestMsgID", channel)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetLatestMsgID_Call) Run(run func(channel string)) *MockMsgStream_GetLatestMsgID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetLatestMsgID_Call) Return(_a0 mqwrapper.MessageID, _a1 error) *MockMsgStream_GetLatestMsgID_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetProduceChannels provides a mock function with given fields:
|
||||
func (_m *MockMsgStream) GetProduceChannels() []string {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 []string
|
||||
if rf, ok := ret.Get(0).(func() []string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMsgStream_GetProduceChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProduceChannels'
|
||||
type MockMsgStream_GetProduceChannels_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetProduceChannels is a helper method to define mock.On call
|
||||
func (_e *MockMsgStream_Expecter) GetProduceChannels() *MockMsgStream_GetProduceChannels_Call {
|
||||
return &MockMsgStream_GetProduceChannels_Call{Call: _e.mock.On("GetProduceChannels")}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetProduceChannels_Call) Run(run func()) *MockMsgStream_GetProduceChannels_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_GetProduceChannels_Call) Return(_a0 []string) *MockMsgStream_GetProduceChannels_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Produce provides a mock function with given fields: _a0
|
||||
func (_m *MockMsgStream) Produce(_a0 *MsgPack) error {
|
||||
ret := _m.Called(_a0)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(*MsgPack) error); ok {
|
||||
r0 = rf(_a0)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMsgStream_Produce_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Produce'
|
||||
type MockMsgStream_Produce_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Produce is a helper method to define mock.On call
|
||||
// - _a0 *MsgPack
|
||||
func (_e *MockMsgStream_Expecter) Produce(_a0 interface{}) *MockMsgStream_Produce_Call {
|
||||
return &MockMsgStream_Produce_Call{Call: _e.mock.On("Produce", _a0)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Produce_Call) Run(run func(_a0 *MsgPack)) *MockMsgStream_Produce_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*MsgPack))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Seek provides a mock function with given fields: offset
|
||||
func (_m *MockMsgStream) Seek(offset []*msgpb.MsgPosition) error {
|
||||
ret := _m.Called(offset)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]*msgpb.MsgPosition) error); ok {
|
||||
r0 = rf(offset)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockMsgStream_Seek_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Seek'
|
||||
type MockMsgStream_Seek_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Seek is a helper method to define mock.On call
|
||||
// - offset []*msgpb.MsgPosition
|
||||
func (_e *MockMsgStream_Expecter) Seek(offset interface{}) *MockMsgStream_Seek_Call {
|
||||
return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", offset)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Seek_Call) Run(run func(offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]*msgpb.MsgPosition))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetRepackFunc provides a mock function with given fields: repackFunc
|
||||
func (_m *MockMsgStream) SetRepackFunc(repackFunc RepackFunc) {
|
||||
_m.Called(repackFunc)
|
||||
}
|
||||
|
||||
// MockMsgStream_SetRepackFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRepackFunc'
|
||||
type MockMsgStream_SetRepackFunc_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetRepackFunc is a helper method to define mock.On call
|
||||
// - repackFunc RepackFunc
|
||||
func (_e *MockMsgStream_Expecter) SetRepackFunc(repackFunc interface{}) *MockMsgStream_SetRepackFunc_Call {
|
||||
return &MockMsgStream_SetRepackFunc_Call{Call: _e.mock.On("SetRepackFunc", repackFunc)}
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_SetRepackFunc_Call) Run(run func(repackFunc RepackFunc)) *MockMsgStream_SetRepackFunc_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(RepackFunc))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockMsgStream_SetRepackFunc_Call) Return() *MockMsgStream_SetRepackFunc_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockMsgStream interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockMsgStream creates a new instance of MockMsgStream. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockMsgStream(t mockConstructorTestingTNewMockMsgStream) *MockMsgStream {
|
||||
mock := &MockMsgStream{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package msgstream
|
||||
|
||||
type WastedMockMsgStream struct {
|
||||
MsgStream
|
||||
AsProducerFunc func(channels []string)
|
||||
BroadcastMarkFunc func(*MsgPack) (map[string][]MessageID, error)
|
||||
BroadcastFunc func(*MsgPack) error
|
||||
ChanFunc func() <-chan *MsgPack
|
||||
}
|
||||
|
||||
func NewWastedMockMsgStream() *WastedMockMsgStream {
|
||||
return &WastedMockMsgStream{}
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) AsProducer(channels []string) {
|
||||
m.AsProducerFunc(channels)
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) Broadcast(pack *MsgPack) (map[string][]MessageID, error) {
|
||||
return m.BroadcastMarkFunc(pack)
|
||||
}
|
||||
|
||||
func (m WastedMockMsgStream) Chan() <-chan *MsgPack {
|
||||
return m.ChanFunc()
|
||||
}
|
|
@ -72,6 +72,7 @@ service QueryNode {
|
|||
|
||||
rpc GetDataDistribution(GetDataDistributionRequest) returns (GetDataDistributionResponse) {}
|
||||
rpc SyncDistribution(SyncDistributionRequest) returns (common.Status) {}
|
||||
rpc Delete(DeleteRequest) returns (common.Status) {}
|
||||
}
|
||||
|
||||
//--------------------QueryCoord grpc request and response proto------------------
|
||||
|
@ -243,6 +244,7 @@ message SegmentLoadInfo {
|
|||
int64 segment_size = 12;
|
||||
string insert_channel = 13;
|
||||
msg.MsgPosition start_position = 14;
|
||||
msg.MsgPosition end_position = 15;
|
||||
}
|
||||
|
||||
message FieldIndexInfo {
|
||||
|
@ -259,6 +261,11 @@ message FieldIndexInfo {
|
|||
int64 num_rows = 10;
|
||||
}
|
||||
|
||||
enum LoadScope {
|
||||
Full = 0;
|
||||
Delta = 1;
|
||||
}
|
||||
|
||||
message LoadSegmentsRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 dst_nodeID = 2;
|
||||
|
@ -271,6 +278,7 @@ message LoadSegmentsRequest {
|
|||
repeated msg.MsgPosition delta_positions = 9;
|
||||
int64 version = 10;
|
||||
bool need_transfer = 11;
|
||||
LoadScope load_scope = 12;
|
||||
}
|
||||
|
||||
message ReleaseSegmentsRequest {
|
||||
|
@ -316,6 +324,18 @@ message ReplicaSegmentsInfo {
|
|||
repeated int64 versions = 4;
|
||||
}
|
||||
|
||||
message GetLoadInfoRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 collection_id = 2;
|
||||
}
|
||||
|
||||
message GetLoadInfoResponse {
|
||||
common.Status status = 1;
|
||||
schema.CollectionSchema schema = 2;
|
||||
LoadType load_type = 3;
|
||||
repeated int64 partitions = 4;
|
||||
}
|
||||
|
||||
//----------------request auto triggered by QueryCoord-----------------
|
||||
message HandoffSegmentsRequest {
|
||||
common.MsgBase base = 1;
|
||||
|
@ -445,6 +465,7 @@ message SealedSegmentsChangeInfo {
|
|||
|
||||
message GetDataDistributionRequest {
|
||||
common.MsgBase base = 1;
|
||||
map<string, msg.MsgPosition> checkpoints = 2;
|
||||
}
|
||||
|
||||
message GetDataDistributionResponse {
|
||||
|
@ -475,6 +496,7 @@ message SegmentVersionInfo {
|
|||
int64 partition = 3;
|
||||
string channel = 4;
|
||||
int64 version = 5;
|
||||
uint64 last_delta_timestamp = 6;
|
||||
}
|
||||
|
||||
message ChannelVersionInfo {
|
||||
|
@ -516,6 +538,7 @@ message Replica {
|
|||
enum SyncType {
|
||||
Remove = 0;
|
||||
Set = 1;
|
||||
Amend = 2;
|
||||
}
|
||||
|
||||
message SyncAction {
|
||||
|
@ -524,6 +547,7 @@ message SyncAction {
|
|||
int64 segmentID = 3;
|
||||
int64 nodeID = 4;
|
||||
int64 version = 5;
|
||||
SegmentLoadInfo info = 6;
|
||||
}
|
||||
|
||||
message SyncDistributionRequest {
|
||||
|
@ -568,4 +592,13 @@ message ResourceGroupInfo {
|
|||
map<int64, int32> num_outgoing_node = 5;
|
||||
// collection id -> be accessed node num by other rg
|
||||
map<int64, int32> num_incoming_node = 6;
|
||||
}
|
||||
}
|
||||
message DeleteRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 collection_id = 2;
|
||||
int64 partition_id = 3;
|
||||
string vchannel_name = 4;
|
||||
int64 segment_id = 5;
|
||||
schema.IDs primary_keys = 6;
|
||||
repeated uint64 timestamps = 7;
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -191,7 +191,7 @@ func (rl *rateLimiter) registerLimiters() {
|
|||
}
|
||||
limit := ratelimitutil.Limit(r.GetAsFloat())
|
||||
burst := r.GetAsFloat() // use rate as burst, because Limiter is with punishment mechanism, burst is insignificant.
|
||||
rl.limiters.InsertIfNotPresent(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
|
||||
rl.limiters.GetOrInsert(internalpb.RateType(rt), ratelimitutil.NewLimiter(limit, burst))
|
||||
onEvent := func(rateType internalpb.RateType) func(*config.Event) {
|
||||
return func(event *config.Event) {
|
||||
f, err := strconv.ParseFloat(event.Value, 64)
|
||||
|
|
|
@ -135,3 +135,7 @@ func (m *QueryNodeMock) GetDataDistribution(context.Context, *querypb.GetDataDis
|
|||
func (m *QueryNodeMock) SyncDistribution(context.Context, *querypb.SyncDistributionRequest) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *QueryNodeMock) Delete(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
@ -299,8 +300,10 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
|
|||
}
|
||||
|
||||
func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error {
|
||||
nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest)
|
||||
nodeReq.Base.TargetID = nodeID
|
||||
req := &querypb.GetStatisticsRequest{
|
||||
Req: g.GetStatisticsRequest,
|
||||
Req: nodeReq,
|
||||
DmlChannels: channelIDs,
|
||||
Scope: querypb.DataScope_All,
|
||||
}
|
||||
|
|
|
@ -160,7 +160,8 @@ func (c *SegmentChecker) getStreamingSegmentsDist(distMgr *meta.DistributionMana
|
|||
}
|
||||
|
||||
// GetHistoricalSegmentDiff get historical segment diff between target and dist
|
||||
func (c *SegmentChecker) getHistoricalSegmentDiff(targetMgr *meta.TargetManager,
|
||||
func (c *SegmentChecker) getHistoricalSegmentDiff(
|
||||
targetMgr *meta.TargetManager,
|
||||
distMgr *meta.DistributionManager,
|
||||
metaInfo *meta.Meta,
|
||||
collectionID int64,
|
||||
|
@ -179,7 +180,7 @@ func (c *SegmentChecker) getHistoricalSegmentDiff(targetMgr *meta.TargetManager,
|
|||
nextTargetMap := targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.NextTarget)
|
||||
currentTargetMap := targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.CurrentTarget)
|
||||
|
||||
//get segment which exist on next target, but not on dist
|
||||
// Segment which exist on next target, but not on dist
|
||||
for segmentID, segment := range nextTargetMap {
|
||||
if !distMap.Contain(segmentID) {
|
||||
toLoad = append(toLoad, segment)
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
@ -119,14 +120,16 @@ func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistribut
|
|||
PartitionID: s.GetPartition(),
|
||||
InsertChannel: s.GetChannel(),
|
||||
},
|
||||
Node: resp.GetNodeID(),
|
||||
Version: s.GetVersion(),
|
||||
Node: resp.GetNodeID(),
|
||||
Version: s.GetVersion(),
|
||||
LastDeltaTimestamp: s.GetLastDeltaTimestamp(),
|
||||
}
|
||||
} else {
|
||||
segment = &meta.Segment{
|
||||
SegmentInfo: proto.Clone(segmentInfo).(*datapb.SegmentInfo),
|
||||
Node: resp.GetNodeID(),
|
||||
Version: s.GetVersion(),
|
||||
SegmentInfo: proto.Clone(segmentInfo).(*datapb.SegmentInfo),
|
||||
Node: resp.GetNodeID(),
|
||||
Version: s.GetVersion(),
|
||||
LastDeltaTimestamp: s.GetLastDeltaTimestamp(),
|
||||
}
|
||||
}
|
||||
updates = append(updates, segment)
|
||||
|
@ -200,12 +203,24 @@ func (dh *distHandler) updateLeaderView(resp *querypb.GetDataDistributionRespons
|
|||
func (dh *distHandler) getDistribution(ctx context.Context) error {
|
||||
dh.mu.Lock()
|
||||
defer dh.mu.Unlock()
|
||||
|
||||
channels := make(map[string]*msgpb.MsgPosition)
|
||||
for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) {
|
||||
targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget)
|
||||
if targetChannel == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
channels[channel.GetChannelName()] = targetChannel.GetSeekPosition()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, distReqTimeout)
|
||||
defer cancel()
|
||||
resp, err := dh.client.GetDataDistribution(ctx, dh.nodeID, &querypb.GetDataDistributionRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_GetDistribution),
|
||||
),
|
||||
Checkpoints: channels,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -26,8 +26,9 @@ import (
|
|||
|
||||
type Segment struct {
|
||||
*datapb.SegmentInfo
|
||||
Node int64 // Node the segment is in
|
||||
Version int64 // Version is the timestamp of loading segment
|
||||
Node int64 // Node the segment is in
|
||||
Version int64 // Version is the timestamp of loading segment
|
||||
LastDeltaTimestamp uint64 // The timestamp of the last delta record
|
||||
}
|
||||
|
||||
func SegmentFromInfo(info *datapb.SegmentInfo) *Segment {
|
||||
|
|
|
@ -29,6 +29,53 @@ func (_m *MockQueryNodeServer) EXPECT() *MockQueryNodeServer_Expecter {
|
|||
return &MockQueryNodeServer_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Delete provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNodeServer) Delete(_a0 context.Context, _a1 *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 *commonpb.Status
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) *commonpb.Status); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*commonpb.Status)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.DeleteRequest) error); ok {
|
||||
r1 = rf(_a0, _a1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockQueryNodeServer_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
|
||||
type MockQueryNodeServer_Delete_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Delete is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.DeleteRequest
|
||||
func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call {
|
||||
return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeServer_Delete_Call) Run(run func(_a0 context.Context, _a1 *querypb.DeleteRequest)) *MockQueryNodeServer_Delete_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.DeleteRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockQueryNodeServer_Delete_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeServer_Delete_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetComponentStates provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockQueryNodeServer) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
|
|
@ -31,7 +31,10 @@ import (
|
|||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const interval = 1 * time.Second
|
||||
const (
|
||||
interval = 1 * time.Second
|
||||
RPCTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// LeaderObserver is to sync the distribution with leader
|
||||
type LeaderObserver struct {
|
||||
|
@ -40,6 +43,7 @@ type LeaderObserver struct {
|
|||
dist *meta.DistributionManager
|
||||
meta *meta.Meta
|
||||
target *meta.TargetManager
|
||||
broker meta.Broker
|
||||
cluster session.Cluster
|
||||
|
||||
stopOnce sync.Once
|
||||
|
@ -106,18 +110,54 @@ func (o *LeaderObserver) findNeedLoadedSegments(leaderView *meta.LeaderView, dis
|
|||
dists = utils.FindMaxVersionSegments(dists)
|
||||
for _, s := range dists {
|
||||
version, ok := leaderView.Segments[s.GetID()]
|
||||
existInCurrentTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.CurrentTarget) != nil
|
||||
currentTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.CurrentTarget)
|
||||
existInCurrentTarget := currentTarget != nil
|
||||
existInNextTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.NextTarget) != nil
|
||||
if ok && version.GetVersion() >= s.Version || (!existInCurrentTarget && !existInNextTarget) {
|
||||
|
||||
if !existInCurrentTarget && !existInNextTarget {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, &querypb.SyncAction{
|
||||
Type: querypb.SyncType_Set,
|
||||
PartitionID: s.GetPartitionID(),
|
||||
SegmentID: s.GetID(),
|
||||
NodeID: s.Node,
|
||||
Version: s.Version,
|
||||
})
|
||||
|
||||
if !ok || version.GetVersion() < s.Version { // Leader misses this segment
|
||||
ctx := context.Background()
|
||||
resp, err := o.broker.GetSegmentInfo(ctx, s.GetID())
|
||||
if err != nil || len(resp.GetInfos()) == 0 {
|
||||
log.Warn("failed to get segment info from DataCoord", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
segment := resp.GetInfos()[0]
|
||||
loadInfo := utils.PackSegmentLoadInfo(segment, nil)
|
||||
|
||||
// Fix the leader view with lacks of delta logs
|
||||
if existInCurrentTarget && s.LastDeltaTimestamp < currentTarget.GetDmlPosition().GetTimestamp() {
|
||||
log.Info("Fix QueryNode delta logs lag",
|
||||
zap.Int64("nodeID", s.Node),
|
||||
zap.Int64("collectionID", s.GetCollectionID()),
|
||||
zap.Int64("partitionID", s.GetPartitionID()),
|
||||
zap.Int64("segmentID", s.GetID()),
|
||||
zap.Uint64("segmentDeltaTimestamp", s.LastDeltaTimestamp),
|
||||
zap.Uint64("channelTimestamp", currentTarget.GetDmlPosition().GetTimestamp()),
|
||||
)
|
||||
|
||||
ret = append(ret, &querypb.SyncAction{
|
||||
Type: querypb.SyncType_Amend,
|
||||
PartitionID: s.GetPartitionID(),
|
||||
SegmentID: s.GetID(),
|
||||
NodeID: s.Node,
|
||||
Version: s.Version,
|
||||
Info: loadInfo,
|
||||
})
|
||||
}
|
||||
|
||||
ret = append(ret, &querypb.SyncAction{
|
||||
Type: querypb.SyncType_Set,
|
||||
PartitionID: s.GetPartitionID(),
|
||||
SegmentID: s.GetID(),
|
||||
NodeID: s.Node,
|
||||
Version: s.Version,
|
||||
Info: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -181,6 +221,7 @@ func NewLeaderObserver(
|
|||
dist *meta.DistributionManager,
|
||||
meta *meta.Meta,
|
||||
targetMgr *meta.TargetManager,
|
||||
broker meta.Broker,
|
||||
cluster session.Cluster,
|
||||
) *LeaderObserver {
|
||||
return &LeaderObserver{
|
||||
|
@ -188,6 +229,7 @@ func NewLeaderObserver(
|
|||
dist: dist,
|
||||
meta: meta,
|
||||
target: targetMgr,
|
||||
broker: broker,
|
||||
cluster: cluster,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ func (suite *LeaderObserverTestSuite) SetupTest() {
|
|||
suite.mockCluster = session.NewMockCluster(suite.T())
|
||||
distManager := meta.NewDistributionManager()
|
||||
targetManager := meta.NewTargetManager(suite.broker, suite.meta)
|
||||
suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.mockCluster)
|
||||
suite.observer = NewLeaderObserver(distManager, suite.meta, targetManager, suite.broker, suite.mockCluster)
|
||||
}
|
||||
|
||||
func (suite *LeaderObserverTestSuite) TearDownTest() {
|
||||
|
@ -98,6 +98,15 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
|
|||
},
|
||||
}
|
||||
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
|
||||
info := &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "test-insert-channel",
|
||||
}
|
||||
loadInfo := utils.PackSegmentLoadInfo(info, nil)
|
||||
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
|
||||
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
|
||||
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
|
||||
channels, segments, nil)
|
||||
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
|
||||
|
@ -118,6 +127,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
|
|||
SegmentID: 1,
|
||||
NodeID: 1,
|
||||
Version: 1,
|
||||
Info: loadInfo,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -154,6 +164,15 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
|
|||
},
|
||||
}
|
||||
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
|
||||
info := &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "test-insert-channel",
|
||||
}
|
||||
loadInfo := utils.PackSegmentLoadInfo(info, nil)
|
||||
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
|
||||
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
|
||||
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
|
||||
channels, segments, nil)
|
||||
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
|
||||
|
@ -176,6 +195,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
|
|||
SegmentID: 1,
|
||||
NodeID: 1,
|
||||
Version: 1,
|
||||
Info: loadInfo,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -252,6 +272,15 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
|
|||
},
|
||||
}
|
||||
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
|
||||
info := &datapb.SegmentInfo{
|
||||
ID: 1,
|
||||
CollectionID: 1,
|
||||
PartitionID: 1,
|
||||
InsertChannel: "test-insert-channel",
|
||||
}
|
||||
loadInfo := utils.PackSegmentLoadInfo(info, nil)
|
||||
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
|
||||
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
|
||||
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
|
||||
channels, segments, nil)
|
||||
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
|
||||
|
@ -274,6 +303,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
|
|||
SegmentID: 1,
|
||||
NodeID: 1,
|
||||
Version: 1,
|
||||
Info: loadInfo,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -331,6 +331,7 @@ func (s *Server) initObserver() {
|
|||
s.dist,
|
||||
s.meta,
|
||||
s.targetMgr,
|
||||
s.broker,
|
||||
s.cluster,
|
||||
)
|
||||
s.targetObserver = observers.NewTargetObserver(
|
||||
|
|
|
@ -77,6 +77,7 @@ type SegmentAction struct {
|
|||
func NewSegmentAction(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID) *SegmentAction {
|
||||
return NewSegmentActionWithScope(nodeID, typ, shard, segmentID, querypb.DataScope_All)
|
||||
}
|
||||
|
||||
func NewSegmentActionWithScope(nodeID UniqueID, typ ActionType, shard string, segmentID UniqueID, scope querypb.DataScope) *SegmentAction {
|
||||
base := NewBaseAction(nodeID, typ, shard)
|
||||
return &SegmentAction{
|
||||
|
|
|
@ -99,10 +99,7 @@ func packLoadSegmentRequest(
|
|||
segment := resp.GetInfos()[0]
|
||||
|
||||
var posSrcStr string
|
||||
if resp.GetChannelCheckpoint() != nil && resp.ChannelCheckpoint[segment.InsertChannel] != nil {
|
||||
deltaPosition = resp.ChannelCheckpoint[segment.InsertChannel]
|
||||
posSrcStr = "channelCheckpoint"
|
||||
} else if segment.GetDmlPosition() != nil {
|
||||
if segment.GetDmlPosition() != nil {
|
||||
deltaPosition = segment.GetDmlPosition()
|
||||
posSrcStr = "segmentDMLPos"
|
||||
} else {
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
)
|
||||
|
||||
|
@ -36,7 +37,6 @@ func Test_packLoadSegmentRequest(t *testing.T) {
|
|||
t0 := tsoutil.ComposeTSByTime(time.Now().Add(-20*time.Minute), 0)
|
||||
t1 := tsoutil.ComposeTSByTime(time.Now().Add(-8*time.Minute), 0)
|
||||
t2 := tsoutil.ComposeTSByTime(time.Now().Add(-5*time.Minute), 0)
|
||||
t3 := tsoutil.ComposeTSByTime(time.Now().Add(-1*time.Minute), 0)
|
||||
|
||||
segmentInfo := &datapb.SegmentInfo{
|
||||
ID: 0,
|
||||
|
@ -51,28 +51,6 @@ func Test_packLoadSegmentRequest(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
t.Run("test set deltaPosition from channel checkpoint", func(t *testing.T) {
|
||||
segmentAction := NewSegmentAction(0, 0, "", 0)
|
||||
segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, segmentAction)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp := &datapb.GetSegmentInfoResponse{
|
||||
Infos: []*datapb.SegmentInfo{
|
||||
proto.Clone(segmentInfo).(*datapb.SegmentInfo),
|
||||
},
|
||||
ChannelCheckpoint: map[string]*msgpb.MsgPosition{
|
||||
mockVChannel: {
|
||||
ChannelName: mockPChannel,
|
||||
Timestamp: t3,
|
||||
},
|
||||
},
|
||||
}
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
|
||||
assert.Equal(t, 1, len(req.GetDeltaPositions()))
|
||||
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
|
||||
assert.Equal(t, t3, req.DeltaPositions[0].Timestamp)
|
||||
})
|
||||
|
||||
t.Run("test set deltaPosition from segment dmlPosition", func(t *testing.T) {
|
||||
segmentAction := NewSegmentAction(0, 0, "", 0)
|
||||
segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, segmentAction)
|
||||
|
@ -83,7 +61,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
|
|||
proto.Clone(segmentInfo).(*datapb.SegmentInfo),
|
||||
},
|
||||
}
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
|
||||
assert.Equal(t, 1, len(req.GetDeltaPositions()))
|
||||
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
|
||||
assert.Equal(t, t2, req.DeltaPositions[0].Timestamp)
|
||||
|
@ -99,7 +77,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
|
|||
resp := &datapb.GetSegmentInfoResponse{
|
||||
Infos: []*datapb.SegmentInfo{segInfo},
|
||||
}
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
|
||||
assert.Equal(t, 1, len(req.GetDeltaPositions()))
|
||||
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
|
||||
assert.Equal(t, t1, req.DeltaPositions[0].Timestamp)
|
||||
|
@ -115,7 +93,7 @@ func Test_packLoadSegmentRequest(t *testing.T) {
|
|||
resp := &datapb.GetSegmentInfoResponse{
|
||||
Infos: []*datapb.SegmentInfo{segInfo},
|
||||
}
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, nil, resp)
|
||||
req := packLoadSegmentRequest(segmentTask, segmentAction, nil, nil, &querypb.SegmentLoadInfo{}, resp)
|
||||
assert.Equal(t, 1, len(req.GetDeltaPositions()))
|
||||
assert.Equal(t, mockPChannel, req.DeltaPositions[0].ChannelName)
|
||||
assert.Equal(t, t0, req.DeltaPositions[0].Timestamp)
|
||||
|
|
|
@ -90,6 +90,8 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, indexes []*querypb.FieldIn
|
|||
Deltalogs: segment.Deltalogs,
|
||||
InsertChannel: segment.InsertChannel,
|
||||
IndexInfos: indexes,
|
||||
StartPosition: segment.GetStartPosition(),
|
||||
EndPosition: segment.GetDmlPosition(),
|
||||
}
|
||||
loadInfo.SegmentSize = calculateSegmentSize(loadInfo)
|
||||
return loadInfo
|
||||
|
|
|
@ -1499,3 +1499,10 @@ func (node *QueryNode) GetSession() *sessionutil.Session {
|
|||
defer node.sessionMu.Unlock()
|
||||
return node.session
|
||||
}
|
||||
|
||||
func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "not implemented in qnv1",
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -622,24 +622,24 @@ func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
|
|||
zap.Int64("segment ID", segID),
|
||||
zap.String("segment type", segType.String()),
|
||||
zap.Int64("row count", rowCount),
|
||||
zap.Uint64("segment indexed fields", segment.indexedFieldInfos.Len()),
|
||||
)
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Inc()
|
||||
if rowCount > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
/*
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Add(float64(rowCount))
|
||||
}
|
||||
).Inc()
|
||||
if rowCount > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Add(float64(rowCount))
|
||||
}*/
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -730,25 +730,25 @@ func (replica *metaReplica) removeSegmentPrivate(segmentID UniqueID, segType seg
|
|||
zap.Int64("segment ID", segmentID),
|
||||
zap.String("segment type", segType.String()),
|
||||
zap.Int64("row count", rowCount),
|
||||
zap.Uint64("segment indexed fields", segment.indexedFieldInfos.Len()),
|
||||
)
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
// Note: this field is mutable after segment is loaded.
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Dec()
|
||||
if rowCount > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
/*
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
// Note: this field is mutable after segment is loaded.
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Sub(float64(rowCount))
|
||||
}
|
||||
).Dec()
|
||||
if rowCount > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.collectionID),
|
||||
fmt.Sprint(segment.partitionID),
|
||||
segType.String(),
|
||||
fmt.Sprint(segment.indexedFieldInfos.Len()),
|
||||
).Sub(float64(rowCount))
|
||||
}*/
|
||||
}
|
||||
replica.sendNoSegmentSignal()
|
||||
}
|
||||
|
|
|
@ -595,10 +595,7 @@ func genVectorChunkManager(ctx context.Context, col *Collection) (*storage.Vecto
|
|||
return nil, err
|
||||
}
|
||||
|
||||
vcm, err := storage.NewVectorChunkManager(ctx, lcm, rcm, &etcdpb.CollectionMeta{
|
||||
ID: col.id,
|
||||
Schema: col.schema,
|
||||
}, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), false)
|
||||
vcm, err := storage.NewVectorChunkManager(ctx, lcm, rcm, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -908,7 +905,7 @@ func genStorageBlob(collectionID UniqueID,
|
|||
}
|
||||
tmpSchema.Fields = append(tmpSchema.Fields, schema.Fields...)
|
||||
collMeta := genCollectionMeta(collectionID, tmpSchema)
|
||||
inCodec := storage.NewInsertCodec(collMeta)
|
||||
inCodec := storage.NewInsertCodecWithSchema(collMeta)
|
||||
insertData, err := genInsertData(msgLength, schema)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -23,7 +23,6 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
)
|
||||
|
@ -68,11 +67,11 @@ func newQueryShard(
|
|||
if remoteChunkManager == nil {
|
||||
return nil, fmt.Errorf("can not create vector chunk manager for remote chunk manager is nil")
|
||||
}
|
||||
vectorChunkManager, err := storage.NewVectorChunkManager(ctx, localChunkManager, remoteChunkManager,
|
||||
&etcdpb.CollectionMeta{
|
||||
ID: collectionID,
|
||||
Schema: collection.schema,
|
||||
}, Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(), localCacheEnabled)
|
||||
vectorChunkManager, err := storage.NewVectorChunkManager(ctx,
|
||||
localChunkManager,
|
||||
remoteChunkManager,
|
||||
Params.QueryNodeCfg.CacheMemoryLimit.GetAsInt64(),
|
||||
localCacheEnabled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ func (s *Segment) getType() segmentType {
|
|||
}
|
||||
|
||||
func (s *Segment) setIndexedFieldInfo(fieldID UniqueID, info *IndexedFieldInfo) {
|
||||
s.indexedFieldInfos.InsertIfNotPresent(fieldID, info)
|
||||
s.indexedFieldInfos.GetOrInsert(fieldID, info)
|
||||
}
|
||||
|
||||
func (s *Segment) getIndexedFieldInfo(fieldID UniqueID) (*IndexedFieldInfo, error) {
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager is the interface for worker manager.
|
||||
type Manager interface {
|
||||
GetWorker(nodeID int64) (Worker, error)
|
||||
}
|
||||
|
||||
// WorkerBuilder is function alias to build a worker from NodeID
|
||||
type WorkerBuilder func(nodeID int64) (Worker, error)
|
||||
|
||||
type grpcWorkerManager struct {
|
||||
workers *typeutil.ConcurrentMap[int64, Worker]
|
||||
builder WorkerBuilder
|
||||
}
|
||||
|
||||
// GetWorker returns worker with specified nodeID.
|
||||
func (m *grpcWorkerManager) GetWorker(nodeID int64) (Worker, error) {
|
||||
worker, ok := m.workers.Get(nodeID)
|
||||
var err error
|
||||
if !ok {
|
||||
//TODO merge request?
|
||||
worker, err = m.builder(nodeID)
|
||||
if err != nil {
|
||||
log.Warn("failed to build worker",
|
||||
zap.Int64("nodeID", nodeID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
old, exist := m.workers.GetOrInsert(nodeID, worker)
|
||||
if exist {
|
||||
worker.Stop()
|
||||
worker = old
|
||||
}
|
||||
}
|
||||
if !worker.IsHealthy() {
|
||||
// TODO wrap error
|
||||
return nil, fmt.Errorf("node is not healthy: %d", nodeID)
|
||||
}
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
func NewWorkerManager(builder WorkerBuilder) Manager {
|
||||
return &grpcWorkerManager{
|
||||
workers: typeutil.NewConcurrentMap[int64, Worker](),
|
||||
builder: builder,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
t.Run("normal_get", func(t *testing.T) {
|
||||
worker := &MockWorker{}
|
||||
worker.EXPECT().IsHealthy().Return(true)
|
||||
var buildErr error
|
||||
var called int
|
||||
builder := func(nodeID int64) (Worker, error) {
|
||||
called++
|
||||
return worker, buildErr
|
||||
}
|
||||
manager := NewWorkerManager(builder)
|
||||
|
||||
w, err := manager.GetWorker(0)
|
||||
assert.Equal(t, worker, w)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, called)
|
||||
|
||||
w, err = manager.GetWorker(0)
|
||||
assert.Equal(t, worker, w)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, called)
|
||||
})
|
||||
|
||||
t.Run("builder_return_error", func(t *testing.T) {
|
||||
worker := &MockWorker{}
|
||||
worker.EXPECT().IsHealthy().Return(true)
|
||||
var buildErr error
|
||||
var called int
|
||||
buildErr = errors.New("mocked error")
|
||||
builder := func(nodeID int64) (Worker, error) {
|
||||
called++
|
||||
return worker, buildErr
|
||||
}
|
||||
manager := NewWorkerManager(builder)
|
||||
|
||||
_, err := manager.GetWorker(0)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, called)
|
||||
})
|
||||
|
||||
t.Run("worker_not_healthy", func(t *testing.T) {
|
||||
worker := &MockWorker{}
|
||||
worker.EXPECT().IsHealthy().Return(false)
|
||||
var buildErr error
|
||||
var called int
|
||||
builder := func(nodeID int64) (Worker, error) {
|
||||
called++
|
||||
return worker, buildErr
|
||||
}
|
||||
manager := NewWorkerManager(builder)
|
||||
|
||||
_, err := manager.GetWorker(0)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, called)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
// Code generated by mockery v2.16.0. DO NOT EDIT.
|
||||
|
||||
package cluster
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockManager is an autogenerated mock type for the Manager type
|
||||
type MockManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockManager_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockManager) EXPECT() *MockManager_Expecter {
|
||||
return &MockManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// GetWorker provides a mock function with given fields: nodeID
|
||||
func (_m *MockManager) GetWorker(nodeID int64) (Worker, error) {
|
||||
ret := _m.Called(nodeID)
|
||||
|
||||
var r0 Worker
|
||||
if rf, ok := ret.Get(0).(func(int64) Worker); ok {
|
||||
r0 = rf(nodeID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(Worker)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(int64) error); ok {
|
||||
r1 = rf(nodeID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockManager_GetWorker_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorker'
|
||||
type MockManager_GetWorker_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetWorker is a helper method to define mock.On call
|
||||
// - nodeID int64
|
||||
func (_e *MockManager_Expecter) GetWorker(nodeID interface{}) *MockManager_GetWorker_Call {
|
||||
return &MockManager_GetWorker_Call{Call: _e.mock.On("GetWorker", nodeID)}
|
||||
}
|
||||
|
||||
func (_c *MockManager_GetWorker_Call) Run(run func(nodeID int64)) *MockManager_GetWorker_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockManager_GetWorker_Call) Return(_a0 Worker, _a1 error) *MockManager_GetWorker_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockManager interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockManager creates a new instance of MockManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockManager(t mockConstructorTestingTNewMockManager) *MockManager {
|
||||
mock := &MockManager{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
// Code generated by mockery v2.16.0. DO NOT EDIT.
|
||||
|
||||
package cluster
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
)
|
||||
|
||||
// MockWorker is an autogenerated mock type for the Worker type
|
||||
type MockWorker struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockWorker_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockWorker) EXPECT() *MockWorker_Expecter {
|
||||
return &MockWorker_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Delete provides a mock function with given fields: ctx, req
|
||||
func (_m *MockWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest) error); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockWorker_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
|
||||
type MockWorker_Delete_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Delete is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.DeleteRequest
|
||||
func (_e *MockWorker_Expecter) Delete(ctx interface{}, req interface{}) *MockWorker_Delete_Call {
|
||||
return &MockWorker_Delete_Call{Call: _e.mock.On("Delete", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Delete_Call) Run(run func(ctx context.Context, req *querypb.DeleteRequest)) *MockWorker_Delete_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.DeleteRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Delete_Call) Return(_a0 error) *MockWorker_Delete_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetStatistics provides a mock function with given fields: ctx, req
|
||||
func (_m *MockWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 *internalpb.GetStatisticsResponse
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*internalpb.GetStatisticsResponse)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockWorker_GetStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatistics'
|
||||
type MockWorker_GetStatistics_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetStatistics is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.GetStatisticsRequest
|
||||
func (_e *MockWorker_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockWorker_GetStatistics_Call {
|
||||
return &MockWorker_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_GetStatistics_Call) Run(run func(ctx context.Context, req *querypb.GetStatisticsRequest)) *MockWorker_GetStatistics_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_GetStatistics_Call) Return(_a0 *internalpb.GetStatisticsResponse, _a1 error) *MockWorker_GetStatistics_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// IsHealthy provides a mock function with given fields:
|
||||
func (_m *MockWorker) IsHealthy() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockWorker_IsHealthy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsHealthy'
|
||||
type MockWorker_IsHealthy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// IsHealthy is a helper method to define mock.On call
|
||||
func (_e *MockWorker_Expecter) IsHealthy() *MockWorker_IsHealthy_Call {
|
||||
return &MockWorker_IsHealthy_Call{Call: _e.mock.On("IsHealthy")}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_IsHealthy_Call) Run(run func()) *MockWorker_IsHealthy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_IsHealthy_Call) Return(_a0 bool) *MockWorker_IsHealthy_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadSegments provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockWorker) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockWorker_LoadSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadSegments'
|
||||
type MockWorker_LoadSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LoadSegments is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.LoadSegmentsRequest
|
||||
func (_e *MockWorker_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockWorker_LoadSegments_Call {
|
||||
return &MockWorker_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_LoadSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest)) *MockWorker_LoadSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_LoadSegments_Call) Return(_a0 error) *MockWorker_LoadSegments_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Query provides a mock function with given fields: ctx, req
|
||||
func (_m *MockWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 *internalpb.RetrieveResults
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*internalpb.RetrieveResults)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockWorker_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query'
|
||||
type MockWorker_Query_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Query is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.QueryRequest
|
||||
func (_e *MockWorker_Expecter) Query(ctx interface{}, req interface{}) *MockWorker_Query_Call {
|
||||
return &MockWorker_Query_Call{Call: _e.mock.On("Query", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Query_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockWorker_Query_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.QueryRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Query_Call) Return(_a0 *internalpb.RetrieveResults, _a1 error) *MockWorker_Query_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReleaseSegments provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockWorker) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) error {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) error); ok {
|
||||
r0 = rf(_a0, _a1)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockWorker_ReleaseSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseSegments'
|
||||
type MockWorker_ReleaseSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReleaseSegments is a helper method to define mock.On call
|
||||
// - _a0 context.Context
|
||||
// - _a1 *querypb.ReleaseSegmentsRequest
|
||||
func (_e *MockWorker_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockWorker_ReleaseSegments_Call {
|
||||
return &MockWorker_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_ReleaseSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest)) *MockWorker_ReleaseSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_ReleaseSegments_Call) Return(_a0 error) *MockWorker_ReleaseSegments_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Search provides a mock function with given fields: ctx, req
|
||||
func (_m *MockWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 *internalpb.SearchResults
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*internalpb.SearchResults)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockWorker_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search'
|
||||
type MockWorker_Search_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Search is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.SearchRequest
|
||||
func (_e *MockWorker_Expecter) Search(ctx interface{}, req interface{}) *MockWorker_Search_Call {
|
||||
return &MockWorker_Search_Call{Call: _e.mock.On("Search", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Search_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockWorker_Search_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.SearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Search_Call) Return(_a0 *internalpb.SearchResults, _a1 error) *MockWorker_Search_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields:
|
||||
func (_m *MockWorker) Stop() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockWorker_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type MockWorker_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *MockWorker_Expecter) Stop() *MockWorker_Stop_Call {
|
||||
return &MockWorker_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Stop_Call) Run(run func()) *MockWorker_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockWorker_Stop_Call) Return() *MockWorker_Stop_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockWorker interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockWorker creates a new instance of MockWorker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockWorker(t mockConstructorTestingTNewMockWorker) *MockWorker {
|
||||
mock := &MockWorker{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// delegator package contains the logic of shard delegator.
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/merr"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Worker is the interface definition for querynode worker role.
|
||||
type Worker interface {
|
||||
LoadSegments(context.Context, *querypb.LoadSegmentsRequest) error
|
||||
ReleaseSegments(context.Context, *querypb.ReleaseSegmentsRequest) error
|
||||
Delete(ctx context.Context, req *querypb.DeleteRequest) error
|
||||
Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error)
|
||||
Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error)
|
||||
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)
|
||||
|
||||
IsHealthy() bool
|
||||
Stop()
|
||||
}
|
||||
|
||||
// remoteWorker wraps grpc QueryNode client as Worker.
|
||||
type remoteWorker struct {
|
||||
client types.QueryNode
|
||||
}
|
||||
|
||||
// NewRemoteWorker creates a grpcWorker.
|
||||
func NewRemoteWorker(client types.QueryNode) Worker {
|
||||
return &remoteWorker{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSegments implements Worker.
|
||||
func (w *remoteWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("workerID", req.GetDstNodeID()),
|
||||
)
|
||||
status, err := w.client.LoadSegments(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to call LoadSegments via grpc worker",
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
} else if status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
log.Warn("failed to call LoadSegments, worker return error",
|
||||
zap.String("errorCode", status.GetErrorCode().String()),
|
||||
zap.String("reason", status.GetReason()),
|
||||
)
|
||||
return fmt.Errorf(status.Reason)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *remoteWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("workerID", req.GetNodeID()),
|
||||
)
|
||||
status, err := w.client.ReleaseSegments(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to call ReleaseSegments via grpc worker",
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
} else if status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
log.Warn("failed to call ReleaseSegments, worker return error",
|
||||
zap.String("errorCode", status.GetErrorCode().String()),
|
||||
zap.String("reason", status.GetReason()),
|
||||
)
|
||||
return fmt.Errorf(status.Reason)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *remoteWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("workerID", req.GetBase().GetTargetID()),
|
||||
)
|
||||
status, err := w.client.Delete(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to call Delete via grpc worker",
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
} else if status.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Warn("failed to call Delete, worker return error",
|
||||
zap.String("errorCode", status.GetErrorCode().String()),
|
||||
zap.String("reason", status.GetReason()),
|
||||
)
|
||||
return merr.Error(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *remoteWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
||||
return w.client.Search(ctx, req)
|
||||
}
|
||||
|
||||
func (w *remoteWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
|
||||
return w.client.Query(ctx, req)
|
||||
}
|
||||
|
||||
func (w *remoteWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
|
||||
return w.client.GetStatistics(ctx, req)
|
||||
}
|
||||
|
||||
func (w *remoteWorker) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *remoteWorker) Stop() {
|
||||
w.client.Stop()
|
||||
}
|
|
@ -0,0 +1,372 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// delegator package contains the logic of shard delegator.
|
||||
|
||||
package cluster
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
commonpb "github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RemoteWorkerSuite struct {
|
||||
suite.Suite
|
||||
|
||||
mockClient *types.MockQueryNode
|
||||
worker *remoteWorker
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) SetupTest() {
|
||||
s.mockClient = &types.MockQueryNode{}
|
||||
s.worker = &remoteWorker{client: s.mockClient}
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TearDownTest() {
|
||||
s.mockClient = nil
|
||||
s.worker = nil
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestLoadSegments() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
|
||||
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil, errors.New("mocked error"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.LoadSegments(ctx, &querypb.LoadSegmentsRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestReleaseSegments() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
|
||||
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
|
||||
Return(nil, errors.New("mocked error"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestDelete() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
|
||||
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
|
||||
Return(nil, errors.New("mocked error"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).
|
||||
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mocked failure"}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.worker.Delete(ctx, &querypb.DeleteRequest{})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestSearch() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.SearchResults
|
||||
var err error
|
||||
|
||||
result = &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}
|
||||
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.SearchResults
|
||||
err := errors.New("mocked error")
|
||||
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.SearchResults
|
||||
var err error
|
||||
|
||||
result = &internalpb.SearchResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
|
||||
}
|
||||
s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Search(ctx, &querypb.SearchRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestQuery() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.RetrieveResults
|
||||
var err error
|
||||
|
||||
result = &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}
|
||||
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.RetrieveResults
|
||||
|
||||
err := errors.New("mocked error")
|
||||
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.RetrieveResults
|
||||
var err error
|
||||
|
||||
result = &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
|
||||
}
|
||||
s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.Query(ctx, &querypb.QueryRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestGetStatistics() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.GetStatisticsResponse
|
||||
var err error
|
||||
|
||||
result = &internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
}
|
||||
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_error", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.GetStatisticsResponse
|
||||
|
||||
err := errors.New("mocked error")
|
||||
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
|
||||
s.Run("client_return_fail_status", func() {
|
||||
defer func() { s.mockClient.ExpectedCalls = nil }()
|
||||
|
||||
var result *internalpb.GetStatisticsResponse
|
||||
var err error
|
||||
|
||||
result = &internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
|
||||
}
|
||||
s.mockClient.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Return(result, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sr, serr := s.worker.GetStatistics(ctx, &querypb.GetStatisticsRequest{})
|
||||
|
||||
s.Equal(err, serr)
|
||||
s.Equal(result, sr)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RemoteWorkerSuite) TestBasic() {
|
||||
s.True(s.worker.IsHealthy())
|
||||
|
||||
s.mockClient.EXPECT().Stop().Return(nil)
|
||||
s.worker.Stop()
|
||||
s.mockClient.AssertCalled(s.T(), "Stop")
|
||||
}
|
||||
|
||||
func TestRemoteWorker(t *testing.T) {
|
||||
suite.Run(t, new(RemoteWorkerSuite))
|
||||
}
|
||||
|
||||
func TestNewRemoteWorker(t *testing.T) {
|
||||
client := &types.MockQueryNode{}
|
||||
|
||||
w := NewRemoteWorker(client)
|
||||
|
||||
rw, ok := w.(*remoteWorker)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, client, rw.client)
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type averageData struct {
|
||||
total float64
|
||||
count int32
|
||||
}
|
||||
|
||||
func (v *averageData) Add(value float64) {
|
||||
v.total += value
|
||||
v.count++
|
||||
}
|
||||
|
||||
func (v *averageData) Value() float64 {
|
||||
if v.count == 0 {
|
||||
return 0
|
||||
}
|
||||
return v.total / float64(v.count)
|
||||
}
|
||||
|
||||
type averageCollector struct {
|
||||
sync.Mutex
|
||||
averages map[string]*averageData
|
||||
}
|
||||
|
||||
func (c *averageCollector) Register(label string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if _, ok := c.averages[label]; !ok {
|
||||
c.averages[label] = &averageData{}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *averageCollector) Add(label string, value float64) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if average, ok := c.averages[label]; ok {
|
||||
average.Add(value)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *averageCollector) Average(label string) (float64, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
average, ok := c.averages[label]
|
||||
if !ok {
|
||||
return 0, WrapErrAvarageLabelNotRegister(label)
|
||||
}
|
||||
|
||||
return average.Value(), nil
|
||||
}
|
||||
|
||||
func newAverageCollector() *averageCollector {
|
||||
return &averageCollector{
|
||||
averages: make(map[string]*averageData),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AverageCollectorTestSuite struct {
|
||||
suite.Suite
|
||||
label string
|
||||
average *averageCollector
|
||||
}
|
||||
|
||||
func (suite *AverageCollectorTestSuite) SetupSuite() {
|
||||
suite.average = newAverageCollector()
|
||||
suite.label = "test_label"
|
||||
}
|
||||
|
||||
func (suite *AverageCollectorTestSuite) TestBasic() {
|
||||
//Get average not register
|
||||
_, err := suite.average.Average(suite.label)
|
||||
suite.Error(err)
|
||||
|
||||
//register and get
|
||||
suite.average.Register(suite.label)
|
||||
value, err := suite.average.Average(suite.label)
|
||||
suite.Equal(float64(0), value)
|
||||
suite.NoError(err)
|
||||
|
||||
//add and get
|
||||
sum := 4
|
||||
for i := 0; i <= sum; i++ {
|
||||
suite.average.Add(suite.label, float64(i))
|
||||
}
|
||||
value, err = suite.average.Average(suite.label)
|
||||
suite.NoError(err)
|
||||
suite.Equal(float64(sum)/2, value)
|
||||
}
|
||||
|
||||
func TestAverageCollector(t *testing.T) {
|
||||
suite.Run(t, new(AverageCollectorTestSuite))
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/ratelimitutil"
|
||||
)
|
||||
|
||||
var Average *averageCollector
|
||||
var Rate *ratelimitutil.RateCollector
|
||||
var Counter *counter
|
||||
|
||||
func RateMetrics() []string {
|
||||
return []string{
|
||||
metricsinfo.NQPerSecond,
|
||||
metricsinfo.SearchThroughput,
|
||||
metricsinfo.InsertConsumeThroughput,
|
||||
metricsinfo.DeleteConsumeThroughput,
|
||||
}
|
||||
}
|
||||
|
||||
func AverageMetrics() []string {
|
||||
return []string{
|
||||
metricsinfo.QueryQueueMetric,
|
||||
metricsinfo.SearchQueueMetric,
|
||||
}
|
||||
}
|
||||
|
||||
func ConstructLabel(subs ...string) string {
|
||||
label := ""
|
||||
for id, sub := range subs {
|
||||
label += sub
|
||||
if id != len(subs)-1 {
|
||||
label += "-"
|
||||
}
|
||||
}
|
||||
return label
|
||||
}
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
Rate, err = ratelimitutil.NewRateCollector(ratelimitutil.DefaultWindow, ratelimitutil.DefaultGranularity)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("querynode collector init failed, err = %s", err)
|
||||
log.Error(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
Average = newAverageCollector()
|
||||
Counter = newCounter()
|
||||
|
||||
//init rate Metric
|
||||
for _, label := range RateMetrics() {
|
||||
Rate.Register(label)
|
||||
}
|
||||
//init average metric
|
||||
|
||||
for _, label := range AverageMetrics() {
|
||||
Average.Register(label)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type counter struct {
|
||||
sync.Mutex
|
||||
values map[string]int64
|
||||
}
|
||||
|
||||
func (c *counter) Inc(label string, value int64) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
v, ok := c.values[label]
|
||||
if !ok {
|
||||
c.values[label] = value
|
||||
} else {
|
||||
v += value
|
||||
c.values[label] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (c *counter) Get(label string) int64 {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
v, ok := c.values[label]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *counter) Remove(label string) bool {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
_, ok := c.values[label]
|
||||
if ok {
|
||||
delete(c.values, label)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
func newCounter() *counter {
|
||||
return &counter{
|
||||
values: make(map[string]int64),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type CounterTestSuite struct {
|
||||
suite.Suite
|
||||
label string
|
||||
counter *counter
|
||||
}
|
||||
|
||||
func (suite *CounterTestSuite) SetupSuite() {
|
||||
suite.counter = newCounter()
|
||||
suite.label = "test_label"
|
||||
}
|
||||
|
||||
func (suite *CounterTestSuite) TestBasic() {
|
||||
//get default value(zero)
|
||||
value := suite.counter.Get(suite.label)
|
||||
suite.Equal(int64(0), value)
|
||||
|
||||
//get after inc
|
||||
suite.counter.Inc(suite.label, 3)
|
||||
value = suite.counter.Get(suite.label)
|
||||
suite.Equal(int64(3), value)
|
||||
|
||||
//remote
|
||||
suite.counter.Remove(suite.label)
|
||||
value = suite.counter.Get(suite.label)
|
||||
suite.Equal(int64(0), value)
|
||||
}
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
suite.Run(t, new(CounterTestSuite))
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package collector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAvarageLabelNotRegister = errors.New("AvarageLabelNotRegister")
|
||||
)
|
||||
|
||||
func WrapErrAvarageLabelNotRegister(label string) error {
|
||||
return fmt.Errorf("%w :%s", ErrAvarageLabelNotRegister, label)
|
||||
}
|
|
@ -0,0 +1,565 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// delegator package contains the logic of shard delegator.
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type lifetime struct {
|
||||
state atomic.Int32
|
||||
closeCh chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (lt *lifetime) SetState(state int32) {
|
||||
lt.state.Store(state)
|
||||
}
|
||||
|
||||
func (lt *lifetime) GetState() int32 {
|
||||
return lt.state.Load()
|
||||
}
|
||||
|
||||
func (lt *lifetime) Close() {
|
||||
lt.closeOnce.Do(func() {
|
||||
close(lt.closeCh)
|
||||
})
|
||||
}
|
||||
|
||||
func newLifetime() *lifetime {
|
||||
return &lifetime{
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ShardDelegator is the interface definition.
|
||||
type ShardDelegator interface {
|
||||
Collection() int64
|
||||
Version() int64
|
||||
GetDistribution() *distribution
|
||||
SyncDistribution(ctx context.Context, entries ...SegmentEntry)
|
||||
Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error)
|
||||
Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)
|
||||
GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)
|
||||
|
||||
//data
|
||||
ProcessInsert(insertRecords map[int64]*InsertData)
|
||||
ProcessDelete(deleteData []*DeleteData, ts uint64)
|
||||
LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error
|
||||
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error
|
||||
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error
|
||||
|
||||
// control
|
||||
Serviceable() bool
|
||||
Start()
|
||||
Close()
|
||||
}
|
||||
|
||||
var _ ShardDelegator = (*shardDelegator)(nil)
|
||||
|
||||
const (
|
||||
initializing int32 = iota
|
||||
working
|
||||
stopped
|
||||
)
|
||||
|
||||
// shardDelegator maintains the shard distribution and streaming part of the data.
|
||||
type shardDelegator struct {
|
||||
// shard information attributes
|
||||
collectionID int64
|
||||
replicaID int64
|
||||
vchannelName string
|
||||
version int64
|
||||
// collection schema
|
||||
collection *segments.Collection
|
||||
|
||||
workerManager cluster.Manager
|
||||
|
||||
lifetime *lifetime
|
||||
|
||||
distribution *distribution
|
||||
segmentManager segments.SegmentManager
|
||||
tsafeManager tsafe.Manager
|
||||
pkOracle pkoracle.PkOracle
|
||||
// L0 delete buffer
|
||||
deleteMut sync.Mutex
|
||||
deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item]
|
||||
//dispatcherClient msgdispatcher.Client
|
||||
factory msgstream.Factory
|
||||
|
||||
loader segments.Loader
|
||||
wg sync.WaitGroup
|
||||
tsCond *sync.Cond
|
||||
latestTsafe *atomic.Uint64
|
||||
}
|
||||
|
||||
// getLogger returns the zap logger with pre-defined shard attributes.
|
||||
func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger {
|
||||
return log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", sd.collectionID),
|
||||
zap.String("channel", sd.vchannelName),
|
||||
zap.Int64("replicaID", sd.replicaID),
|
||||
)
|
||||
}
|
||||
|
||||
// Serviceable returns whether delegator is serviceable now.
|
||||
func (sd *shardDelegator) Serviceable() bool {
|
||||
return sd.lifetime.GetState() == working
|
||||
}
|
||||
|
||||
// Start sets delegator to working state.
|
||||
func (sd *shardDelegator) Start() {
|
||||
sd.lifetime.SetState(working)
|
||||
}
|
||||
|
||||
// Collection returns delegator collection id.
|
||||
func (sd *shardDelegator) Collection() int64 {
|
||||
return sd.collectionID
|
||||
}
|
||||
|
||||
// Version returns delegator version.
|
||||
func (sd *shardDelegator) Version() int64 {
|
||||
return sd.version
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) GetDistribution() *distribution {
|
||||
return sd.distribution
|
||||
}
|
||||
|
||||
// SyncDistribution revises distribution.
|
||||
func (sd *shardDelegator) SyncDistribution(ctx context.Context, entries ...SegmentEntry) {
|
||||
log := sd.getLogger(ctx)
|
||||
|
||||
log.Info("sync distribution", zap.Any("entries", entries))
|
||||
|
||||
sd.distribution.AddDistributions(entries...)
|
||||
}
|
||||
|
||||
func modifySearchRequest(req *querypb.SearchRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.SearchRequest {
|
||||
nodeReq := proto.Clone(req).(*querypb.SearchRequest)
|
||||
nodeReq.Scope = scope
|
||||
nodeReq.Req.Base.TargetID = targetID
|
||||
nodeReq.SegmentIDs = segmentIDs
|
||||
nodeReq.FromShardLeader = true
|
||||
return nodeReq
|
||||
}
|
||||
|
||||
func modifyQueryRequest(req *querypb.QueryRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.QueryRequest {
|
||||
nodeReq := proto.Clone(req).(*querypb.QueryRequest)
|
||||
nodeReq.Scope = scope
|
||||
nodeReq.Req.Base.TargetID = targetID
|
||||
nodeReq.SegmentIDs = segmentIDs
|
||||
nodeReq.FromShardLeader = true
|
||||
return nodeReq
|
||||
}
|
||||
|
||||
// Search preforms search operation on shard.
|
||||
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received search request not belongs to it",
|
||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||
)
|
||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||
}
|
||||
|
||||
// wait tsafe
|
||||
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sealed, growing, version := sd.distribution.GetCurrent(req.GetReq().GetPartitionIDs()...)
|
||||
defer sd.distribution.FinishUsage(version)
|
||||
|
||||
if req.Req.IgnoreGrowing {
|
||||
growing = []SegmentEntry{}
|
||||
}
|
||||
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, modifySearchRequest)
|
||||
if err != nil {
|
||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
|
||||
return worker.Search(ctx, req)
|
||||
}, "Search", log)
|
||||
if err != nil {
|
||||
log.Warn("Delegator search failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Delegator search done")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Query performs query operation on shard.
|
||||
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received query request not belongs to it",
|
||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||
)
|
||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||
}
|
||||
|
||||
// wait tsafe
|
||||
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sealed, growing, version := sd.distribution.GetCurrent(req.GetReq().GetPartitionIDs()...)
|
||||
defer sd.distribution.FinishUsage(version)
|
||||
if req.Req.IgnoreGrowing {
|
||||
growing = []SegmentEntry{}
|
||||
}
|
||||
|
||||
log.Info("query segments...",
|
||||
zap.Int("sealedNum", len(sealed)),
|
||||
zap.Int("growingNum", len(growing)),
|
||||
)
|
||||
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, modifyQueryRequest)
|
||||
if err != nil {
|
||||
log.Warn("query organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.QueryRequest, worker cluster.Worker) (*internalpb.RetrieveResults, error) {
|
||||
return worker.Query(ctx, req)
|
||||
}, "Query", log)
|
||||
if err != nil {
|
||||
log.Warn("Delegator query failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Delegator Query done")
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetStatistics returns statistics aggregated by delegator.
|
||||
func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
|
||||
log := sd.getLogger(ctx)
|
||||
if !sd.Serviceable() {
|
||||
return nil, errors.New("delegator is not serviceable")
|
||||
}
|
||||
|
||||
if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) {
|
||||
log.Warn("deletgator received query request not belongs to it",
|
||||
zap.Strings("reqChannels", req.GetDmlChannels()),
|
||||
)
|
||||
return nil, fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels())
|
||||
}
|
||||
|
||||
// wait tsafe
|
||||
err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sealed, growing, version := sd.distribution.GetCurrent(req.Req.GetPartitionIDs()...)
|
||||
defer sd.distribution.FinishUsage(version)
|
||||
|
||||
tasks, err := organizeSubTask(req, sealed, growing, sd.workerManager, func(req *querypb.GetStatisticsRequest, scope querypb.DataScope, segmentIDs []int64, targetID int64) *querypb.GetStatisticsRequest {
|
||||
nodeReq := proto.Clone(req).(*querypb.GetStatisticsRequest)
|
||||
nodeReq.GetReq().GetBase().TargetID = targetID
|
||||
nodeReq.Scope = scope
|
||||
nodeReq.SegmentIDs = segmentIDs
|
||||
nodeReq.FromShardLeader = true
|
||||
return nodeReq
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("Get statistics organizeSubTask failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.GetStatisticsRequest, worker cluster.Worker) (*internalpb.GetStatisticsResponse, error) {
|
||||
return worker.GetStatistics(ctx, req)
|
||||
}, "GetStatistics", log)
|
||||
if err != nil {
|
||||
log.Warn("Delegator get statistics failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type subTask[T any] struct {
|
||||
req T
|
||||
targetID int64
|
||||
worker cluster.Worker
|
||||
}
|
||||
|
||||
func organizeSubTask[T any](req T, sealed []SnapshotItem, growing []SegmentEntry, workerManager cluster.Manager, modify func(T, querypb.DataScope, []int64, int64) T) ([]subTask[T], error) {
|
||||
result := make([]subTask[T], 0, len(sealed)+1)
|
||||
|
||||
packSubTask := func(segments []SegmentEntry, workerID int64, scope querypb.DataScope) error {
|
||||
segmentIDs := lo.Map(segments, func(item SegmentEntry, _ int) int64 {
|
||||
return item.SegmentID
|
||||
})
|
||||
if len(segmentIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
// update request
|
||||
req := modify(req, scope, segmentIDs, workerID)
|
||||
|
||||
worker, err := workerManager.GetWorker(workerID)
|
||||
if err != nil {
|
||||
log.Warn("failed to get worker",
|
||||
zap.Int64("nodeID", workerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("failed to get worker %d, %w", workerID, err)
|
||||
}
|
||||
|
||||
result = append(result, subTask[T]{
|
||||
req: req,
|
||||
targetID: workerID,
|
||||
worker: worker,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, entry := range sealed {
|
||||
err := packSubTask(entry.Segments, entry.NodeID, querypb.DataScope_Historical)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
packSubTask(growing, paramtable.GetNodeID(), querypb.DataScope_Streaming)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func executeSubTasks[T any, R interface {
|
||||
GetStatus() *commonpb.Status
|
||||
}](ctx context.Context, tasks []subTask[T], execute func(context.Context, T, cluster.Worker) (R, error), taskType string, log *log.MLogger) ([]R, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(tasks))
|
||||
|
||||
resultCh := make(chan R, len(tasks))
|
||||
errCh := make(chan error, 1)
|
||||
for _, task := range tasks {
|
||||
go func(task subTask[T]) {
|
||||
defer wg.Done()
|
||||
result, err := execute(ctx, task.req, task.worker)
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn("failed to execute sub task",
|
||||
zap.String("taskType", taskType),
|
||||
zap.Int64("nodeID", task.targetID),
|
||||
zap.Error(err),
|
||||
)
|
||||
select {
|
||||
case errCh <- err: // must be the first
|
||||
default: // skip other errors
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
resultCh <- result
|
||||
}(task)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
log.Warn("Delegator execute subTask failed",
|
||||
zap.String("taskType", taskType),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
|
||||
results := make([]R, 0, len(tasks))
|
||||
for result := range resultCh {
|
||||
results = append(results, result)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// waitTSafe returns when tsafe listener notifies a timestamp which meet the guarantee ts.
|
||||
func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error {
|
||||
log := sd.getLogger(ctx)
|
||||
// already safe to search
|
||||
if sd.latestTsafe.Load() >= ts {
|
||||
return nil
|
||||
}
|
||||
// check lag duration too large
|
||||
st, _ := tsoutil.ParseTS(sd.latestTsafe.Load())
|
||||
gt, _ := tsoutil.ParseTS(ts)
|
||||
lag := gt.Sub(st)
|
||||
maxLag := paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)
|
||||
if lag > maxLag {
|
||||
log.Warn("guarantee and servicable ts larger than MaxLag",
|
||||
zap.Time("guaranteeTime", gt),
|
||||
zap.Time("serviceableTime", st),
|
||||
zap.Duration("lag", lag),
|
||||
zap.Duration("maxTsLag", maxLag),
|
||||
)
|
||||
return WrapErrTsLagTooLarge(lag, maxLag)
|
||||
}
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
sd.tsCond.L.Lock()
|
||||
defer sd.tsCond.L.Unlock()
|
||||
|
||||
for sd.latestTsafe.Load() < ts && ctx.Err() == nil {
|
||||
sd.tsCond.Wait()
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
// timeout
|
||||
case <-ctx.Done():
|
||||
// notify wait goroutine to quit
|
||||
sd.tsCond.Broadcast()
|
||||
return ctx.Err()
|
||||
case <-ch:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchTSafe is the worker function to update serviceable timestamp.
|
||||
func (sd *shardDelegator) watchTSafe() {
|
||||
defer sd.wg.Done()
|
||||
listener := sd.tsafeManager.WatchChannel(sd.vchannelName)
|
||||
sd.updateTSafe()
|
||||
log := sd.getLogger(context.Background())
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-listener.On():
|
||||
if !ok {
|
||||
// listener close
|
||||
log.Warn("tsafe listener closed")
|
||||
return
|
||||
}
|
||||
sd.updateTSafe()
|
||||
case <-sd.lifetime.closeCh:
|
||||
log.Info("updateTSafe quit")
|
||||
// shard delegator closed
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateTSafe read current tsafe value from tsafeManager.
|
||||
func (sd *shardDelegator) updateTSafe() {
|
||||
sd.tsCond.L.Lock()
|
||||
tsafe, err := sd.tsafeManager.Get(sd.vchannelName)
|
||||
if err != nil {
|
||||
log.Warn("tsafeManager failed to get lastest", zap.Error(err))
|
||||
}
|
||||
if tsafe > sd.latestTsafe.Load() {
|
||||
sd.latestTsafe.Store(tsafe)
|
||||
sd.tsCond.Broadcast()
|
||||
}
|
||||
sd.tsCond.L.Unlock()
|
||||
}
|
||||
|
||||
// Close closes the delegator.
|
||||
func (sd *shardDelegator) Close() {
|
||||
sd.lifetime.SetState(stopped)
|
||||
sd.lifetime.Close()
|
||||
sd.wg.Wait()
|
||||
}
|
||||
|
||||
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
|
||||
func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string, version int64,
|
||||
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
|
||||
factory msgstream.Factory, startTs uint64) (ShardDelegator, error) {
|
||||
|
||||
collection := manager.Collection.Get(collectionID)
|
||||
if collection == nil {
|
||||
return nil, fmt.Errorf("collection(%d) not found in manager", collectionID)
|
||||
}
|
||||
|
||||
maxSegmentDeleteBuffer := paramtable.Get().QueryNodeCfg.MaxSegmentDeleteBuffer.GetAsInt64()
|
||||
log.Info("Init delte cache", zap.Int64("maxSegmentCacheBuffer", maxSegmentDeleteBuffer), zap.Time("startTime", tsoutil.PhysicalTime(startTs)))
|
||||
|
||||
sd := &shardDelegator{
|
||||
collectionID: collectionID,
|
||||
replicaID: replicaID,
|
||||
vchannelName: channel,
|
||||
version: version,
|
||||
collection: collection,
|
||||
segmentManager: manager.Segment,
|
||||
workerManager: workerManager,
|
||||
lifetime: newLifetime(),
|
||||
distribution: NewDistribution(),
|
||||
deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer),
|
||||
pkOracle: pkoracle.NewPkOracle(),
|
||||
tsafeManager: tsafeManager,
|
||||
latestTsafe: atomic.NewUint64(0),
|
||||
loader: loader,
|
||||
factory: factory,
|
||||
}
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
go sd.watchTSafe()
|
||||
return sd, nil
|
||||
}
|
|
@ -0,0 +1,615 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/merr"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/retry"
|
||||
"github.com/milvus-io/milvus/internal/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// delegator data related part
|
||||
|
||||
// InsertData
|
||||
type InsertData struct {
|
||||
RowIDs []int64
|
||||
PrimaryKeys []storage.PrimaryKey
|
||||
Timestamps []uint64
|
||||
InsertRecord *segcorepb.InsertRecord
|
||||
StartPosition *msgpb.MsgPosition
|
||||
PartitionID int64
|
||||
}
|
||||
|
||||
type DeleteData struct {
|
||||
PartitionID int64
|
||||
PrimaryKeys []storage.PrimaryKey
|
||||
Timestamps []uint64
|
||||
RowCount int64
|
||||
}
|
||||
|
||||
// Append appends another delete data into this one.
|
||||
func (d *DeleteData) Append(ad DeleteData) {
|
||||
d.PrimaryKeys = append(d.PrimaryKeys, ad.PrimaryKeys...)
|
||||
d.Timestamps = append(d.Timestamps, ad.Timestamps...)
|
||||
d.RowCount += ad.RowCount
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) newGrowing(segmentID int64, insertData *InsertData) segments.Segment {
|
||||
log := sd.getLogger(context.Background()).With(zap.Int64("segmentID", segmentID))
|
||||
|
||||
// try add partition
|
||||
if sd.collection.GetLoadType() == loadTypeCollection {
|
||||
sd.collection.AddPartition(insertData.PartitionID)
|
||||
}
|
||||
|
||||
segment, err := segments.NewSegment(sd.collection, segmentID, insertData.PartitionID, sd.collectionID, sd.vchannelName, segments.SegmentTypeGrowing, 0, insertData.StartPosition, insertData.StartPosition)
|
||||
if err != nil {
|
||||
log.Error("failed to create new segment", zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sd.pkOracle.Register(segment, paramtable.GetNodeID())
|
||||
|
||||
sd.segmentManager.Put(segments.SegmentTypeGrowing, segment)
|
||||
sd.addGrowing(SegmentEntry{
|
||||
NodeID: paramtable.GetNodeID(),
|
||||
SegmentID: segmentID,
|
||||
PartitionID: insertData.PartitionID,
|
||||
Version: 0,
|
||||
})
|
||||
return segment
|
||||
}
|
||||
|
||||
// ProcessInsert handles insert data in delegator.
|
||||
func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
|
||||
log := sd.getLogger(context.Background())
|
||||
for segmentID, insertData := range insertRecords {
|
||||
growing := sd.segmentManager.GetGrowing(segmentID)
|
||||
if growing == nil {
|
||||
growing = sd.newGrowing(segmentID, insertData)
|
||||
}
|
||||
|
||||
err := growing.Insert(insertData.RowIDs, insertData.Timestamps, insertData.InsertRecord)
|
||||
if err != nil {
|
||||
log.Error("failed to insert data into growing segment",
|
||||
zap.Int64("segmentID", segmentID),
|
||||
zap.Error(err),
|
||||
)
|
||||
// panic here, insert failure
|
||||
panic(err)
|
||||
}
|
||||
growing.UpdateBloomFilter(insertData.PrimaryKeys)
|
||||
|
||||
log.Debug("insert into growing segment",
|
||||
zap.Int64("collectionID", growing.Collection()),
|
||||
zap.Int64("segmentID", segmentID),
|
||||
zap.Int("rowCount", len(insertData.RowIDs)),
|
||||
zap.Uint64("maxTimestamp", insertData.Timestamps[len(insertData.Timestamps)-1]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessDelete handles delete data in delegator.
|
||||
// delegator puts deleteData into buffer first,
|
||||
// then dispatch data to segments acoording to the result of pkOracle.
|
||||
func (sd *shardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
|
||||
// block load segment handle delete buffer
|
||||
sd.deleteMut.Lock()
|
||||
defer sd.deleteMut.Unlock()
|
||||
|
||||
log := sd.getLogger(context.Background())
|
||||
|
||||
log.Debug("start to process delete", zap.Uint64("ts", ts))
|
||||
// add deleteData into buffer.
|
||||
cacheItems := make([]deletebuffer.BufferItem, 0, len(deleteData))
|
||||
for _, entry := range deleteData {
|
||||
cacheItems = append(cacheItems, deletebuffer.BufferItem{
|
||||
PartitionID: entry.PartitionID,
|
||||
DeleteData: storage.DeleteData{
|
||||
Pks: entry.PrimaryKeys,
|
||||
Tss: entry.Timestamps,
|
||||
RowCount: entry.RowCount,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
sd.deleteBuffer.Put(&deletebuffer.Item{
|
||||
Ts: ts,
|
||||
Data: cacheItems,
|
||||
})
|
||||
|
||||
// segment => delete data
|
||||
delRecords := make(map[int64]DeleteData)
|
||||
for _, data := range deleteData {
|
||||
for i, pk := range data.PrimaryKeys {
|
||||
segmentIDs, err := sd.pkOracle.Get(pk, pkoracle.WithPartitionID(data.PartitionID))
|
||||
if err != nil {
|
||||
log.Warn("failed to get delete candidates for pk", zap.Any("pk", pk.GetValue()))
|
||||
continue
|
||||
}
|
||||
for _, segmentID := range segmentIDs {
|
||||
delRecord := delRecords[segmentID]
|
||||
delRecord.PrimaryKeys = append(delRecord.PrimaryKeys, pk)
|
||||
delRecord.Timestamps = append(delRecord.Timestamps, data.Timestamps[i])
|
||||
delRecord.RowCount++
|
||||
delRecords[segmentID] = delRecord
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
offlineSegments := typeutil.NewConcurrentSet[int64]()
|
||||
|
||||
sealed, growing, version := sd.distribution.GetCurrent()
|
||||
|
||||
eg, ctx := errgroup.WithContext(context.Background())
|
||||
for _, entry := range sealed {
|
||||
entry := entry
|
||||
eg.Go(func() error {
|
||||
worker, err := sd.workerManager.GetWorker(entry.NodeID)
|
||||
if err != nil {
|
||||
log.Warn("failed to get worker",
|
||||
zap.Int64("nodeID", paramtable.GetNodeID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
// skip if node down
|
||||
// delete will be processed after loaded again
|
||||
return nil
|
||||
}
|
||||
offlineSegments.Upsert(sd.applyDelete(ctx, entry.NodeID, worker, delRecords, entry.Segments)...)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if len(growing) > 0 {
|
||||
eg.Go(func() error {
|
||||
worker, err := sd.workerManager.GetWorker(paramtable.GetNodeID())
|
||||
if err != nil {
|
||||
log.Error("failed to get worker(local)",
|
||||
zap.Int64("nodeID", paramtable.GetNodeID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
// panic here, local worker shall not have error
|
||||
panic(err)
|
||||
}
|
||||
offlineSegments.Upsert(sd.applyDelete(ctx, paramtable.GetNodeID(), worker, delRecords, growing)...)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// not error return in apply delete
|
||||
_ = eg.Wait()
|
||||
|
||||
sd.distribution.FinishUsage(version)
|
||||
offlineSegIDs := offlineSegments.Collect()
|
||||
if len(offlineSegIDs) > 0 {
|
||||
log.Warn("failed to apply delete, mark segment offline", zap.Int64s("offlineSegments", offlineSegIDs))
|
||||
sd.markSegmentOffline(offlineSegIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
// applyDelete handles delete record and apply them to corresponding workers.
|
||||
func (sd *shardDelegator) applyDelete(ctx context.Context, nodeID int64, worker cluster.Worker, delRecords map[int64]DeleteData, entries []SegmentEntry) []int64 {
|
||||
var offlineSegments []int64
|
||||
log := sd.getLogger(ctx)
|
||||
for _, segmentEntry := range entries {
|
||||
log := log.With(
|
||||
zap.Int64("segmentID", segmentEntry.SegmentID),
|
||||
zap.Int64("workerID", nodeID),
|
||||
)
|
||||
delRecord, ok := delRecords[segmentEntry.SegmentID]
|
||||
if ok {
|
||||
log.Debug("delegator plan to applyDelete via worker")
|
||||
err := retry.Do(ctx, func() error {
|
||||
err := worker.Delete(ctx, &querypb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(nodeID)),
|
||||
CollectionId: sd.collectionID,
|
||||
PartitionId: segmentEntry.PartitionID,
|
||||
VchannelName: sd.vchannelName,
|
||||
SegmentId: segmentEntry.SegmentID,
|
||||
PrimaryKeys: storage.ParsePrimaryKeys2IDs(delRecord.PrimaryKeys),
|
||||
Timestamps: delRecord.Timestamps,
|
||||
})
|
||||
if errors.Is(err, merr.ErrSegmentNotFound) {
|
||||
log.Warn("try to delete data of released segment")
|
||||
return nil
|
||||
} else if err != nil {
|
||||
log.Warn("worker failed to delete on segment",
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, retry.Attempts(10))
|
||||
if err != nil {
|
||||
log.Warn("apply delete for segment failed, marking it offline")
|
||||
offlineSegments = append(offlineSegments, segmentEntry.SegmentID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return offlineSegments
|
||||
}
|
||||
|
||||
// markSegmentOffline makes segment go offline and waits for QueryCoord to fix.
|
||||
func (sd *shardDelegator) markSegmentOffline(segmentIDs ...int64) {
|
||||
sd.distribution.AddOfflines(segmentIDs...)
|
||||
}
|
||||
|
||||
// addGrowing add growing segment record for delegator.
|
||||
func (sd *shardDelegator) addGrowing(entries ...SegmentEntry) {
|
||||
log := sd.getLogger(context.Background())
|
||||
log.Info("add growing segments to delegator", zap.Int64s("segmentIDs", lo.Map(entries, func(entry SegmentEntry, _ int) int64 {
|
||||
return entry.SegmentID
|
||||
})))
|
||||
sd.distribution.AddGrowing(entries...)
|
||||
}
|
||||
|
||||
// LoadGrowing load growing segments locally.
|
||||
func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
|
||||
log := sd.getLogger(ctx)
|
||||
|
||||
loaded, err := sd.loader.Load(ctx, sd.collectionID, segments.SegmentTypeGrowing, version, infos...)
|
||||
if err != nil {
|
||||
log.Warn("failed to load growing segment", zap.Error(err))
|
||||
for _, segment := range loaded {
|
||||
segments.DeleteSegment(segment.(*segments.LocalSegment))
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, candidate := range loaded {
|
||||
sd.pkOracle.Register(candidate, paramtable.GetNodeID())
|
||||
}
|
||||
sd.segmentManager.Put(segments.SegmentTypeGrowing, loaded...)
|
||||
sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry {
|
||||
return SegmentEntry{
|
||||
NodeID: paramtable.GetNodeID(),
|
||||
SegmentID: segment.ID(),
|
||||
PartitionID: segment.Partition(),
|
||||
Version: version,
|
||||
}
|
||||
})...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSegments load segments local or remotely depends on the target node.
|
||||
func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
|
||||
log := sd.getLogger(ctx)
|
||||
|
||||
targetNodeID := req.GetDstNodeID()
|
||||
// add common log fields
|
||||
log = log.With(
|
||||
zap.Int64("workID", req.GetDstNodeID()),
|
||||
zap.Int64s("segments", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() })),
|
||||
)
|
||||
|
||||
worker, err := sd.workerManager.GetWorker(targetNodeID)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to find worker", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// load bloom filter only when candidate not exists
|
||||
infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool {
|
||||
return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID)
|
||||
})
|
||||
candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...)
|
||||
if err != nil {
|
||||
log.Warn("failed to load bloom filter set for segment", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
req.Base.TargetID = req.GetDstNodeID()
|
||||
log.Info("worker loads segments...")
|
||||
err = worker.LoadSegments(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("worker failed to load segments", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Info("work loads segments done")
|
||||
|
||||
log.Info("load delete...")
|
||||
err = sd.loadStreamDelete(ctx, candidates, infos, targetNodeID, worker)
|
||||
if err != nil {
|
||||
log.Warn("load stream delete failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Info("load delete done")
|
||||
// alter distribution
|
||||
entries := lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) SegmentEntry {
|
||||
return SegmentEntry{
|
||||
SegmentID: info.GetSegmentID(),
|
||||
PartitionID: info.GetPartitionID(),
|
||||
NodeID: req.GetDstNodeID(),
|
||||
Version: req.GetVersion(),
|
||||
}
|
||||
})
|
||||
removed := sd.distribution.AddDistributions(entries...)
|
||||
|
||||
// call worker release async
|
||||
if len(removed) > 0 {
|
||||
go func() {
|
||||
worker, err := sd.workerManager.GetWorker(paramtable.GetNodeID())
|
||||
if err != nil {
|
||||
log.Warn("failed to get local worker when try to release related growing", zap.Error(err))
|
||||
return
|
||||
}
|
||||
err = worker.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(paramtable.GetNodeID())),
|
||||
CollectionID: sd.collectionID,
|
||||
NodeID: paramtable.GetNodeID(),
|
||||
Scope: querypb.DataScope_Streaming,
|
||||
SegmentIDs: removed,
|
||||
Shard: sd.vchannelName,
|
||||
NeedTransfer: false,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("failed to call release segments(local)", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) loadStreamDelete(ctx context.Context, candidates []*pkoracle.BloomFilterSet, infos []*querypb.SegmentLoadInfo,
|
||||
targetNodeID int64, worker cluster.Worker) error {
|
||||
log := sd.getLogger(ctx)
|
||||
sd.deleteMut.Lock()
|
||||
defer sd.deleteMut.Unlock()
|
||||
|
||||
idCandidates := lo.SliceToMap(candidates, func(candidate *pkoracle.BloomFilterSet) (int64, *pkoracle.BloomFilterSet) {
|
||||
return candidate.ID(), candidate
|
||||
})
|
||||
|
||||
// apply buffered delete for new segments
|
||||
// no goroutines here since qnv2 has no load merging logic
|
||||
for _, info := range infos {
|
||||
candidate := idCandidates[info.GetSegmentID()]
|
||||
|
||||
deleteData := &storage.DeleteData{}
|
||||
// start position is dml position for segment
|
||||
// if this position is before deleteBuffer's safe ts, it means some delete shall be read from msgstream
|
||||
if info.GetEndPosition().GetTimestamp() < sd.deleteBuffer.SafeTs() {
|
||||
log.Info("load delete from stream...")
|
||||
var err error
|
||||
deleteData, err = sd.readDeleteFromMsgstream(ctx, info.GetEndPosition(), sd.deleteBuffer.SafeTs(), candidate)
|
||||
if err != nil {
|
||||
log.Warn("failed to read delete data from msgstream", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Info("load delete from stream done")
|
||||
}
|
||||
|
||||
// list buffered delete
|
||||
deleteRecords := sd.deleteBuffer.ListAfter(info.GetEndPosition().GetTimestamp())
|
||||
for _, entry := range deleteRecords {
|
||||
for _, record := range entry.Data {
|
||||
if record.PartitionID != common.InvalidPartitionID && candidate.Partition() != record.PartitionID {
|
||||
continue
|
||||
}
|
||||
for i, pk := range record.DeleteData.Pks {
|
||||
if candidate.MayPkExist(pk) {
|
||||
deleteData.Pks = append(deleteData.Pks, pk)
|
||||
deleteData.Tss = append(deleteData.Tss, record.DeleteData.Tss[i])
|
||||
deleteData.RowCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if delete count not empty, apply
|
||||
if deleteData.RowCount > 0 {
|
||||
log.Info("forward delete to worker...", zap.Int64("deleteRowNum", deleteData.RowCount))
|
||||
err := worker.Delete(ctx, &querypb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithTargetID(targetNodeID)),
|
||||
CollectionId: info.GetCollectionID(),
|
||||
PartitionId: info.GetPartitionID(),
|
||||
SegmentId: info.GetSegmentID(),
|
||||
PrimaryKeys: storage.ParsePrimaryKeys2IDs(deleteData.Pks),
|
||||
Timestamps: deleteData.Tss,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("failed to apply delete when LoadSegment", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add candidate after load success
|
||||
for _, candidate := range candidates {
|
||||
log.Info("register sealed segment bfs into pko candidates",
|
||||
zap.Int64("segmentID", candidate.ID()),
|
||||
)
|
||||
sd.pkOracle.Register(candidate, targetNodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) {
|
||||
|
||||
log := sd.getLogger(ctx).With(
|
||||
zap.String("channel", position.ChannelName),
|
||||
zap.Int64("segmentID", candidate.ID()),
|
||||
)
|
||||
stream, err := sd.factory.NewTtMsgStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vchannelName := position.ChannelName
|
||||
pChannelName := funcutil.ToPhysicalChannel(vchannelName)
|
||||
position.ChannelName = pChannelName
|
||||
|
||||
ts, _ := tsoutil.ParseTS(position.Timestamp)
|
||||
|
||||
// Random the subname in case we trying to load same delta at the same time
|
||||
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
|
||||
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("subName", subName), zap.Time("positionTs", ts))
|
||||
stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown)
|
||||
|
||||
err = stream.Seek([]*msgpb.MsgPosition{position})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &storage.DeleteData{}
|
||||
hasMore := true
|
||||
for hasMore {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("read delta msg from seek position done", zap.Error(ctx.Err()))
|
||||
return nil, ctx.Err()
|
||||
case msgPack, ok := <-stream.Chan():
|
||||
if !ok {
|
||||
err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID())
|
||||
log.Warn("fail to read delta msg",
|
||||
zap.String("pChannelName", pChannelName),
|
||||
zap.Binary("msgID", position.GetMsgID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msgPack == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tsMsg := range msgPack.Msgs {
|
||||
if tsMsg.Type() == commonpb.MsgType_Delete {
|
||||
dmsg := tsMsg.(*msgstream.DeleteMsg)
|
||||
if dmsg.CollectionID != sd.collectionID || dmsg.GetPartitionID() != candidate.Partition() {
|
||||
continue
|
||||
}
|
||||
|
||||
for idx, pk := range storage.ParseIDs2PrimaryKeys(dmsg.GetPrimaryKeys()) {
|
||||
if candidate.MayPkExist(pk) {
|
||||
result.Pks = append(result.Pks, pk)
|
||||
result.Tss = append(result.Tss, dmsg.Timestamps[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reach safe ts
|
||||
if safeTs <= msgPack.EndPositions[0].GetTimestamp() {
|
||||
hasMore = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ReleaseSegments releases segments local or remotely depends ont the target node.
|
||||
func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
|
||||
log := sd.getLogger(ctx)
|
||||
|
||||
targetNodeID := req.GetNodeID()
|
||||
// add common log fields
|
||||
log = log.With(
|
||||
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
|
||||
zap.Int64("nodeID", req.GetNodeID()),
|
||||
zap.String("scope", req.GetScope().String()),
|
||||
zap.Bool("force", force))
|
||||
|
||||
log.Info("delegator start to release segments")
|
||||
// alter distribution first
|
||||
if force {
|
||||
targetNodeID = wildcardNodeID
|
||||
}
|
||||
|
||||
var sealed, growing []SegmentEntry
|
||||
convertSealed := func(segmentID int64, _ int) SegmentEntry {
|
||||
return SegmentEntry{
|
||||
SegmentID: segmentID,
|
||||
NodeID: targetNodeID,
|
||||
}
|
||||
}
|
||||
convertGrowing := func(segmentID int64, _ int) SegmentEntry {
|
||||
return SegmentEntry{
|
||||
SegmentID: segmentID,
|
||||
}
|
||||
}
|
||||
switch req.GetScope() {
|
||||
case querypb.DataScope_All:
|
||||
sealed = lo.Map(req.GetSegmentIDs(), convertSealed)
|
||||
growing = lo.Map(req.GetSegmentIDs(), convertGrowing)
|
||||
case querypb.DataScope_Streaming:
|
||||
growing = lo.Map(req.GetSegmentIDs(), convertGrowing)
|
||||
case querypb.DataScope_Historical:
|
||||
sealed = lo.Map(req.GetSegmentIDs(), convertSealed)
|
||||
}
|
||||
|
||||
signal := sd.distribution.RemoveDistributions(sealed, growing)
|
||||
// wait cleared signal
|
||||
<-signal
|
||||
if len(sealed) > 0 {
|
||||
sd.pkOracle.Remove(
|
||||
pkoracle.WithSegmentIDs(lo.Map(sealed, func(entry SegmentEntry, _ int) int64 { return entry.SegmentID })...),
|
||||
pkoracle.WithSegmentType(commonpb.SegmentState_Sealed),
|
||||
pkoracle.WithWorkerID(targetNodeID),
|
||||
)
|
||||
}
|
||||
if len(growing) > 0 {
|
||||
sd.pkOracle.Remove(
|
||||
pkoracle.WithSegmentIDs(lo.Map(growing, func(entry SegmentEntry, _ int) int64 { return entry.SegmentID })...),
|
||||
pkoracle.WithSegmentType(commonpb.SegmentState_Growing),
|
||||
)
|
||||
}
|
||||
|
||||
if !force {
|
||||
worker, err := sd.workerManager.GetWorker(targetNodeID)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to find worker",
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
err = worker.ReleaseSegments(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("worker failed to release segments",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,657 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type DelegatorDataSuite struct {
|
||||
suite.Suite
|
||||
|
||||
collectionID int64
|
||||
replicaID int64
|
||||
vchannelName string
|
||||
version int64
|
||||
workerManager *cluster.MockManager
|
||||
manager *segments.Manager
|
||||
tsafeManager tsafe.Manager
|
||||
loader *segments.MockLoader
|
||||
mq *msgstream.MockMsgStream
|
||||
|
||||
delegator ShardDelegator
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
paramtable.SetNodeID(1)
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) SetupTest() {
|
||||
s.collectionID = 1000
|
||||
s.replicaID = 65535
|
||||
s.vchannelName = "rootcoord-dml_1000_v0"
|
||||
s.version = 2000
|
||||
s.workerManager = &cluster.MockManager{}
|
||||
s.manager = segments.NewManager()
|
||||
s.tsafeManager = tsafe.NewTSafeReplica()
|
||||
s.loader = &segments.MockLoader{}
|
||||
|
||||
// init schema
|
||||
s.manager.Collection.Put(s.collectionID, &schemapb.CollectionSchema{
|
||||
Name: "TestCollection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: "id",
|
||||
FieldID: 100,
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
AutoID: true,
|
||||
},
|
||||
{
|
||||
Name: "vector",
|
||||
FieldID: 101,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &querypb.LoadMetaInfo{
|
||||
LoadType: querypb.LoadType_LoadCollection,
|
||||
})
|
||||
|
||||
s.mq = &msgstream.MockMsgStream{}
|
||||
|
||||
var err error
|
||||
s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestProcessInsert() {
|
||||
s.Run("normal_insert", func() {
|
||||
s.delegator.ProcessInsert(map[int64]*InsertData{
|
||||
100: {
|
||||
RowIDs: []int64{0, 1},
|
||||
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)},
|
||||
Timestamps: []uint64{10, 10},
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{},
|
||||
InsertRecord: &segcorepb.InsertRecord{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: 100,
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: "vector",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{Data: make([]float32, 128*2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: 101,
|
||||
},
|
||||
},
|
||||
NumRows: 2,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.NotNil(s.manager.Segment.GetGrowing(100))
|
||||
})
|
||||
|
||||
s.Run("insert_bad_data", func() {
|
||||
s.Panics(func() {
|
||||
s.delegator.ProcessInsert(map[int64]*InsertData{
|
||||
100: {
|
||||
RowIDs: []int64{0, 1},
|
||||
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)},
|
||||
Timestamps: []uint64{10, 10},
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{},
|
||||
InsertRecord: &segcorepb.InsertRecord{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "id",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
FieldId: 100,
|
||||
},
|
||||
},
|
||||
NumRows: 2,
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestProcessDelete() {
|
||||
s.loader.EXPECT().
|
||||
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
|
||||
ms := &segments.MockSegment{}
|
||||
ms.EXPECT().ID().Return(info.GetSegmentID())
|
||||
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
|
||||
ms.EXPECT().Collection().Return(info.GetCollectionID())
|
||||
ms.EXPECT().Partition().Return(info.GetPartitionID())
|
||||
ms.EXPECT().Indexes().Return(nil)
|
||||
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
|
||||
ms.EXPECT().MayPkExist(mock.Anything).Call.Return(func(pk storage.PrimaryKey) bool {
|
||||
return pk.EQ(storage.NewInt64PrimaryKey(10))
|
||||
})
|
||||
return ms
|
||||
})
|
||||
}, nil)
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
|
||||
pks := &storage.PkStatistics{
|
||||
PkFilter: bf,
|
||||
}
|
||||
pks.UpdatePKRange(&storage.Int64FieldData{
|
||||
Data: []int64{10, 20, 30},
|
||||
})
|
||||
bfs.AddHistoricalStats(pks)
|
||||
return bfs
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
// load growing
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadGrowing(ctx, []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1001,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, 0)
|
||||
s.Require().NoError(err)
|
||||
// load sealed
|
||||
s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1000,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.delegator.ProcessDelete([]*DeleteData{
|
||||
{
|
||||
PartitionID: 500,
|
||||
PrimaryKeys: []storage.PrimaryKey{storage.NewInt64PrimaryKey(10)},
|
||||
Timestamps: []uint64{10},
|
||||
RowCount: 1,
|
||||
},
|
||||
}, 10)
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
s.Run("normal_run", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 100,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
sealed, _, _ := s.delegator.GetDistribution().GetCurrent()
|
||||
s.Require().Equal(1, len(sealed))
|
||||
s.Equal(int64(1), sealed[0].NodeID)
|
||||
s.ElementsMatch([]SegmentEntry{
|
||||
{
|
||||
SegmentID: 100,
|
||||
NodeID: 1,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, sealed[0].Segments)
|
||||
})
|
||||
|
||||
s.Run("load_segments_with_delete", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
|
||||
pks := &storage.PkStatistics{
|
||||
PkFilter: bf,
|
||||
}
|
||||
pks.UpdatePKRange(&storage.Int64FieldData{
|
||||
Data: []int64{10, 20, 30},
|
||||
})
|
||||
bfs.AddHistoricalStats(pks)
|
||||
return bfs
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
s.delegator.ProcessDelete([]*DeleteData{
|
||||
{
|
||||
PartitionID: 500,
|
||||
PrimaryKeys: []storage.PrimaryKey{
|
||||
storage.NewInt64PrimaryKey(1),
|
||||
storage.NewInt64PrimaryKey(10),
|
||||
},
|
||||
Timestamps: []uint64{10, 10},
|
||||
RowCount: 2,
|
||||
},
|
||||
}, 10)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 200,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
sealed, _, _ := s.delegator.GetDistribution().GetCurrent()
|
||||
s.Require().Equal(1, len(sealed))
|
||||
s.Equal(int64(1), sealed[0].NodeID)
|
||||
s.ElementsMatch([]SegmentEntry{
|
||||
{
|
||||
SegmentID: 100,
|
||||
NodeID: 1,
|
||||
PartitionID: 500,
|
||||
},
|
||||
{
|
||||
SegmentID: 200,
|
||||
NodeID: 1,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, sealed[0].Segments)
|
||||
|
||||
})
|
||||
|
||||
s.Run("get_worker_fail", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Return(nil, errors.New("mock error"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 100,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("loader_bfs_fail", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Return(nil, errors.New("mocked error"))
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 100,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("worker_load_fail", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
s.loader.ExpectedCalls = nil
|
||||
}()
|
||||
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(errors.New("mocked error"))
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 100,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestReleaseSegment() {
|
||||
s.loader.EXPECT().
|
||||
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
|
||||
ms := &segments.MockSegment{}
|
||||
ms.EXPECT().ID().Return(info.GetSegmentID())
|
||||
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
|
||||
ms.EXPECT().Partition().Return(info.GetPartitionID())
|
||||
ms.EXPECT().Collection().Return(info.GetCollectionID())
|
||||
ms.EXPECT().Indexes().Return(nil)
|
||||
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
|
||||
ms.EXPECT().MayPkExist(mock.Anything).Call.Return(func(pk storage.PrimaryKey) bool {
|
||||
return pk.EQ(storage.NewInt64PrimaryKey(10))
|
||||
})
|
||||
return ms
|
||||
})
|
||||
}, nil)
|
||||
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
|
||||
bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
|
||||
bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive)
|
||||
pks := &storage.PkStatistics{
|
||||
PkFilter: bf,
|
||||
}
|
||||
pks.UpdatePKRange(&storage.Int64FieldData{
|
||||
Data: []int64{10, 20, 30},
|
||||
})
|
||||
bfs.AddHistoricalStats(pks)
|
||||
return bfs
|
||||
})
|
||||
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
workers[1] = worker1
|
||||
|
||||
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
|
||||
Return(nil)
|
||||
worker1.EXPECT().ReleaseSegments(mock.Anything, mock.AnythingOfType("*querypb.ReleaseSegmentsRequest")).Return(nil)
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
// load growing
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
err := s.delegator.LoadGrowing(ctx, []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1001,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
}, 0)
|
||||
s.Require().NoError(err)
|
||||
// load sealed
|
||||
s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
DstNodeID: 1,
|
||||
CollectionID: s.collectionID,
|
||||
Infos: []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1000,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
EndPosition: &msgpb.MsgPosition{Timestamp: 20000},
|
||||
},
|
||||
},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
sealed, growing, version := s.delegator.GetDistribution().GetCurrent()
|
||||
s.delegator.GetDistribution().FinishUsage(version)
|
||||
s.Require().Equal(1, len(sealed))
|
||||
s.Equal(int64(1), sealed[0].NodeID)
|
||||
s.ElementsMatch([]SegmentEntry{
|
||||
{
|
||||
SegmentID: 1000,
|
||||
NodeID: 1,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, sealed[0].Segments)
|
||||
|
||||
s.ElementsMatch([]SegmentEntry{
|
||||
{
|
||||
SegmentID: 1001,
|
||||
NodeID: 1,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, growing)
|
||||
|
||||
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: 1,
|
||||
SegmentIDs: []int64{1000},
|
||||
Scope: querypb.DataScope_Historical,
|
||||
}, false)
|
||||
|
||||
s.NoError(err)
|
||||
sealed, _, version = s.delegator.GetDistribution().GetCurrent()
|
||||
s.delegator.GetDistribution().FinishUsage(version)
|
||||
s.Equal(0, len(sealed))
|
||||
|
||||
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: 1,
|
||||
SegmentIDs: []int64{1001},
|
||||
Scope: querypb.DataScope_Streaming,
|
||||
}, false)
|
||||
|
||||
s.NoError(err)
|
||||
_, growing, version = s.delegator.GetDistribution().GetCurrent()
|
||||
s.delegator.GetDistribution().FinishUsage(version)
|
||||
s.Equal(0, len(growing))
|
||||
|
||||
err = s.delegator.ReleaseSegments(ctx, &querypb.ReleaseSegmentsRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
NodeID: 1,
|
||||
SegmentIDs: []int64{1000},
|
||||
Scope: querypb.DataScope_All,
|
||||
}, true)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestDelegatorDataSuite(t *testing.T) {
|
||||
suite.Run(t, new(DelegatorDataSuite))
|
||||
}
|
|
@ -0,0 +1,831 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type DelegatorSuite struct {
|
||||
suite.Suite
|
||||
|
||||
collectionID int64
|
||||
replicaID int64
|
||||
vchannelName string
|
||||
version int64
|
||||
workerManager *cluster.MockManager
|
||||
manager *segments.Manager
|
||||
tsafeManager tsafe.Manager
|
||||
loader *segments.MockLoader
|
||||
mq *msgstream.MockMsgStream
|
||||
|
||||
delegator ShardDelegator
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TearDownSuite() {
|
||||
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) SetupTest() {
|
||||
s.collectionID = 1000
|
||||
s.replicaID = 65535
|
||||
s.vchannelName = "rootcoord-dml_1000_v0"
|
||||
s.version = 2000
|
||||
s.workerManager = &cluster.MockManager{}
|
||||
s.manager = segments.NewManager()
|
||||
s.tsafeManager = tsafe.NewTSafeReplica()
|
||||
s.loader = &segments.MockLoader{}
|
||||
s.loader.EXPECT().
|
||||
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
|
||||
Call.Return(func(ctx context.Context, collectionID int64, segmentType segments.SegmentType, version int64, infos ...*querypb.SegmentLoadInfo) []segments.Segment {
|
||||
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) segments.Segment {
|
||||
ms := &segments.MockSegment{}
|
||||
ms.EXPECT().ID().Return(info.GetSegmentID())
|
||||
ms.EXPECT().Type().Return(segments.SegmentTypeGrowing)
|
||||
ms.EXPECT().Partition().Return(info.GetPartitionID())
|
||||
ms.EXPECT().Collection().Return(info.GetCollectionID())
|
||||
ms.EXPECT().Indexes().Return(nil)
|
||||
ms.EXPECT().RowNum().Return(info.GetNumOfRows())
|
||||
return ms
|
||||
})
|
||||
}, nil)
|
||||
|
||||
// init schema
|
||||
s.manager.Collection.Put(s.collectionID, &schemapb.CollectionSchema{
|
||||
Name: "TestCollection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: "id",
|
||||
FieldID: 100,
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
AutoID: true,
|
||||
},
|
||||
{
|
||||
Name: "vector",
|
||||
FieldID: 101,
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "128",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &querypb.LoadMetaInfo{})
|
||||
|
||||
s.mq = &msgstream.MockMsgStream{}
|
||||
|
||||
var err error
|
||||
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
|
||||
s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TearDownTest() {
|
||||
s.delegator.Close()
|
||||
s.delegator = nil
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestBasicInfo() {
|
||||
s.Equal(s.collectionID, s.delegator.Collection())
|
||||
s.Equal(s.version, s.delegator.Version())
|
||||
s.False(s.delegator.Serviceable())
|
||||
s.delegator.Start()
|
||||
s.True(s.delegator.Serviceable())
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestDistribution() {
|
||||
sealed, growing, version := s.delegator.GetDistribution().GetCurrent()
|
||||
s.delegator.GetDistribution().FinishUsage(version)
|
||||
s.Equal(0, len(sealed))
|
||||
s.Equal(0, len(growing))
|
||||
|
||||
s.delegator.SyncDistribution(context.Background(), SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1001,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
})
|
||||
|
||||
sealed, growing, version = s.delegator.GetDistribution().GetCurrent()
|
||||
s.EqualValues([]SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 1001,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, sealed)
|
||||
s.Equal(0, len(growing))
|
||||
s.delegator.GetDistribution().FinishUsage(version)
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestSearch() {
|
||||
s.delegator.Start()
|
||||
// 1 => sealed segment 1000, 1001
|
||||
// 1 => growing segment 1004
|
||||
// 2 => sealed segment 1002, 1003
|
||||
paramtable.SetNodeID(1)
|
||||
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1004,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, 0)
|
||||
s.delegator.SyncDistribution(context.Background(),
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1000,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1001,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1002,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1003,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
)
|
||||
s.Run("normal", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(1, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
if req.GetScope() == querypb.DataScope_Streaming {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
|
||||
}
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
|
||||
}
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
results, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(3, len(results))
|
||||
})
|
||||
|
||||
s.Run("worker_return_error", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(nil, errors.New("mock error"))
|
||||
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("worker_return_failure_code", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).Return(&internalpb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mocked error",
|
||||
},
|
||||
}, nil)
|
||||
worker2.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")).
|
||||
Run(func(_ context.Context, req *querypb.SearchRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.SearchResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wrong_channel", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{"non_exist_channel"},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wait_tsafe_timeout", func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
GuaranteeTimestamp: 100,
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("tsafe_behind_max_lag", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
Base: commonpbutil.NewMsgBase(),
|
||||
GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)) + 1,
|
||||
},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("cluster_not_serviceable", func() {
|
||||
s.delegator.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Search(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestQuery() {
|
||||
s.delegator.Start()
|
||||
// 1 => sealed segment 1000, 1001
|
||||
// 1 => growing segment 1004
|
||||
// 2 => sealed segment 1002, 1003
|
||||
paramtable.SetNodeID(1)
|
||||
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1004,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, 0)
|
||||
s.delegator.SyncDistribution(context.Background(),
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1000,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1001,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1002,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1003,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
)
|
||||
s.Run("normal", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Run(func(_ context.Context, req *querypb.QueryRequest) {
|
||||
s.EqualValues(1, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
if req.GetScope() == querypb.DataScope_Streaming {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
|
||||
}
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
|
||||
}
|
||||
}).Return(&internalpb.RetrieveResults{}, nil)
|
||||
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Run(func(_ context.Context, req *querypb.QueryRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.RetrieveResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
results, err := s.delegator.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(3, len(results))
|
||||
})
|
||||
|
||||
s.Run("worker_return_error", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).Return(nil, errors.New("mock error"))
|
||||
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Run(func(_ context.Context, req *querypb.QueryRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.RetrieveResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("worker_return_failure_code", func() {
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).Return(&internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mocked error",
|
||||
},
|
||||
}, nil)
|
||||
worker2.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")).
|
||||
Run(func(_ context.Context, req *querypb.QueryRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.RetrieveResults{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wrong_channel", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{"non_exist_channel"},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("cluster_not_serviceable", func() {
|
||||
s.delegator.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.Query(ctx, &querypb.QueryRequest{
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestGetStats() {
|
||||
s.delegator.Start()
|
||||
// 1 => sealed segment 1000, 1001
|
||||
// 1 => growing segment 1004
|
||||
// 2 => sealed segment 1002, 1003
|
||||
paramtable.SetNodeID(1)
|
||||
s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{
|
||||
{
|
||||
SegmentID: 1004,
|
||||
CollectionID: s.collectionID,
|
||||
PartitionID: 500,
|
||||
},
|
||||
}, 0)
|
||||
s.delegator.SyncDistribution(context.Background(),
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1000,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 1,
|
||||
SegmentID: 1001,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1002,
|
||||
PartitionID: 500,
|
||||
Version: 2001,
|
||||
},
|
||||
SegmentEntry{
|
||||
NodeID: 2,
|
||||
SegmentID: 1003,
|
||||
PartitionID: 501,
|
||||
Version: 2001,
|
||||
},
|
||||
)
|
||||
s.Run("normal", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
|
||||
s.EqualValues(1, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
if req.GetScope() == querypb.DataScope_Streaming {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1004}, req.GetSegmentIDs())
|
||||
}
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs())
|
||||
}
|
||||
}).Return(&internalpb.GetStatisticsResponse{}, nil)
|
||||
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.GetStatisticsResponse{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
results, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
|
||||
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(3, len(results))
|
||||
})
|
||||
|
||||
s.Run("worker_return_error", func() {
|
||||
defer func() {
|
||||
s.workerManager.ExpectedCalls = nil
|
||||
}()
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).Return(nil, errors.New("mock error"))
|
||||
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.GetStatisticsResponse{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
|
||||
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("worker_return_failure_code", func() {
|
||||
workers := make(map[int64]*cluster.MockWorker)
|
||||
worker1 := &cluster.MockWorker{}
|
||||
worker2 := &cluster.MockWorker{}
|
||||
|
||||
workers[1] = worker1
|
||||
workers[2] = worker2
|
||||
|
||||
worker1.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).Return(&internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mocked error",
|
||||
},
|
||||
}, nil)
|
||||
worker2.EXPECT().GetStatistics(mock.Anything, mock.AnythingOfType("*querypb.GetStatisticsRequest")).
|
||||
Run(func(_ context.Context, req *querypb.GetStatisticsRequest) {
|
||||
s.EqualValues(2, req.Req.GetBase().GetTargetID())
|
||||
s.True(req.GetFromShardLeader())
|
||||
s.Equal(querypb.DataScope_Historical, req.GetScope())
|
||||
s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels())
|
||||
s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs())
|
||||
}).Return(&internalpb.GetStatisticsResponse{}, nil)
|
||||
|
||||
s.workerManager.EXPECT().GetWorker(mock.AnythingOfType("int64")).Call.Return(func(nodeID int64) cluster.Worker {
|
||||
return workers[nodeID]
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
|
||||
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("wrong_channel", func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
|
||||
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{"non_exist_channel"},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("cluster_not_serviceable", func() {
|
||||
s.delegator.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
_, err := s.delegator.GetStatistics(ctx, &querypb.GetStatisticsRequest{
|
||||
Req: &internalpb.GetStatisticsRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
})
|
||||
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDelegatorSuite(t *testing.T) {
|
||||
suite.Run(t, new(DelegatorSuite))
|
||||
}
|
||||
|
||||
func TestDelegatorWatchTsafe(t *testing.T) {
|
||||
channelName := "default_dml_channel"
|
||||
|
||||
tsafeManager := tsafe.NewTSafeReplica()
|
||||
tsafeManager.Add(channelName, 100)
|
||||
sd := &shardDelegator{
|
||||
tsafeManager: tsafeManager,
|
||||
vchannelName: channelName,
|
||||
lifetime: newLifetime(),
|
||||
latestTsafe: atomic.NewUint64(0),
|
||||
}
|
||||
defer sd.Close()
|
||||
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
go sd.watchTSafe()
|
||||
|
||||
err := tsafeManager.Set(channelName, 200)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return sd.latestTsafe.Load() == 200
|
||||
}, time.Second*10, time.Millisecond*10)
|
||||
}
|
||||
|
||||
func TestDelegatorTSafeListenerClosed(t *testing.T) {
|
||||
channelName := "default_dml_channel"
|
||||
|
||||
tsafeManager := tsafe.NewTSafeReplica()
|
||||
tsafeManager.Add(channelName, 100)
|
||||
sd := &shardDelegator{
|
||||
tsafeManager: tsafeManager,
|
||||
vchannelName: channelName,
|
||||
lifetime: newLifetime(),
|
||||
latestTsafe: atomic.NewUint64(0),
|
||||
}
|
||||
defer sd.Close()
|
||||
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
sd.wg.Add(1)
|
||||
signal := make(chan struct{})
|
||||
go func() {
|
||||
sd.watchTSafe()
|
||||
close(signal)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-signal:
|
||||
assert.FailNow(t, "watchTsafe quit unexpectedly")
|
||||
case <-time.After(time.Millisecond * 10):
|
||||
}
|
||||
|
||||
tsafeManager.Remove(channelName)
|
||||
|
||||
select {
|
||||
case <-signal:
|
||||
case <-time.After(time.Second):
|
||||
assert.FailNow(t, "watchTsafe still working after listener closed")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deletebuffer
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errBufferFull = errors.New("buffer full")
|
||||
)
|
||||
|
||||
type timed interface {
|
||||
Timestamp() uint64
|
||||
Size() int64
|
||||
}
|
||||
|
||||
// DeleteBuffer is the interface for delete buffer.
|
||||
type DeleteBuffer[T timed] interface {
|
||||
Put(T)
|
||||
ListAfter(uint64) []T
|
||||
SafeTs() uint64
|
||||
}
|
||||
|
||||
func NewDoubleCacheDeleteBuffer[T timed](startTs uint64, maxSize int64) DeleteBuffer[T] {
|
||||
return &doubleCacheBuffer[T]{
|
||||
head: newDoubleCacheItem[T](startTs, maxSize),
|
||||
maxSize: maxSize,
|
||||
ts: startTs,
|
||||
}
|
||||
}
|
||||
|
||||
// doubleCacheBuffer implements DeleteBuffer with fixed sized double cache.
|
||||
type doubleCacheBuffer[T timed] struct {
|
||||
mut sync.RWMutex
|
||||
head, tail *doubleCacheItem[T]
|
||||
maxSize int64
|
||||
ts uint64
|
||||
}
|
||||
|
||||
func (c *doubleCacheBuffer[T]) SafeTs() uint64 {
|
||||
return c.ts
|
||||
}
|
||||
|
||||
// Put implements DeleteBuffer.
|
||||
func (c *doubleCacheBuffer[T]) Put(entry T) {
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
err := c.head.Put(entry)
|
||||
if errors.Is(err, errBufferFull) {
|
||||
c.evict(entry.Timestamp())
|
||||
c.head.Put(entry)
|
||||
}
|
||||
}
|
||||
|
||||
// ListAfter implements DeleteBuffer.
|
||||
func (c *doubleCacheBuffer[T]) ListAfter(ts uint64) []T {
|
||||
c.mut.RLock()
|
||||
defer c.mut.RUnlock()
|
||||
var result []T
|
||||
if c.tail != nil {
|
||||
result = append(result, c.tail.ListAfter(ts)...)
|
||||
}
|
||||
if c.head != nil {
|
||||
result = append(result, c.head.ListAfter(ts)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// evict sets head as tail and evicts tail.
|
||||
func (c *doubleCacheBuffer[T]) evict(newTs uint64) {
|
||||
c.tail = c.head
|
||||
c.head = newDoubleCacheItem[T](newTs, c.maxSize/2)
|
||||
c.ts = c.tail.headTs
|
||||
}
|
||||
|
||||
func newDoubleCacheItem[T timed](ts uint64, maxSize int64) *doubleCacheItem[T] {
|
||||
return &doubleCacheItem[T]{
|
||||
headTs: ts,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
type doubleCacheItem[T timed] struct {
|
||||
mut sync.RWMutex
|
||||
headTs uint64
|
||||
size int64
|
||||
maxSize int64
|
||||
|
||||
data []T
|
||||
}
|
||||
|
||||
// Cache adds entry into cache item.
|
||||
// returns error if item is full
|
||||
func (c *doubleCacheItem[T]) Put(entry T) error {
|
||||
c.mut.Lock()
|
||||
defer c.mut.Unlock()
|
||||
|
||||
if c.size+entry.Size() > c.maxSize {
|
||||
return errBufferFull
|
||||
}
|
||||
|
||||
c.data = append(c.data, entry)
|
||||
c.size += entry.Size()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAfter returns entries of which ts after provided value.
|
||||
func (c *doubleCacheItem[T]) ListAfter(ts uint64) []T {
|
||||
c.mut.RLock()
|
||||
defer c.mut.RUnlock()
|
||||
idx := sort.Search(len(c.data), func(idx int) bool {
|
||||
return c.data[idx].Timestamp() >= ts
|
||||
})
|
||||
// not found
|
||||
if idx == len(c.data) {
|
||||
return nil
|
||||
}
|
||||
return c.data[idx:]
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deletebuffer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
func TestSkipListDeleteBuffer(t *testing.T) {
|
||||
db := NewDeleteBuffer()
|
||||
|
||||
db.Cache(10, []BufferItem{
|
||||
{PartitionID: 1},
|
||||
})
|
||||
|
||||
result := db.List(0)
|
||||
|
||||
assert.Equal(t, 1, len(result))
|
||||
assert.Equal(t, int64(1), result[0][0].PartitionID)
|
||||
|
||||
db.TruncateBefore(11)
|
||||
result = db.List(0)
|
||||
assert.Equal(t, 0, len(result))
|
||||
}
|
||||
|
||||
type DoubleCacheBufferSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *DoubleCacheBufferSuite) TestNewBuffer() {
|
||||
buffer := NewDoubleCacheDeleteBuffer[*Item](10, 1000)
|
||||
|
||||
s.EqualValues(10, buffer.SafeTs())
|
||||
}
|
||||
|
||||
func (s *DoubleCacheBufferSuite) TestCache() {
|
||||
buffer := NewDoubleCacheDeleteBuffer[*Item](10, 1000)
|
||||
buffer.Put(&Item{
|
||||
Ts: 11,
|
||||
Data: []BufferItem{
|
||||
{
|
||||
PartitionID: 200,
|
||||
DeleteData: storage.DeleteData{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
buffer.Put(&Item{
|
||||
Ts: 12,
|
||||
Data: []BufferItem{
|
||||
{
|
||||
PartitionID: 200,
|
||||
DeleteData: storage.DeleteData{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
s.Equal(2, len(buffer.ListAfter(11)))
|
||||
s.Equal(1, len(buffer.ListAfter(12)))
|
||||
}
|
||||
|
||||
func TestDoubleCacheDeleteBuffer(t *testing.T) {
|
||||
suite.Run(t, new(DoubleCacheBufferSuite))
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package deletebuffer
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// Item wraps cache item as `timed`.
|
||||
type Item struct {
|
||||
Ts uint64
|
||||
Data []BufferItem
|
||||
}
|
||||
|
||||
// Timestamp implements `timed`.
|
||||
func (item *Item) Timestamp() uint64 {
|
||||
return item.Ts
|
||||
}
|
||||
|
||||
// Size implements `timed`.
|
||||
func (item *Item) Size() int64 {
|
||||
return lo.Reduce(item.Data, func(size int64, item BufferItem, _ int) int64 {
|
||||
return size + item.Size()
|
||||
}, int64(0))
|
||||
}
|
||||
|
||||
type BufferItem struct {
|
||||
PartitionID int64
|
||||
DeleteData storage.DeleteData
|
||||
}
|
||||
|
||||
func (item *BufferItem) Size() int64 {
|
||||
var pkSize int64
|
||||
if len(item.DeleteData.Pks) > 0 {
|
||||
pkSize = int64(len(item.DeleteData.Pks)) * item.DeleteData.Pks[0].Size()
|
||||
}
|
||||
|
||||
return int64(96) + pkSize + int64(8*len(item.DeleteData.Tss))
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package deletebuffer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDeleteBufferItem(t *testing.T) {
|
||||
item := &BufferItem{
|
||||
PartitionID: 100,
|
||||
DeleteData: storage.DeleteData{},
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(96), item.Size())
|
||||
|
||||
item.DeleteData.Pks = []storage.PrimaryKey{
|
||||
storage.NewInt64PrimaryKey(10),
|
||||
}
|
||||
item.DeleteData.Tss = []uint64{2000}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package deletebuffer
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
// deleteBuffer caches L0 delete buffer for remote segments.
|
||||
type deleteBuffer struct {
|
||||
// timestamp => DeleteData
|
||||
cache *typeutil.SkipList[uint64, []BufferItem]
|
||||
}
|
||||
|
||||
// Cache delete data.
|
||||
func (b *deleteBuffer) Cache(timestamp uint64, data []BufferItem) {
|
||||
b.cache.Upsert(timestamp, data)
|
||||
}
|
||||
|
||||
func (b *deleteBuffer) List(since uint64) [][]BufferItem {
|
||||
return b.cache.ListAfter(since, false)
|
||||
}
|
||||
|
||||
func (b *deleteBuffer) TruncateBefore(ts uint64) {
|
||||
b.cache.TruncateBefore(ts)
|
||||
}
|
||||
|
||||
func NewDeleteBuffer() *deleteBuffer {
|
||||
cache, _ := typeutil.NewSkipList[uint64, []BufferItem]()
|
||||
return &deleteBuffer{
|
||||
cache: cache,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
// wildcardNodeID matches any nodeID, used for force distribution correction.
|
||||
wildcardNodeID = int64(-1)
|
||||
)
|
||||
|
||||
// distribution is the struct to store segment distribution.
|
||||
// it contains both growing and sealed segments.
|
||||
type distribution struct {
|
||||
// segments information
|
||||
// map[SegmentID]=>segmentEntry
|
||||
growingSegments map[UniqueID]SegmentEntry
|
||||
sealedSegments map[UniqueID]SegmentEntry
|
||||
|
||||
// version indicator
|
||||
version int64
|
||||
// quick flag for current snapshot is serviceable
|
||||
serviceable *atomic.Bool
|
||||
offlines typeutil.Set[int64]
|
||||
|
||||
snapshots *typeutil.ConcurrentMap[int64, *snapshot]
|
||||
// current is the snapshot for quick usage for search/query
|
||||
// generated for each change of distribution
|
||||
current *snapshot
|
||||
// protects current & segments
|
||||
mut sync.RWMutex
|
||||
}
|
||||
|
||||
// SegmentEntry stores the segment meta information.
|
||||
type SegmentEntry struct {
|
||||
NodeID int64
|
||||
SegmentID UniqueID
|
||||
PartitionID UniqueID
|
||||
Version int64
|
||||
}
|
||||
|
||||
// NewDistribution creates a new distribution instance with all field initialized.
|
||||
func NewDistribution() *distribution {
|
||||
dist := &distribution{
|
||||
serviceable: atomic.NewBool(false),
|
||||
growingSegments: make(map[UniqueID]SegmentEntry),
|
||||
sealedSegments: make(map[UniqueID]SegmentEntry),
|
||||
snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
|
||||
}
|
||||
|
||||
dist.genSnapshot()
|
||||
return dist
|
||||
}
|
||||
|
||||
// GetCurrent returns current snapshot.
|
||||
func (d *distribution) GetCurrent(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64) {
|
||||
d.mut.RLock()
|
||||
defer d.mut.RUnlock()
|
||||
|
||||
sealed, growing = d.current.Get(partitions...)
|
||||
version = d.current.version
|
||||
return
|
||||
}
|
||||
|
||||
// FinishUsage notifies snapshot one reference is released.
|
||||
func (d *distribution) FinishUsage(version int64) {
|
||||
snapshot, ok := d.snapshots.Get(version)
|
||||
if ok {
|
||||
snapshot.Done(d.getCleanup(snapshot.version))
|
||||
}
|
||||
}
|
||||
|
||||
// Serviceable returns wether current snapshot is serviceable.
|
||||
func (d *distribution) Serviceable() bool {
|
||||
return d.serviceable.Load()
|
||||
}
|
||||
|
||||
// AddDistributions add multiple segment entries.
|
||||
func (d *distribution) AddDistributions(entries ...SegmentEntry) []int64 {
|
||||
d.mut.Lock()
|
||||
defer d.mut.Unlock()
|
||||
|
||||
// remove growing if sealed is loaded
|
||||
var removed []int64
|
||||
for _, entry := range entries {
|
||||
d.sealedSegments[entry.SegmentID] = entry
|
||||
d.offlines.Remove(entry.SegmentID)
|
||||
_, ok := d.growingSegments[entry.SegmentID]
|
||||
if ok {
|
||||
removed = append(removed, entry.SegmentID)
|
||||
delete(d.growingSegments, entry.SegmentID)
|
||||
}
|
||||
}
|
||||
|
||||
d.genSnapshot()
|
||||
return removed
|
||||
}
|
||||
|
||||
// AddGrowing adds growing segment distribution.
|
||||
func (d *distribution) AddGrowing(entries ...SegmentEntry) {
|
||||
d.mut.Lock()
|
||||
defer d.mut.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
d.growingSegments[entry.SegmentID] = entry
|
||||
}
|
||||
|
||||
d.genSnapshot()
|
||||
}
|
||||
|
||||
// AddOffline set segmentIDs to offlines.
|
||||
func (d *distribution) AddOfflines(segmentIDs ...int64) {
|
||||
d.mut.Lock()
|
||||
defer d.mut.Unlock()
|
||||
|
||||
updated := false
|
||||
for _, segmentID := range segmentIDs {
|
||||
_, ok := d.sealedSegments[segmentID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated = true
|
||||
d.offlines.Insert(segmentID)
|
||||
}
|
||||
|
||||
if updated {
|
||||
d.genSnapshot()
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveDistributions remove segments distributions and returns the clear signal channel.
|
||||
func (d *distribution) RemoveDistributions(sealedSegments []SegmentEntry, growingSegments []SegmentEntry) chan struct{} {
|
||||
d.mut.Lock()
|
||||
defer d.mut.Unlock()
|
||||
|
||||
changed := false
|
||||
for _, sealed := range sealedSegments {
|
||||
if d.offlines.Contain(sealed.SegmentID) {
|
||||
d.offlines.Remove(sealed.SegmentID)
|
||||
changed = true
|
||||
}
|
||||
entry, ok := d.sealedSegments[sealed.SegmentID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if entry.NodeID == sealed.NodeID || sealed.NodeID == wildcardNodeID {
|
||||
delete(d.sealedSegments, sealed.SegmentID)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, growing := range growingSegments {
|
||||
_, ok := d.growingSegments[growing.SegmentID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(d.growingSegments, growing.SegmentID)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
// no change made, return closed signal channel
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
return d.genSnapshot()
|
||||
}
|
||||
|
||||
// getSnapshot converts current distribution to snapshot format.
|
||||
// in which, user could juse found nodeID=>segmentID list.
|
||||
// mutex RLock is required before calling this method.
|
||||
func (d *distribution) genSnapshot() chan struct{} {
|
||||
|
||||
nodeSegments := make(map[int64][]SegmentEntry)
|
||||
for _, entry := range d.sealedSegments {
|
||||
nodeSegments[entry.NodeID] = append(nodeSegments[entry.NodeID], entry)
|
||||
}
|
||||
|
||||
dist := make([]SnapshotItem, 0, len(nodeSegments))
|
||||
for nodeID, items := range nodeSegments {
|
||||
dist = append(dist, SnapshotItem{
|
||||
NodeID: nodeID,
|
||||
Segments: items,
|
||||
})
|
||||
}
|
||||
|
||||
growing := make([]SegmentEntry, 0, len(d.growingSegments))
|
||||
for _, entry := range d.growingSegments {
|
||||
growing = append(growing, entry)
|
||||
}
|
||||
|
||||
d.serviceable.Store(d.offlines.Len() == 0)
|
||||
|
||||
// stores last snapshot
|
||||
// ok to be nil
|
||||
last := d.current
|
||||
// increase version
|
||||
d.version++
|
||||
d.current = NewSnapshot(dist, growing, last, d.version)
|
||||
// shall be a new one
|
||||
d.snapshots.GetOrInsert(d.version, d.current)
|
||||
|
||||
// first snapshot, return closed chan
|
||||
if last == nil {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
last.Expire(d.getCleanup(last.version))
|
||||
|
||||
return last.cleared
|
||||
}
|
||||
|
||||
// getCleanup returns cleanup snapshots function.
|
||||
func (d *distribution) getCleanup(version int64) snapshotCleanup {
|
||||
return func() {
|
||||
d.snapshots.GetAndRemove(version)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,395 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type DistributionSuite struct {
|
||||
suite.Suite
|
||||
dist *distribution
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) SetupTest() {
|
||||
s.dist = NewDistribution()
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) TearDownTest() {
|
||||
s.dist = nil
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) TestAddDistribution() {
|
||||
type testCase struct {
|
||||
tag string
|
||||
input []SegmentEntry
|
||||
expected []SnapshotItem
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "one node",
|
||||
input: []SegmentEntry{
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 2,
|
||||
},
|
||||
},
|
||||
expected: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
tag: "multiple nodes",
|
||||
input: []SegmentEntry{
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
SegmentID: 2,
|
||||
},
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 3,
|
||||
},
|
||||
},
|
||||
expected: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 1,
|
||||
},
|
||||
|
||||
{
|
||||
NodeID: 1,
|
||||
SegmentID: 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
NodeID: 2,
|
||||
SegmentID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
s.dist.AddDistributions(tc.input...)
|
||||
sealed, _, version := s.dist.GetCurrent()
|
||||
defer s.dist.FinishUsage(version)
|
||||
s.compareSnapshotItems(tc.expected, sealed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) compareSnapshotItems(target, value []SnapshotItem) {
|
||||
if !s.Equal(len(target), len(value)) {
|
||||
return
|
||||
}
|
||||
mapNodeItem := make(map[int64]SnapshotItem)
|
||||
for _, valueItem := range value {
|
||||
mapNodeItem[valueItem.NodeID] = valueItem
|
||||
}
|
||||
|
||||
for _, targetItem := range target {
|
||||
valueItem, ok := mapNodeItem[targetItem.NodeID]
|
||||
if !s.True(ok) {
|
||||
return
|
||||
}
|
||||
|
||||
s.ElementsMatch(targetItem.Segments, valueItem.Segments)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) TestAddGrowing() {
|
||||
type testCase struct {
|
||||
tag string
|
||||
input []SegmentEntry
|
||||
expected []SegmentEntry
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "nil input",
|
||||
input: nil,
|
||||
expected: []SegmentEntry{},
|
||||
},
|
||||
{
|
||||
tag: "normal case",
|
||||
input: []SegmentEntry{
|
||||
{SegmentID: 1, PartitionID: 1},
|
||||
{SegmentID: 2, PartitionID: 2},
|
||||
},
|
||||
expected: []SegmentEntry{
|
||||
{SegmentID: 1, PartitionID: 1},
|
||||
{SegmentID: 2, PartitionID: 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.dist.AddGrowing(tc.input...)
|
||||
_, growing, version := s.dist.GetCurrent()
|
||||
defer s.dist.FinishUsage(version)
|
||||
|
||||
s.ElementsMatch(tc.expected, growing)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) TestRemoveDistribution() {
|
||||
type testCase struct {
|
||||
tag string
|
||||
presetSealed []SegmentEntry
|
||||
presetGrowing []SegmentEntry
|
||||
removalSealed []SegmentEntry
|
||||
removalGrowing []SegmentEntry
|
||||
|
||||
withMockRead bool
|
||||
expectSealed []SnapshotItem
|
||||
expectGrowing []SegmentEntry
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "remove with no read",
|
||||
presetSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
presetGrowing: []SegmentEntry{
|
||||
{SegmentID: 4},
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
removalSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
},
|
||||
removalGrowing: []SegmentEntry{
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
withMockRead: false,
|
||||
|
||||
expectSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectGrowing: []SegmentEntry{{SegmentID: 4}},
|
||||
},
|
||||
{
|
||||
tag: "remove with wrong nodeID",
|
||||
presetSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
presetGrowing: []SegmentEntry{
|
||||
{SegmentID: 4},
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
removalSealed: []SegmentEntry{
|
||||
{NodeID: 2, SegmentID: 1},
|
||||
},
|
||||
removalGrowing: nil,
|
||||
|
||||
withMockRead: false,
|
||||
|
||||
expectSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectGrowing: []SegmentEntry{{SegmentID: 4}, {SegmentID: 5}},
|
||||
},
|
||||
{
|
||||
tag: "remove with wildcardNodeID",
|
||||
presetSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
presetGrowing: []SegmentEntry{
|
||||
{SegmentID: 4},
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
removalSealed: []SegmentEntry{
|
||||
{NodeID: wildcardNodeID, SegmentID: 1},
|
||||
},
|
||||
removalGrowing: nil,
|
||||
|
||||
withMockRead: false,
|
||||
|
||||
expectSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectGrowing: []SegmentEntry{{SegmentID: 4}, {SegmentID: 5}},
|
||||
},
|
||||
{
|
||||
tag: "remove with read",
|
||||
presetSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
presetGrowing: []SegmentEntry{
|
||||
{SegmentID: 4},
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
removalSealed: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 1},
|
||||
},
|
||||
removalGrowing: []SegmentEntry{
|
||||
{SegmentID: 5},
|
||||
},
|
||||
|
||||
withMockRead: true,
|
||||
|
||||
expectSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 1, SegmentID: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{NodeID: 2, SegmentID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectGrowing: []SegmentEntry{{SegmentID: 4}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.dist.AddGrowing(tc.presetGrowing...)
|
||||
s.dist.AddDistributions(tc.presetSealed...)
|
||||
|
||||
var version int64
|
||||
if tc.withMockRead {
|
||||
_, _, version = s.dist.GetCurrent()
|
||||
}
|
||||
|
||||
ch := s.dist.RemoveDistributions(tc.removalSealed, tc.removalGrowing)
|
||||
|
||||
if tc.withMockRead {
|
||||
// check ch not closed
|
||||
select {
|
||||
case <-ch:
|
||||
s.Fail("ch closed with running read")
|
||||
default:
|
||||
}
|
||||
|
||||
s.dist.FinishUsage(version)
|
||||
}
|
||||
// check ch close very soon
|
||||
timeout := time.NewTimer(time.Second)
|
||||
defer timeout.Stop()
|
||||
select {
|
||||
case <-timeout.C:
|
||||
s.Fail("ch not closed after 1 second")
|
||||
case <-ch:
|
||||
}
|
||||
|
||||
sealed, growing, version := s.dist.GetCurrent()
|
||||
defer s.dist.FinishUsage(version)
|
||||
s.compareSnapshotItems(tc.expectSealed, sealed)
|
||||
s.ElementsMatch(tc.expectGrowing, growing)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistributionSuite(t *testing.T) {
|
||||
suite.Run(t, new(DistributionSuite))
|
||||
}
|
|
@ -0,0 +1,597 @@
|
|||
// Code generated by mockery v2.16.0. DO NOT EDIT.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
)
|
||||
|
||||
// MockShardDelegator is an autogenerated mock type for the ShardDelegator type
|
||||
type MockShardDelegator struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockShardDelegator_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockShardDelegator) EXPECT() *MockShardDelegator_Expecter {
|
||||
return &MockShardDelegator_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardDelegator_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type MockShardDelegator_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) Close() *MockShardDelegator_Close_Call {
|
||||
return &MockShardDelegator_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Close_Call) Run(run func()) *MockShardDelegator_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Close_Call) Return() *MockShardDelegator_Close_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// Collection provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Collection() int64 {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 int64
|
||||
if rf, ok := ret.Get(0).(func() int64); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_Collection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Collection'
|
||||
type MockShardDelegator_Collection_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Collection is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) Collection() *MockShardDelegator_Collection_Call {
|
||||
return &MockShardDelegator_Collection_Call{Call: _e.mock.On("Collection")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Collection_Call) Run(run func()) *MockShardDelegator_Collection_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Collection_Call) Return(_a0 int64) *MockShardDelegator_Collection_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetDistribution provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) GetDistribution() *distribution {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 *distribution
|
||||
if rf, ok := ret.Get(0).(func() *distribution); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*distribution)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_GetDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDistribution'
|
||||
type MockShardDelegator_GetDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetDistribution is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) GetDistribution() *MockShardDelegator_GetDistribution_Call {
|
||||
return &MockShardDelegator_GetDistribution_Call{Call: _e.mock.On("GetDistribution")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetDistribution_Call) Run(run func()) *MockShardDelegator_GetDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetDistribution_Call) Return(_a0 *distribution) *MockShardDelegator_GetDistribution_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetStatistics provides a mock function with given fields: ctx, req
|
||||
func (_m *MockShardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 []*internalpb.GetStatisticsResponse
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) []*internalpb.GetStatisticsResponse); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*internalpb.GetStatisticsResponse)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardDelegator_GetStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatistics'
|
||||
type MockShardDelegator_GetStatistics_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetStatistics is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.GetStatisticsRequest
|
||||
func (_e *MockShardDelegator_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockShardDelegator_GetStatistics_Call {
|
||||
return &MockShardDelegator_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetStatistics_Call) Run(run func(ctx context.Context, req *querypb.GetStatisticsRequest)) *MockShardDelegator_GetStatistics_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetStatistics_Call) Return(_a0 []*internalpb.GetStatisticsResponse, _a1 error) *MockShardDelegator_GetStatistics_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadGrowing provides a mock function with given fields: ctx, infos, version
|
||||
func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error {
|
||||
ret := _m.Called(ctx, infos, version)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []*querypb.SegmentLoadInfo, int64) error); ok {
|
||||
r0 = rf(ctx, infos, version)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_LoadGrowing_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadGrowing'
|
||||
type MockShardDelegator_LoadGrowing_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LoadGrowing is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - infos []*querypb.SegmentLoadInfo
|
||||
// - version int64
|
||||
func (_e *MockShardDelegator_Expecter) LoadGrowing(ctx interface{}, infos interface{}, version interface{}) *MockShardDelegator_LoadGrowing_Call {
|
||||
return &MockShardDelegator_LoadGrowing_Call{Call: _e.mock.On("LoadGrowing", ctx, infos, version)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_LoadGrowing_Call) Run(run func(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64)) *MockShardDelegator_LoadGrowing_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].([]*querypb.SegmentLoadInfo), args[2].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_LoadGrowing_Call) Return(_a0 error) *MockShardDelegator_LoadGrowing_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LoadSegments provides a mock function with given fields: ctx, req
|
||||
func (_m *MockShardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_LoadSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadSegments'
|
||||
type MockShardDelegator_LoadSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LoadSegments is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.LoadSegmentsRequest
|
||||
func (_e *MockShardDelegator_Expecter) LoadSegments(ctx interface{}, req interface{}) *MockShardDelegator_LoadSegments_Call {
|
||||
return &MockShardDelegator_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_LoadSegments_Call) Run(run func(ctx context.Context, req *querypb.LoadSegmentsRequest)) *MockShardDelegator_LoadSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_LoadSegments_Call) Return(_a0 error) *MockShardDelegator_LoadSegments_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ProcessDelete provides a mock function with given fields: deleteData, ts
|
||||
func (_m *MockShardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) {
|
||||
_m.Called(deleteData, ts)
|
||||
}
|
||||
|
||||
// MockShardDelegator_ProcessDelete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessDelete'
|
||||
type MockShardDelegator_ProcessDelete_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ProcessDelete is a helper method to define mock.On call
|
||||
// - deleteData []*DeleteData
|
||||
// - ts uint64
|
||||
func (_e *MockShardDelegator_Expecter) ProcessDelete(deleteData interface{}, ts interface{}) *MockShardDelegator_ProcessDelete_Call {
|
||||
return &MockShardDelegator_ProcessDelete_Call{Call: _e.mock.On("ProcessDelete", deleteData, ts)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ProcessDelete_Call) Run(run func(deleteData []*DeleteData, ts uint64)) *MockShardDelegator_ProcessDelete_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]*DeleteData), args[1].(uint64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ProcessDelete_Call) Return() *MockShardDelegator_ProcessDelete_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// ProcessInsert provides a mock function with given fields: insertRecords
|
||||
func (_m *MockShardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
|
||||
_m.Called(insertRecords)
|
||||
}
|
||||
|
||||
// MockShardDelegator_ProcessInsert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessInsert'
|
||||
type MockShardDelegator_ProcessInsert_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ProcessInsert is a helper method to define mock.On call
|
||||
// - insertRecords map[int64]*InsertData
|
||||
func (_e *MockShardDelegator_Expecter) ProcessInsert(insertRecords interface{}) *MockShardDelegator_ProcessInsert_Call {
|
||||
return &MockShardDelegator_ProcessInsert_Call{Call: _e.mock.On("ProcessInsert", insertRecords)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ProcessInsert_Call) Run(run func(insertRecords map[int64]*InsertData)) *MockShardDelegator_ProcessInsert_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(map[int64]*InsertData))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ProcessInsert_Call) Return() *MockShardDelegator_ProcessInsert_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// Query provides a mock function with given fields: ctx, req
|
||||
func (_m *MockShardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 []*internalpb.RetrieveResults
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) []*internalpb.RetrieveResults); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*internalpb.RetrieveResults)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardDelegator_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query'
|
||||
type MockShardDelegator_Query_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Query is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.QueryRequest
|
||||
func (_e *MockShardDelegator_Expecter) Query(ctx interface{}, req interface{}) *MockShardDelegator_Query_Call {
|
||||
return &MockShardDelegator_Query_Call{Call: _e.mock.On("Query", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Query_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockShardDelegator_Query_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.QueryRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Query_Call) Return(_a0 []*internalpb.RetrieveResults, _a1 error) *MockShardDelegator_Query_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReleaseSegments provides a mock function with given fields: ctx, req, force
|
||||
func (_m *MockShardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
|
||||
ret := _m.Called(ctx, req, force)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest, bool) error); ok {
|
||||
r0 = rf(ctx, req, force)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_ReleaseSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseSegments'
|
||||
type MockShardDelegator_ReleaseSegments_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReleaseSegments is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.ReleaseSegmentsRequest
|
||||
// - force bool
|
||||
func (_e *MockShardDelegator_Expecter) ReleaseSegments(ctx interface{}, req interface{}, force interface{}) *MockShardDelegator_ReleaseSegments_Call {
|
||||
return &MockShardDelegator_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, req, force)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ReleaseSegments_Call) Run(run func(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool)) *MockShardDelegator_ReleaseSegments_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest), args[2].(bool))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_ReleaseSegments_Call) Return(_a0 error) *MockShardDelegator_ReleaseSegments_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Search provides a mock function with given fields: ctx, req
|
||||
func (_m *MockShardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {
|
||||
ret := _m.Called(ctx, req)
|
||||
|
||||
var r0 []*internalpb.SearchResults
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) []*internalpb.SearchResults); ok {
|
||||
r0 = rf(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]*internalpb.SearchResults)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok {
|
||||
r1 = rf(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardDelegator_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search'
|
||||
type MockShardDelegator_Search_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Search is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *querypb.SearchRequest
|
||||
func (_e *MockShardDelegator_Expecter) Search(ctx interface{}, req interface{}) *MockShardDelegator_Search_Call {
|
||||
return &MockShardDelegator_Search_Call{Call: _e.mock.On("Search", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Search_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockShardDelegator_Search_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*querypb.SearchRequest))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Search_Call) Return(_a0 []*internalpb.SearchResults, _a1 error) *MockShardDelegator_Search_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Serviceable provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Serviceable() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_Serviceable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Serviceable'
|
||||
type MockShardDelegator_Serviceable_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Serviceable is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) Serviceable() *MockShardDelegator_Serviceable_Call {
|
||||
return &MockShardDelegator_Serviceable_Call{Call: _e.mock.On("Serviceable")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Serviceable_Call) Run(run func()) *MockShardDelegator_Serviceable_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Serviceable_Call) Return(_a0 bool) *MockShardDelegator_Serviceable_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Start() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardDelegator_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
|
||||
type MockShardDelegator_Start_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Start is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) Start() *MockShardDelegator_Start_Call {
|
||||
return &MockShardDelegator_Start_Call{Call: _e.mock.On("Start")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Start_Call) Run(run func()) *MockShardDelegator_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Start_Call) Return() *MockShardDelegator_Start_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// SyncDistribution provides a mock function with given fields: ctx, entries
|
||||
func (_m *MockShardDelegator) SyncDistribution(ctx context.Context, entries ...SegmentEntry) {
|
||||
_va := make([]interface{}, len(entries))
|
||||
for _i := range entries {
|
||||
_va[_i] = entries[_i]
|
||||
}
|
||||
var _ca []interface{}
|
||||
_ca = append(_ca, ctx)
|
||||
_ca = append(_ca, _va...)
|
||||
_m.Called(_ca...)
|
||||
}
|
||||
|
||||
// MockShardDelegator_SyncDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncDistribution'
|
||||
type MockShardDelegator_SyncDistribution_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SyncDistribution is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - entries ...SegmentEntry
|
||||
func (_e *MockShardDelegator_Expecter) SyncDistribution(ctx interface{}, entries ...interface{}) *MockShardDelegator_SyncDistribution_Call {
|
||||
return &MockShardDelegator_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution",
|
||||
append([]interface{}{ctx}, entries...)...)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_SyncDistribution_Call) Run(run func(ctx context.Context, entries ...SegmentEntry)) *MockShardDelegator_SyncDistribution_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
variadicArgs := make([]SegmentEntry, len(args)-1)
|
||||
for i, a := range args[1:] {
|
||||
if a != nil {
|
||||
variadicArgs[i] = a.(SegmentEntry)
|
||||
}
|
||||
}
|
||||
run(args[0].(context.Context), variadicArgs...)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_SyncDistribution_Call) Return() *MockShardDelegator_SyncDistribution_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
// Version provides a mock function with given fields:
|
||||
func (_m *MockShardDelegator) Version() int64 {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 int64
|
||||
if rf, ok := ret.Get(0).(func() int64); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(int64)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version'
|
||||
type MockShardDelegator_Version_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Version is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) Version() *MockShardDelegator_Version_Call {
|
||||
return &MockShardDelegator_Version_Call{Call: _e.mock.On("Version")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Version_Call) Run(run func()) *MockShardDelegator_Version_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_Version_Call) Return(_a0 int64) *MockShardDelegator_Version_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockShardDelegator interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockShardDelegator creates a new instance of MockShardDelegator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockShardDelegator(t mockConstructorTestingTNewMockShardDelegator) *MockShardDelegator {
|
||||
mock := &MockShardDelegator{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// SnapshotItem group segmentEntry slice
|
||||
type SnapshotItem struct {
|
||||
NodeID int64
|
||||
Segments []SegmentEntry
|
||||
}
|
||||
|
||||
// snapshotCleanup cleanup function signature.
|
||||
type snapshotCleanup func()
|
||||
|
||||
// snapshot records segment distribution with ref count.
|
||||
type snapshot struct {
|
||||
dist []SnapshotItem
|
||||
growing []SegmentEntry
|
||||
|
||||
// version ID for tracking
|
||||
version int64
|
||||
|
||||
// signal channel for this snapshot cleared
|
||||
cleared chan struct{}
|
||||
once sync.Once
|
||||
|
||||
// reference to last version
|
||||
last *snapshot
|
||||
|
||||
// reference count for this snapshot
|
||||
inUse atomic.Int64
|
||||
|
||||
// expired flag
|
||||
expired bool
|
||||
}
|
||||
|
||||
// NewSnapshot returns a prepared snapshot with channel initialized.
|
||||
func NewSnapshot(sealed []SnapshotItem, growing []SegmentEntry, last *snapshot, version int64) *snapshot {
|
||||
return &snapshot{
|
||||
version: version,
|
||||
growing: growing,
|
||||
dist: sealed,
|
||||
last: last,
|
||||
cleared: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Expire sets expired flag to true.
|
||||
func (s *snapshot) Expire(cleanup snapshotCleanup) {
|
||||
s.expired = true
|
||||
s.checkCleared(cleanup)
|
||||
}
|
||||
|
||||
// Get returns segment distributions with provided partition ids.
|
||||
func (s *snapshot) Get(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry) {
|
||||
s.inUse.Inc()
|
||||
|
||||
filter := func(entry SegmentEntry, idx int) bool {
|
||||
return len(partitions) == 0 || funcutil.SliceContain(partitions, entry.PartitionID)
|
||||
}
|
||||
|
||||
sealed = make([]SnapshotItem, 0, len(s.dist))
|
||||
for _, item := range s.dist {
|
||||
segments := lo.Filter(item.Segments, filter)
|
||||
sealed = append(sealed, SnapshotItem{
|
||||
NodeID: item.NodeID,
|
||||
Segments: segments,
|
||||
})
|
||||
}
|
||||
|
||||
growing = lo.Filter(s.growing, filter)
|
||||
return
|
||||
}
|
||||
|
||||
// Done decreases inUse count for snapshot.
|
||||
// also performs cleared check.
|
||||
func (s *snapshot) Done(cleanup snapshotCleanup) {
|
||||
s.inUse.Dec()
|
||||
s.checkCleared(cleanup)
|
||||
}
|
||||
|
||||
// checkCleared performs safety check for snapshot closing the cleared signal.
|
||||
func (s *snapshot) checkCleared(cleanup snapshotCleanup) {
|
||||
if s.expired && s.inUse.Load() == 0 {
|
||||
s.once.Do(func() {
|
||||
// first snapshot
|
||||
if s.last == nil {
|
||||
close(s.cleared)
|
||||
cleanup()
|
||||
return
|
||||
}
|
||||
|
||||
// wait last version cleared
|
||||
go func() {
|
||||
<-s.last.cleared
|
||||
s.last = nil
|
||||
cleanup()
|
||||
close(s.cleared)
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func inList(list []int64, target int64) bool {
|
||||
for _, i := range list {
|
||||
if i == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,254 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package delegator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SnapshotSuite struct {
|
||||
suite.Suite
|
||||
|
||||
snapshot *snapshot
|
||||
}
|
||||
|
||||
func (s *SnapshotSuite) SetupTest() {
|
||||
last := NewSnapshot(nil, nil, nil, 0)
|
||||
last.Expire(func() {})
|
||||
|
||||
dist := []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
SegmentID: 1,
|
||||
PartitionID: 1,
|
||||
},
|
||||
{
|
||||
SegmentID: 2,
|
||||
PartitionID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{
|
||||
SegmentID: 3,
|
||||
PartitionID: 1,
|
||||
},
|
||||
{
|
||||
SegmentID: 4,
|
||||
PartitionID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
growing := []SegmentEntry{
|
||||
{
|
||||
SegmentID: 5,
|
||||
PartitionID: 1,
|
||||
},
|
||||
{
|
||||
SegmentID: 6,
|
||||
PartitionID: 2,
|
||||
},
|
||||
}
|
||||
|
||||
s.snapshot = NewSnapshot(dist, growing, last, 1)
|
||||
}
|
||||
|
||||
func (s *SnapshotSuite) TearDownTest() {
|
||||
s.snapshot = nil
|
||||
}
|
||||
|
||||
func (s *SnapshotSuite) TestGet() {
|
||||
type testCase struct {
|
||||
tag string
|
||||
partitions []int64
|
||||
expectedSealed []SnapshotItem
|
||||
expectedGrowing []SegmentEntry
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "nil partition",
|
||||
partitions: nil,
|
||||
expectedSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 1, PartitionID: 1},
|
||||
{SegmentID: 2, PartitionID: 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 3, PartitionID: 1},
|
||||
{SegmentID: 4, PartitionID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedGrowing: []SegmentEntry{
|
||||
{
|
||||
SegmentID: 5,
|
||||
PartitionID: 1,
|
||||
},
|
||||
{
|
||||
SegmentID: 6,
|
||||
PartitionID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
tag: "partition_1",
|
||||
partitions: []int64{1},
|
||||
expectedSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 1, PartitionID: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 3, PartitionID: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedGrowing: []SegmentEntry{
|
||||
{
|
||||
SegmentID: 5,
|
||||
PartitionID: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
tag: "partition_2",
|
||||
partitions: []int64{2},
|
||||
expectedSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 2, PartitionID: 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{
|
||||
{SegmentID: 4, PartitionID: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedGrowing: []SegmentEntry{
|
||||
{
|
||||
SegmentID: 6,
|
||||
PartitionID: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
tag: "partition not exists",
|
||||
partitions: []int64{3},
|
||||
expectedSealed: []SnapshotItem{
|
||||
{
|
||||
NodeID: 1,
|
||||
Segments: []SegmentEntry{},
|
||||
},
|
||||
{
|
||||
NodeID: 2,
|
||||
Segments: []SegmentEntry{},
|
||||
},
|
||||
},
|
||||
expectedGrowing: []SegmentEntry{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
before := s.snapshot.inUse.Load()
|
||||
sealed, growing := s.snapshot.Get(tc.partitions...)
|
||||
after := s.snapshot.inUse.Load()
|
||||
s.ElementsMatch(tc.expectedSealed, sealed)
|
||||
s.ElementsMatch(tc.expectedGrowing, growing)
|
||||
s.EqualValues(1, after-before)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SnapshotSuite) TestDone() {
|
||||
s.Run("done not expired snapshot", func() {
|
||||
inUse := s.snapshot.inUse.Load()
|
||||
s.Require().EqualValues(0, inUse)
|
||||
s.snapshot.Get()
|
||||
inUse = s.snapshot.inUse.Load()
|
||||
s.Require().EqualValues(1, inUse)
|
||||
|
||||
s.snapshot.Done(func() {})
|
||||
inUse = s.snapshot.inUse.Load()
|
||||
s.EqualValues(0, inUse)
|
||||
// check cleared channel closed
|
||||
select {
|
||||
case <-s.snapshot.cleared:
|
||||
s.Fail("snapshot channel closed in non-expired state")
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("done expired snapshot", func() {
|
||||
inUse := s.snapshot.inUse.Load()
|
||||
s.Require().EqualValues(0, inUse)
|
||||
s.snapshot.Get()
|
||||
inUse = s.snapshot.inUse.Load()
|
||||
s.Require().EqualValues(1, inUse)
|
||||
s.snapshot.Expire(func() {})
|
||||
// check cleared channel closed
|
||||
select {
|
||||
case <-s.snapshot.cleared:
|
||||
s.FailNow("snapshot channel closed in non-expired state")
|
||||
default:
|
||||
}
|
||||
|
||||
signal := make(chan struct{})
|
||||
s.snapshot.Done(func() { close(signal) })
|
||||
inUse = s.snapshot.inUse.Load()
|
||||
s.EqualValues(0, inUse)
|
||||
timeout := time.NewTimer(time.Second)
|
||||
defer timeout.Stop()
|
||||
select {
|
||||
case <-timeout.C:
|
||||
s.FailNow("cleanup never called")
|
||||
case <-signal:
|
||||
}
|
||||
timeout = time.NewTimer(10 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
select {
|
||||
case <-timeout.C:
|
||||
s.FailNow("snapshot channel not closed after expired and no use")
|
||||
case <-s.snapshot.cleared:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
suite.Run(t, new(SnapshotSuite))
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package delegator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
rowIDFieldID FieldID = 0
|
||||
timestampFieldID FieldID = 1
|
||||
)
|
||||
|
||||
type (
|
||||
// UniqueID is an identifier that is guaranteed to be unique among all the collections, partitions and segments
|
||||
UniqueID = typeutil.UniqueID
|
||||
// Timestamp is timestamp
|
||||
Timestamp = typeutil.Timestamp
|
||||
// FieldID is to uniquely identify the field
|
||||
FieldID = int64
|
||||
// IntPrimaryKey is the primary key of int type
|
||||
IntPrimaryKey = typeutil.IntPrimaryKey
|
||||
// DSL is the Domain Specific Language
|
||||
DSL = string
|
||||
// ConsumeSubName is consumer's subscription name of the message stream
|
||||
ConsumeSubName = string
|
||||
)
|
||||
|
||||
// TimeRange is a range of time periods
|
||||
type TimeRange struct {
|
||||
timestampMin Timestamp
|
||||
timestampMax Timestamp
|
||||
}
|
||||
|
||||
// loadType is load collection or load partition
|
||||
type loadType = querypb.LoadType
|
||||
|
||||
const (
|
||||
loadTypeCollection = querypb.LoadType_LoadCollection
|
||||
loadTypePartition = querypb.LoadType_LoadPartition
|
||||
)
|
||||
|
||||
// TSafeUpdater is the interface for type provides tsafe update event
|
||||
type TSafeUpdater interface {
|
||||
RegisterChannel(string) chan Timestamp
|
||||
UnregisterChannel(string) error
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrTsLagTooLarge serviceable and guarantee lag too large.
|
||||
ErrTsLagTooLarge = errors.New("Timestamp lag too large")
|
||||
)
|
||||
|
||||
// WrapErrTsLagTooLarge wraps ErrTsLagTooLarge with lag and max value.
|
||||
func WrapErrTsLagTooLarge(duration time.Duration, maxLag time.Duration) error {
|
||||
return fmt.Errorf("%w lag(%s) max(%s)", ErrTsLagTooLarge, duration, maxLag)
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNodeUnhealthy = errors.New("NodeIsUnhealthy")
|
||||
ErrGetDelegatorFailed = errors.New("GetShardDelefatorFailed")
|
||||
ErrInitPipelineFailed = errors.New("InitPipelineFailed")
|
||||
)
|
||||
|
||||
// WrapErrNodeUnhealthy wraps ErrNodeUnhealthy with nodeID.
|
||||
func WrapErrNodeUnhealthy(nodeID int64) error {
|
||||
return fmt.Errorf("%w id: %d", ErrNodeUnhealthy, nodeID)
|
||||
}
|
||||
|
||||
func WrapErrInitPipelineFailed(err error) error {
|
||||
return fmt.Errorf("%w err: %s", ErrInitPipelineFailed, err.Error())
|
||||
}
|
||||
|
||||
func msgQueryNodeIsUnhealthy(nodeID int64) string {
|
||||
return fmt.Sprintf("query node %d is not ready", nodeID)
|
||||
}
|
|
@ -0,0 +1,455 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
||||
"github.com/milvus-io/milvus/internal/util"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator, req *querypb.WatchDmChannelsRequest) error {
|
||||
// load growing segments
|
||||
growingSegments := make([]*querypb.SegmentLoadInfo, 0, len(req.Infos))
|
||||
for _, info := range req.Infos {
|
||||
for _, segmentID := range info.GetUnflushedSegmentIds() {
|
||||
// unFlushed segment may not have binLogs, skip loading
|
||||
segmentInfo := req.GetSegmentInfos()[segmentID]
|
||||
if segmentInfo == nil {
|
||||
log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", segmentID))
|
||||
continue
|
||||
}
|
||||
if len(segmentInfo.GetBinlogs()) > 0 {
|
||||
growingSegments = append(growingSegments, &querypb.SegmentLoadInfo{
|
||||
SegmentID: segmentInfo.ID,
|
||||
PartitionID: segmentInfo.PartitionID,
|
||||
CollectionID: segmentInfo.CollectionID,
|
||||
BinlogPaths: segmentInfo.Binlogs,
|
||||
NumOfRows: segmentInfo.NumOfRows,
|
||||
Statslogs: segmentInfo.Statslogs,
|
||||
Deltalogs: segmentInfo.Deltalogs,
|
||||
InsertChannel: segmentInfo.InsertChannel,
|
||||
})
|
||||
} else {
|
||||
log.Info("skip segment which binlog is empty", zap.Int64("segmentID", segmentInfo.ID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return delegator.LoadGrowing(ctx, growingSegments, req.GetVersion())
|
||||
}
|
||||
|
||||
func (node *QueryNode) loadDeltaLogs(ctx context.Context, req *querypb.LoadSegmentsRequest) *commonpb.Status {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", req.GetCollectionID()),
|
||||
)
|
||||
loadedSegments := make([]int64, 0, len(req.GetInfos()))
|
||||
var finalErr error
|
||||
for _, info := range req.GetInfos() {
|
||||
segment := node.manager.Segment.GetSealed(info.GetSegmentID())
|
||||
if segment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
local := segment.(*segments.LocalSegment)
|
||||
err := node.loader.LoadDeltaLogs(ctx, local, info.GetDeltalogs())
|
||||
if err != nil {
|
||||
if finalErr == nil {
|
||||
finalErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
loadedSegments = append(loadedSegments, info.GetSegmentID())
|
||||
}
|
||||
|
||||
if finalErr != nil {
|
||||
log.Warn("failed to load delta logs", zap.Error(finalErr))
|
||||
return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to load delta logs", finalErr)
|
||||
}
|
||||
|
||||
return util.SuccessStatus()
|
||||
}
|
||||
|
||||
func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryRequest, channel string) (*internalpb.RetrieveResults, error) {
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel).Inc()
|
||||
failRet := WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "")
|
||||
msgID := req.Req.Base.GetMsgID()
|
||||
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("msgID", msgID),
|
||||
zap.Int64("collectionID", req.GetReq().GetCollectionID()),
|
||||
zap.String("channel", channel),
|
||||
zap.String("scope", req.GetScope().String()),
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if failRet.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel).Inc()
|
||||
}
|
||||
}()
|
||||
|
||||
if !node.lifetime.Add(commonpbutil.IsHealthy) {
|
||||
failRet.Status.Reason = msgQueryNodeIsUnhealthy(paramtable.GetNodeID())
|
||||
return failRet, nil
|
||||
}
|
||||
defer node.lifetime.Done()
|
||||
|
||||
log.Debug("start do query with channel",
|
||||
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
|
||||
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
|
||||
)
|
||||
// add cancel when error occurs
|
||||
queryCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// TODO From Shard Delegator
|
||||
if req.FromShardLeader {
|
||||
tr := timerecord.NewTimeRecorder("queryChannel")
|
||||
results, err := node.querySegments(queryCtx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to query channel", zap.Error(err))
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("do query done, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
|
||||
traceID,
|
||||
req.GetFromShardLeader(),
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
|
||||
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
|
||||
latency := tr.ElapseSpan()
|
||||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
|
||||
return results, nil
|
||||
}
|
||||
// From Proxy
|
||||
tr := timerecord.NewTimeRecorder("queryDelegator")
|
||||
// get delegator
|
||||
sd, ok := node.delegators.Get(channel)
|
||||
if !ok {
|
||||
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
|
||||
failRet.Status.Reason = ErrGetDelegatorFailed.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
// do query
|
||||
results, err := sd.Query(queryCtx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to query on delegator", zap.Error(err))
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
// reduce result
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
|
||||
traceID,
|
||||
req.GetFromShardLeader(),
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
|
||||
if collection == nil {
|
||||
log.Warn("Query failed, failed to get collection")
|
||||
failRet.Status.Reason = segments.WrapCollectionNotFound(req.Req.CollectionID).Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
ret, err := segments.MergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), collection.Schema())
|
||||
if err != nil {
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("do query with channel done , vChannel = %s, segmentIDs = %v",
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
//
|
||||
failRet.Status.ErrorCode = commonpb.ErrorCode_Success
|
||||
latency := tr.ElapseSpan()
|
||||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
|
||||
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
|
||||
if collection == nil {
|
||||
return nil, segments.ErrCollectionNotFound
|
||||
}
|
||||
// build plan
|
||||
retrievePlan, err := segments.NewRetrievePlan(
|
||||
collection,
|
||||
req.Req.GetSerializedExprPlan(),
|
||||
req.Req.GetTravelTimestamp(),
|
||||
req.Req.Base.GetMsgID(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer retrievePlan.Delete()
|
||||
|
||||
var results []*segcorepb.RetrieveResults
|
||||
if req.GetScope() == querypb.DataScope_Historical {
|
||||
results, _, _, err = segments.RetrieveHistorical(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
|
||||
} else {
|
||||
results, _, _, err = segments.RetrieveStreaming(ctx, node.manager, retrievePlan, req.Req.CollectionID, nil, req.GetSegmentIDs(), node.cacheChunkManager)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reducedResult, err := segments.MergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.Req.GetOutputFieldsId(), collection.Schema())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &internalpb.RetrieveResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Ids: reducedResult.Ids,
|
||||
FieldsData: reducedResult.FieldsData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
|
||||
zap.Int64("collectionID", req.Req.GetCollectionID()),
|
||||
zap.String("channel", channel),
|
||||
zap.String("scope", req.GetScope().String()),
|
||||
)
|
||||
traceID := trace.SpanFromContext(ctx).SpanContext().TraceID()
|
||||
|
||||
if !node.lifetime.Add(commonpbutil.IsHealthy) {
|
||||
return nil, WrapErrNodeUnhealthy(paramtable.GetNodeID())
|
||||
}
|
||||
defer node.lifetime.Done()
|
||||
|
||||
log.Debug("start to search channel",
|
||||
zap.Bool("fromShardLeader", req.GetFromShardLeader()),
|
||||
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
|
||||
)
|
||||
searchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// TODO From Shard Delegator
|
||||
if req.GetFromShardLeader() {
|
||||
tr := timerecord.NewTimeRecorder("searchChannel")
|
||||
log.Debug("search channel...")
|
||||
|
||||
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
|
||||
if collection == nil {
|
||||
log.Warn("failed to search channel", zap.Error(segments.ErrCollectionNotFound))
|
||||
return nil, segments.WrapCollectionNotFound(req.GetReq().GetCollectionID())
|
||||
}
|
||||
|
||||
task := tasks.NewSearchTask(searchCtx, collection, node.manager, req)
|
||||
if !node.scheduler.Add(task) {
|
||||
log.Warn("failed to search channel", zap.Error(tasks.ErrTaskQueueFull))
|
||||
return nil, tasks.ErrTaskQueueFull
|
||||
}
|
||||
|
||||
err := task.Wait()
|
||||
if err != nil {
|
||||
log.Warn("failed to search channel", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("search channel done, channel = %s, segmentIDs = %v",
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
// TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency
|
||||
latency := tr.ElapseSpan()
|
||||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
|
||||
return task.Result(), nil
|
||||
}
|
||||
|
||||
// From Proxy
|
||||
tr := timerecord.NewTimeRecorder("searchDelegator")
|
||||
// get delegator
|
||||
sd, ok := node.delegators.Get(channel)
|
||||
if !ok {
|
||||
log.Warn("Query failed, failed to get query shard delegator", zap.Error(ErrGetDelegatorFailed))
|
||||
return nil, ErrGetDelegatorFailed
|
||||
}
|
||||
// do search
|
||||
results, err := sd.Search(searchCtx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to search on delegator", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// reduce result
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("start reduce query result, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v",
|
||||
traceID,
|
||||
req.GetFromShardLeader(),
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
ret, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , vChannel = %s, segmentIDs = %v",
|
||||
channel,
|
||||
req.GetSegmentIDs(),
|
||||
))
|
||||
|
||||
// update metric to prometheus
|
||||
latency := tr.ElapseSpan()
|
||||
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.Leader).Observe(float64(latency.Milliseconds()))
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
|
||||
metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq()))
|
||||
metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk()))
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", req.Req.GetCollectionID()),
|
||||
zap.String("channel", channel),
|
||||
zap.String("scope", req.GetScope().String()),
|
||||
)
|
||||
failRet := &internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
},
|
||||
}
|
||||
|
||||
if req.GetFromShardLeader() {
|
||||
var results []segments.SegmentStats
|
||||
var err error
|
||||
|
||||
switch req.GetScope() {
|
||||
case querypb.DataScope_Historical:
|
||||
results, _, _, err = segments.StatisticsHistorical(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
|
||||
case querypb.DataScope_Streaming:
|
||||
results, _, _, err = segments.StatisticStreaming(ctx, node.manager, req.Req.GetCollectionID(), req.Req.GetPartitionIDs(), req.GetSegmentIDs())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("get segments statistics failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return segmentStatsResponse(results), nil
|
||||
}
|
||||
|
||||
sd, ok := node.delegators.Get(channel)
|
||||
if !ok {
|
||||
log.Warn("GetStatistics failed, failed to get query shard delegator")
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
results, err := sd.GetStatistics(ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("failed to get statistics from delegator", zap.Error(err))
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
ret, err := reduceStatisticResponse(results)
|
||||
if err != nil {
|
||||
failRet.Status.Reason = err.Error()
|
||||
return failRet, nil
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStatisticsResponse {
|
||||
var totalRowNum int64
|
||||
for _, stats := range segStats {
|
||||
totalRowNum += stats.RowCount
|
||||
}
|
||||
|
||||
resultMap := make(map[string]string)
|
||||
resultMap["row_count"] = strconv.FormatInt(totalRowNum, 10)
|
||||
|
||||
ret := &internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Stats: funcutil.Map2KeyValuePair(resultMap),
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*internalpb.GetStatisticsResponse, error) {
|
||||
mergedResults := map[string]interface{}{
|
||||
"row_count": int64(0),
|
||||
}
|
||||
fieldMethod := map[string]func(string) error{
|
||||
"row_count": func(str string) error {
|
||||
count, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mergedResults["row_count"] = mergedResults["row_count"].(int64) + count
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
for _, partialResult := range results {
|
||||
for _, pair := range partialResult.Stats {
|
||||
fn, ok := fieldMethod[pair.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown statistic field: %s", pair.Key)
|
||||
}
|
||||
if err := fn(pair.Value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stringMap := make(map[string]string)
|
||||
for k, v := range mergedResults {
|
||||
stringMap[k] = fmt.Sprint(v)
|
||||
}
|
||||
|
||||
ret := &internalpb.GetStatisticsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Stats: funcutil.Map2KeyValuePair(stringMap),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
type HandlersSuite struct {
|
||||
suite.Suite
|
||||
// Data
|
||||
collectionID int64
|
||||
collectionName string
|
||||
segmentID int64
|
||||
channel string
|
||||
|
||||
// Dependency
|
||||
params *paramtable.ComponentParam
|
||||
node *QueryNode
|
||||
etcd *clientv3.Client
|
||||
chunkManagerFactory *storage.ChunkManagerFactory
|
||||
|
||||
// Mock
|
||||
factory *dependency.MockFactory
|
||||
}
|
||||
|
||||
func (suite *HandlersSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.segmentID = 1
|
||||
suite.channel = "test-channel"
|
||||
}
|
||||
|
||||
func (suite *HandlersSuite) SetupTest() {
|
||||
var err error
|
||||
paramtable.Init()
|
||||
suite.params = paramtable.Get()
|
||||
suite.params.Save(suite.params.QueryNodeCfg.GCEnabled.Key, "false")
|
||||
|
||||
// mock factory
|
||||
suite.factory = dependency.NewMockFactory(suite.T())
|
||||
suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test"))
|
||||
|
||||
// new node
|
||||
suite.node = NewQueryNode(context.Background(), suite.factory)
|
||||
// init etcd
|
||||
suite.etcd, err = etcd.GetEtcdClient(
|
||||
suite.params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
|
||||
suite.params.EtcdCfg.EtcdUseSSL.GetAsBool(),
|
||||
suite.params.EtcdCfg.Endpoints.GetAsStrings(),
|
||||
suite.params.EtcdCfg.EtcdTLSCert.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSKey.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSCACert.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSMinVersion.GetValue())
|
||||
suite.NoError(err)
|
||||
}
|
||||
|
||||
func (suite *HandlersSuite) TearDownTest() {
|
||||
suite.etcd.Close()
|
||||
os.RemoveAll("/tmp/milvus-test")
|
||||
}
|
||||
|
||||
func (suite *HandlersSuite) TestLoadGrowingSegments() {
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
// mock
|
||||
loadSegmetns := []int64{}
|
||||
delegator := delegator.NewMockShardDelegator(suite.T())
|
||||
delegator.EXPECT().LoadGrowing(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) {
|
||||
for _, info := range infos {
|
||||
loadSegmetns = append(loadSegmetns, info.SegmentID)
|
||||
}
|
||||
}).Return(nil)
|
||||
|
||||
req := &querypb.WatchDmChannelsRequest{
|
||||
Infos: []*datapb.VchannelInfo{
|
||||
{
|
||||
CollectionID: suite.collectionID,
|
||||
ChannelName: suite.channel,
|
||||
UnflushedSegmentIds: []int64{suite.segmentID},
|
||||
},
|
||||
},
|
||||
SegmentInfos: make(map[int64]*datapb.SegmentInfo),
|
||||
}
|
||||
|
||||
// unflushed segment not in segmentInfos, will skip
|
||||
err = loadGrowingSegments(ctx, delegator, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(0, len(loadSegmetns))
|
||||
|
||||
// binlog was empty, will skip
|
||||
req.SegmentInfos[suite.segmentID] = &datapb.SegmentInfo{
|
||||
ID: suite.segmentID,
|
||||
CollectionID: suite.collectionID,
|
||||
Binlogs: make([]*datapb.FieldBinlog, 0),
|
||||
}
|
||||
err = loadGrowingSegments(ctx, delegator, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(0, len(loadSegmetns))
|
||||
|
||||
// normal load
|
||||
binlog := &datapb.FieldBinlog{}
|
||||
req.SegmentInfos[suite.segmentID].Binlogs = append(req.SegmentInfos[suite.segmentID].Binlogs, binlog)
|
||||
err = loadGrowingSegments(ctx, delegator, req)
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, len(loadSegmetns))
|
||||
}
|
||||
|
||||
func TestHandlersSuite(t *testing.T) {
|
||||
suite.Run(t, new(HandlersSuite))
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var _ cluster.Worker = &LocalWorker{}
|
||||
|
||||
type LocalWorker struct {
|
||||
node *QueryNode
|
||||
}
|
||||
|
||||
func NewLocalWorker(node *QueryNode) *LocalWorker {
|
||||
return &LocalWorker{
|
||||
node: node,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *LocalWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("start to load segments...")
|
||||
loaded, err := w.node.loader.Load(ctx,
|
||||
req.GetCollectionID(),
|
||||
segments.SegmentTypeSealed,
|
||||
req.GetVersion(),
|
||||
req.GetInfos()...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("save loaded segments...",
|
||||
zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() })))
|
||||
w.node.manager.Segment.Put(segments.SegmentTypeSealed, loaded...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *LocalWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("start to release segments")
|
||||
for _, id := range req.GetSegmentIDs() {
|
||||
w.node.manager.Segment.Remove(id, req.GetScope())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
|
||||
log := log.Ctx(ctx)
|
||||
log.Info("start to process segment delete")
|
||||
status, err := w.node.Delete(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if status.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return fmt.Errorf(status.GetReason())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *LocalWorker) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
|
||||
return w.node.Search(ctx, req)
|
||||
}
|
||||
|
||||
func (w *LocalWorker) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
|
||||
return w.node.Query(ctx, req)
|
||||
}
|
||||
|
||||
func (w *LocalWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
|
||||
return w.node.GetStatistics(ctx, req)
|
||||
}
|
||||
|
||||
func (w *LocalWorker) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *LocalWorker) Stop() {
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/etcd"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/suite"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
type LocalWorkerTestSuite struct {
|
||||
suite.Suite
|
||||
params *paramtable.ComponentParam
|
||||
// data
|
||||
collectionID int64
|
||||
collectionName string
|
||||
channel string
|
||||
partitionIDs []int64
|
||||
segmentIDs []int64
|
||||
schema *schemapb.CollectionSchema
|
||||
|
||||
// dependency
|
||||
node *QueryNode
|
||||
worker *LocalWorker
|
||||
etcdClient *clientv3.Client
|
||||
// context
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (suite *LocalWorkerTestSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.channel = "test-channel"
|
||||
suite.partitionIDs = []int64{11, 22}
|
||||
suite.segmentIDs = []int64{0, 1}
|
||||
}
|
||||
|
||||
func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) {
|
||||
var err error
|
||||
// init param
|
||||
paramtable.Init()
|
||||
suite.params = paramtable.Get()
|
||||
// close GC at test to avoid data race
|
||||
suite.params.Save(suite.params.QueryNodeCfg.GCEnabled.Key, "false")
|
||||
|
||||
suite.ctx, suite.cancel = context.WithCancel(context.Background())
|
||||
// init node
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
suite.node = NewQueryNode(suite.ctx, factory)
|
||||
// init etcd
|
||||
suite.etcdClient, err = etcd.GetEtcdClient(
|
||||
suite.params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
|
||||
suite.params.EtcdCfg.EtcdUseSSL.GetAsBool(),
|
||||
suite.params.EtcdCfg.Endpoints.GetAsStrings(),
|
||||
suite.params.EtcdCfg.EtcdTLSCert.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSKey.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSCACert.GetValue(),
|
||||
suite.params.EtcdCfg.EtcdTLSMinVersion.GetValue())
|
||||
suite.NoError(err)
|
||||
suite.node.SetEtcdClient(suite.etcdClient)
|
||||
err = suite.node.Init()
|
||||
suite.NoError(err)
|
||||
err = suite.node.Start()
|
||||
suite.NoError(err)
|
||||
|
||||
suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
|
||||
collection := segments.NewCollection(suite.collectionID, suite.schema, querypb.LoadType_LoadCollection)
|
||||
loadMata := &querypb.LoadMetaInfo{
|
||||
LoadType: querypb.LoadType_LoadCollection,
|
||||
CollectionID: suite.collectionID,
|
||||
}
|
||||
suite.node.manager.Collection.Put(suite.collectionID, collection.Schema(), loadMata)
|
||||
suite.worker = NewLocalWorker(suite.node)
|
||||
}
|
||||
|
||||
func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) {
|
||||
suite.node.Stop()
|
||||
suite.etcdClient.Close()
|
||||
suite.cancel()
|
||||
}
|
||||
|
||||
func (suite *LocalWorkerTestSuite) TestLoadSegment() {
|
||||
// load empty
|
||||
req := &querypb.LoadSegmentsRequest{
|
||||
CollectionID: suite.collectionID,
|
||||
Infos: lo.Map(suite.segmentIDs, func(segID int64, _ int) *querypb.SegmentLoadInfo {
|
||||
return &querypb.SegmentLoadInfo{
|
||||
CollectionID: suite.collectionID,
|
||||
PartitionID: suite.partitionIDs[segID%2],
|
||||
SegmentID: segID,
|
||||
}
|
||||
}),
|
||||
}
|
||||
err := suite.worker.LoadSegments(suite.ctx, req)
|
||||
suite.NoError(err)
|
||||
}
|
||||
|
||||
func (suite *LocalWorkerTestSuite) TestReleaseSegment() {
|
||||
req := &querypb.ReleaseSegmentsRequest{
|
||||
CollectionID: suite.collectionID,
|
||||
SegmentIDs: suite.segmentIDs,
|
||||
}
|
||||
err := suite.worker.ReleaseSegments(suite.ctx, req)
|
||||
suite.NoError(err)
|
||||
}
|
||||
|
||||
func TestLocalWorker(t *testing.T) {
|
||||
suite.Run(t, new(LocalWorkerTestSuite))
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/collector"
|
||||
"github.com/milvus-io/milvus/internal/util/hardware"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func getRateMetric() ([]metricsinfo.RateMetric, error) {
|
||||
rms := make([]metricsinfo.RateMetric, 0)
|
||||
for _, label := range collector.RateMetrics() {
|
||||
rate, err := collector.Rate.Rate(label, ratelimitutil.DefaultAvgDuration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rms = append(rms, metricsinfo.RateMetric{
|
||||
Label: label,
|
||||
Rate: rate,
|
||||
})
|
||||
}
|
||||
return rms, nil
|
||||
}
|
||||
|
||||
func getSearchNQInQueue() (metricsinfo.ReadInfoInQueue, error) {
|
||||
average, err := collector.Average.Average(metricsinfo.SearchQueueMetric)
|
||||
if err != nil {
|
||||
return metricsinfo.ReadInfoInQueue{}, err
|
||||
}
|
||||
|
||||
unsolvedQueueLabel := collector.ConstructLabel(metricsinfo.UnsolvedQueueType, metricsinfo.SearchQueueMetric)
|
||||
readyQueueLabel := collector.ConstructLabel(metricsinfo.ReadyQueueType, metricsinfo.SearchQueueMetric)
|
||||
receiveQueueLabel := collector.ConstructLabel(metricsinfo.ReceiveQueueType, metricsinfo.SearchQueueMetric)
|
||||
executeQueueLabel := collector.ConstructLabel(metricsinfo.ExecuteQueueType, metricsinfo.SearchQueueMetric)
|
||||
|
||||
return metricsinfo.ReadInfoInQueue{
|
||||
UnsolvedQueue: collector.Counter.Get(unsolvedQueueLabel),
|
||||
ReadyQueue: collector.Counter.Get(readyQueueLabel),
|
||||
ReceiveChan: collector.Counter.Get(receiveQueueLabel),
|
||||
ExecuteChan: collector.Counter.Get(executeQueueLabel),
|
||||
AvgQueueDuration: time.Duration(int64(average)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getQueryTasksInQueue() (metricsinfo.ReadInfoInQueue, error) {
|
||||
average, err := collector.Average.Average(metricsinfo.QueryQueueMetric)
|
||||
if err != nil {
|
||||
return metricsinfo.ReadInfoInQueue{}, err
|
||||
}
|
||||
|
||||
unsolvedQueueLabel := collector.ConstructLabel(metricsinfo.UnsolvedQueueType, metricsinfo.QueryQueueMetric)
|
||||
readyQueueLabel := collector.ConstructLabel(metricsinfo.ReadyQueueType, metricsinfo.QueryQueueMetric)
|
||||
receiveQueueLabel := collector.ConstructLabel(metricsinfo.ReceiveQueueType, metricsinfo.QueryQueueMetric)
|
||||
executeQueueLabel := collector.ConstructLabel(metricsinfo.ExecuteQueueType, metricsinfo.QueryQueueMetric)
|
||||
|
||||
return metricsinfo.ReadInfoInQueue{
|
||||
UnsolvedQueue: collector.Counter.Get(unsolvedQueueLabel),
|
||||
ReadyQueue: collector.Counter.Get(readyQueueLabel),
|
||||
ReceiveChan: collector.Counter.Get(receiveQueueLabel),
|
||||
ExecuteChan: collector.Counter.Get(executeQueueLabel),
|
||||
AvgQueueDuration: time.Duration(int64(average)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getQuotaMetrics returns QueryNodeQuotaMetrics.
|
||||
func getQuotaMetrics(node *QueryNode) (*metricsinfo.QueryNodeQuotaMetrics, error) {
|
||||
rms, err := getRateMetric()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqms, err := getSearchNQInQueue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qqms, err := getQueryTasksInQueue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minTsafeChannel, minTsafe := node.tSafeManager.Min()
|
||||
return &metricsinfo.QueryNodeQuotaMetrics{
|
||||
Hms: metricsinfo.HardwareMetrics{},
|
||||
Rms: rms,
|
||||
Fgm: metricsinfo.FlowGraphMetric{
|
||||
MinFlowGraphChannel: minTsafeChannel,
|
||||
MinFlowGraphTt: minTsafe,
|
||||
NumFlowGraph: node.pipelineManager.Num(),
|
||||
},
|
||||
SearchQueue: sqms,
|
||||
QueryQueue: qqms,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getSystemInfoMetrics returns metrics info of QueryNode
|
||||
func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node *QueryNode) (*milvuspb.GetMetricsResponse, error) {
|
||||
usedMem := hardware.GetUsedMemoryCount()
|
||||
totalMem := hardware.GetMemoryCount()
|
||||
|
||||
quotaMetrics, err := getQuotaMetrics(node)
|
||||
if err != nil {
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()),
|
||||
}, nil
|
||||
}
|
||||
hardwareInfos := metricsinfo.HardwareMetrics{
|
||||
IP: node.session.Address,
|
||||
CPUCoreCount: hardware.GetCPUNum(),
|
||||
CPUCoreUsage: hardware.GetCPUUsage(),
|
||||
Memory: totalMem,
|
||||
MemoryUsage: usedMem,
|
||||
Disk: hardware.GetDiskCount(),
|
||||
DiskUsage: hardware.GetDiskUsage(),
|
||||
}
|
||||
quotaMetrics.Hms = hardwareInfos
|
||||
|
||||
nodeInfos := metricsinfo.QueryNodeInfos{
|
||||
BaseComponentInfos: metricsinfo.BaseComponentInfos{
|
||||
Name: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
|
||||
HardwareInfos: hardwareInfos,
|
||||
SystemInfo: metricsinfo.DeployMetrics{},
|
||||
CreatedTime: paramtable.GetCreateTime().String(),
|
||||
UpdatedTime: paramtable.GetUpdateTime().String(),
|
||||
Type: typeutil.QueryNodeRole,
|
||||
ID: node.session.ServerID,
|
||||
},
|
||||
SystemConfigurations: metricsinfo.QueryNodeConfiguration{
|
||||
SimdType: paramtable.Get().CommonCfg.SimdType.GetValue(),
|
||||
},
|
||||
QuotaMetrics: quotaMetrics,
|
||||
}
|
||||
metricsinfo.FillDeployMetricsWithEnv(&nodeInfos.SystemInfo)
|
||||
|
||||
resp, err := metricsinfo.MarshalComponentInfos(nodeInfos)
|
||||
if err != nil {
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
Response: "",
|
||||
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Response: resp,
|
||||
ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package querynodev2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
// ---------- unittest util functions ----------
|
||||
// common definitions
|
||||
const (
|
||||
metricTypeKey = common.MetricTypeKey
|
||||
defaultTopK = int64(10)
|
||||
defaultRoundDecimal = int64(6)
|
||||
defaultDim = 128
|
||||
defaultNProb = 10
|
||||
defaultEf = 10
|
||||
defaultMetricType = "L2"
|
||||
defaultNQ = 10
|
||||
)
|
||||
|
||||
const (
|
||||
// index type
|
||||
IndexFaissIDMap = "FLAT"
|
||||
IndexFaissIVFFlat = "IVF_FLAT"
|
||||
IndexFaissIVFPQ = "IVF_PQ"
|
||||
IndexFaissIVFSQ8 = "IVF_SQ8"
|
||||
IndexFaissBinIDMap = "BIN_FLAT"
|
||||
IndexFaissBinIVFFlat = "BIN_IVF_FLAT"
|
||||
|
||||
IndexHNSW = "HNSW"
|
||||
IndexANNOY = "ANNOY"
|
||||
)
|
||||
|
||||
// ---------- unittest util functions ----------
|
||||
// functions of messages and requests
|
||||
func genIVFFlatDSL(schema *schemapb.CollectionSchema, nProb int, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
nProbStr := strconv.Itoa(nProb)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
efStr := strconv.Itoa(ef)
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"ef\": " + efStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecimal int64) (string, error) {
|
||||
var vecFieldName string
|
||||
var metricType string
|
||||
topKStr := strconv.FormatInt(topK, 10)
|
||||
nProbStr := strconv.Itoa(defaultNProb)
|
||||
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
|
||||
for _, f := range schema.Fields {
|
||||
if f.DataType == schemapb.DataType_FloatVector {
|
||||
vecFieldName = f.Name
|
||||
for _, p := range f.IndexParams {
|
||||
if p.Key == metricTypeKey {
|
||||
metricType = p.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if vecFieldName == "" || metricType == "" {
|
||||
err := errors.New("invalid vector field name or metric type")
|
||||
return "", err
|
||||
}
|
||||
return "{\"bool\": { \n\"vector\": {\n \"" + vecFieldName +
|
||||
"\": {\n \"metric_type\": \"" + metricType +
|
||||
"\", \n \"params\": {\n \"nprobe\": " + nProbStr + " \n},\n \"query\": \"$0\",\n \"topk\": " + topKStr +
|
||||
" \n,\"round_decimal\": " + roundDecimalStr +
|
||||
"\n } \n } \n } \n }", nil
|
||||
}
|
||||
|
||||
func genDSLByIndexType(schema *schemapb.CollectionSchema, indexType string) (string, error) {
|
||||
if indexType == IndexFaissIDMap { // float vector
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIDMap {
|
||||
return genBruteForceDSL(schema, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissIVFFlat {
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexFaissBinIVFFlat { // binary vector
|
||||
return genIVFFlatDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
|
||||
} else if indexType == IndexHNSW {
|
||||
return genHNSWDSL(schema, defaultEf, defaultTopK, defaultRoundDecimal)
|
||||
}
|
||||
return "", fmt.Errorf("Invalid indexType")
|
||||
}
|
||||
|
||||
func genPlaceHolderGroup(nq int64) ([]byte, error) {
|
||||
placeholderValue := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Type: commonpb.PlaceholderType_FloatVector,
|
||||
Values: make([][]byte, 0),
|
||||
}
|
||||
for i := int64(0); i < nq; i++ {
|
||||
var vec = make([]float32, defaultDim)
|
||||
for j := 0; j < defaultDim; j++ {
|
||||
vec[j] = rand.Float32()
|
||||
}
|
||||
var rawData []byte
|
||||
for k, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
common.Endian.PutUint32(buf, math.Float32bits(ele+float32(k*2)))
|
||||
rawData = append(rawData, buf...)
|
||||
}
|
||||
placeholderValue.Values = append(placeholderValue.Values, rawData)
|
||||
}
|
||||
|
||||
// generate placeholder
|
||||
placeholderGroup := commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{placeholderValue},
|
||||
}
|
||||
placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return placeGroupByte, nil
|
||||
}
|
||||
|
||||
func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error) {
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
planNode := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Predicates{
|
||||
Predicates: &planpb.Expr{
|
||||
Expr: &planpb.Expr_TermExpr{
|
||||
TermExpr: &planpb.TermExpr{
|
||||
ColumnInfo: &planpb.ColumnInfo{
|
||||
FieldId: pkField.FieldID,
|
||||
DataType: pkField.DataType,
|
||||
},
|
||||
Values: []*planpb.GenericValue{
|
||||
{
|
||||
Val: &planpb.GenericValue_Int64Val{
|
||||
Int64Val: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Val: &planpb.GenericValue_Int64Val{
|
||||
Int64Val: 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Val: &planpb.GenericValue_Int64Val{
|
||||
Int64Val: 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
OutputFieldIds: []int64{pkField.FieldID, 100},
|
||||
}
|
||||
planExpr, err := proto.Marshal(planNode)
|
||||
return planExpr, err
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type deleteNode struct {
|
||||
*BaseNode
|
||||
collectionID UniqueID
|
||||
channel string
|
||||
|
||||
manager *DataManager
|
||||
tSafeManager TSafeManager
|
||||
delegator delegator.ShardDelegator
|
||||
}
|
||||
|
||||
//addDeleteData find the segment of delete column in DeleteMsg and save in deleteData
|
||||
func (dNode *deleteNode) addDeleteData(deleteDatas map[UniqueID]*delegator.DeleteData, msg *DeleteMsg) {
|
||||
deleteData, ok := deleteDatas[msg.PartitionID]
|
||||
if !ok {
|
||||
deleteData = &delegator.DeleteData{
|
||||
PartitionID: msg.PartitionID,
|
||||
}
|
||||
deleteDatas[msg.PartitionID] = deleteData
|
||||
}
|
||||
pks := storage.ParseIDs2PrimaryKeys(msg.PrimaryKeys)
|
||||
deleteData.PrimaryKeys = append(deleteData.PrimaryKeys, pks...)
|
||||
deleteData.Timestamps = append(deleteData.Timestamps, msg.Timestamps...)
|
||||
deleteData.RowCount += int64(len(pks))
|
||||
|
||||
log.Info("pipeline fetch delete msg",
|
||||
zap.Int64("collectionID", dNode.collectionID),
|
||||
zap.Int64("partitionID", msg.PartitionID),
|
||||
zap.Int("insertRowNum", len(pks)),
|
||||
zap.Uint64("timestampMin", msg.BeginTimestamp),
|
||||
zap.Uint64("timestampMax", msg.EndTimestamp))
|
||||
}
|
||||
|
||||
func (dNode *deleteNode) Operate(in Msg) Msg {
|
||||
nodeMsg := in.(*deleteNodeMsg)
|
||||
|
||||
// partition id = > DeleteData
|
||||
deleteDatas := make(map[UniqueID]*delegator.DeleteData)
|
||||
|
||||
for _, msg := range nodeMsg.deleteMsgs {
|
||||
dNode.addDeleteData(deleteDatas, msg)
|
||||
}
|
||||
|
||||
if len(deleteDatas) > 0 {
|
||||
//do Delete, use ts range max as ts
|
||||
dNode.delegator.ProcessDelete(lo.Values(deleteDatas), nodeMsg.timeRange.timestampMax)
|
||||
}
|
||||
|
||||
//update tSafe
|
||||
err := dNode.tSafeManager.Set(dNode.channel, nodeMsg.timeRange.timestampMax)
|
||||
if err != nil {
|
||||
// should not happen, QueryNode should addTSafe before start pipeline
|
||||
panic(fmt.Errorf("serviceTimeNode setTSafe timeout, collectionID = %d, err = %s", dNode.collectionID, err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDeleteNode(
|
||||
collectionID UniqueID, channel string,
|
||||
manager *DataManager, tSafeManager TSafeManager, delegator delegator.ShardDelegator,
|
||||
maxQueueLength int32,
|
||||
) *deleteNode {
|
||||
return &deleteNode{
|
||||
BaseNode: base.NewBaseNode(fmt.Sprintf("DeleteNode-%s", channel), maxQueueLength),
|
||||
collectionID: collectionID,
|
||||
channel: channel,
|
||||
manager: manager,
|
||||
tSafeManager: tSafeManager,
|
||||
delegator: delegator,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type DeleteNodeSuite struct {
|
||||
suite.Suite
|
||||
//datas
|
||||
collectionID int64
|
||||
collectionName string
|
||||
partitionIDs []int64
|
||||
deletePKs []int64
|
||||
channel string
|
||||
timeRange TimeRange
|
||||
|
||||
//dependency
|
||||
tSafeManager TSafeManager
|
||||
//mocks
|
||||
manager *segments.Manager
|
||||
delegator *delegator.MockShardDelegator
|
||||
}
|
||||
|
||||
func (suite *DeleteNodeSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.partitionIDs = []int64{11, 22}
|
||||
suite.channel = "test-channel"
|
||||
//segment own data row which‘s pk same with segment‘s ID
|
||||
suite.deletePKs = []int64{1, 2, 3, 4}
|
||||
suite.timeRange = TimeRange{
|
||||
timestampMin: 0,
|
||||
timestampMax: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *DeleteNodeSuite) buildDeleteNodeMsg() *deleteNodeMsg {
|
||||
nodeMsg := &deleteNodeMsg{
|
||||
deleteMsgs: []*DeleteMsg{},
|
||||
timeRange: suite.timeRange,
|
||||
}
|
||||
|
||||
for i, pk := range suite.deletePKs {
|
||||
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[i%len(suite.partitionIDs)], suite.channel, 1)
|
||||
deleteMsg.PrimaryKeys = genDeletePK(pk)
|
||||
nodeMsg.deleteMsgs = append(nodeMsg.deleteMsgs, deleteMsg)
|
||||
}
|
||||
return nodeMsg
|
||||
}
|
||||
|
||||
func (suite *DeleteNodeSuite) TestBasic() {
|
||||
//mock
|
||||
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
|
||||
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
|
||||
suite.manager = &segments.Manager{
|
||||
Collection: mockCollectionManager,
|
||||
Segment: mockSegmentManager,
|
||||
}
|
||||
suite.delegator = delegator.NewMockShardDelegator(suite.T())
|
||||
suite.delegator.EXPECT().ProcessDelete(mock.Anything, mock.Anything).Run(
|
||||
func(deleteData []*delegator.DeleteData, ts uint64) {
|
||||
for _, data := range deleteData {
|
||||
for _, pk := range data.PrimaryKeys {
|
||||
suite.True(lo.Contains(suite.deletePKs, pk.GetValue().(int64)))
|
||||
}
|
||||
}
|
||||
})
|
||||
//init dependency
|
||||
suite.tSafeManager = tsafe.NewTSafeReplica()
|
||||
suite.tSafeManager.Add(suite.channel, 0)
|
||||
//build delete node and data
|
||||
node := newDeleteNode(suite.collectionID, suite.channel, suite.manager, suite.tSafeManager, suite.delegator, 8)
|
||||
in := suite.buildDeleteNodeMsg()
|
||||
//run
|
||||
out := node.Operate(in)
|
||||
suite.Nil(out)
|
||||
//check tsafe
|
||||
tt, err := suite.tSafeManager.Get(suite.channel)
|
||||
suite.NoError(err)
|
||||
suite.Equal(suite.timeRange.timestampMax, tt)
|
||||
}
|
||||
|
||||
func TestDeleteNode(t *testing.T) {
|
||||
suite.Run(t, new(DeleteNodeSuite))
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMsgInvalidType = errors.New("InvalidMessageType")
|
||||
ErrMsgNotAligned = errors.New("CheckAlignedFailed")
|
||||
ErrMsgEmpty = errors.New("EmptyMessage")
|
||||
ErrMsgNotTarget = errors.New("NotTarget")
|
||||
ErrMsgExcluded = errors.New("SegmentExcluded")
|
||||
|
||||
ErrCollectionNotFound = errors.New("CollectionNotFound")
|
||||
ErrShardDelegatorNotFound = errors.New("ShardDelegatorNotFound")
|
||||
|
||||
ErrNewPipelineFailed = errors.New("FailedCreateNewPipeline")
|
||||
ErrStartPipeline = errors.New("PipineStartFailed")
|
||||
)
|
||||
|
||||
func WrapErrMsgNotAligned(err error) error {
|
||||
return fmt.Errorf("%w :%s", ErrMsgNotAligned, err)
|
||||
}
|
||||
|
||||
func WrapErrMsgNotTarget(reason string) error {
|
||||
return fmt.Errorf("%w%s", ErrMsgNotTarget, reason)
|
||||
}
|
||||
|
||||
func WrapErrMsgExcluded(segmentID int64) error {
|
||||
return fmt.Errorf("%w ID:%d", ErrMsgExcluded, segmentID)
|
||||
}
|
||||
|
||||
func WrapErrNewPipelineFailed(err error) error {
|
||||
return fmt.Errorf("%w :%s", ErrNewPipelineFailed, err)
|
||||
}
|
||||
|
||||
func WrapErrStartPipeline(reason string) error {
|
||||
return fmt.Errorf("%w :%s", ErrStartPipeline, reason)
|
||||
}
|
||||
|
||||
func WrapErrShardDelegatorNotFound(channel string) error {
|
||||
return fmt.Errorf("%w channel:%s", ErrShardDelegatorNotFound, channel)
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//filterNode filter the invalid message of pipeline
|
||||
type filterNode struct {
|
||||
*BaseNode
|
||||
collectionID UniqueID
|
||||
manager *DataManager
|
||||
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
|
||||
channel string
|
||||
InsertMsgPolicys []InsertMsgFilter
|
||||
DeleteMsgPolicys []DeleteMsgFilter
|
||||
}
|
||||
|
||||
func (fNode *filterNode) Operate(in Msg) Msg {
|
||||
if in == nil {
|
||||
log.Debug("type assertion failed for Msg in filterNode because it's nil",
|
||||
zap.String("name", fNode.Name()))
|
||||
return nil
|
||||
}
|
||||
|
||||
streamMsgPack, ok := in.(*msgstream.MsgPack)
|
||||
if !ok {
|
||||
log.Warn("type assertion failed for MsgPack",
|
||||
zap.String("msgType", reflect.TypeOf(in).Name()),
|
||||
zap.String("name", fNode.Name()))
|
||||
return nil
|
||||
}
|
||||
|
||||
metrics.QueryNodeConsumerMsgCount.
|
||||
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel, fmt.Sprint(fNode.collectionID)).
|
||||
Inc()
|
||||
|
||||
metrics.QueryNodeConsumeTimeTickLag.
|
||||
WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel, fmt.Sprint(fNode.collectionID)).
|
||||
Set(float64(streamMsgPack.EndTs))
|
||||
|
||||
//Get collection from collection manager
|
||||
collection := fNode.manager.Collection.Get(fNode.collectionID)
|
||||
if collection == nil {
|
||||
err := segments.WrapCollectionNotFound(fNode.collectionID)
|
||||
log.Error(err.Error())
|
||||
panic(err)
|
||||
}
|
||||
|
||||
out := &insertNodeMsg{
|
||||
insertMsgs: []*InsertMsg{},
|
||||
deleteMsgs: []*DeleteMsg{},
|
||||
timeRange: TimeRange{
|
||||
timestampMin: streamMsgPack.BeginTs,
|
||||
timestampMax: streamMsgPack.EndTs,
|
||||
},
|
||||
}
|
||||
|
||||
//add msg to out if msg pass check of filter
|
||||
for _, msg := range streamMsgPack.Msgs {
|
||||
err := fNode.filtrate(collection, msg)
|
||||
if err != nil {
|
||||
log.Debug("filter invalid message",
|
||||
zap.String("message type", msg.Type().String()),
|
||||
zap.String("channel", fNode.channel),
|
||||
zap.Int64("collectionID", fNode.collectionID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
out.append(msg)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
//filtrate message with filter policy
|
||||
func (fNode *filterNode) filtrate(c *Collection, msg msgstream.TsMsg) error {
|
||||
|
||||
switch msg.Type() {
|
||||
case commonpb.MsgType_Insert:
|
||||
insertMsg := msg.(*msgstream.InsertMsg)
|
||||
metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(insertMsg)))
|
||||
for _, policy := range fNode.InsertMsgPolicys {
|
||||
err := policy(fNode, c, insertMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case commonpb.MsgType_Delete:
|
||||
deleteMsg := msg.(*msgstream.DeleteMsg)
|
||||
metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(deleteMsg)))
|
||||
for _, policy := range fNode.DeleteMsgPolicys {
|
||||
err := policy(fNode, c, deleteMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return ErrMsgInvalidType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newFilterNode(
|
||||
collectionID int64,
|
||||
channel string,
|
||||
manager *DataManager,
|
||||
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo],
|
||||
maxQueueLength int32,
|
||||
) *filterNode {
|
||||
return &filterNode{
|
||||
BaseNode: base.NewBaseNode(fmt.Sprintf("FilterNode-%s", channel), maxQueueLength),
|
||||
collectionID: collectionID,
|
||||
manager: manager,
|
||||
channel: channel,
|
||||
excludedSegments: excludedSegments,
|
||||
InsertMsgPolicys: []InsertMsgFilter{
|
||||
InsertNotAligned,
|
||||
InsertEmpty,
|
||||
InsertOutOfTarget,
|
||||
InsertExcluded,
|
||||
},
|
||||
DeleteMsgPolicys: []DeleteMsgFilter{
|
||||
DeleteNotAligned,
|
||||
DeleteEmpty,
|
||||
DeleteOutOfTarget,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
//test of filter node
|
||||
type FilterNodeSuite struct {
|
||||
suite.Suite
|
||||
//datas
|
||||
collectionID int64
|
||||
partitionIDs []int64
|
||||
channel string
|
||||
|
||||
validSegmentIDs []int64
|
||||
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
|
||||
excludedSegmentIDs []int64
|
||||
insertSegmentIDs []int64
|
||||
deleteSegmentSum int
|
||||
//segmentID of msg invalid because empty of not aligned
|
||||
errSegmentID int64
|
||||
|
||||
//mocks
|
||||
manager *segments.Manager
|
||||
}
|
||||
|
||||
func (suite *FilterNodeSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
suite.collectionID = 111
|
||||
suite.partitionIDs = []int64{11, 22}
|
||||
suite.channel = "test-channel"
|
||||
|
||||
// first one invalid because insert max timestamp before dmlPosition timestamp
|
||||
suite.excludedSegmentIDs = []int64{1, 2}
|
||||
suite.insertSegmentIDs = []int64{3, 4, 5, 6}
|
||||
suite.deleteSegmentSum = 4
|
||||
suite.errSegmentID = 7
|
||||
|
||||
//init excludedSegment
|
||||
suite.excludedSegments = typeutil.NewConcurrentMap[int64, *datapb.SegmentInfo]()
|
||||
for _, id := range suite.excludedSegmentIDs {
|
||||
suite.excludedSegments.Insert(id, &datapb.SegmentInfo{
|
||||
DmlPosition: &msgpb.MsgPosition{
|
||||
Timestamp: 1,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//test filter node with collection load collection
|
||||
func (suite *FilterNodeSuite) TestWithLoadCollection() {
|
||||
//data
|
||||
suite.validSegmentIDs = []int64{2, 3, 4, 5, 6}
|
||||
|
||||
//mock
|
||||
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadCollection)
|
||||
for _, partitionID := range suite.partitionIDs {
|
||||
collection.AddPartition(partitionID)
|
||||
}
|
||||
|
||||
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
|
||||
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
|
||||
|
||||
suite.manager = &segments.Manager{
|
||||
Collection: mockCollectionManager,
|
||||
Segment: mockSegmentManager,
|
||||
}
|
||||
|
||||
node := newFilterNode(suite.collectionID, suite.channel, suite.manager, suite.excludedSegments, 8)
|
||||
in := suite.buildMsgPack()
|
||||
out := node.Operate(in)
|
||||
|
||||
nodeMsg, ok := out.(*insertNodeMsg)
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(len(suite.validSegmentIDs), len(nodeMsg.insertMsgs))
|
||||
for _, msg := range nodeMsg.insertMsgs {
|
||||
suite.True(lo.Contains(suite.validSegmentIDs, msg.SegmentID))
|
||||
}
|
||||
suite.Equal(suite.deleteSegmentSum, len(nodeMsg.deleteMsgs))
|
||||
}
|
||||
|
||||
//test filter node with collection load partition
|
||||
func (suite *FilterNodeSuite) TestWithLoadPartation() {
|
||||
//data
|
||||
suite.validSegmentIDs = []int64{2, 4, 6}
|
||||
|
||||
//mock
|
||||
collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition)
|
||||
collection.AddPartition(suite.partitionIDs[0])
|
||||
|
||||
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
|
||||
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
|
||||
|
||||
suite.manager = &segments.Manager{
|
||||
Collection: mockCollectionManager,
|
||||
Segment: mockSegmentManager,
|
||||
}
|
||||
|
||||
node := newFilterNode(suite.collectionID, suite.channel, suite.manager, suite.excludedSegments, 8)
|
||||
in := suite.buildMsgPack()
|
||||
out := node.Operate(in)
|
||||
|
||||
nodeMsg, ok := out.(*insertNodeMsg)
|
||||
suite.True(ok)
|
||||
|
||||
suite.Equal(len(suite.validSegmentIDs), len(nodeMsg.insertMsgs))
|
||||
for _, msg := range nodeMsg.insertMsgs {
|
||||
suite.True(lo.Contains(suite.validSegmentIDs, msg.SegmentID))
|
||||
}
|
||||
suite.Equal(suite.deleteSegmentSum/2, len(nodeMsg.deleteMsgs))
|
||||
}
|
||||
|
||||
func (suite *FilterNodeSuite) buildMsgPack() *msgstream.MsgPack {
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: 0,
|
||||
EndTs: 0,
|
||||
Msgs: []msgstream.TsMsg{},
|
||||
}
|
||||
|
||||
//add valid insert
|
||||
for _, id := range suite.insertSegmentIDs {
|
||||
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
//add valid delete
|
||||
for i := 0; i < suite.deleteSegmentSum; i++ {
|
||||
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[i%2], suite.channel, 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
|
||||
}
|
||||
|
||||
//add invalid msg
|
||||
|
||||
//segment in excludedSegments
|
||||
//some one end timestamp befroe dmlPosition timestamp will be invalid
|
||||
for _, id := range suite.excludedSegmentIDs {
|
||||
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1)
|
||||
insertMsg.EndTimestamp = uint64(id)
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
//empty msg
|
||||
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 0)
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
|
||||
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[0], suite.channel, 0)
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
|
||||
|
||||
//msg not target
|
||||
insertMsg = buildInsertMsg(suite.collectionID+1, 1, 0, "Unknown", 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
|
||||
deleteMsg = buildDeleteMsg(suite.collectionID+1, 1, "Unknown", 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
|
||||
|
||||
//msg not aligned
|
||||
insertMsg = buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 1)
|
||||
insertMsg.Timestamps = []uint64{}
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
|
||||
deleteMsg = buildDeleteMsg(suite.collectionID, suite.partitionIDs[0], suite.channel, 1)
|
||||
deleteMsg.Timestamps = append(deleteMsg.Timestamps, 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
|
||||
return msgPack
|
||||
}
|
||||
|
||||
func TestFilterNode(t *testing.T) {
|
||||
suite.Run(t, new(FilterNodeSuite))
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/common"
|
||||
|
||||
//MsgFilter will return error if Msg was invalid
|
||||
type InsertMsgFilter = func(n *filterNode, c *Collection, msg *InsertMsg) error
|
||||
type DeleteMsgFilter = func(n *filterNode, c *Collection, msg *DeleteMsg) error
|
||||
|
||||
//Chack msg is aligned --
|
||||
//len of each kind of infos in InsertMsg should match each other
|
||||
func InsertNotAligned(n *filterNode, c *Collection, msg *InsertMsg) error {
|
||||
err := msg.CheckAligned()
|
||||
if err != nil {
|
||||
return WrapErrMsgNotAligned(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertEmpty(n *filterNode, c *Collection, msg *InsertMsg) error {
|
||||
if len(msg.GetTimestamps()) <= 0 {
|
||||
return ErrMsgEmpty
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertOutOfTarget(n *filterNode, c *Collection, msg *InsertMsg) error {
|
||||
if msg.GetCollectionID() != c.ID() {
|
||||
return WrapErrMsgNotTarget("Collection")
|
||||
}
|
||||
|
||||
if c.GetLoadType() == loadTypePartition {
|
||||
if msg.PartitionID != common.InvalidPartitionID && !c.ExistPartition(msg.PartitionID) {
|
||||
return WrapErrMsgNotTarget("Partition")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertExcluded(n *filterNode, c *Collection, msg *InsertMsg) error {
|
||||
segInfo, ok := n.excludedSegments.Get(msg.SegmentID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if msg.EndTimestamp <= segInfo.GetDmlPosition().Timestamp {
|
||||
return WrapErrMsgExcluded(msg.SegmentID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteNotAligned(n *filterNode, c *Collection, msg *DeleteMsg) error {
|
||||
err := msg.CheckAligned()
|
||||
if err != nil {
|
||||
return WrapErrMsgNotAligned(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteEmpty(n *filterNode, c *Collection, msg *DeleteMsg) error {
|
||||
if len(msg.GetTimestamps()) <= 0 {
|
||||
return ErrMsgEmpty
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteOutOfTarget(n *filterNode, c *Collection, msg *DeleteMsg) error {
|
||||
if msg.GetCollectionID() != c.ID() {
|
||||
return WrapErrMsgNotTarget("Collection")
|
||||
}
|
||||
if c.GetLoadType() == loadTypePartition {
|
||||
if msg.PartitionID != common.InvalidPartitionID && !c.ExistPartition(msg.PartitionID) {
|
||||
return WrapErrMsgNotTarget("Partition")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type insertNode struct {
|
||||
*BaseNode
|
||||
collectionID int64
|
||||
channel string
|
||||
manager *DataManager
|
||||
delegator delegator.ShardDelegator
|
||||
}
|
||||
|
||||
func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) {
|
||||
insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get primary keys, err = %d", err)
|
||||
log.Error(err.Error(), zap.Int64("collectionID", iNode.collectionID), zap.String("channel", iNode.channel))
|
||||
panic(err)
|
||||
}
|
||||
iData, ok := insertDatas[msg.SegmentID]
|
||||
if !ok {
|
||||
iData = &delegator.InsertData{
|
||||
PartitionID: msg.PartitionID,
|
||||
InsertRecord: insertRecord,
|
||||
StartPosition: &msgpb.MsgPosition{
|
||||
Timestamp: msg.BeginTs(),
|
||||
ChannelName: msg.GetShardName(),
|
||||
},
|
||||
}
|
||||
insertDatas[msg.SegmentID] = iData
|
||||
} else {
|
||||
typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
|
||||
iData.InsertRecord.NumRows += insertRecord.NumRows
|
||||
}
|
||||
|
||||
pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
|
||||
if err != nil {
|
||||
log.Error("failed to get primary keys from insert message", zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
|
||||
iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
|
||||
iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
|
||||
iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
|
||||
log.Info("pipeline fetch insert msg",
|
||||
zap.Int64("collectionID", iNode.collectionID),
|
||||
zap.Int64("segmentID", msg.SegmentID),
|
||||
zap.Int("insertRowNum", len(pks)),
|
||||
zap.Uint64("timestampMin", msg.BeginTimestamp),
|
||||
zap.Uint64("timestampMax", msg.EndTimestamp))
|
||||
}
|
||||
|
||||
//Insert task
|
||||
func (iNode *insertNode) Operate(in Msg) Msg {
|
||||
nodeMsg := in.(*insertNodeMsg)
|
||||
|
||||
sort.Slice(nodeMsg.insertMsgs, func(i, j int) bool {
|
||||
return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
|
||||
})
|
||||
|
||||
insertDatas := make(map[UniqueID]*delegator.InsertData)
|
||||
collection := iNode.manager.Collection.Get(iNode.collectionID)
|
||||
if collection == nil {
|
||||
log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
|
||||
panic("insertNode with collection not exist")
|
||||
}
|
||||
|
||||
//get InsertData and merge datas of same segment
|
||||
for _, msg := range nodeMsg.insertMsgs {
|
||||
iNode.addInsertData(insertDatas, msg, collection)
|
||||
}
|
||||
|
||||
iNode.delegator.ProcessInsert(insertDatas)
|
||||
|
||||
return &deleteNodeMsg{
|
||||
deleteMsgs: nodeMsg.deleteMsgs,
|
||||
timeRange: nodeMsg.timeRange,
|
||||
}
|
||||
}
|
||||
|
||||
func newInsertNode(
|
||||
collectionID UniqueID,
|
||||
channel string,
|
||||
manager *DataManager,
|
||||
delegator delegator.ShardDelegator,
|
||||
maxQueueLength int32,
|
||||
) *insertNode {
|
||||
return &insertNode{
|
||||
BaseNode: base.NewBaseNode(fmt.Sprintf("InsertNode-%s", channel), maxQueueLength),
|
||||
collectionID: collectionID,
|
||||
channel: channel,
|
||||
manager: manager,
|
||||
delegator: delegator,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type InsertNodeSuite struct {
|
||||
suite.Suite
|
||||
//datas
|
||||
collectionName string
|
||||
collectionID int64
|
||||
partitionID int64
|
||||
channel string
|
||||
insertSegmentIDs []int64
|
||||
deleteSegmentSum int
|
||||
//mocks
|
||||
manager *segments.Manager
|
||||
delegator *delegator.MockShardDelegator
|
||||
}
|
||||
|
||||
func (suite *InsertNodeSuite) SetupSuite() {
|
||||
suite.collectionName = "test-collection"
|
||||
suite.collectionID = 111
|
||||
suite.partitionID = 11
|
||||
suite.channel = "test_channel"
|
||||
|
||||
suite.insertSegmentIDs = []int64{4, 3}
|
||||
suite.deleteSegmentSum = 2
|
||||
}
|
||||
|
||||
func (suite *InsertNodeSuite) TestBasic() {
|
||||
//data
|
||||
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
|
||||
in := suite.buildInsertNodeMsg(schema)
|
||||
|
||||
collection := segments.NewCollection(suite.collectionID, schema, querypb.LoadType_LoadCollection)
|
||||
collection.AddPartition(suite.partitionID)
|
||||
|
||||
//init mock
|
||||
mockCollectionManager := segments.NewMockCollectionManager(suite.T())
|
||||
mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
mockSegmentManager := segments.NewMockSegmentManager(suite.T())
|
||||
|
||||
suite.manager = &segments.Manager{
|
||||
Collection: mockCollectionManager,
|
||||
Segment: mockSegmentManager,
|
||||
}
|
||||
|
||||
suite.delegator = delegator.NewMockShardDelegator(suite.T())
|
||||
suite.delegator.EXPECT().ProcessInsert(mock.Anything).Run(func(insertRecords map[int64]*delegator.InsertData) {
|
||||
for segID := range insertRecords {
|
||||
suite.True(lo.Contains(suite.insertSegmentIDs, segID))
|
||||
}
|
||||
})
|
||||
|
||||
//TODO mock a delgator for test
|
||||
node := newInsertNode(suite.collectionID, suite.channel, suite.manager, suite.delegator, 8)
|
||||
out := node.Operate(in)
|
||||
|
||||
nodeMsg, ok := out.(*deleteNodeMsg)
|
||||
suite.True(ok)
|
||||
suite.Equal(suite.deleteSegmentSum, len(nodeMsg.deleteMsgs))
|
||||
}
|
||||
|
||||
func (suite *InsertNodeSuite) buildInsertNodeMsg(schema *schemapb.CollectionSchema) *insertNodeMsg {
|
||||
nodeMsg := insertNodeMsg{
|
||||
insertMsgs: []*InsertMsg{},
|
||||
deleteMsgs: []*DeleteMsg{},
|
||||
timeRange: TimeRange{
|
||||
timestampMin: 0,
|
||||
timestampMax: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, segmentID := range suite.insertSegmentIDs {
|
||||
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionID, segmentID, suite.channel, 1)
|
||||
insertMsg.FieldsData = genFiledDataWithSchema(schema, 1)
|
||||
nodeMsg.insertMsgs = append(nodeMsg.insertMsgs, insertMsg)
|
||||
}
|
||||
|
||||
for i := 0; i < suite.deleteSegmentSum; i++ {
|
||||
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionID, suite.channel, 1)
|
||||
nodeMsg.deleteMsgs = append(nodeMsg.deleteMsgs, deleteMsg)
|
||||
}
|
||||
|
||||
return &nodeMsg
|
||||
}
|
||||
|
||||
func TestInsertNode(t *testing.T) {
|
||||
suite.Run(t, new(InsertNodeSuite))
|
||||
}
|
|
@ -0,0 +1,163 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//Manager manage pipeline in querynode
|
||||
type Manager interface {
|
||||
Num() int
|
||||
Add(collectionID UniqueID, channel string) (Pipeline, error)
|
||||
Get(channel string) Pipeline
|
||||
Remove(channels ...string)
|
||||
Start(channels ...string) error
|
||||
Close()
|
||||
}
|
||||
type manager struct {
|
||||
channel2Pipeline map[string]Pipeline
|
||||
dataManager *DataManager
|
||||
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
|
||||
|
||||
tSafeManager TSafeManager
|
||||
dispatcher msgdispatcher.Client
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *manager) Num() int {
|
||||
return len(m.channel2Pipeline)
|
||||
}
|
||||
|
||||
//Add pipeline for each channel of collection
|
||||
func (m *manager) Add(collectionID UniqueID, channel string) (Pipeline, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
log.Debug("start create pipeine",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.String("channel", channel),
|
||||
)
|
||||
collection := m.dataManager.Collection.Get(collectionID)
|
||||
if collection == nil {
|
||||
return nil, segments.WrapCollectionNotFound(collectionID)
|
||||
}
|
||||
|
||||
if pipeline, ok := m.channel2Pipeline[channel]; ok {
|
||||
return pipeline, nil
|
||||
}
|
||||
|
||||
//get shard delegator for add growing in pipeline
|
||||
delegator, ok := m.delegators.Get(channel)
|
||||
if !ok {
|
||||
return nil, WrapErrShardDelegatorNotFound(channel)
|
||||
}
|
||||
|
||||
newPipeLine, err := NewPipeLine(collectionID, channel, m.dataManager, m.tSafeManager, m.dispatcher, delegator)
|
||||
if err != nil {
|
||||
return nil, WrapErrNewPipelineFailed(err)
|
||||
}
|
||||
|
||||
m.channel2Pipeline[channel] = newPipeLine
|
||||
metrics.QueryNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
metrics.QueryNodeNumDmlChannels.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
|
||||
return newPipeLine, nil
|
||||
}
|
||||
|
||||
func (m *manager) Get(channel string) Pipeline {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
pipeline, ok := m.channel2Pipeline[channel]
|
||||
if !ok {
|
||||
log.Warn("pipeline not existed",
|
||||
zap.String("channel", channel),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return pipeline
|
||||
}
|
||||
|
||||
//Remove pipeline from Manager by channel
|
||||
func (m *manager) Remove(channels ...string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, channel := range channels {
|
||||
if pipeline, ok := m.channel2Pipeline[channel]; ok {
|
||||
pipeline.Close()
|
||||
delete(m.channel2Pipeline, channel)
|
||||
} else {
|
||||
log.Warn("pipeline to be removed doesn't existed", zap.Any("channel", channel))
|
||||
}
|
||||
}
|
||||
metrics.QueryNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
|
||||
metrics.QueryNodeNumDmlChannels.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
|
||||
}
|
||||
|
||||
//Start pipeline by channel
|
||||
func (m *manager) Start(channels ...string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
//check pipelie all exist before start
|
||||
for _, channel := range channels {
|
||||
if _, ok := m.channel2Pipeline[channel]; !ok {
|
||||
return WrapErrStartPipeline(fmt.Sprintf("pipeline with channel %s not exist", channel))
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
m.channel2Pipeline[channel].Start()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//Close all pipeline of Manager
|
||||
func (m *manager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, pipeline := range m.channel2Pipeline {
|
||||
pipeline.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(dataManager *DataManager,
|
||||
tSafeManager TSafeManager,
|
||||
dispatcher msgdispatcher.Client,
|
||||
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator],
|
||||
) Manager {
|
||||
return &manager{
|
||||
channel2Pipeline: make(map[string]Pipeline),
|
||||
dataManager: dataManager,
|
||||
delegators: delegators,
|
||||
tSafeManager: tSafeManager,
|
||||
dispatcher: dispatcher,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PipelineManagerTestSuite struct {
|
||||
suite.Suite
|
||||
//data
|
||||
collectionID int64
|
||||
channel string
|
||||
//dependencies
|
||||
tSafeManager TSafeManager
|
||||
delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator]
|
||||
|
||||
//mocks
|
||||
segmentManager *segments.MockSegmentManager
|
||||
collectionManager *segments.MockCollectionManager
|
||||
delegator *delegator.MockShardDelegator
|
||||
msgDispatcher *msgdispatcher.MockClient
|
||||
msgChan chan *msgstream.MsgPack
|
||||
}
|
||||
|
||||
func (suite *PipelineManagerTestSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.msgChan = make(chan *msgstream.MsgPack, 1)
|
||||
}
|
||||
|
||||
func (suite *PipelineManagerTestSuite) SetupTest() {
|
||||
paramtable.Init()
|
||||
//init dependency
|
||||
// init tsafeManager
|
||||
suite.tSafeManager = tsafe.NewTSafeReplica()
|
||||
suite.tSafeManager.Add(suite.channel, 0)
|
||||
suite.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]()
|
||||
|
||||
//init mock
|
||||
// init manager
|
||||
suite.collectionManager = segments.NewMockCollectionManager(suite.T())
|
||||
suite.segmentManager = segments.NewMockSegmentManager(suite.T())
|
||||
// init delegator
|
||||
suite.delegator = delegator.NewMockShardDelegator(suite.T())
|
||||
suite.delegators.Insert(suite.channel, suite.delegator)
|
||||
// init mq dispatcher
|
||||
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
|
||||
}
|
||||
|
||||
func (suite *PipelineManagerTestSuite) TestBasic() {
|
||||
//init mock
|
||||
// mock collection manager
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{})
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
//build manager
|
||||
manager := &segments.Manager{
|
||||
Collection: suite.collectionManager,
|
||||
Segment: suite.segmentManager,
|
||||
}
|
||||
pipelineManager := NewManager(manager, suite.tSafeManager, suite.msgDispatcher, suite.delegators)
|
||||
defer pipelineManager.Close()
|
||||
|
||||
//Add pipeline
|
||||
_, err := pipelineManager.Add(suite.collectionID, suite.channel)
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, pipelineManager.Num())
|
||||
|
||||
//Get pipeline
|
||||
pipeline := pipelineManager.Get(suite.channel)
|
||||
suite.NotNil(pipeline)
|
||||
|
||||
//Init Consumer
|
||||
err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{})
|
||||
suite.NoError(err)
|
||||
|
||||
//Start pipeline
|
||||
err = pipelineManager.Start(suite.channel)
|
||||
suite.NoError(err)
|
||||
|
||||
//Remove pipeline
|
||||
pipelineManager.Remove(suite.channel)
|
||||
suite.Equal(0, pipelineManager.Num())
|
||||
}
|
||||
|
||||
func TestQueryNodePipelineManager(t *testing.T) {
|
||||
suite.Run(t, new(PipelineManagerTestSuite))
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/collector"
|
||||
"github.com/milvus-io/milvus/internal/util/metricsinfo"
|
||||
)
|
||||
|
||||
type insertNodeMsg struct {
|
||||
insertMsgs []*InsertMsg
|
||||
deleteMsgs []*DeleteMsg
|
||||
timeRange TimeRange
|
||||
}
|
||||
|
||||
type deleteNodeMsg struct {
|
||||
deleteMsgs []*DeleteMsg
|
||||
timeRange TimeRange
|
||||
}
|
||||
|
||||
func (msg *insertNodeMsg) append(taskMsg msgstream.TsMsg) error {
|
||||
switch taskMsg.Type() {
|
||||
case commonpb.MsgType_Insert:
|
||||
insertMsg := taskMsg.(*InsertMsg)
|
||||
msg.insertMsgs = append(msg.insertMsgs, insertMsg)
|
||||
collector.Rate.Add(metricsinfo.InsertConsumeThroughput, float64(proto.Size(&insertMsg.InsertRequest)))
|
||||
case commonpb.MsgType_Delete:
|
||||
deleteMsg := taskMsg.(*DeleteMsg)
|
||||
msg.deleteMsgs = append(msg.deleteMsgs, deleteMsg)
|
||||
collector.Rate.Add(metricsinfo.DeleteConsumeThroughput, float64(proto.Size(&deleteMsg.DeleteRequest)))
|
||||
default:
|
||||
return ErrMsgInvalidType
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
)
|
||||
|
||||
const defaultDim = 128
|
||||
|
||||
func buildDeleteMsg(collectionID int64, partitionID int64, channel string, rowSum int) *msgstream.DeleteMsg {
|
||||
deleteMsg := emptyDeleteMsg(collectionID, partitionID, channel)
|
||||
for i := 1; i <= rowSum; i++ {
|
||||
deleteMsg.Timestamps = append(deleteMsg.Timestamps, 0)
|
||||
deleteMsg.HashValues = append(deleteMsg.HashValues, 0)
|
||||
deleteMsg.NumRows++
|
||||
}
|
||||
deleteMsg.PrimaryKeys = genDefaultDeletePK(rowSum)
|
||||
return deleteMsg
|
||||
}
|
||||
|
||||
func buildInsertMsg(collectionID int64, partitionID int64, segmentID int64, channel string, rowSum int) *msgstream.InsertMsg {
|
||||
insertMsg := emptyInsertMsg(collectionID, partitionID, segmentID, channel)
|
||||
for i := 1; i <= rowSum; i++ {
|
||||
insertMsg.HashValues = append(insertMsg.HashValues, 0)
|
||||
insertMsg.Timestamps = append(insertMsg.Timestamps, 0)
|
||||
insertMsg.RowIDs = append(insertMsg.RowIDs, rand.Int63n(100))
|
||||
insertMsg.NumRows++
|
||||
}
|
||||
insertMsg.FieldsData = genDefaultFiledData(rowSum)
|
||||
return insertMsg
|
||||
}
|
||||
|
||||
func emptyDeleteMsg(collectionID int64, partitionID int64, channel string) *msgstream.DeleteMsg {
|
||||
deleteReq := msgpb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
||||
commonpbutil.WithTimeStamp(0),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
ShardName: channel,
|
||||
}
|
||||
|
||||
return &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
DeleteRequest: deleteReq,
|
||||
}
|
||||
}
|
||||
|
||||
func emptyInsertMsg(collectionID int64, partitionID int64, segmentID int64, channel string) *msgstream.InsertMsg {
|
||||
insertReq := msgpb.InsertRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
||||
commonpbutil.WithTimeStamp(0),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
SegmentID: segmentID,
|
||||
ShardName: channel,
|
||||
Version: msgpb.InsertDataVersion_ColumnBased,
|
||||
}
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{},
|
||||
InsertRequest: insertReq,
|
||||
}
|
||||
|
||||
return insertMsg
|
||||
}
|
||||
|
||||
//gen IDs with random pks for DeleteMsg
|
||||
func genDefaultDeletePK(rowSum int) *schemapb.IDs {
|
||||
pkDatas := []int64{}
|
||||
|
||||
for i := 1; i <= rowSum; i++ {
|
||||
pkDatas = append(pkDatas, int64(i))
|
||||
}
|
||||
|
||||
return &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: pkDatas,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
//gen IDs with specified pk
|
||||
func genDeletePK(pks ...int64) *schemapb.IDs {
|
||||
pkDatas := make([]int64, 0, len(pks))
|
||||
pkDatas = append(pkDatas, pks...)
|
||||
|
||||
return &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: pkDatas,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func genDefaultFiledData(numRows int) []*schemapb.FieldData {
|
||||
pkDatas := []int64{}
|
||||
vectorDatas := []byte{}
|
||||
|
||||
for i := 1; i <= numRows; i++ {
|
||||
pkDatas = append(pkDatas, int64(i))
|
||||
vectorDatas = append(vectorDatas, uint8(i))
|
||||
}
|
||||
|
||||
return []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "pk",
|
||||
FieldId: 100,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: pkDatas,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: "vector",
|
||||
FieldId: 101,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 8,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: vectorDatas,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func genFiledDataWithSchema(schema *schemapb.CollectionSchema, numRows int) []*schemapb.FieldData {
|
||||
fieldsData := make([]*schemapb.FieldData, 0)
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType < 100 {
|
||||
fieldsData = append(fieldsData, segments.GenTestScalarFieldData(field.DataType, field.DataType.String(), numRows))
|
||||
} else {
|
||||
fieldsData = append(fieldsData, segments.GenTestVectorFiledData(field.DataType, field.DataType.String(), numRows, defaultDim))
|
||||
}
|
||||
}
|
||||
return fieldsData
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//pipeline used for querynode
|
||||
type Pipeline interface {
|
||||
base.StreamPipeline
|
||||
ExcludedSegments(segInfos ...*datapb.SegmentInfo)
|
||||
}
|
||||
type pipeline struct {
|
||||
base.StreamPipeline
|
||||
|
||||
excludedSegments *typeutil.ConcurrentMap[int64, *datapb.SegmentInfo]
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
func (p *pipeline) ExcludedSegments(segInfos ...*datapb.SegmentInfo) {
|
||||
for _, segInfo := range segInfos {
|
||||
log.Debug("pipeline add exclude info",
|
||||
zap.Int64("segmentID", segInfo.GetID()),
|
||||
zap.Uint64("tss", segInfo.GetDmlPosition().Timestamp),
|
||||
)
|
||||
p.excludedSegments.Insert(segInfo.GetID(), segInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pipeline) Close() {
|
||||
p.StreamPipeline.Close()
|
||||
metrics.CleanupQueryNodeCollectionMetrics(paramtable.GetNodeID(), p.collectionID)
|
||||
}
|
||||
|
||||
func NewPipeLine(
|
||||
collectionID UniqueID,
|
||||
channel string,
|
||||
manager *DataManager,
|
||||
tSafeManager TSafeManager,
|
||||
dispatcher msgdispatcher.Client,
|
||||
delegator delegator.ShardDelegator,
|
||||
) (Pipeline, error) {
|
||||
pipelineQueueLength := paramtable.Get().QueryNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()
|
||||
excludedSegments := typeutil.NewConcurrentMap[int64, *datapb.SegmentInfo]()
|
||||
|
||||
p := &pipeline{
|
||||
collectionID: collectionID,
|
||||
excludedSegments: excludedSegments,
|
||||
StreamPipeline: base.NewPipelineWithStream(dispatcher, nodeCtxTtInterval, enableTtChecker, channel),
|
||||
}
|
||||
|
||||
filterNode := newFilterNode(collectionID, channel, manager, excludedSegments, pipelineQueueLength)
|
||||
insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
|
||||
deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)
|
||||
p.Add(filterNode, insertNode, deleteNode)
|
||||
return p, nil
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgdispatcher"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PipelineTestSuite struct {
|
||||
suite.Suite
|
||||
//datas
|
||||
collectionName string
|
||||
collectionID int64
|
||||
partitionIDs []int64
|
||||
channel string
|
||||
insertSegmentIDs []int64
|
||||
deletePKs []int64
|
||||
|
||||
//dependencies
|
||||
tSafeManager TSafeManager
|
||||
|
||||
//mocks
|
||||
segmentManager *segments.MockSegmentManager
|
||||
collectionManager *segments.MockCollectionManager
|
||||
delegator *delegator.MockShardDelegator
|
||||
msgDispatcher *msgdispatcher.MockClient
|
||||
msgChan chan *msgstream.MsgPack
|
||||
}
|
||||
|
||||
func (suite *PipelineTestSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.channel = "test-channel"
|
||||
suite.partitionIDs = []int64{11, 22}
|
||||
suite.insertSegmentIDs = []int64{1, 2, 3}
|
||||
suite.deletePKs = []int64{1, 2, 3}
|
||||
suite.msgChan = make(chan *msgstream.MsgPack, 1)
|
||||
}
|
||||
|
||||
func (suite *PipelineTestSuite) buildMsgPack(schema *schemapb.CollectionSchema) *msgstream.MsgPack {
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: 0,
|
||||
EndTs: 1,
|
||||
Msgs: []msgstream.TsMsg{},
|
||||
}
|
||||
|
||||
for id, segmentID := range suite.insertSegmentIDs {
|
||||
insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%len(suite.partitionIDs)], segmentID, suite.channel, 1)
|
||||
insertMsg.FieldsData = genFiledDataWithSchema(schema, 1)
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsg)
|
||||
}
|
||||
|
||||
for id, pk := range suite.deletePKs {
|
||||
deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[id%len(suite.partitionIDs)], suite.channel, 1)
|
||||
deleteMsg.PrimaryKeys = genDeletePK(pk)
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsg)
|
||||
}
|
||||
return msgPack
|
||||
}
|
||||
|
||||
func (suite *PipelineTestSuite) SetupTest() {
|
||||
paramtable.Init()
|
||||
//init mock
|
||||
// init manager
|
||||
suite.collectionManager = segments.NewMockCollectionManager(suite.T())
|
||||
suite.segmentManager = segments.NewMockSegmentManager(suite.T())
|
||||
// init delegator
|
||||
suite.delegator = delegator.NewMockShardDelegator(suite.T())
|
||||
// init mq dispatcher
|
||||
suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T())
|
||||
|
||||
//init dependency
|
||||
// init tsafeManager
|
||||
suite.tSafeManager = tsafe.NewTSafeReplica()
|
||||
suite.tSafeManager.Add(suite.channel, 0)
|
||||
}
|
||||
|
||||
func (suite *PipelineTestSuite) TestBasic() {
|
||||
//init mock
|
||||
// mock collection manager
|
||||
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
|
||||
collection := segments.NewCollection(suite.collectionID, schema, querypb.LoadType_LoadCollection)
|
||||
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)
|
||||
|
||||
// mock mq factory
|
||||
suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil)
|
||||
suite.msgDispatcher.EXPECT().Deregister(suite.channel)
|
||||
|
||||
// mock delegator
|
||||
suite.delegator.EXPECT().ProcessInsert(mock.Anything).Run(
|
||||
func(insertRecords map[int64]*delegator.InsertData) {
|
||||
for segmentID := range insertRecords {
|
||||
suite.True(lo.Contains(suite.insertSegmentIDs, segmentID))
|
||||
}
|
||||
})
|
||||
|
||||
suite.delegator.EXPECT().ProcessDelete(mock.Anything, mock.Anything).Run(
|
||||
func(deleteData []*delegator.DeleteData, ts uint64) {
|
||||
for _, data := range deleteData {
|
||||
for _, pk := range data.PrimaryKeys {
|
||||
suite.True(lo.Contains(suite.deletePKs, pk.GetValue().(int64)))
|
||||
}
|
||||
}
|
||||
})
|
||||
//build pipleine
|
||||
manager := &segments.Manager{
|
||||
Collection: suite.collectionManager,
|
||||
Segment: suite.segmentManager,
|
||||
}
|
||||
pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator)
|
||||
suite.NoError(err)
|
||||
|
||||
//Init Consumer
|
||||
err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{})
|
||||
suite.NoError(err)
|
||||
|
||||
err = pipeline.Start()
|
||||
suite.NoError(err)
|
||||
defer pipeline.Close()
|
||||
|
||||
// watch tsafe manager
|
||||
listener := suite.tSafeManager.WatchChannel(suite.channel)
|
||||
|
||||
// build input msg
|
||||
in := suite.buildMsgPack(schema)
|
||||
suite.msgChan <- in
|
||||
|
||||
// wait pipeline work
|
||||
<-listener.On()
|
||||
|
||||
//check tsafe
|
||||
tsafe, err := suite.tSafeManager.Get(suite.channel)
|
||||
suite.NoError(err)
|
||||
suite.Equal(in.EndTs, tsafe)
|
||||
}
|
||||
|
||||
func TestQueryNodePipeline(t *testing.T) {
|
||||
suite.Run(t, new(PipelineTestSuite))
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
base "github.com/milvus-io/milvus/internal/util/pipeline"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO: better to be configured
|
||||
nodeCtxTtInterval = 2 * time.Minute
|
||||
enableTtChecker = true
|
||||
loadTypeCollection = querypb.LoadType_LoadCollection
|
||||
loadTypePartition = querypb.LoadType_LoadPartition
|
||||
segmentTypeGrowing = commonpb.SegmentState_Growing
|
||||
segmentTypeSealed = commonpb.SegmentState_Sealed
|
||||
)
|
||||
|
||||
type (
|
||||
UniqueID = typeutil.UniqueID
|
||||
Timestamp = typeutil.Timestamp
|
||||
PrimaryKey = storage.PrimaryKey
|
||||
|
||||
InsertMsg = msgstream.InsertMsg
|
||||
DeleteMsg = msgstream.DeleteMsg
|
||||
|
||||
Collection = segments.Collection
|
||||
DataManager = segments.Manager
|
||||
Segment = segments.Segment
|
||||
|
||||
TSafeManager = tsafe.Manager
|
||||
|
||||
BaseNode = base.BaseNode
|
||||
Msg = base.Msg
|
||||
)
|
||||
|
||||
type TimeRange struct {
|
||||
timestampMin Timestamp
|
||||
timestampMax Timestamp
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pkoracle
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var _ Candidate = (*BloomFilterSet)(nil)
|
||||
|
||||
// BloomFilterSet is one implementation of Candidate with bloom filter in statslog.
|
||||
type BloomFilterSet struct {
|
||||
statsMutex sync.RWMutex
|
||||
segmentID int64
|
||||
paritionID int64
|
||||
segType commonpb.SegmentState
|
||||
currentStat *storage.PkStatistics
|
||||
historyStats []*storage.PkStatistics
|
||||
}
|
||||
|
||||
// MayPkExist returns whether any bloom filters returns positive.
|
||||
func (s *BloomFilterSet) MayPkExist(pk storage.PrimaryKey) bool {
|
||||
s.statsMutex.RLock()
|
||||
defer s.statsMutex.RUnlock()
|
||||
if s.currentStat != nil && s.currentStat.PkExist(pk) {
|
||||
return true
|
||||
}
|
||||
|
||||
// for sealed, if one of the stats shows it exist, then we have to check it
|
||||
for _, historyStat := range s.historyStats {
|
||||
if historyStat.PkExist(pk) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ID implement candidate.
|
||||
func (s *BloomFilterSet) ID() int64 {
|
||||
return s.segmentID
|
||||
}
|
||||
|
||||
// Partition implements candidate.
|
||||
func (s *BloomFilterSet) Partition() int64 {
|
||||
return s.paritionID
|
||||
}
|
||||
|
||||
// Type implements candidate.
|
||||
func (s *BloomFilterSet) Type() commonpb.SegmentState {
|
||||
return s.segType
|
||||
}
|
||||
|
||||
// UpdateBloomFilter updates currentStats with provided pks.
|
||||
func (s *BloomFilterSet) UpdateBloomFilter(pks []storage.PrimaryKey) {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
if s.currentStat == nil {
|
||||
s.currentStat = &storage.PkStatistics{
|
||||
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 8)
|
||||
for _, pk := range pks {
|
||||
s.currentStat.UpdateMinMax(pk)
|
||||
switch pk.Type() {
|
||||
case schemapb.DataType_Int64:
|
||||
int64Value := pk.(*storage.Int64PrimaryKey).Value
|
||||
common.Endian.PutUint64(buf, uint64(int64Value))
|
||||
s.currentStat.PkFilter.Add(buf)
|
||||
case schemapb.DataType_VarChar:
|
||||
stringValue := pk.(*storage.VarCharPrimaryKey).Value
|
||||
s.currentStat.PkFilter.AddString(stringValue)
|
||||
default:
|
||||
log.Error("failed to update bloomfilter", zap.Any("PK type", pk.Type()))
|
||||
panic("failed to update bloomfilter")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddHistoricalStats add loaded historical stats.
|
||||
func (s *BloomFilterSet) AddHistoricalStats(stats *storage.PkStatistics) {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
s.historyStats = append(s.historyStats, stats)
|
||||
}
|
||||
|
||||
// initCurrentStat initialize currentStats if nil.
|
||||
// Note: invoker shall acquire statsMutex lock first.
|
||||
func (s *BloomFilterSet) initCurrentStat() {
|
||||
if s.currentStat == nil {
|
||||
s.currentStat = &storage.PkStatistics{
|
||||
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewBloomFilterSet returns a new BloomFilterSet.
|
||||
func NewBloomFilterSet(segmentID int64, paritionID int64, segType commonpb.SegmentState) *BloomFilterSet {
|
||||
bfs := &BloomFilterSet{
|
||||
segmentID: segmentID,
|
||||
paritionID: paritionID,
|
||||
segType: segType,
|
||||
}
|
||||
// does not need to init current
|
||||
return bfs
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pkoracle
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
// Candidate is the interface for pk oracle candidate.
|
||||
type Candidate interface {
|
||||
// MayPkExist checks whether primary key could exists in this candidate.
|
||||
MayPkExist(pk storage.PrimaryKey) bool
|
||||
|
||||
ID() int64
|
||||
Partition() int64
|
||||
Type() commonpb.SegmentState
|
||||
}
|
||||
|
||||
type candidateWithWorker struct {
|
||||
Candidate
|
||||
workerID int64
|
||||
}
|
||||
|
||||
// CandidateFilter filter type for candidate.
|
||||
type CandidateFilter func(candidate candidateWithWorker) bool
|
||||
|
||||
// WithSegmentType returns CandiateFilter with provided segment type.
|
||||
func WithSegmentType(typ commonpb.SegmentState) CandidateFilter {
|
||||
return func(candidate candidateWithWorker) bool {
|
||||
return candidate.Type() == typ
|
||||
}
|
||||
}
|
||||
|
||||
// WithWorkerID returns CandidateFilter with provided worker id.
|
||||
func WithWorkerID(workerID int64) CandidateFilter {
|
||||
return func(candidate candidateWithWorker) bool {
|
||||
return candidate.workerID == workerID
|
||||
}
|
||||
}
|
||||
|
||||
// WithSegmentIDs returns CandidateFilter with provided segment ids.
|
||||
func WithSegmentIDs(segmentIDs ...int64) CandidateFilter {
|
||||
set := typeutil.NewSet[int64]()
|
||||
set.Insert(segmentIDs...)
|
||||
return func(candidate candidateWithWorker) bool {
|
||||
return set.Contain(candidate.ID())
|
||||
}
|
||||
}
|
||||
|
||||
// WithPartitionID returns CandidateFilter with provided partitionID.
|
||||
func WithPartitionID(partitionID int64) CandidateFilter {
|
||||
return func(candidate candidateWithWorker) bool {
|
||||
return candidate.Partition() == partitionID || partitionID == common.InvalidPartitionID
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pkoracle
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
)
|
||||
|
||||
type candidateKey struct {
|
||||
segmentID int64
|
||||
partitionID int64
|
||||
typ commonpb.SegmentState
|
||||
}
|
||||
|
||||
// MayPkExist checks whether primary key could exists in this candidate.
|
||||
func (k candidateKey) MayPkExist(pk storage.PrimaryKey) bool {
|
||||
// always return true to prevent miuse
|
||||
return true
|
||||
}
|
||||
|
||||
// ID implements Candidate.
|
||||
func (k candidateKey) ID() int64 {
|
||||
return k.segmentID
|
||||
}
|
||||
|
||||
// Partition implements Candidate.
|
||||
func (k candidateKey) Partition() int64 {
|
||||
return k.partitionID
|
||||
}
|
||||
|
||||
// Type implements Candidate.
|
||||
func (k candidateKey) Type() commonpb.SegmentState {
|
||||
return k.typ
|
||||
}
|
||||
|
||||
// NewCandidateKey creates a candidateKey and returns as Candidate.
|
||||
func NewCandidateKey(id int64, partitionID int64, typ commonpb.SegmentState) Candidate {
|
||||
return candidateKey{
|
||||
segmentID: id,
|
||||
partitionID: partitionID,
|
||||
typ: typ,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// pkoracle package contains pk - segment mapping logic.
|
||||
package pkoracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
// PkOracle interface for pk oracle.
|
||||
type PkOracle interface {
|
||||
// GetCandidates returns segment candidates of which pk might belongs to.
|
||||
Get(pk storage.PrimaryKey, filters ...CandidateFilter) ([]int64, error)
|
||||
// RegisterCandidate adds candidate into pkOracle.
|
||||
Register(candidate Candidate, workerID int64) error
|
||||
// RemoveCandidate removes candidate
|
||||
Remove(filters ...CandidateFilter) error
|
||||
// CheckCandidate checks whether candidate with provided key exists.
|
||||
Exists(candidate Candidate, workerID int64) bool
|
||||
}
|
||||
|
||||
var _ PkOracle = (*pkOracle)(nil)
|
||||
|
||||
// pkOracle implementation.
|
||||
type pkOracle struct {
|
||||
candidates *typeutil.ConcurrentMap[string, candidateWithWorker]
|
||||
}
|
||||
|
||||
// Get implements PkOracle.
|
||||
func (pko *pkOracle) Get(pk storage.PrimaryKey, filters ...CandidateFilter) ([]int64, error) {
|
||||
var result []int64
|
||||
pko.candidates.Range(func(key string, candidate candidateWithWorker) bool {
|
||||
for _, filter := range filters {
|
||||
if !filter(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if candidate.MayPkExist(pk) {
|
||||
result = append(result, candidate.ID())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (pko *pkOracle) candidateKey(candidate Candidate, workerID int64) string {
|
||||
return fmt.Sprintf("%s-%d-%d", candidate.Type().String(), workerID, candidate.ID())
|
||||
}
|
||||
|
||||
// Register register candidate
|
||||
func (pko *pkOracle) Register(candidate Candidate, workerID int64) error {
|
||||
pko.candidates.Insert(pko.candidateKey(candidate, workerID), candidateWithWorker{
|
||||
Candidate: candidate,
|
||||
workerID: workerID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove removes candidate from pko.
|
||||
func (pko *pkOracle) Remove(filters ...CandidateFilter) error {
|
||||
pko.candidates.Range(func(key string, candidate candidateWithWorker) bool {
|
||||
for _, filter := range filters {
|
||||
if !filter(candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
pko.candidates.GetAndRemove(pko.candidateKey(candidate, candidate.workerID))
|
||||
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pko *pkOracle) Exists(candidate Candidate, workerID int64) bool {
|
||||
_, ok := pko.candidates.Get(pko.candidateKey(candidate, workerID))
|
||||
return ok
|
||||
}
|
||||
|
||||
// NewPkOracle returns pkOracle as PkOracle interface.
|
||||
func NewPkOracle() PkOracle {
|
||||
return &pkOracle{
|
||||
candidates: typeutil.NewConcurrentMap[string, candidateWithWorker](),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
storage "github.com/milvus-io/milvus/internal/storage"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type bloomFilterSet struct {
|
||||
statsMutex sync.RWMutex
|
||||
currentStat *storage.PkStatistics
|
||||
historyStats []*storage.PkStatistics
|
||||
}
|
||||
|
||||
func newBloomFilterSet() *bloomFilterSet {
|
||||
return &bloomFilterSet{}
|
||||
}
|
||||
|
||||
// MayPkExist returns whether any bloom filters returns positive.
|
||||
func (s *bloomFilterSet) MayPkExist(pk storage.PrimaryKey) bool {
|
||||
s.statsMutex.RLock()
|
||||
defer s.statsMutex.RUnlock()
|
||||
if s.currentStat != nil && s.currentStat.PkExist(pk) {
|
||||
return true
|
||||
}
|
||||
|
||||
// for sealed, if one of the stats shows it exist, then we have to check it
|
||||
for _, historyStat := range s.historyStats {
|
||||
if historyStat.PkExist(pk) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateBloomFilter updates currentStats with provided pks.
|
||||
func (s *bloomFilterSet) UpdateBloomFilter(pks []storage.PrimaryKey) {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
if s.currentStat == nil {
|
||||
s.initCurrentStat()
|
||||
}
|
||||
|
||||
buf := make([]byte, 8)
|
||||
for _, pk := range pks {
|
||||
s.currentStat.UpdateMinMax(pk)
|
||||
switch pk.Type() {
|
||||
case schemapb.DataType_Int64:
|
||||
int64Value := pk.(*storage.Int64PrimaryKey).Value
|
||||
common.Endian.PutUint64(buf, uint64(int64Value))
|
||||
s.currentStat.PkFilter.Add(buf)
|
||||
case schemapb.DataType_VarChar:
|
||||
stringValue := pk.(*storage.VarCharPrimaryKey).Value
|
||||
s.currentStat.PkFilter.AddString(stringValue)
|
||||
default:
|
||||
log.Error("failed to update bloomfilter", zap.Any("PK type", pk.Type()))
|
||||
panic("failed to update bloomfilter")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddHistoricalStats add loaded historical stats.
|
||||
func (s *bloomFilterSet) AddHistoricalStats(stats *storage.PkStatistics) {
|
||||
s.statsMutex.Lock()
|
||||
defer s.statsMutex.Unlock()
|
||||
|
||||
s.historyStats = append(s.historyStats, stats)
|
||||
}
|
||||
|
||||
// initCurrentStat initialize currentStats if nil.
|
||||
// Note: invoker shall acquire statsMutex lock first.
|
||||
func (s *bloomFilterSet) initCurrentStat() {
|
||||
s.currentStat = &storage.PkStatistics{
|
||||
PkFilter: bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BloomFilterSetSuite struct {
|
||||
suite.Suite
|
||||
|
||||
intPks []int64
|
||||
stringPks []string
|
||||
set *bloomFilterSet
|
||||
}
|
||||
|
||||
func (suite *BloomFilterSetSuite) SetupTest() {
|
||||
suite.intPks = []int64{1, 2, 3}
|
||||
suite.stringPks = []string{"1", "2", "3"}
|
||||
suite.set = newBloomFilterSet()
|
||||
}
|
||||
|
||||
func (suite *BloomFilterSetSuite) TestInt64PkBloomFilter() {
|
||||
pks, err := storage.GenInt64PrimaryKeys(suite.intPks...)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.set.UpdateBloomFilter(pks)
|
||||
for _, pk := range pks {
|
||||
exist := suite.set.MayPkExist(pk)
|
||||
suite.True(exist)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *BloomFilterSetSuite) TestStringPkBloomFilter() {
|
||||
pks, err := storage.GenVarcharPrimaryKeys(suite.stringPks...)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.set.UpdateBloomFilter(pks)
|
||||
for _, pk := range pks {
|
||||
exist := suite.set.MayPkExist(pk)
|
||||
suite.True(exist)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *BloomFilterSetSuite) TestHistoricalBloomFilter() {
|
||||
pks, err := storage.GenVarcharPrimaryKeys(suite.stringPks...)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.set.UpdateBloomFilter(pks)
|
||||
for _, pk := range pks {
|
||||
exist := suite.set.MayPkExist(pk)
|
||||
suite.True(exist)
|
||||
}
|
||||
|
||||
old := suite.set.currentStat
|
||||
suite.set.currentStat = nil
|
||||
for _, pk := range pks {
|
||||
exist := suite.set.MayPkExist(pk)
|
||||
suite.False(exist)
|
||||
}
|
||||
|
||||
suite.set.AddHistoricalStats(old)
|
||||
for _, pk := range pks {
|
||||
exist := suite.set.MayPkExist(pk)
|
||||
suite.True(exist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBloomFilterSet(t *testing.T) {
|
||||
suite.Run(t, &BloomFilterSetSuite{})
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
/*
|
||||
#cgo pkg-config: milvus_segcore milvus_storage
|
||||
|
||||
#include "segcore/collection_c.h"
|
||||
#include "common/type_c.h"
|
||||
#include "segcore/segment_c.h"
|
||||
#include "storage/storage_c.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/cgoconverter"
|
||||
)
|
||||
|
||||
// HandleCStatus deals with the error returned from CGO
|
||||
func HandleCStatus(status *C.CStatus, extraInfo string) error {
|
||||
if status.error_code == 0 {
|
||||
return nil
|
||||
}
|
||||
errorCode := status.error_code
|
||||
errorName, ok := commonpb.ErrorCode_name[int32(errorCode)]
|
||||
if !ok {
|
||||
errorName = "UnknownError"
|
||||
}
|
||||
errorMsg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
|
||||
finalMsg := fmt.Sprintf("[%s] %s", errorName, errorMsg)
|
||||
logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, finalMsg)
|
||||
log.Warn(logMsg)
|
||||
return errors.New(finalMsg)
|
||||
}
|
||||
|
||||
// HandleCProto deal with the result proto returned from CGO
|
||||
func HandleCProto(cRes *C.CProto, msg proto.Message) error {
|
||||
// Standalone CProto is protobuf created by C side,
|
||||
// Passed from c side
|
||||
// memory is managed manually
|
||||
lease, blob := cgoconverter.UnsafeGoBytes(&cRes.proto_blob, int(cRes.proto_size))
|
||||
defer cgoconverter.Release(lease)
|
||||
|
||||
return proto.Unmarshal(blob, msg)
|
||||
}
|
||||
|
||||
// CopyCProtoBlob returns the copy of C memory
|
||||
func CopyCProtoBlob(cProto *C.CProto) []byte {
|
||||
blob := C.GoBytes(cProto.proto_blob, C.int32_t(cProto.proto_size))
|
||||
C.free(cProto.proto_blob)
|
||||
return blob
|
||||
}
|
||||
|
||||
// GetCProtoBlob returns the raw C memory, invoker should release it itself
|
||||
func GetCProtoBlob(cProto *C.CProto) []byte {
|
||||
lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
|
||||
cgoconverter.Extract(lease)
|
||||
return blob
|
||||
}
|
||||
|
||||
func GetLocalUsedSize() (int64, error) {
|
||||
var availableSize int64
|
||||
cSize := C.int64_t(availableSize)
|
||||
|
||||
status := C.GetLocalUsedSize(&cSize)
|
||||
err := HandleCStatus(&status, "get local used size failed")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return availableSize, nil
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
/*
|
||||
#cgo pkg-config: milvus_segcore
|
||||
|
||||
#include "segcore/collection_c.h"
|
||||
#include "segcore/segment_c.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type CollectionManager interface {
|
||||
Get(collectionID int64) *Collection
|
||||
Put(collectionID int64, schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo)
|
||||
}
|
||||
|
||||
type collectionManager struct {
|
||||
mut sync.RWMutex
|
||||
collections map[int64]*Collection
|
||||
}
|
||||
|
||||
func NewCollectionManager() *collectionManager {
|
||||
return &collectionManager{
|
||||
collections: make(map[int64]*Collection),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *collectionManager) Get(collectionID int64) *Collection {
|
||||
m.mut.RLock()
|
||||
defer m.mut.RUnlock()
|
||||
|
||||
return m.collections[collectionID]
|
||||
}
|
||||
|
||||
func (m *collectionManager) Put(collectionID int64, schema *schemapb.CollectionSchema, loadMeta *querypb.LoadMetaInfo) {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
|
||||
if _, ok := m.collections[collectionID]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
collection := NewCollection(collectionID, schema, loadMeta.GetLoadType())
|
||||
collection.AddPartition(loadMeta.GetPartitionIDs()...)
|
||||
m.collections[collectionID] = collection
|
||||
}
|
||||
|
||||
// Collection is a wrapper of the underlying C-structure C.CCollection
|
||||
type Collection struct {
|
||||
mu sync.RWMutex // protects colllectionPtr
|
||||
collectionPtr C.CCollection
|
||||
id int64
|
||||
partitions *typeutil.ConcurrentSet[int64]
|
||||
loadType querypb.LoadType
|
||||
schema *schemapb.CollectionSchema
|
||||
}
|
||||
|
||||
// ID returns collection id
|
||||
func (c *Collection) ID() int64 {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// Schema returns the schema of collection
|
||||
func (c *Collection) Schema() *schemapb.CollectionSchema {
|
||||
return c.schema
|
||||
}
|
||||
|
||||
// getPartitionIDs return partitionIDs of collection
|
||||
func (c *Collection) GetPartitions() []int64 {
|
||||
return c.partitions.Collect()
|
||||
}
|
||||
|
||||
func (c *Collection) ExistPartition(partitionIDs ...int64) bool {
|
||||
return c.partitions.Contain(partitionIDs...)
|
||||
}
|
||||
|
||||
// addPartitionID would add a partition id to partition id list of collection
|
||||
func (c *Collection) AddPartition(partitions ...int64) {
|
||||
for i := range partitions {
|
||||
c.partitions.Insert(partitions[i])
|
||||
}
|
||||
}
|
||||
|
||||
// removePartitionID removes the partition id from partition id list of collection
|
||||
func (c *Collection) RemovePartition(partitionID int64) {
|
||||
c.partitions.Remove(partitionID)
|
||||
}
|
||||
|
||||
// getLoadType get the loadType of collection, which is loadTypeCollection or loadTypePartition
|
||||
func (c *Collection) GetLoadType() querypb.LoadType {
|
||||
return c.loadType
|
||||
}
|
||||
|
||||
// newCollection returns a new Collection
|
||||
func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, loadType querypb.LoadType) *Collection {
|
||||
/*
|
||||
CCollection
|
||||
NewCollection(const char* schema_proto_blob);
|
||||
*/
|
||||
schemaBlob := proto.MarshalTextString(schema)
|
||||
cSchemaBlob := C.CString(schemaBlob)
|
||||
defer C.free(unsafe.Pointer(cSchemaBlob))
|
||||
|
||||
collection := C.NewCollection(cSchemaBlob)
|
||||
return &Collection{
|
||||
collectionPtr: collection,
|
||||
id: collectionID,
|
||||
schema: schema,
|
||||
partitions: typeutil.NewConcurrentSet[int64](),
|
||||
loadType: loadType,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *Collection {
|
||||
return &Collection{
|
||||
id: collectionID,
|
||||
partitions: typeutil.NewConcurrentSet[int64](),
|
||||
loadType: loadType,
|
||||
}
|
||||
}
|
||||
|
||||
// deleteCollection delete collection and free the collection memory
|
||||
func DeleteCollection(collection *Collection) {
|
||||
/*
|
||||
void
|
||||
deleteCollection(CCollection collection);
|
||||
*/
|
||||
cPtr := collection.collectionPtr
|
||||
C.DeleteCollection(cPtr)
|
||||
|
||||
collection.collectionPtr = nil
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// Manager related errors
|
||||
ErrCollectionNotFound = errors.New("CollectionNotFound")
|
||||
ErrPartitionNotFound = errors.New("PartitionNotFound")
|
||||
ErrSegmentNotFound = errors.New("SegmentNotFound")
|
||||
ErrFieldNotFound = errors.New("FieldNotFound")
|
||||
ErrSegmentReleased = errors.New("SegmentReleased")
|
||||
)
|
||||
|
||||
func WrapSegmentNotFound(segmentID int64) error {
|
||||
return fmt.Errorf("%w(%v)", ErrSegmentNotFound, segmentID)
|
||||
}
|
||||
|
||||
func WrapCollectionNotFound(collectionID int64) error {
|
||||
return fmt.Errorf("%w(%v)", ErrCollectionNotFound, collectionID)
|
||||
}
|
||||
|
||||
func WrapFieldNotFound(fieldID int64) error {
|
||||
return fmt.Errorf("%w(%v)", ErrFieldNotFound, fieldID)
|
||||
}
|
||||
|
||||
// WrapSegmentReleased wrap ErrSegmentReleased with segmentID.
|
||||
func WrapSegmentReleased(segmentID int64) error {
|
||||
return fmt.Errorf("%w(%d)", ErrSegmentReleased, segmentID)
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
/*
|
||||
#cgo pkg-config: milvus_common milvus_segcore
|
||||
|
||||
#include "segcore/load_index_c.h"
|
||||
#include "common/binary_set_c.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparams"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
)
|
||||
|
||||
// LoadIndexInfo is a wrapper of the underlying C-structure C.CLoadIndexInfo
|
||||
type LoadIndexInfo struct {
|
||||
cLoadIndexInfo C.CLoadIndexInfo
|
||||
}
|
||||
|
||||
// newLoadIndexInfo returns a new LoadIndexInfo and error
|
||||
func newLoadIndexInfo() (*LoadIndexInfo, error) {
|
||||
var cLoadIndexInfo C.CLoadIndexInfo
|
||||
|
||||
// TODO::xige-16 support embedded milvus
|
||||
storageType := "minio"
|
||||
cAddress := C.CString(paramtable.Get().MinioCfg.Address.GetValue())
|
||||
cBucketName := C.CString(paramtable.Get().MinioCfg.BucketName.GetValue())
|
||||
cAccessKey := C.CString(paramtable.Get().MinioCfg.AccessKeyID.GetValue())
|
||||
cAccessValue := C.CString(paramtable.Get().MinioCfg.SecretAccessKey.GetValue())
|
||||
cRootPath := C.CString(paramtable.Get().MinioCfg.RootPath.GetValue())
|
||||
cStorageType := C.CString(storageType)
|
||||
cIamEndPoint := C.CString(paramtable.Get().MinioCfg.IAMEndpoint.GetValue())
|
||||
defer C.free(unsafe.Pointer(cAddress))
|
||||
defer C.free(unsafe.Pointer(cBucketName))
|
||||
defer C.free(unsafe.Pointer(cAccessKey))
|
||||
defer C.free(unsafe.Pointer(cAccessValue))
|
||||
defer C.free(unsafe.Pointer(cRootPath))
|
||||
defer C.free(unsafe.Pointer(cStorageType))
|
||||
defer C.free(unsafe.Pointer(cIamEndPoint))
|
||||
storageConfig := C.CStorageConfig{
|
||||
address: cAddress,
|
||||
bucket_name: cBucketName,
|
||||
access_key_id: cAccessKey,
|
||||
access_key_value: cAccessValue,
|
||||
remote_root_path: cRootPath,
|
||||
storage_type: cStorageType,
|
||||
iam_endpoint: cIamEndPoint,
|
||||
useSSL: C.bool(paramtable.Get().MinioCfg.UseSSL.GetAsBool()),
|
||||
useIAM: C.bool(paramtable.Get().MinioCfg.UseIAM.GetAsBool()),
|
||||
}
|
||||
|
||||
status := C.NewLoadIndexInfo(&cLoadIndexInfo, storageConfig)
|
||||
if err := HandleCStatus(&status, "NewLoadIndexInfo failed"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil
|
||||
}
|
||||
|
||||
// deleteLoadIndexInfo would delete C.CLoadIndexInfo
|
||||
func deleteLoadIndexInfo(info *LoadIndexInfo) {
|
||||
C.DeleteLoadIndexInfo(info.cLoadIndexInfo)
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) appendLoadIndexInfo(bytesIndex [][]byte, indexInfo *querypb.FieldIndexInfo, collectionID int64, partitionID int64, segmentID int64, fieldType schemapb.DataType) error {
|
||||
fieldID := indexInfo.FieldID
|
||||
indexPaths := indexInfo.IndexFilePaths
|
||||
|
||||
err := li.appendFieldInfo(collectionID, partitionID, segmentID, fieldID, fieldType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = li.appendIndexInfo(indexInfo.IndexID, indexInfo.BuildID, indexInfo.IndexVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// some build params also exist in indexParams, which are useless during loading process
|
||||
indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams)
|
||||
indexparams.SetDiskIndexLoadParams(indexParams, indexInfo.GetNumRows())
|
||||
|
||||
jsonIndexParams, err := json.Marshal(indexParams)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to json marshal index params %w", err)
|
||||
return err
|
||||
}
|
||||
log.Info("start append index params", zap.String("index params", string(jsonIndexParams)))
|
||||
|
||||
for key, value := range indexParams {
|
||||
err = li.appendIndexParam(key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = li.appendIndexData(bytesIndex, indexPaths)
|
||||
return err
|
||||
}
|
||||
|
||||
// appendIndexParam append indexParam to index
|
||||
func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error {
|
||||
cIndexKey := C.CString(indexKey)
|
||||
defer C.free(unsafe.Pointer(cIndexKey))
|
||||
cIndexValue := C.CString(indexValue)
|
||||
defer C.free(unsafe.Pointer(cIndexValue))
|
||||
status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue)
|
||||
return HandleCStatus(&status, "AppendIndexParam failed")
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) appendIndexInfo(indexID int64, buildID int64, indexVersion int64) error {
|
||||
cIndexID := C.int64_t(indexID)
|
||||
cBuildID := C.int64_t(buildID)
|
||||
cIndexVersion := C.int64_t(indexVersion)
|
||||
|
||||
status := C.AppendIndexInfo(li.cLoadIndexInfo, cIndexID, cBuildID, cIndexVersion)
|
||||
return HandleCStatus(&status, "AppendIndexInfo failed")
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) cleanLocalData() error {
|
||||
status := C.CleanLoadedIndex(li.cLoadIndexInfo)
|
||||
return HandleCStatus(&status, "failed to clean cached data on disk")
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) appendIndexFile(filePath string) error {
|
||||
cIndexFilePath := C.CString(filePath)
|
||||
defer C.free(unsafe.Pointer(cIndexFilePath))
|
||||
|
||||
status := C.AppendIndexFilePath(li.cLoadIndexInfo, cIndexFilePath)
|
||||
return HandleCStatus(&status, "AppendIndexIFile failed")
|
||||
}
|
||||
|
||||
// appendFieldInfo appends fieldID & fieldType to index
|
||||
func (li *LoadIndexInfo) appendFieldInfo(collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType) error {
|
||||
cColID := C.int64_t(collectionID)
|
||||
cParID := C.int64_t(partitionID)
|
||||
cSegID := C.int64_t(segmentID)
|
||||
cFieldID := C.int64_t(fieldID)
|
||||
cintDType := uint32(fieldType)
|
||||
status := C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType)
|
||||
return HandleCStatus(&status, "AppendFieldInfo failed")
|
||||
}
|
||||
|
||||
// appendIndexData appends binarySet index to cLoadIndexInfo
|
||||
func (li *LoadIndexInfo) appendIndexData(bytesIndex [][]byte, indexKeys []string) error {
|
||||
for _, indexPath := range indexKeys {
|
||||
err := li.appendIndexFile(indexPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var cBinarySet C.CBinarySet
|
||||
status := C.NewBinarySet(&cBinarySet)
|
||||
defer C.DeleteBinarySet(cBinarySet)
|
||||
|
||||
if err := HandleCStatus(&status, "NewBinarySet failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, byteIndex := range bytesIndex {
|
||||
indexPtr := unsafe.Pointer(&byteIndex[0])
|
||||
indexLen := C.int64_t(len(byteIndex))
|
||||
binarySetKey := filepath.Base(indexKeys[i])
|
||||
indexKey := C.CString(binarySetKey)
|
||||
status = C.AppendIndexBinary(cBinarySet, indexPtr, indexLen, indexKey)
|
||||
C.free(unsafe.Pointer(indexKey))
|
||||
if err := HandleCStatus(&status, "LoadIndexInfo AppendIndexBinary failed"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
status = C.AppendIndex(li.cLoadIndexInfo, cBinarySet)
|
||||
return HandleCStatus(&status, "AppendIndex failed")
|
||||
}
|
|
@ -0,0 +1,308 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package segments
|
||||
|
||||
/*
|
||||
#cgo pkg-config: milvus_segcore
|
||||
|
||||
#include "segcore/collection_c.h"
|
||||
#include "segcore/segment_c.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
. "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type SegmentFilter func(segment Segment) bool
|
||||
|
||||
func WithPartition(partitionID UniqueID) SegmentFilter {
|
||||
return func(segment Segment) bool {
|
||||
return segment.Partition() == partitionID
|
||||
}
|
||||
}
|
||||
|
||||
func WithChannel(channel string) SegmentFilter {
|
||||
return func(segment Segment) bool {
|
||||
return segment.Shard() == channel
|
||||
}
|
||||
}
|
||||
|
||||
func WithType(typ SegmentType) SegmentFilter {
|
||||
return func(segment Segment) bool {
|
||||
return segment.Type() == typ
|
||||
}
|
||||
}
|
||||
|
||||
func WithID(id int64) SegmentFilter {
|
||||
return func(segment Segment) bool {
|
||||
return segment.ID() == id
|
||||
}
|
||||
}
|
||||
|
||||
type actionType int32
|
||||
|
||||
const (
|
||||
removeAction actionType = iota
|
||||
addAction
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
Collection CollectionManager
|
||||
Segment SegmentManager
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
Collection: NewCollectionManager(),
|
||||
Segment: NewSegmentManager(),
|
||||
}
|
||||
}
|
||||
|
||||
type SegmentManager interface {
|
||||
// Put puts the given segments in,
|
||||
// and increases the ref count of the corresponding collection,
|
||||
// dup segments will not increase the ref count
|
||||
Put(segmentType SegmentType, segments ...Segment)
|
||||
Get(segmentID UniqueID) Segment
|
||||
GetWithType(segmentID UniqueID, typ SegmentType) Segment
|
||||
GetBy(filters ...SegmentFilter) []Segment
|
||||
GetSealed(segmentID UniqueID) Segment
|
||||
GetGrowing(segmentID UniqueID) Segment
|
||||
// Remove removes the given segment,
|
||||
// and decreases the ref count of the corresponding collection,
|
||||
// will not decrease the ref count if the given segment not exists
|
||||
Remove(segmentID UniqueID, scope querypb.DataScope)
|
||||
RemoveBy(filters ...SegmentFilter)
|
||||
}
|
||||
|
||||
var _ SegmentManager = (*segmentManager)(nil)
|
||||
|
||||
// Manager manages all collections and segments
|
||||
type segmentManager struct {
|
||||
mu sync.RWMutex // guards all
|
||||
|
||||
growingSegments map[UniqueID]Segment
|
||||
sealedSegments map[UniqueID]Segment
|
||||
}
|
||||
|
||||
func NewSegmentManager() *segmentManager {
|
||||
return &segmentManager{
|
||||
growingSegments: make(map[int64]Segment),
|
||||
sealedSegments: make(map[int64]Segment),
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
targetMap := mgr.growingSegments
|
||||
switch segmentType {
|
||||
case SegmentTypeGrowing:
|
||||
targetMap = mgr.growingSegments
|
||||
case SegmentTypeSealed:
|
||||
targetMap = mgr.sealedSegments
|
||||
default:
|
||||
panic("unexpected segment type")
|
||||
}
|
||||
|
||||
for _, segment := range segments {
|
||||
if _, ok := targetMap[segment.ID()]; ok {
|
||||
continue
|
||||
}
|
||||
targetMap[segment.ID()] = segment
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.Collection()),
|
||||
fmt.Sprint(segment.Partition()),
|
||||
segment.Type().String(),
|
||||
fmt.Sprint(len(segment.Indexes())),
|
||||
).Inc()
|
||||
if segment.RowNum() > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.Collection()),
|
||||
fmt.Sprint(segment.Partition()),
|
||||
segment.Type().String(),
|
||||
fmt.Sprint(len(segment.Indexes())),
|
||||
).Add(float64(segment.RowNum()))
|
||||
}
|
||||
}
|
||||
mgr.updateMetric()
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) Get(segmentID UniqueID) Segment {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
if segment, ok := mgr.growingSegments[segmentID]; ok {
|
||||
return segment
|
||||
} else if segment, ok = mgr.sealedSegments[segmentID]; ok {
|
||||
return segment
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) GetWithType(segmentID UniqueID, typ SegmentType) Segment {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
switch typ {
|
||||
case SegmentTypeSealed:
|
||||
return mgr.sealedSegments[segmentID]
|
||||
case SegmentTypeGrowing:
|
||||
return mgr.growingSegments[segmentID]
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) GetBy(filters ...SegmentFilter) []Segment {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
ret := make([]Segment, 0)
|
||||
for _, segment := range mgr.growingSegments {
|
||||
if filter(segment, filters...) {
|
||||
ret = append(ret, segment)
|
||||
}
|
||||
}
|
||||
|
||||
for _, segment := range mgr.sealedSegments {
|
||||
if filter(segment, filters...) {
|
||||
ret = append(ret, segment)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func filter(segment Segment, filters ...SegmentFilter) bool {
|
||||
for _, filter := range filters {
|
||||
if !filter(segment) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) GetSealed(segmentID UniqueID) Segment {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
if segment, ok := mgr.sealedSegments[segmentID]; ok {
|
||||
return segment
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) GetGrowing(segmentID UniqueID) Segment {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
|
||||
if segment, ok := mgr.growingSegments[segmentID]; ok {
|
||||
return segment
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
switch scope {
|
||||
case querypb.DataScope_Streaming:
|
||||
remove(segmentID, mgr.growingSegments)
|
||||
|
||||
case querypb.DataScope_Historical:
|
||||
remove(segmentID, mgr.sealedSegments)
|
||||
|
||||
case querypb.DataScope_All:
|
||||
remove(segmentID, mgr.growingSegments)
|
||||
remove(segmentID, mgr.sealedSegments)
|
||||
}
|
||||
mgr.updateMetric()
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
for id, segment := range mgr.growingSegments {
|
||||
if filter(segment, filters...) {
|
||||
remove(id, mgr.growingSegments)
|
||||
}
|
||||
}
|
||||
|
||||
for id, segment := range mgr.sealedSegments {
|
||||
if filter(segment, filters...) {
|
||||
remove(id, mgr.sealedSegments)
|
||||
}
|
||||
}
|
||||
mgr.updateMetric()
|
||||
}
|
||||
|
||||
func (mgr *segmentManager) updateMetric() {
|
||||
// update collection and partiation metric
|
||||
var collections, partiations = make(Set[int64]), make(Set[int64])
|
||||
for _, seg := range mgr.growingSegments {
|
||||
collections.Insert(seg.Collection())
|
||||
partiations.Insert(seg.Partition())
|
||||
}
|
||||
for _, seg := range mgr.sealedSegments {
|
||||
collections.Insert(seg.Collection())
|
||||
partiations.Insert(seg.Partition())
|
||||
}
|
||||
metrics.QueryNodeNumCollections.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(collections.Len()))
|
||||
metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partiations.Len()))
|
||||
}
|
||||
|
||||
func remove(segmentID int64, container map[int64]Segment) {
|
||||
segment, ok := container[segmentID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(container, segmentID)
|
||||
|
||||
rowNum := segment.RowNum()
|
||||
DeleteSegment(segment.(*LocalSegment))
|
||||
|
||||
metrics.QueryNodeNumSegments.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.Collection()),
|
||||
fmt.Sprint(segment.Partition()),
|
||||
segment.Type().String(),
|
||||
fmt.Sprint(len(segment.Indexes())),
|
||||
).Dec()
|
||||
if rowNum > 0 {
|
||||
metrics.QueryNodeNumEntities.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()),
|
||||
fmt.Sprint(segment.Collection()),
|
||||
fmt.Sprint(segment.Partition()),
|
||||
segment.Type().String(),
|
||||
fmt.Sprint(len(segment.Indexes())),
|
||||
).Sub(float64(rowNum))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package segments
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/paramtable"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ManagerSuite struct {
|
||||
suite.Suite
|
||||
|
||||
// Data
|
||||
segmentIDs []int64
|
||||
collectionIDs []int64
|
||||
partitionIDs []int64
|
||||
channels []string
|
||||
types []SegmentType
|
||||
segments []*LocalSegment
|
||||
|
||||
mgr *segmentManager
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
s.segmentIDs = []int64{1, 2, 3}
|
||||
s.collectionIDs = []int64{100, 200, 300}
|
||||
s.partitionIDs = []int64{10, 11, 12}
|
||||
s.channels = []string{"dml1", "dml2", "dml3"}
|
||||
s.types = []SegmentType{SegmentTypeSealed, SegmentTypeGrowing, SegmentTypeSealed}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) SetupTest() {
|
||||
s.mgr = NewSegmentManager()
|
||||
|
||||
for i, id := range s.segmentIDs {
|
||||
segment, err := NewSegment(
|
||||
NewCollection(s.collectionIDs[i], GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64), querypb.LoadType_LoadCollection),
|
||||
id,
|
||||
s.partitionIDs[i],
|
||||
s.collectionIDs[i],
|
||||
s.channels[i],
|
||||
s.types[i],
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.segments = append(s.segments, segment)
|
||||
|
||||
s.mgr.Put(s.types[i], segment)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestGetBy() {
|
||||
for i, partitionID := range s.partitionIDs {
|
||||
segments := s.mgr.GetBy(WithPartition(partitionID))
|
||||
s.Contains(segments, s.segments[i])
|
||||
}
|
||||
|
||||
for i, channel := range s.channels {
|
||||
segments := s.mgr.GetBy(WithChannel(channel))
|
||||
s.Contains(segments, s.segments[i])
|
||||
}
|
||||
|
||||
for i, typ := range s.types {
|
||||
segments := s.mgr.GetBy(WithType(typ))
|
||||
s.Contains(segments, s.segments[i])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveGrowing() {
|
||||
for i, id := range s.segmentIDs {
|
||||
isGrowing := s.types[i] == SegmentTypeGrowing
|
||||
|
||||
s.mgr.Remove(id, querypb.DataScope_Streaming)
|
||||
s.Equal(s.mgr.Get(id) == nil, isGrowing)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveSealed() {
|
||||
for i, id := range s.segmentIDs {
|
||||
isSealed := s.types[i] == SegmentTypeSealed
|
||||
|
||||
s.mgr.Remove(id, querypb.DataScope_Historical)
|
||||
s.Equal(s.mgr.Get(id) == nil, isSealed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveAll() {
|
||||
for _, id := range s.segmentIDs {
|
||||
s.mgr.Remove(id, querypb.DataScope_All)
|
||||
s.Nil(s.mgr.Get(id))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerSuite) TestRemoveBy() {
|
||||
for _, id := range s.segmentIDs {
|
||||
s.mgr.RemoveBy(WithID(id))
|
||||
s.Nil(s.mgr.Get(id))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
suite.Run(t, new(ManagerSuite))
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue