mirror of https://github.com/milvus-io/milvus.git
feat: support load and query with bm25 metric (#36071)
relate: https://github.com/milvus-io/milvus/issues/35853

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
parent 90285830de
commit db34572c56
Makefile | 3 +++

@@ -532,6 +532,9 @@ generate-mockery-utils: getdeps
 	# proxy_client_manager.go
 	$(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage
 	$(INSTALL_PATH)/mockery --name=ProxyWatcherInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_watcher.go --with-expecter --structname=MockProxyWatcher --inpackage
+	# function
+	$(INSTALL_PATH)/mockery --name=FunctionRunner --dir=$(PWD)/internal/util/function --output=$(PWD)/internal/util/function --filename=mock_function.go --with-expecter --structname=MockFunctionRunner --inpackage
+
 
 generate-mockery-kv: getdeps
 	$(INSTALL_PATH)/mockery --name=TxnKV --dir=$(PWD)/pkg/kv --output=$(PWD)/internal/kv/mocks --filename=txn_kv.go --with-expecter
@@ -410,7 +410,8 @@ inline bool
 IsFloatVectorMetricType(const MetricType& metric_type) {
     return metric_type == knowhere::metric::L2 ||
            metric_type == knowhere::metric::IP ||
-           metric_type == knowhere::metric::COSINE;
+           metric_type == knowhere::metric::COSINE ||
+           metric_type == knowhere::metric::BM25;
 }
 
 inline bool
@@ -160,13 +160,15 @@ inline bool
 IsFloatMetricType(const knowhere::MetricType& metric_type) {
     return IsMetricType(metric_type, knowhere::metric::L2) ||
            IsMetricType(metric_type, knowhere::metric::IP) ||
-           IsMetricType(metric_type, knowhere::metric::COSINE);
+           IsMetricType(metric_type, knowhere::metric::COSINE) ||
+           IsMetricType(metric_type, knowhere::metric::BM25);
 }
 
 inline bool
 PositivelyRelated(const knowhere::MetricType& metric_type) {
     return IsMetricType(metric_type, knowhere::metric::IP) ||
-           IsMetricType(metric_type, knowhere::metric::COSINE);
+           IsMetricType(metric_type, knowhere::metric::COSINE) ||
+           IsMetricType(metric_type, knowhere::metric::BM25);
 }
 
 inline std::string
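BM25 joins IP and COSINE in `PositivelyRelated` because a larger BM25 score means a better match. For intuition, here is a minimal sketch of the classic Okapi BM25 term score that the `k1`/`b`/`avgdl` parameters appearing throughout this diff feed into (this is the standard textbook formula, not Milvus/knowhere source):

```go
package main

import "fmt"

// bm25TermScore computes the classic Okapi BM25 contribution of one term:
// idf * tf*(k1+1) / (tf + k1*(1 - b + b*docLen/avgdl)).
// k1 dampens term frequency, b controls length normalization, and avgdl is
// the average document length that this commit plumbs from the delegator
// down into knowhere as BM25_AVGDL.
func bm25TermScore(tf, idf, docLen, avgdl, k1, b float64) float64 {
	norm := k1 * (1 - b + b*docLen/avgdl)
	return idf * tf * (k1 + 1) / (tf + norm)
}

func main() {
	// k1=1.2, b=0.75 match the defaults used by generate_build_conf below.
	fmt.Println(bm25TermScore(2, 1.5, 120, 100, 1.2, 0.75))
}
```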
@@ -409,7 +409,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     milvus::tracer::AddEvent("finish_knowhere_index_search");
     if (!res.has_value()) {
         PanicInfo(ErrorCode::UnexpectedError,
-                  "failed to search: {}: {}",
+                  "failed to search: config={} {}: {}",
+                  search_conf.dump(),
                   KnowhereStatusString(res.error()),
                   res.what());
     }
@@ -52,6 +52,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
     search_info.materialized_view_involved =
         query_info_proto.materialized_view_involved();
 
+    if (query_info_proto.bm25_avgdl() > 0) {
+        search_info.search_params_[knowhere::meta::BM25_AVGDL] =
+            query_info_proto.bm25_avgdl();
+    }
+
     if (query_info_proto.group_by_field_id() > 0) {
         auto group_by_field_id =
             FieldId(query_info_proto.group_by_field_id());
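This is the receiving end of the avgdl hand-off: the delegator computes the collection-wide average document length at query time and ships it in `QueryInfo.bm25_avgdl`; segcore copies it into the knowhere search params here. A minimal sketch of the search-config shape this produces (the literal values are illustrative; the key mirrors `knowhere::meta::BM25_AVGDL`):

```go
package main

import "fmt"

func main() {
	// Illustrative only: the knowhere-level search config a BM25 query ends
	// up carrying once ProtoParser copies bm25_avgdl out of QueryInfo.
	searchParams := map[string]any{
		"metric_type": "BM25",
		"bm25_avgdl":  100.0, // computed by the delegator's idfOracle at query time
	}
	fmt.Println(searchParams)
}
```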
@@ -10,6 +10,7 @@
 // or implied. See the License for the specific language governing permissions and limitations under the License
 
 #include "IndexConfigGenerator.h"
+#include "knowhere/comp/index_param.h"
 #include "log/Log.h"
 
 namespace milvus::segcore {
@@ -49,15 +50,28 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout,
         std::to_string(config_.get_nlist());
     build_params_[knowhere::indexparam::SSIZE] = std::to_string(
         std::max((int)(config_.get_chunk_rows() / config_.get_nlist()), 48));
+
+    if (is_sparse && metric_type_ == knowhere::metric::BM25) {
+        build_params_[knowhere::meta::BM25_K1] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_K1);
+        build_params_[knowhere::meta::BM25_B] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_B);
+        build_params_[knowhere::meta::BM25_AVGDL] =
+            index_meta_.GetIndexParams().at(knowhere::meta::BM25_AVGDL);
+    }
+
     search_params_[knowhere::indexparam::NPROBE] =
         std::to_string(config_.get_nprobe());
 
     // note for sparse vector index: drop_ratio_build is not allowed for growing
     // segment index.
     LOG_INFO(
-        "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}",
+        "VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}, "
+        "config={}",
         origin_index_type_,
         index_type_,
-        metric_type_);
+        metric_type_,
+        build_params_.dump());
 }
 
 int64_t
@@ -100,6 +114,11 @@ VecIndexConfig::GetSearchConf(const SearchInfo& searchInfo) {
             searchParam.search_params_[key] = searchInfo.search_params_[key];
         }
     }
+
+    if (metric_type_ == knowhere::metric::BM25) {
+        searchParam.search_params_[knowhere::meta::BM25_AVGDL] =
+            searchInfo.search_params_[knowhere::meta::BM25_AVGDL];
+    }
     return searchParam;
 }
 
@@ -25,6 +25,7 @@
 #include "indexbuilder/ScalarIndexCreator.h"
 #include "indexbuilder/VecIndexCreator.h"
 #include "indexbuilder/index_c.h"
+#include "knowhere/comp/index_param.h"
 #include "pb/index_cgo_msg.pb.h"
 #include "storage/Types.h"
 
@@ -100,6 +101,14 @@ generate_build_conf(const milvus::IndexType& index_type,
         };
     } else if (index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
                index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
+        if (metric_type == knowhere::metric::BM25) {
+            return knowhere::Json{
+                {knowhere::meta::METRIC_TYPE, metric_type},
+                {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
+                {knowhere::meta::BM25_K1, "1.2"},
+                {knowhere::meta::BM25_B, "0.75"},
+                {knowhere::meta::BM25_AVGDL, "100"}};
+        }
         return knowhere::Json{
             {knowhere::meta::METRIC_TYPE, metric_type},
             {knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
@@ -652,7 +652,7 @@ func (s *L0CompactionTaskSuite) TestPorcessStateTrans() {
 		s.Equal(datapb.CompactionTaskState_failed, t.GetState())
 	})
 
-	s.Run("test unkonwn task", func() {
+	s.Run("test unknown task", func() {
 		t := s.generateTestL0Task(datapb.CompactionTaskState_unknown)
 
 		got := t.Process()
@@ -73,7 +73,7 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
 }
 
 func (eNode *embeddingNode) Name() string {
-	return fmt.Sprintf("embeddingNode-%s-%s", "BM25test", eNode.channelName)
+	return fmt.Sprintf("embeddingNode-%s", eNode.channelName)
 }
 
 func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, inputFieldId, outputFieldId int64, data *storage.InsertData, meta map[int64]*storage.BM25Stats) error {
@@ -96,6 +96,7 @@ message SubSearchRequest {
   string metricType = 9;
   int64 group_by_field_id = 10;
   int64 group_size = 11;
+  int64 field_id = 12;
 }
 
 message SearchRequest {
@@ -124,6 +125,7 @@ message SearchRequest {
   common.ConsistencyLevel consistency_level = 22;
   int64 group_by_field_id = 23;
   int64 group_size = 24;
+  int64 field_id = 25;
 }
 
 message SubSearchResults {
@@ -64,6 +64,8 @@ message QueryInfo {
   bool materialized_view_involved = 7;
   int64 group_size = 8;
   bool group_strict_size = 9;
+  double bm25_avgdl = 10;
+  int64 query_field_id = 11;
 }
 
 message ColumnInfo {
@@ -367,6 +367,7 @@ message SegmentLoadInfo {
   int64 storageVersion = 18;
   bool is_sorted = 19;
   map<int64, data.TextIndexStats> textStatsLogs = 20;
+  repeated data.FieldBinlog bm25logs = 21;
 }
 
 message FieldIndexInfo {
@@ -361,8 +361,12 @@ func (cit *createIndexTask) parseIndexParams() error {
 			return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType)
 		}
 	} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) {
-		if metricType != metric.IP {
-			return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP is the supported metric type for sparse index")
+		if metricType != metric.IP && metricType != metric.BM25 {
+			return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP&BM25 is the supported metric type for sparse index")
 		}
+
+		if metricType == metric.BM25 && cit.functionSchema.GetType() != schemapb.FunctionType_BM25 {
+			return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only BM25 Function output field support BM25 metric type")
+		}
 	} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
 		if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {
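For reference, a sketch of the user-facing index params this validation accepts: a sparse index whose metric is BM25, created on the output field of a `FunctionType_BM25` function. The exact key names here are an assumption mirroring knowhere's `BM25_K1`/`BM25_B` constants, not confirmed by this diff:

```go
package main

import "fmt"

func main() {
	// Illustrative index params that pass the checks above: a sparse index
	// with the BM25 metric on a field produced by a BM25 function.
	indexParams := map[string]string{
		"index_type":  "SPARSE_INVERTED_INDEX",
		"metric_type": "BM25",
		"bm25_k1":     "1.2",  // assumed key name, mirrors knowhere::meta::BM25_K1
		"bm25_b":      "0.75", // assumed key name, mirrors knowhere::meta::BM25_B
	}
	fmt.Println(indexParams)
}
```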
@@ -370,6 +370,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 			GroupSize:      t.rankParams.GetGroupSize(),
 		}
 
+		internalSubReq.FieldId = queryInfo.GetQueryFieldId()
 		// set PartitionIDs for sub search
 		if t.partitionKeyMode {
 			// isolatioin has tighter constraint, check first
@@ -449,6 +450,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
 	}
 
 	t.SearchRequest.Offset = offset
+	t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
 
 	if t.partitionKeyMode {
 		// isolatioin has tighter constraint, check first
@@ -511,6 +513,8 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
 	if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
 		return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
 	}
+
+	queryInfo.QueryFieldId = annField.GetFieldID()
 	plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
 	if planErr != nil {
 		log.Warn("failed to create query plan", zap.Error(planErr),
@@ -81,6 +81,7 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, channelCheckpoint *msgpb.M
 		NumOfRows:     segment.NumOfRows,
 		Statslogs:     segment.Statslogs,
 		Deltalogs:     segment.Deltalogs,
+		Bm25Logs:      segment.Bm25Statslogs,
 		InsertChannel: segment.InsertChannel,
 		IndexInfos:    indexes,
 		StartPosition: segment.GetStartPosition(),
@@ -34,6 +34,7 @@ import (
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -43,6 +44,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querynodev2/segments"
 	"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
 	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/function"
 	"github.com/milvus-io/milvus/internal/util/reduce"
 	"github.com/milvus-io/milvus/internal/util/streamrpc"
 	"github.com/milvus-io/milvus/pkg/common"
@@ -54,6 +56,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/lifetime"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metautil"
+	"github.com/milvus-io/milvus/pkg/util/metric"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/timerecord"
 	"github.com/milvus-io/milvus/pkg/util/tsoutil"
@@ -110,7 +113,9 @@ type shardDelegator struct {
 
 	lifetime lifetime.Lifetime[lifetime.State]
 
-	distribution *distribution
+	distribution *distribution
+	idfOracle    IDFOracle
+
 	segmentManager segments.SegmentManager
 	tsafeManager   tsafe.Manager
 	pkOracle       pkoracle.PkOracle
@@ -135,6 +140,10 @@ type shardDelegator struct {
 	// in order to make add/remove growing be atomic, need lock before modify these meta info
 	growingSegmentLock sync.RWMutex
 	partitionStatsMut  sync.RWMutex
+
+	// fieldId -> functionRunner map for search function field
+	functionRunners map[UniqueID]function.FunctionRunner
+	hasBM25Field    bool
 }
 
 // getLogger returns the zap logger with pre-defined shard attributes.
@@ -235,6 +244,19 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
 		}()
 	}
 
+	// build idf for bm25 search
+	if req.GetReq().GetMetricType() == metric.BM25 {
+		avgdl, err := sd.buildBM25IDF(req.GetReq())
+		if err != nil {
+			return nil, err
+		}
+
+		if avgdl <= 0 {
+			log.Warn("search bm25 from empty data, skip search", zap.String("channel", sd.vchannelName), zap.Float64("avgdl", avgdl))
+			return []*internalpb.SearchResults{}, nil
+		}
+	}
+
 	// get final sealedNum after possible segment prune
 	sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
 	log.Debug("search segments...",
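Two details worth noting in this hunk: the IDF/avgdl rewrite happens once in the delegator, before the request fans out to workers, so every sub-search scores against the same global term statistics; and `avgdl <= 0` can only occur when the channel holds no rows for the BM25 field, so an empty result set is the correct answer rather than an error.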
@@ -335,6 +357,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
 			IsAdvanced:     false,
 			GroupByFieldId: subReq.GetGroupByFieldId(),
 			GroupSize:      subReq.GetGroupSize(),
+			FieldId:        subReq.GetFieldId(),
 		}
 		future := conc.Go(func() (*internalpb.SearchResults, error) {
 			searchReq := &querypb.SearchRequest{
@@ -862,6 +885,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 
 	excludedSegments := NewExcludedSegments(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.GetAsDuration(time.Second))
 
+	idfOracle := NewIDFOracle(collection.Schema().GetFunctions())
 	sd := &shardDelegator{
 		collectionID: collectionID,
 		replicaID:    replicaID,
@@ -871,7 +895,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 		segmentManager: manager.Segment,
 		workerManager:  workerManager,
 		lifetime:       lifetime.NewLifetime(lifetime.Initializing),
-		distribution:   NewDistribution(),
+		distribution:   NewDistribution(idfOracle),
 		deleteBuffer:   deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock),
 		pkOracle:       pkoracle.NewPkOracle(),
 		tsafeManager:   tsafeManager,
@@ -880,9 +904,25 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
 		factory:          factory,
 		queryHook:        queryHook,
 		chunkManager:     chunkManager,
+		idfOracle:        idfOracle,
 		partitionStats:   make(map[UniqueID]*storage.PartitionStatsSnapshot),
 		excludedSegments: excludedSegments,
+		functionRunners:  make(map[int64]function.FunctionRunner),
 	}
+
+	for _, tf := range collection.Schema().GetFunctions() {
+		if tf.GetType() == schemapb.FunctionType_BM25 {
+			functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
+			if err != nil {
+				return nil, err
+			}
+			sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner
+			if tf.GetType() == schemapb.FunctionType_BM25 {
+				sd.hasBM25Field = true
+			}
+		}
+	}
+
 	m := sync.Mutex{}
 	sd.tsCond = sync.NewCond(&m)
 	if sd.lifetime.Add(lifetime.NotStopped) == nil {
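The delegator keeps one runner per BM25 output field so query text can be tokenized at search time. A minimal sketch of the runner contract, inferred from the `BatchRun` call sites in this diff — the real interface lives in `internal/util/function`; treat the signature and the fake tokenizer below as approximations:

```go
package main

import "fmt"

// functionRunner is a stand-in for the role function.FunctionRunner plays
// here: feed the raw input column in (query strings for BM25), get the
// output column back (sparse term-frequency rows).
type functionRunner interface {
	BatchRun(inputs ...any) ([]any, error)
}

type fakeBM25Runner struct{}

// BatchRun maps each input string to a fixed term-id -> frequency row.
func (fakeBM25Runner) BatchRun(inputs ...any) ([]any, error) {
	strs := inputs[0].([]string)
	rows := make([]map[uint32]float32, 0, len(strs))
	for range strs {
		rows = append(rows, map[uint32]float32{1: 1, 7: 2}) // fake tokens
	}
	return []any{rows}, nil
}

func main() {
	var r functionRunner = fakeBM25Runner{}
	out, _ := r.BatchRun([]string{"hello bm25"})
	fmt.Println(out)
}
```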
@@ -27,11 +27,14 @@ import (
 	"github.com/samber/lo"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/distributed/streaming"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -63,10 +66,12 @@ import (
 
 // InsertData
 type InsertData struct {
-	RowIDs       []int64
-	PrimaryKeys  []storage.PrimaryKey
-	Timestamps   []uint64
-	InsertRecord *segcorepb.InsertRecord
+	RowIDs       []int64
+	PrimaryKeys  []storage.PrimaryKey
+	Timestamps   []uint64
+	InsertRecord *segcorepb.InsertRecord
+	BM25Stats    map[int64]*storage.BM25Stats
 
 	StartPosition *msgpb.MsgPosition
 	PartitionID   int64
 }
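`storage.BM25Stats` is the per-segment accumulator behind all of this: each inserted row contributes its term frequencies, and segment stats can later be merged into (or subtracted from) a channel-wide total. A small usage sketch built only from the calls visible in this diff (`Append`, `Merge`, `Minus`, `NumRow`):

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/storage"
)

func main() {
	// Per-segment accumulation, then roll-up into a channel-wide total,
	// mirroring how the delegator's idfOracle uses storage.BM25Stats.
	segStats := storage.NewBM25Stats()
	segStats.Append(map[uint32]float32{1: 2, 9: 1}) // one row: term -> frequency
	segStats.Append(map[uint32]float32{1: 1})

	total := storage.NewBM25Stats()
	total.Merge(segStats)       // segment becomes visible to IDF
	fmt.Println(total.NumRow()) // 2 rows counted
	total.Minus(segStats)       // segment released: statistics withdrawn
}
```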
@@ -149,6 +154,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
 			if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) {
 				// register created growing segment after insert, avoid to add empty growing to delegator
 				sd.pkOracle.Register(growing, paramtable.GetNodeID())
+				sd.idfOracle.Register(segmentID, insertData.BM25Stats, segments.SegmentTypeGrowing)
 				sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing)
 				sd.addGrowing(SegmentEntry{
 					NodeID:        paramtable.GetNodeID(),
@@ -158,10 +164,12 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
 					TargetVersion: initialTargetVersion,
 				})
 			}
-			sd.growingSegmentLock.Unlock()
-		}
-
-		log.Debug("insert into growing segment",
+			sd.growingSegmentLock.Unlock()
+		} else {
+			sd.idfOracle.UpdateGrowing(growing.ID(), insertData.BM25Stats)
+		}
+
+		log.Info("insert into growing segment",
 			zap.Int64("collectionID", growing.Collection()),
 			zap.Int64("segmentID", segmentID),
 			zap.Int("rowCount", len(insertData.RowIDs)),
@@ -375,8 +383,11 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
 	segmentIDs = lo.Map(loaded, func(segment segments.Segment, _ int) int64 { return segment.ID() })
 	log.Info("load growing segments done", zap.Int64s("segmentIDs", segmentIDs))
 
-	for _, candidate := range loaded {
-		sd.pkOracle.Register(candidate, paramtable.GetNodeID())
+	for _, segment := range loaded {
+		sd.pkOracle.Register(segment, paramtable.GetNodeID())
+		if sd.hasBM25Field {
+			sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing)
+		}
 	}
 	sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry {
 		return SegmentEntry{
@@ -472,6 +483,16 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
 	infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool {
 		return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID)
 	})
+
+	var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
+	if sd.hasBM25Field {
+		bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
+		if err != nil {
+			log.Warn("failed to load bm25 stats for segment", zap.Error(err))
+			return err
+		}
+	}
+
 	candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...)
 	if err != nil {
 		log.Warn("failed to load bloom filter set for segment", zap.Error(err))
@@ -479,7 +500,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
 	}
 
 	log.Debug("load delete...")
-	err = sd.loadStreamDelete(ctx, candidates, infos, req, targetNodeID, worker)
+	err = sd.loadStreamDelete(ctx, candidates, bm25Stats, infos, req, targetNodeID, worker)
 	if err != nil {
 		log.Warn("load stream delete failed", zap.Error(err))
 		return err
@@ -552,6 +573,7 @@ func (sd *shardDelegator) RefreshLevel0DeletionStats() {
 
 func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
 	candidates []*pkoracle.BloomFilterSet,
+	bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats],
 	infos []*querypb.SegmentLoadInfo,
 	req *querypb.LoadSegmentsRequest,
 	targetNodeID int64,
@@ -665,6 +687,14 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
 		)
 		sd.pkOracle.Register(candidate, targetNodeID)
 	}
 
+	if bm25Stats != nil {
+		bm25Stats.Range(func(segmentID int64, stats map[int64]*storage.BM25Stats) bool {
+			sd.idfOracle.Register(segmentID, stats, segments.SegmentTypeSealed)
+			return false
+		})
+	}
+
 	log.Info("load delete done")
 
 	return nil
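`typeutil.ConcurrentMap.Range` follows the `sync.Map` convention: iteration continues while the callback returns true and stops at the first false. A stdlib illustration of that contract:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var m sync.Map
	m.Store(int64(1), "a")
	m.Store(int64(2), "b")

	// Returning true visits every entry; returning false would stop
	// after the first entry the iterator happens to yield.
	m.Range(func(key, value any) bool {
		fmt.Println(key, value)
		return true
	})
}
```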
@@ -963,3 +993,47 @@ func (sd *shardDelegator) TryCleanExcludedSegments(ts uint64) {
 		sd.excludedSegments.CleanInvalid(ts)
 	}
 }
+
+func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64, error) {
+	pb := &commonpb.PlaceholderGroup{}
+	proto.Unmarshal(req.GetPlaceholderGroup(), pb)
+
+	if len(pb.Placeholders) != 1 || len(pb.Placeholders[0].Values) != 1 {
+		return 0, merr.WrapErrParameterInvalidMsg("please provide varchar for bm25")
+	}
+
+	holder := pb.Placeholders[0]
+	if holder.Type != commonpb.PlaceholderType_VarChar {
+		return 0, fmt.Errorf("can't build BM25 IDF for data not varchar")
+	}
+
+	str := funcutil.GetVarCharFromPlaceholder(holder)
+	functionRunner, ok := sd.functionRunners[req.GetFieldId()]
+	if !ok {
+		return 0, fmt.Errorf("functionRunner not found for field: %d", req.GetFieldId())
+	}
+
+	// get search text term frequency
+	output, err := functionRunner.BatchRun(str)
+	if err != nil {
+		return 0, err
+	}
+
+	tfArray, ok := output[0].(*schemapb.SparseFloatArray)
+	if !ok {
+		return 0, fmt.Errorf("functionRunner return unknown data")
+	}
+
+	idfSparseVector, avgdl, err := sd.idfOracle.BuildIDF(req.GetFieldId(), tfArray)
+	if err != nil {
+		return 0, err
+	}
+
+	err = SetBM25Params(req, avgdl)
+	if err != nil {
+		return 0, err
+	}
+
+	req.PlaceholderGroup = funcutil.SparseVectorDataToPlaceholderGroupBytes(idfSparseVector)
+	return avgdl, nil
+}
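Taken together with the delegator changes above: a BM25 search arrives carrying plain text and leaves as a sparse vector. `buildBM25IDF` unmarshals the placeholder group, has the field's BM25 FunctionRunner tokenize the query into term frequencies, asks the idfOracle to weight those terms by global IDF and report the current avgdl, folds avgdl back into the serialized plan via `SetBM25Params` (where segcore's ProtoParser, shown earlier, copies it into the knowhere search params), and finally replaces the placeholder bytes with the IDF-weighted sparse vector.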
@@ -26,14 +26,19 @@ import (
 	"time"
 
 	"github.com/cockroachdb/errors"
+	"github.com/pingcap/log"
 	"github.com/samber/lo"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/suite"
+	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/proto/segcorepb"
 	"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@@ -42,10 +47,12 @@ import (
 	"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/bloomfilter"
+	"github.com/milvus-io/milvus/internal/util/function"
 	"github.com/milvus-io/milvus/internal/util/initcore"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
+	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metautil"
 	"github.com/milvus-io/milvus/pkg/util/metric"
@@ -95,13 +102,7 @@ func (s *DelegatorDataSuite) TearDownSuite() {
 	paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.Key)
 }
 
-func (s *DelegatorDataSuite) SetupTest() {
-	s.workerManager = &cluster.MockManager{}
-	s.manager = segments.NewManager()
-	s.tsafeManager = tsafe.NewTSafeReplica()
-	s.loader = &segments.MockLoader{}
-
-	// init schema
+func (s *DelegatorDataSuite) genNormalCollection() {
 	s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
 		Name:   "TestCollection",
 		Fields: []*schemapb.FieldSchema{
@@ -154,7 +155,59 @@ func (s *DelegatorDataSuite) SetupTest() {
 		LoadType:     querypb.LoadType_LoadCollection,
 		PartitionIDs: []int64{1001, 1002},
 	})
 }
+
+func (s *DelegatorDataSuite) genCollectionWithFunction() {
+	s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
+		Name: "TestCollection",
+		Fields: []*schemapb.FieldSchema{
+			{
+				Name:         "id",
+				FieldID:      100,
+				IsPrimaryKey: true,
+				DataType:     schemapb.DataType_Int64,
+				AutoID:       true,
+			}, {
+				Name:         "vector",
+				FieldID:      101,
+				IsPrimaryKey: false,
+				DataType:     schemapb.DataType_SparseFloatVector,
+			}, {
+				Name:     "text",
+				FieldID:  102,
+				DataType: schemapb.DataType_VarChar,
+				TypeParams: []*commonpb.KeyValuePair{
+					{
+						Key:   common.MaxLengthKey,
+						Value: "256",
+					},
+				},
+			},
+		},
+		Functions: []*schemapb.FunctionSchema{{
+			Type:           schemapb.FunctionType_BM25,
+			InputFieldIds:  []int64{102},
+			OutputFieldIds: []int64{101},
+		}},
+	}, nil, nil)
+
+	delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
+		NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
+			return s.mq, nil
+		},
+	}, 10000, nil, s.chunkManager)
+	s.NoError(err)
+	s.delegator = delegator.(*shardDelegator)
+}
+
+func (s *DelegatorDataSuite) SetupTest() {
+	s.workerManager = &cluster.MockManager{}
+	s.manager = segments.NewManager()
+	s.tsafeManager = tsafe.NewTSafeReplica()
+	s.loader = &segments.MockLoader{}
+
+	// init schema
+	s.genNormalCollection()
 	s.mq = &msgstream.MockMsgStream{}
 	s.rootPath = s.Suite.T().Name()
 	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
@@ -471,6 +524,127 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
 	s.False(s.delegator.distribution.Serviceable())
 }
 
+func (s *DelegatorDataSuite) TestLoadGrowingWithBM25() {
+	s.genCollectionWithFunction()
+	mockSegment := segments.NewMockSegment(s.T())
+	s.loader.EXPECT().Load(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]segments.Segment{mockSegment}, nil)
+
+	mockSegment.EXPECT().Partition().Return(111)
+	mockSegment.EXPECT().ID().Return(111)
+	mockSegment.EXPECT().Type().Return(commonpb.SegmentState_Growing)
+	mockSegment.EXPECT().GetBM25Stats().Return(map[int64]*storage.BM25Stats{})
+
+	err := s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{{SegmentID: 1}}, 1)
+	s.NoError(err)
+}
+
+func (s *DelegatorDataSuite) TestLoadSegmentsWithBm25() {
+	s.genCollectionWithFunction()
+	s.Run("normal_run", func() {
+		defer func() {
+			s.workerManager.ExpectedCalls = nil
+			s.loader.ExpectedCalls = nil
+		}()
+
+		statsMap := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
+		stats := storage.NewBM25Stats()
+		stats.Append(map[uint32]float32{1: 1})
+
+		statsMap.Insert(1, map[int64]*storage.BM25Stats{101: stats})
+
+		s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(statsMap, nil)
+		s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
+			Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
+			return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
+				return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
+			})
+		}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
+			return nil
+		})
+
+		workers := make(map[int64]*cluster.MockWorker)
+		worker1 := &cluster.MockWorker{}
+		workers[1] = worker1
+
+		worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
+			Return(nil)
+		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
+			return workers[nodeID]
+		}, nil)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
+			Base:         commonpbutil.NewMsgBase(),
+			DstNodeID:    1,
+			CollectionID: s.collectionID,
+			Infos: []*querypb.SegmentLoadInfo{
+				{
+					SegmentID:     100,
+					PartitionID:   500,
+					StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					Level:         datapb.SegmentLevel_L1,
+					InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
+				},
+			},
+		})
+
+		s.NoError(err)
+		sealed, _ := s.delegator.GetSegmentInfo(false)
+		s.Require().Equal(1, len(sealed))
+		s.Equal(int64(1), sealed[0].NodeID)
+		s.ElementsMatch([]SegmentEntry{
+			{
+				SegmentID:     100,
+				NodeID:        1,
+				PartitionID:   500,
+				TargetVersion: unreadableTargetVersion,
+				Level:         datapb.SegmentLevel_L1,
+			},
+		}, sealed[0].Segments)
+	})
+
+	s.Run("loadBM25_failed", func() {
+		defer func() {
+			s.workerManager.ExpectedCalls = nil
+			s.loader.ExpectedCalls = nil
+		}()
+
+		s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(nil, fmt.Errorf("mock error"))
+
+		workers := make(map[int64]*cluster.MockWorker)
+		worker1 := &cluster.MockWorker{}
+		workers[1] = worker1
+
+		worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
+			Return(nil)
+		s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
+			return workers[nodeID]
+		}, nil)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
+			Base:         commonpbutil.NewMsgBase(),
+			DstNodeID:    1,
+			CollectionID: s.collectionID,
+			Infos: []*querypb.SegmentLoadInfo{
+				{
+					SegmentID:     100,
+					PartitionID:   500,
+					StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
+					Level:         datapb.SegmentLevel_L1,
+					InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
+				},
+			},
+		})
+
+		s.Error(err)
+	})
+}
+
 func (s *DelegatorDataSuite) TestLoadSegments() {
 	s.Run("normal_run", func() {
 		defer func() {
@@ -883,6 +1057,214 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
 	})
 }
 
+func (s *DelegatorDataSuite) TestBuildBM25IDF() {
+	s.genCollectionWithFunction()
+
+	genBM25Stats := func(start uint32, end uint32) map[int64]*storage.BM25Stats {
+		result := make(map[int64]*storage.BM25Stats)
+		result[101] = storage.NewBM25Stats()
+		for i := start; i < end; i++ {
+			row := map[uint32]float32{i: 1}
+			result[101].Append(row)
+		}
+		return result
+	}
+
+	genSnapShot := func(seals, grows []int64, targetVersion int64) *snapshot {
+		snapshot := &snapshot{
+			dist:          []SnapshotItem{{1, make([]SegmentEntry, 0)}},
+			targetVersion: targetVersion,
+		}
+
+		newSeal := []SegmentEntry{}
+		for _, seg := range seals {
+			newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
+		}
+
+		newGrow := []SegmentEntry{}
+		for _, seg := range grows {
+			newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
+		}
+
+		log.Info("Test-", zap.Any("shanshot", snapshot), zap.Any("seg", newSeal))
+		snapshot.dist[0].Segments = newSeal
+		snapshot.growing = newGrow
+		return snapshot
+	}
+
+	genStringFieldData := func(strs ...string) *schemapb.FieldData {
+		return &schemapb.FieldData{
+			Type:    schemapb.DataType_VarChar,
+			FieldId: 102,
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_StringData{
+						StringData: &schemapb.StringArray{
+							Data: strs,
+						},
+					},
+				},
+			},
+		}
+	}
+
+	s.Run("normal case", func() {
+		// register sealed
+		sealedSegs := []int64{1, 2, 3, 4}
+		for _, segID := range sealedSegs {
+			// every segment stats only has one token, avgdl = 1
+			s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
+		}
+		snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
+
+		s.delegator.idfOracle.SyncDistribution(snapshot)
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		plan, err := proto.Marshal(&planpb.PlanNode{
+			Node: &planpb.PlanNode_VectorAnns{
+				VectorAnns: &planpb.VectorANNS{
+					QueryInfo: &planpb.QueryInfo{},
+				},
+			},
+		})
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup:   placeholderGroupBytes,
+			SerializedExprPlan: plan,
+			FieldId:            101,
+		}
+		avgdl, err := s.delegator.buildBM25IDF(req)
+		s.NoError(err)
+		s.Equal(float64(1), avgdl)
+
+		// check avgdl in plan
+		newplan := &planpb.PlanNode{}
+		err = proto.Unmarshal(req.GetSerializedExprPlan(), newplan)
+		s.NoError(err)
+
+		annplan, ok := newplan.GetNode().(*planpb.PlanNode_VectorAnns)
+		s.Require().True(ok)
+		s.Equal(avgdl, annplan.VectorAnns.QueryInfo.Bm25Avgdl)
+
+		// check idf in placeholder
+		placeholder := &commonpb.PlaceholderGroup{}
+		err = proto.Unmarshal(req.GetPlaceholderGroup(), placeholder)
+		s.Require().NoError(err)
+		s.Equal(placeholder.GetPlaceholders()[0].GetType(), commonpb.PlaceholderType_SparseFloatVector)
+	})
+
+	s.Run("invalid place holder type error", func() {
+		placeholderGroupBytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
+			Placeholders: []*commonpb.PlaceholderValue{{Type: commonpb.PlaceholderType_SparseFloatVector}},
+		})
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("no function runner error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          103, // invalid field id
+		}
+
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("function runner run failed error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock err"))
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("function runner output type error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+
+	s.Run("idf oracle build idf error", func() {
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		oldRunner := s.delegator.functionRunners
+		mockRunner := function.NewMockFunctionRunner(s.T())
+		s.delegator.functionRunners = map[int64]function.FunctionRunner{103: mockRunner}
+		mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{&schemapb.SparseFloatArray{Contents: [][]byte{typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1})}}}, nil)
+		defer func() {
+			s.delegator.functionRunners = oldRunner
+		}()
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          103, // invalid field
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+		log.Info("test", zap.Error(err))
+	})
+
+	s.Run("set avgdl failed", func() {
+		// register sealed
+		sealedSegs := []int64{1, 2, 3, 4}
+		for _, segID := range sealedSegs {
+			// every segment stats only has one token, avgdl = 1
+			s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
+		}
+		snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
+
+		s.delegator.idfOracle.SyncDistribution(snapshot)
+		placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
+		s.NoError(err)
+
+		req := &internalpb.SearchRequest{
+			PlaceholderGroup: placeholderGroupBytes,
+			FieldId:          101,
+		}
+		_, err = s.delegator.buildBM25IDF(req)
+		s.Error(err)
+	})
+}
+
 func (s *DelegatorDataSuite) TestReleaseSegment() {
 	s.loader.EXPECT().
 		Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).
@@ -178,6 +178,84 @@ func (s *DelegatorSuite) TearDownTest() {
 	s.delegator = nil
 }
 
+func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
+	s.Run("init function failed", func() {
+		manager := segments.NewManager()
+		manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
+			Name: "TestCollection",
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:         "id",
+					FieldID:      100,
+					IsPrimaryKey: true,
+					DataType:     schemapb.DataType_Int64,
+					AutoID:       true,
+				}, {
+					Name:         "vector",
+					FieldID:      101,
+					IsPrimaryKey: false,
+					DataType:     schemapb.DataType_SparseFloatVector,
+				},
+			},
+			Functions: []*schemapb.FunctionSchema{{
+				Type:           schemapb.FunctionType_BM25,
+				InputFieldIds:  []int64{102},
+				OutputFieldIds: []int64{101, 103}, // invalid output field
+			}},
+		}, nil, nil)
+
+		_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
+			NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
+				return s.mq, nil
+			},
+		}, 10000, nil, s.chunkManager)
+		s.Error(err)
+	})
+
+	s.Run("init function failed", func() {
+		manager := segments.NewManager()
+		manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
+			Name: "TestCollection",
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:         "id",
+					FieldID:      100,
+					IsPrimaryKey: true,
+					DataType:     schemapb.DataType_Int64,
+					AutoID:       true,
+				}, {
+					Name:         "vector",
+					FieldID:      101,
+					IsPrimaryKey: false,
+					DataType:     schemapb.DataType_SparseFloatVector,
+				}, {
+					Name:     "text",
+					FieldID:  102,
+					DataType: schemapb.DataType_VarChar,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.MaxLengthKey,
+							Value: "256",
+						},
+					},
+				},
+			},
+			Functions: []*schemapb.FunctionSchema{{
+				Type:           schemapb.FunctionType_BM25,
+				InputFieldIds:  []int64{102},
+				OutputFieldIds: []int64{101},
+			}},
+		}, nil, nil)
+
+		_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
+			NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
+				return s.mq, nil
+			},
+		}, 10000, nil, s.chunkManager)
+		s.NoError(err)
+	})
+}
+
 func (s *DelegatorSuite) TestBasicInfo() {
 	s.Equal(s.collectionID, s.delegator.Collection())
 	s.Equal(s.version, s.delegator.Version())
@@ -74,6 +74,8 @@ type distribution struct {
 	// current is the snapshot for quick usage for search/query
 	// generated for each change of distribution
 	current *atomic.Pointer[snapshot]
+
+	idfOracle IDFOracle
 	// protects current & segments
 	mut sync.RWMutex
 }
@@ -89,7 +91,7 @@ type SegmentEntry struct {
 }
 
 // NewDistribution creates a new distribution instance with all field initialized.
-func NewDistribution() *distribution {
+func NewDistribution(idfOracle IDFOracle) *distribution {
 	dist := &distribution{
 		serviceable:     atomic.NewBool(false),
 		growingSegments: make(map[UniqueID]SegmentEntry),
@@ -98,6 +100,7 @@ func NewDistribution() *distribution {
 		current:       atomic.NewPointer[snapshot](nil),
 		offlines:      typeutil.NewSet[int64](),
 		targetVersion: atomic.NewInt64(initialTargetVersion),
+		idfOracle:     idfOracle,
 	}
 
 	dist.genSnapshot()
@@ -367,6 +370,7 @@ func (d *distribution) genSnapshot() chan struct{} {
 	d.current.Store(newSnapShot)
 	// shall be a new one
 	d.snapshots.GetOrInsert(d.snapshotVersion, newSnapShot)
+	d.idfOracle.SyncDistribution(newSnapShot)
 
 	// first snapshot, return closed chan
 	if last == nil {
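Wiring the oracle into `distribution` makes IDF maintenance purely event-driven: every time the serviceable set changes, `genSnapshot` publishes the new snapshot and the oracle re-derives which segments' statistics should count. A minimal sketch of this observer-style hookup (the type names below are illustrative, not the milvus types):

```go
package main

import "fmt"

// snapshotObserver plays the role idfOracle plays for distribution: it is
// notified of every new snapshot and reconciles its own state against it.
type snapshotObserver interface {
	SyncDistribution(version int64, serviceable []int64)
}

type dist struct {
	version     int64
	serviceable []int64
	observer    snapshotObserver
}

// genSnapshot publishes a new snapshot and, as in this commit, pushes it to
// the observer so derived state (here, IDF stats) can never drift from the
// search distribution.
func (d *dist) genSnapshot() {
	d.version++
	d.observer.SyncDistribution(d.version, d.serviceable)
}

type printObserver struct{}

func (printObserver) SyncDistribution(v int64, segs []int64) {
	fmt.Println("sync to version", v, "segments", segs)
}

func main() {
	d := &dist{serviceable: []int64{100, 101}, observer: printObserver{}}
	d.genSnapshot()
}
```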
@@ -21,6 +21,8 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/suite"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
 
 type DistributionSuite struct {
@@ -29,7 +31,7 @@ type DistributionSuite struct {
 }
 
 func (s *DistributionSuite) SetupTest() {
-	s.dist = NewDistribution()
+	s.dist = NewDistribution(NewIDFOracle([]*schemapb.FunctionSchema{}))
 	s.Equal(initialTargetVersion, s.dist.getTargetVersion())
 }
 
@@ -0,0 +1,262 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package delegator
+
+import (
+	"fmt"
+	"sync"
+
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/querynodev2/segments"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/pkg/log"
+)
+
+type IDFOracle interface {
+	// Activate(segmentID int64, state commonpb.SegmentState) error
+	// Deactivate(segmentID int64, state commonpb.SegmentState) error
+
+	SyncDistribution(snapshot *snapshot)
+
+	UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
+
+	Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
+	Remove(segmentID int64, state commonpb.SegmentState)
+
+	BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
+}
+
+type bm25Stats struct {
+	stats         map[int64]*storage.BM25Stats
+	activate      bool
+	targetVersion int64
+}
+
+func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
+	for fieldID, newstats := range stats {
+		if stats, ok := s.stats[fieldID]; ok {
+			stats.Merge(newstats)
+		} else {
+			log.Panic("merge failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
+		}
+	}
+}
+
+func (s *bm25Stats) Minus(stats map[int64]*storage.BM25Stats) {
+	for fieldID, newstats := range stats {
+		if stats, ok := s.stats[fieldID]; ok {
+			stats.Minus(newstats)
+		} else {
+			log.Panic("minus failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
+		}
+	}
+}
+
+func (s *bm25Stats) GetStats(fieldID int64) (*storage.BM25Stats, error) {
+	stats, ok := s.stats[fieldID]
+	if !ok {
+		return nil, fmt.Errorf("field not found in idf oracle BM25 stats")
+	}
+	return stats, nil
+}
+
+func (s *bm25Stats) NumRow() int64 {
+	for _, stats := range s.stats {
+		return stats.NumRow()
+	}
+	return 0
+}
+
+func newBm25Stats(functions []*schemapb.FunctionSchema) *bm25Stats {
+	stats := &bm25Stats{
+		stats: make(map[int64]*storage.BM25Stats),
+	}
+
+	for _, function := range functions {
+		if function.GetType() == schemapb.FunctionType_BM25 {
+			stats.stats[function.GetOutputFieldIds()[0]] = storage.NewBM25Stats()
+		}
+	}
+	return stats
+}
+
+type idfOracle struct {
+	sync.RWMutex
+
+	current *bm25Stats
+
+	growing map[int64]*bm25Stats
+	sealed  map[int64]*bm25Stats
+
+	targetVersion int64
+}
+
+func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) {
+	o.Lock()
+	defer o.Unlock()
+
+	switch state {
+	case segments.SegmentTypeGrowing:
+		if _, ok := o.growing[segmentID]; ok {
+			return
+		}
+		o.growing[segmentID] = &bm25Stats{
+			stats:         stats,
+			activate:      true,
+			targetVersion: initialTargetVersion,
+		}
+		o.current.Merge(stats)
+	case segments.SegmentTypeSealed:
+		if _, ok := o.sealed[segmentID]; ok {
+			return
+		}
+		o.sealed[segmentID] = &bm25Stats{
+			stats:         stats,
+			activate:      false,
+			targetVersion: initialTargetVersion,
+		}
+	default:
+		log.Warn("register segment with unknown state", zap.String("stats", state.String()))
+		return
+	}
+}
+
+func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) {
+	if len(stats) == 0 {
+		return
+	}
+
+	o.Lock()
+	defer o.Unlock()
+
+	old, ok := o.growing[segmentID]
+	if !ok {
+		return
+	}
+
+	old.Merge(stats)
+	if old.activate {
+		o.current.Merge(stats)
+	}
+}
+
+func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
+	o.Lock()
+	defer o.Unlock()
+
+	switch state {
+	case segments.SegmentTypeGrowing:
+		if stats, ok := o.growing[segmentID]; ok {
+			if stats.activate {
+				o.current.Minus(stats.stats)
+			}
+			delete(o.growing, segmentID)
+		}
+	case segments.SegmentTypeSealed:
+		if stats, ok := o.sealed[segmentID]; ok {
+			if stats.activate {
+				o.current.Minus(stats.stats)
+			}
+			delete(o.sealed, segmentID)
+		}
+	default:
+		return
+	}
+}
+
+func (o *idfOracle) activate(stats *bm25Stats) {
+	stats.activate = true
+	o.current.Merge(stats.stats)
+}
+
+func (o *idfOracle) deactivate(stats *bm25Stats) {
+	stats.activate = false
+	o.current.Minus(stats.stats)
+}
+
+func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
+	o.Lock()
+	defer o.Unlock()
+
+	sealed, growing := snapshot.Peek()
+
+	for _, item := range sealed {
+		for _, segment := range item.Segments {
+			if stats, ok := o.sealed[segment.SegmentID]; ok {
+				stats.targetVersion = segment.TargetVersion
+			} else {
+				log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
+			}
+		}
+	}
+
+	for _, segment := range growing {
+		if stats, ok := o.growing[segment.SegmentID]; ok {
+			stats.targetVersion = segment.TargetVersion
+		} else {
+			log.Warn("idf oracle lack some growing segment", zap.Int64("segmentID", segment.SegmentID))
+		}
+	}
+
+	o.targetVersion = snapshot.targetVersion
+
+	for _, stats := range o.sealed {
+		if !stats.activate && stats.targetVersion == o.targetVersion {
+			o.activate(stats)
+		} else if stats.activate && stats.targetVersion != o.targetVersion {
+			o.deactivate(stats)
+		}
+	}
+
+	for _, stats := range o.growing {
+		if !stats.activate && (stats.targetVersion == o.targetVersion || stats.targetVersion == initialTargetVersion) {
+			o.activate(stats)
+		} else if stats.activate && (stats.targetVersion != o.targetVersion && stats.targetVersion != initialTargetVersion) {
+			o.deactivate(stats)
+		}
+	}
+
+	log.Debug("sync distribution finished", zap.Int64("version", o.targetVersion), zap.Int64("numrow", o.current.NumRow()))
+}
+
+func (o *idfOracle) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) {
+	o.RLock()
+	defer o.RUnlock()
+
+	stats, err := o.current.GetStats(fieldID)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	idfBytes := make([][]byte, len(tfs.GetContents()))
+	for i, tf := range tfs.GetContents() {
+		idf := stats.BuildIDF(tf)
+		idfBytes[i] = idf
+	}
+	return idfBytes, stats.GetAvgdl(), nil
+}
+
+func NewIDFOracle(functions []*schemapb.FunctionSchema) IDFOracle {
+	return &idfOracle{
+		current: newBm25Stats(functions),
+		growing: make(map[int64]*bm25Stats),
+		sealed:  make(map[int64]*bm25Stats),
+	}
+}
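The activate/deactivate dance above exists so `current` always equals the sum of stats for exactly the segments in the serviceable snapshot: sealed segments count only once they carry the current target version, while growing segments also count while still at the initial (unset) version. For intuition, the IDF weights handed back by `BuildIDF` follow the standard BM25 form; the exact variant is implemented by `storage.BM25Stats.BuildIDF` and is not shown in this diff, but the sketch below reproduces the `0.2876821` asserted in `TestSealed` in the test file that follows:

```go
package main

import (
	"fmt"
	"math"
)

// bm25IDF is the textbook BM25 inverse document frequency:
// ln(1 + (N - n + 0.5) / (n + 0.5)), where N is the total row count held
// by the oracle's `current` stats and n is the number of rows containing
// the term. Sketch for intuition only; the real computation lives in
// storage.BM25Stats.BuildIDF.
func bm25IDF(totalRows, rowsWithTerm float64) float64 {
	return math.Log(1 + (totalRows-rowsWithTerm+0.5)/(rowsWithTerm+0.5))
}

func main() {
	// One remaining segment holding one row, term present in that row:
	// N=1, n=1 gives ln(4/3) ≈ 0.2876821, matching the test assertion.
	fmt.Println(bm25IDF(1, 1))
}
```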
@ -0,0 +1,198 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package delegator

import (
	"testing"

	"github.com/stretchr/testify/suite"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type IDFOracleSuite struct {
	suite.Suite
	collectionID     int64
	collectionSchema *schemapb.CollectionSchema
	idfOracle        *idfOracle

	targetVersion int64
	snapshot      *snapshot
}

func (suite *IDFOracleSuite) SetupSuite() {
	suite.collectionID = 111
	suite.collectionSchema = &schemapb.CollectionSchema{
		Functions: []*schemapb.FunctionSchema{{
			Type:           schemapb.FunctionType_BM25,
			InputFieldIds:  []int64{101},
			OutputFieldIds: []int64{102},
		}},
	}
}

func (suite *IDFOracleSuite) SetupTest() {
	suite.idfOracle = NewIDFOracle(suite.collectionSchema.GetFunctions()).(*idfOracle)
	suite.snapshot = &snapshot{
		dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}},
	}
	suite.targetVersion = 0
}

func (suite *IDFOracleSuite) genStats(start uint32, end uint32) map[int64]*storage.BM25Stats {
	result := make(map[int64]*storage.BM25Stats)
	result[102] = storage.NewBM25Stats()
	for i := start; i < end; i++ {
		row := map[uint32]float32{i: 1}
		result[102].Append(row)
	}
	return result
}

// updateSnapshot bumps the target version and rebuilds the test snapshot
// from the given sealed, growing, and dropped segment IDs.
func (suite *IDFOracleSuite) updateSnapshot(seals, grows, drops []int64) *snapshot {
	suite.targetVersion++
	snapshot := &snapshot{
		dist:          []SnapshotItem{{1, make([]SegmentEntry, 0)}},
		targetVersion: suite.targetVersion,
	}

	dropSet := typeutil.NewSet[int64]()
	dropSet.Insert(drops...)

	newSeal := []SegmentEntry{}
	for _, seg := range suite.snapshot.dist[0].Segments {
		if !dropSet.Contain(seg.SegmentID) {
			seg.TargetVersion = suite.targetVersion
		}
		newSeal = append(newSeal, seg)
	}
	for _, seg := range seals {
		newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
	}

	newGrow := []SegmentEntry{}
	for _, seg := range suite.snapshot.growing {
		if !dropSet.Contain(seg.SegmentID) {
			seg.TargetVersion = suite.targetVersion
		} else {
			seg.TargetVersion = redundantTargetVersion
		}
		newGrow = append(newGrow, seg)
	}
	for _, seg := range grows {
		newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
	}

	snapshot.dist[0].Segments = newSeal
	snapshot.growing = newGrow
	suite.snapshot = snapshot
	return snapshot
}

func (suite *IDFOracleSuite) TestSealed() {
	// register sealed segments
	sealedSegs := []int64{1, 2, 3, 4}
	for _, segID := range sealedSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
	}

	// duplicate registration is a no-op
	for _, segID := range sealedSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
	}

	// sealed segments are registered but all still deactivated
	suite.Zero(suite.idfOracle.current.NumRow())

	// update and sync the snapshot to activate all sealed segments
	suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(4), suite.idfOracle.current.NumRow())

	releasedSeg := []int64{1, 2, 3}
	suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(1), suite.idfOracle.current.NumRow())

	for _, segID := range releasedSeg {
		suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
	}

	sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
	bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
	suite.NoError(err)
	suite.Equal(float64(1), avgdl)
	suite.Equal(map[uint32]float32{4: 0.2876821}, typeutil.SparseFloatBytesToMap(bytes[0]))
}

func (suite *IDFOracleSuite) TestGrow() {
	// register growing segments
	growSegs := []int64{1, 2, 3, 4}
	for _, segID := range growSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
	}
	// duplicate registration is a no-op
	for _, segID := range growSegs {
		suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
	}

	// growing segments are activated immediately on registration
	suite.Equal(int64(4), suite.idfOracle.current.NumRow())
	suite.updateSnapshot([]int64{}, growSegs, []int64{})

	releasedSeg := []int64{1, 2, 3}
	suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
	suite.idfOracle.SyncDistribution(suite.snapshot)
	suite.Equal(int64(1), suite.idfOracle.current.NumRow())

	suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
	suite.Equal(int64(2), suite.idfOracle.current.NumRow())

	for _, segID := range releasedSeg {
		suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
	}
}

func (suite *IDFOracleSuite) TestStats() {
	stats := newBm25Stats([]*schemapb.FunctionSchema{{
		Type:           schemapb.FunctionType_BM25,
		InputFieldIds:  []int64{101},
		OutputFieldIds: []int64{102},
	}})

	suite.Panics(func() {
		stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
	})

	suite.Panics(func() {
		stats.Minus(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
	})

	_, err := stats.GetStats(103)
	suite.Error(err)

	_, err = stats.GetStats(102)
	suite.NoError(err)
}

func TestIDFOracle(t *testing.T) {
	suite.Run(t, new(IDFOracleSuite))
}

@ -0,0 +1,64 @@
package delegator

import (
	"fmt"

	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func BuildSparseFieldData(field *schemapb.FieldSchema, sparseArray *schemapb.SparseFloatArray) *schemapb.FieldData {
	return &schemapb.FieldData{
		Type:      field.GetDataType(),
		FieldName: field.GetName(),
		Field: &schemapb.FieldData_Vectors{
			Vectors: &schemapb.VectorField{
				Dim: sparseArray.GetDim(),
				Data: &schemapb.VectorField_SparseFloatVector{
					SparseFloatVector: sparseArray,
				},
			},
		},
		FieldId: field.GetFieldID(),
	}
}

func SetBM25Params(req *internalpb.SearchRequest, avgdl float64) error {
	log := log.With(zap.Int64("collection", req.GetCollectionID()))

	serializedPlan := req.GetSerializedExprPlan()
	// plan not found
	if serializedPlan == nil {
		log.Warn("serialized plan not found")
		return merr.WrapErrParameterInvalid("serialized search plan", "nil")
	}

	plan := planpb.PlanNode{}
	err := proto.Unmarshal(serializedPlan, &plan)
	if err != nil {
		log.Warn("failed to unmarshal plan", zap.Error(err))
		return merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
	}

	switch plan.GetNode().(type) {
	case *planpb.PlanNode_VectorAnns:
		queryInfo := plan.GetVectorAnns().GetQueryInfo()
		queryInfo.Bm25Avgdl = avgdl
		serializedExprPlan, err := proto.Marshal(&plan)
		if err != nil {
			log.Warn("failed to marshal optimized plan", zap.Error(err))
			return merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
		}
		req.SerializedExprPlan = serializedExprPlan
		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
	default:
		log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
	}
	return nil
}

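A minimal sketch of how these two helpers are meant to be used together: before forwarding a search request, a caller can rebuild the query's IDF weights and stamp the collection-wide average document length into the serialized plan. The variable names here (`oracle`, `fieldID`, `queryTFs`, `req`) are illustrative, not the delegator's actual call site:

	// Hypothetical call site; `oracle` is an IDFOracle and `req` an
	// *internalpb.SearchRequest carrying a serialized ANN plan.
	idf, avgdl, err := oracle.BuildIDF(fieldID, queryTFs)
	if err != nil {
		return err
	}
	_ = idf // IDF-weighted query rows, substituted into the placeholders
	if err := SetBM25Params(req, avgdl); err != nil { // rewrites QueryInfo.Bm25Avgdl
		return err
	}
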
@ -60,6 +60,7 @@ func loadL0Segments(ctx context.Context, delegator delegator.ShardDelegator, req
		NumOfRows:     segmentInfo.NumOfRows,
		Statslogs:     segmentInfo.Statslogs,
		Deltalogs:     segmentInfo.Deltalogs,
		Bm25Logs:      segmentInfo.Bm25Statslogs,
		InsertChannel: segmentInfo.InsertChannel,
		StartPosition: segmentInfo.GetStartPosition(),
		Level:         segmentInfo.GetLevel(),

@ -101,6 +102,7 @@ func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator
			NumOfRows:     segmentInfo.NumOfRows,
			Statslogs:     segmentInfo.Statslogs,
			Deltalogs:     segmentInfo.Deltalogs,
			Bm25Logs:      segmentInfo.Bm25Statslogs,
			InsertChannel: segmentInfo.InsertChannel,
		})
	} else {

@ -0,0 +1,211 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pipeline

import (
	"fmt"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/internal/util/function"
	base "github.com/milvus-io/milvus/internal/util/pipeline"
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// embeddingNode runs the collection's functions (currently BM25) over incoming
// insert messages, filling the function-output fields and collecting
// per-segment BM25 stats before the insert node consumes the data.
type embeddingNode struct {
	*BaseNode

	collectionID int64
	channel      string

	manager *DataManager

	functionRunners []function.FunctionRunner
}

// newEmbeddingNode returns (nil, nil) when the collection has no function,
// in which case the pipeline skips the embedding stage.
func newEmbeddingNode(collectionID int64, channelName string, manager *DataManager, maxQueueLength int32) (*embeddingNode, error) {
	collection := manager.Collection.Get(collectionID)
	if collection == nil {
		log.Error("embeddingNode init failed with collection not exist", zap.Int64("collection", collectionID))
		return nil, merr.WrapErrCollectionNotFound(collectionID)
	}

	if len(collection.Schema().GetFunctions()) == 0 {
		return nil, nil
	}

	node := &embeddingNode{
		BaseNode:        base.NewBaseNode(fmt.Sprintf("EmbeddingNode-%s", channelName), maxQueueLength),
		collectionID:    collectionID,
		channel:         channelName,
		manager:         manager,
		functionRunners: make([]function.FunctionRunner, 0),
	}

	for _, tf := range collection.Schema().GetFunctions() {
		functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
		if err != nil {
			return nil, err
		}
		node.functionRunners = append(node.functionRunners, functionRunner)
	}
	return node, nil
}

func (eNode *embeddingNode) Name() string {
	return fmt.Sprintf("embeddingNode-%s", eNode.channel)
}

func (eNode *embeddingNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) error {
	iData, ok := insertDatas[msg.SegmentID]
	if !ok {
		iData = &delegator.InsertData{
			PartitionID: msg.PartitionID,
			BM25Stats:   make(map[int64]*storage.BM25Stats),
			StartPosition: &msgpb.MsgPosition{
				Timestamp:   msg.BeginTs(),
				ChannelName: msg.GetShardName(),
			},
		}
		insertDatas[msg.SegmentID] = iData
	}

	err := eNode.embedding(msg, iData.BM25Stats)
	if err != nil {
		log.Error("failed to run function on insert data", zap.Error(err))
		return err
	}

	insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
	if err != nil {
		err = fmt.Errorf("failed to get primary keys, err = %v", err)
		log.Error(err.Error(), zap.String("channel", eNode.channel))
		return err
	}

	if iData.InsertRecord == nil {
		iData.InsertRecord = insertRecord
	} else {
		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
		if err != nil {
			log.Warn("failed to merge field data", zap.String("channel", eNode.channel), zap.Error(err))
			return err
		}
		iData.InsertRecord.NumRows += insertRecord.NumRows
	}

	pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
	if err != nil {
		log.Warn("failed to get primary keys from insert message", zap.String("channel", eNode.channel), zap.Error(err))
		return err
	}

	iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
	iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
	iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
	log.Debug("pipeline embedding insert msg",
		zap.Int64("collectionID", eNode.collectionID),
		zap.Int64("segmentID", msg.SegmentID),
		zap.Int("insertRowNum", len(pks)),
		zap.Uint64("timestampMin", msg.BeginTimestamp),
		zap.Uint64("timestampMax", msg.EndTimestamp))
	return nil
}

// bm25Embedding runs one BM25 function over the input text field, appends the
// resulting sparse rows to the message, and accumulates them into stats.
func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
	functionSchema := runner.GetSchema()
	inputFieldID := functionSchema.GetInputFieldIds()[0]
	outputFieldID := functionSchema.GetOutputFieldIds()[0]
	outputField := runner.GetOutputFields()[0]

	data, err := GetEmbeddingFieldData(msg.GetFieldsData(), inputFieldID)
	if data == nil || err != nil {
		return merr.WrapErrFieldNotFound(fmt.Sprint(inputFieldID))
	}

	output, err := runner.BatchRun(data)
	if err != nil {
		return err
	}

	sparseArray, ok := output[0].(*schemapb.SparseFloatArray)
	if !ok {
		return fmt.Errorf("BM25 runner return unknown type output")
	}

	if _, ok := stats[outputFieldID]; !ok {
		stats[outputFieldID] = storage.NewBM25Stats()
	}
	stats[outputFieldID].AppendBytes(sparseArray.GetContents()...)
	msg.FieldsData = append(msg.FieldsData, delegator.BuildSparseFieldData(outputField, sparseArray))
	return nil
}

func (eNode *embeddingNode) embedding(msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
	for _, functionRunner := range eNode.functionRunners {
		functionSchema := functionRunner.GetSchema()
		switch functionSchema.GetType() {
		case schemapb.FunctionType_BM25:
			err := eNode.bm25Embedding(functionRunner, msg, stats)
			if err != nil {
				return err
			}
		default:
			log.Warn("pipeline embedding with unknown function type", zap.Any("type", functionSchema.GetType()))
			return fmt.Errorf("unknown function type")
		}
	}

	return nil
}

func (eNode *embeddingNode) Operate(in Msg) Msg {
	nodeMsg := in.(*insertNodeMsg)
	nodeMsg.insertDatas = make(map[int64]*delegator.InsertData)

	collection := eNode.manager.Collection.Get(eNode.collectionID)
	if collection == nil {
		log.Error("embeddingNode with collection not exist", zap.Int64("collection", eNode.collectionID))
		panic("embeddingNode with collection not exist")
	}

	for _, msg := range nodeMsg.insertMsgs {
		err := eNode.addInsertData(nodeMsg.insertDatas, msg, collection)
		if err != nil {
			panic(err)
		}
	}

	return nodeMsg
}

func GetEmbeddingFieldData(datas []*schemapb.FieldData, fieldID int64) ([]string, error) {
	for _, data := range datas {
		if data.GetFieldId() == fieldID {
			return data.GetScalars().GetStringData().GetData(), nil
		}
	}
	return nil, fmt.Errorf("field %d not found", fieldID)
}

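The embedding node above touches only a small surface of function.FunctionRunner. Reconstructed from those call sites (and from the generated mock at the end of this commit), the implied interface is approximately the following; the canonical definition lives in internal/util/function and may differ in detail:

	// Approximation from usage, not copied from the source tree.
	type FunctionRunner interface {
		// BatchRun maps input columns (a []string of texts for BM25)
		// to output columns (a *schemapb.SparseFloatArray for BM25).
		BatchRun(inputs ...interface{}) ([]interface{}, error)
		GetSchema() *schemapb.FunctionSchema
		GetOutputFields() []*schemapb.FieldSchema
	}
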
@ -0,0 +1,281 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pipeline

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/suite"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/internal/querynodev2/segments"
	"github.com/milvus-io/milvus/internal/util/function"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
)

// EmbeddingNodeSuite tests the embedding node.
type EmbeddingNodeSuite struct {
	suite.Suite
	// data
	collectionID     int64
	collectionSchema *schemapb.CollectionSchema
	channel          string
	msgs             []*InsertMsg

	// mocks
	manager    *segments.Manager
	segManager *segments.MockSegmentManager
	colManager *segments.MockCollectionManager
}

func (suite *EmbeddingNodeSuite) SetupTest() {
	paramtable.Init()
	suite.collectionID = 111
	suite.channel = "test-channel"
	suite.collectionSchema = &schemapb.CollectionSchema{
		Name: "test-collection",
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:  common.TimeStampField,
				Name:     common.TimeStampFieldName,
				DataType: schemapb.DataType_Int64,
			}, {
				Name:         "pk",
				FieldID:      100,
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_Int64,
			}, {
				Name:       "text",
				FieldID:    101,
				DataType:   schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{},
			}, {
				Name:             "sparse",
				FieldID:          102,
				DataType:         schemapb.DataType_SparseFloatVector,
				IsFunctionOutput: true,
			},
		},
		Functions: []*schemapb.FunctionSchema{{
			Name:           "BM25",
			Type:           schemapb.FunctionType_BM25,
			InputFieldIds:  []int64{101},
			OutputFieldIds: []int64{102},
		}},
	}

	suite.msgs = []*msgstream.InsertMsg{{
		BaseMsg: msgstream.BaseMsg{},
		InsertRequest: &msgpb.InsertRequest{
			SegmentID:  1,
			NumRows:    3,
			Version:    msgpb.InsertDataVersion_ColumnBased,
			Timestamps: []uint64{1, 1, 1},
			FieldsData: []*schemapb.FieldData{
				{
					FieldId: 100,
					Type:    schemapb.DataType_Int64,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
					},
				}, {
					FieldId: 101,
					Type:    schemapb.DataType_VarChar,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}},
					},
				},
			},
		},
	}}

	suite.segManager = segments.NewMockSegmentManager(suite.T())
	suite.colManager = segments.NewMockCollectionManager(suite.T())

	suite.manager = &segments.Manager{
		Collection: suite.colManager,
		Segment:    suite.segManager,
	}
}

func (suite *EmbeddingNodeSuite) TestCreateEmbeddingNode() {
	suite.Run("collection not found", func() {
		suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.Error(err)
	})

	suite.Run("function invalid", func() {
		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
		collection.Schema().Functions = []*schemapb.FunctionSchema{{}}
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.Error(err)
	})

	suite.Run("normal case", func() {
		collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)
	})
}

func (suite *EmbeddingNodeSuite) TestOperator() {
	suite.Run("collection not found", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
		suite.Panics(func() {
			node.Operate(&insertNodeMsg{})
		})
	})

	suite.Run("add insert data failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.Panics(func() {
			node.Operate(&insertNodeMsg{
				insertMsgs: []*msgstream.InsertMsg{{
					BaseMsg: msgstream.BaseMsg{},
					InsertRequest: &msgpb.InsertRequest{
						SegmentID: 1,
						NumRows:   3,
						Version:   msgpb.InsertDataVersion_ColumnBased,
						FieldsData: []*schemapb.FieldData{
							{
								FieldId: 100,
								Type:    schemapb.DataType_Int64,
								Field: &schemapb.FieldData_Scalars{
									Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
								},
							},
						},
					},
				}},
			})
		})
	})

	suite.Run("normal case", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		suite.NotPanics(func() {
			output := node.Operate(&insertNodeMsg{
				insertMsgs: suite.msgs,
			})

			msg, ok := output.(*insertNodeMsg)
			suite.Require().True(ok)
			suite.Require().NotNil(msg.insertDatas)
			suite.Require().Equal(int64(3), msg.insertDatas[1].BM25Stats[102].NumRow())
			suite.Require().Equal(int64(3), msg.insertDatas[1].InsertRecord.GetNumRows())
		})
	})
}

func (suite *EmbeddingNodeSuite) TestAddInsertData() {
	suite.Run("transfer insert msg failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		// transfer fails because row-based data does not support sparse vectors
		insertDatas := make(map[int64]*delegator.InsertData)
		rowBaseReq := proto.Clone(suite.msgs[0].InsertRequest).(*msgpb.InsertRequest)
		rowBaseReq.Version = msgpb.InsertDataVersion_RowBased
		rowBaseMsg := &msgstream.InsertMsg{
			BaseMsg:       msgstream.BaseMsg{},
			InsertRequest: rowBaseReq,
		}
		err = node.addInsertData(insertDatas, rowBaseMsg, collection)
		suite.Error(err)
	})

	suite.Run("merge field data failed", func() {
		// remove pk
		suite.collectionSchema.Fields[1].IsPrimaryKey = false
		defer func() {
			suite.collectionSchema.Fields[1].IsPrimaryKey = true
		}()

		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		insertDatas := make(map[int64]*delegator.InsertData)
		err = node.addInsertData(insertDatas, suite.msgs[0], collection)
		suite.Error(err)
	})
}

func (suite *EmbeddingNodeSuite) TestBM25Embedding() {
	suite.Run("function run failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		runner := function.NewMockFunctionRunner(suite.T())
		runner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock error"))
		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})

		err = node.bm25Embedding(runner, suite.msgs[0], nil)
		suite.Error(err)
	})

	suite.Run("output with unknown type failed", func() {
		collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
		suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
		node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
		suite.NoError(err)

		runner := function.NewMockFunctionRunner(suite.T())
		runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
		runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})

		err = node.bm25Embedding(runner, suite.msgs[0], nil)
		suite.Error(err)
	})
}

func TestEmbeddingNode(t *testing.T) {
	suite.Run(t, new(EmbeddingNodeSuite))
}

@ -62,7 +62,7 @@ func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.Inser
	} else {
		err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
		if err != nil {
			log.Error("failed to merge field data", zap.String("channel", iNode.channel), zap.Error(err))
			panic(err)
		}
		iData.InsertRecord.NumRows += insertRecord.NumRows

@ -95,21 +95,23 @@ func (iNode *insertNode) Operate(in Msg) Msg {
			return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
		})

		// build insert data only if no embedding node ran before this one
		if nodeMsg.insertDatas == nil {
			collection := iNode.manager.Collection.Get(iNode.collectionID)
			if collection == nil {
				log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
				panic("insertNode with collection not exist")
			}

			nodeMsg.insertDatas = make(map[UniqueID]*delegator.InsertData)
			// get InsertData and merge data of the same segment
			for _, msg := range nodeMsg.insertMsgs {
				iNode.addInsertData(nodeMsg.insertDatas, msg, collection)
			}
		}

		iNode.delegator.ProcessInsert(nodeMsg.insertDatas)
	}

	metrics.QueryNodeWaitProcessingMsgCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.DeleteLabel).Inc()

	return &deleteNodeMsg{

@ -19,15 +19,17 @@ package pipeline
import (
	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus/internal/querynodev2/collector"
	"github.com/milvus-io/milvus/internal/querynodev2/delegator"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
)

type insertNodeMsg struct {
	insertMsgs  []*InsertMsg
	deleteMsgs  []*DeleteMsg
	insertDatas map[int64]*delegator.InsertData
	timeRange   TimeRange
}

type deleteNodeMsg struct {

@ -31,7 +31,8 @@ type Pipeline interface {
type pipeline struct {
	base.StreamPipeline

	collectionID  UniqueID
	embeddingNode embeddingNode
}

func (p *pipeline) Close() {

@ -54,8 +55,21 @@ func NewPipeLine(
	}

	filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)

	embeddingNode, err := newEmbeddingNode(collectionID, channel, manager, pipelineQueueLength)
	if err != nil {
		return nil, err
	}

	insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
	deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)

	// skip adding the embedding node when the collection has no function
	if embeddingNode != nil {
		p.Add(filterNode, embeddingNode, insertNode, deleteNode)
	} else {
		p.Add(filterNode, insertNode, deleteNode)
	}

	return p, nil
}

@ -299,6 +299,18 @@ func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *
	}
}

// NewCollectionWithoutSegcoreForTest creates a collection without preparing segcore.
// ONLY FOR TEST.
func NewCollectionWithoutSegcoreForTest(collectionID int64, schema *schemapb.CollectionSchema) *Collection {
	coll := &Collection{
		id:         collectionID,
		partitions: typeutil.NewConcurrentSet[int64](),
		refCount:   atomic.NewUint32(0),
	}
	coll.schema.Store(schema)
	return coll
}

// DeleteCollection deletes the collection and frees its memory.
func DeleteCollection(collection *Collection) {
	/*

@ -47,6 +47,7 @@ import (
	"github.com/milvus-io/milvus/pkg/log"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/metautil"
	"github.com/milvus-io/milvus/pkg/util/metric"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/testutils"

@ -220,6 +221,11 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema {
		DataType:    param.dataType,
		ElementType: schemapb.DataType_Int32,
	}
	if param.dataType == schemapb.DataType_VarChar {
		field.TypeParams = []*commonpb.KeyValuePair{
			{Key: common.MaxLengthKey, Value: "128"},
		}
	}
	return field
}

@ -263,6 +269,35 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
	return fieldVec
}

func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema {
	fieldRowID := genConstantFieldSchema(rowIDField)
	fieldTimestamp := genConstantFieldSchema(timestampField)
	pkFieldSchema := genPKFieldSchema(simpleInt64Field)
	textFieldSchema := genConstantFieldSchema(simpleVarCharField)
	sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField)
	sparseFieldSchema.IsFunctionOutput = true

	schema := &schemapb.CollectionSchema{
		Name: collectionName,
		Fields: []*schemapb.FieldSchema{
			fieldRowID,
			fieldTimestamp,
			pkFieldSchema,
			textFieldSchema,
			sparseFieldSchema,
		},
		Functions: []*schemapb.FunctionSchema{{
			Name:             "BM25",
			Type:             schemapb.FunctionType_BM25,
			InputFieldNames:  []string{textFieldSchema.GetName()},
			InputFieldIds:    []int64{textFieldSchema.GetFieldID()},
			OutputFieldNames: []string{sparseFieldSchema.GetName()},
			OutputFieldIds:   []int64{sparseFieldSchema.GetFieldID()},
		}},
	}
	return schema
}

// some tests do not yet support sparse float vector, see comments of
// GenSparseFloatVecDataset in indexcgowrapper/dataset.go
func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema {

@ -671,6 +706,32 @@ func SaveDeltaLog(collectionID int64,
	return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}

func SaveBM25Log(collectionID int64, partitionID int64, segmentID int64, fieldID int64, msgLength int, cm storage.ChunkManager) (*datapb.FieldBinlog, error) {
	stats := storage.NewBM25Stats()

	for i := 0; i < msgLength; i++ {
		stats.Append(map[uint32]float32{1: 1})
	}

	bytes, err := stats.Serialize()
	if err != nil {
		return nil, err
	}

	kvs := make(map[string][]byte, 1)
	key := path.Join(cm.RootPath(), common.SegmentBm25LogPath, metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID, 1001))
	kvs[key] = bytes
	fieldBinlog := &datapb.FieldBinlog{
		FieldID: fieldID,
		Binlogs: []*datapb.Binlog{{
			LogPath:       key,
			TimestampFrom: 100,
			TimestampTo:   200,
		}},
	}
	return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}

func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
	fieldSchema *schemapb.FieldSchema,
	indexInfo *indexpb.IndexInfo,

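SaveBM25Log relies on the BM25Stats serialize/deserialize roundtrip that the loader later performs. A compressed sketch of that lifecycle, using only methods that appear in this commit:

	// Writer side: accumulate rows, persist as a bm25 statslog blob.
	stats := storage.NewBM25Stats()
	stats.Append(map[uint32]float32{1: 1}) // one row containing token 1 once
	blob, err := stats.Serialize()
	if err != nil {
		panic(err)
	}
	// Reader side (segment loader): restore and merge from the blob.
	restored := storage.NewBM25Stats()
	if err := restored.Deserialize(blob); err != nil {
		panic(err)
	}
	_ = restored.NumRow() // == 1
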
@ -14,6 +14,10 @@ import (
	pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"

	querypb "github.com/milvus-io/milvus/internal/proto/querypb"

	storage "github.com/milvus-io/milvus/internal/storage"

	typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
)

// MockLoader is an autogenerated mock type for the Loader type

@ -101,6 +105,76 @@ func (_c *MockLoader_Load_Call) RunAndReturn(run func(context.Context, int64, co
	return _c
}

// LoadBM25Stats provides a mock function with given fields: ctx, collectionID, infos
func (_m *MockLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
	_va := make([]interface{}, len(infos))
	for _i := range infos {
		_va[_i] = infos[_i]
	}
	var _ca []interface{}
	_ca = append(_ca, ctx, collectionID)
	_ca = append(_ca, _va...)
	ret := _m.Called(_ca...)

	var r0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
	var r1 error
	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)); ok {
		return rf(ctx, collectionID, infos...)
	}
	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]); ok {
		r0 = rf(ctx, collectionID, infos...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats])
		}
	}

	if rf, ok := ret.Get(1).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) error); ok {
		r1 = rf(ctx, collectionID, infos...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// MockLoader_LoadBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBM25Stats'
type MockLoader_LoadBM25Stats_Call struct {
	*mock.Call
}

// LoadBM25Stats is a helper method to define mock.On call
//   - ctx context.Context
//   - collectionID int64
//   - infos ...*querypb.SegmentLoadInfo
func (_e *MockLoader_Expecter) LoadBM25Stats(ctx interface{}, collectionID interface{}, infos ...interface{}) *MockLoader_LoadBM25Stats_Call {
	return &MockLoader_LoadBM25Stats_Call{Call: _e.mock.On("LoadBM25Stats",
		append([]interface{}{ctx, collectionID}, infos...)...)}
}

func (_c *MockLoader_LoadBM25Stats_Call) Run(run func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBM25Stats_Call {
	_c.Call.Run(func(args mock.Arguments) {
		variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-2)
		for i, a := range args[2:] {
			if a != nil {
				variadicArgs[i] = a.(*querypb.SegmentLoadInfo)
			}
		}
		run(args[0].(context.Context), args[1].(int64), variadicArgs...)
	})
	return _c
}

func (_c *MockLoader_LoadBM25Stats_Call) Return(_a0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], _a1 error) *MockLoader_LoadBM25Stats_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}

func (_c *MockLoader_LoadBM25Stats_Call) RunAndReturn(run func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)) *MockLoader_LoadBM25Stats_Call {
	_c.Call.Return(run)
	return _c
}

// LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, version, infos
func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
	_va := make([]interface{}, len(infos))

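Being mockery-generated, the mock is driven through its expecter in tests. A hedged usage sketch (the surrounding variables such as `mockLoader`, `collectionID`, and `statsMap` are illustrative, not from the source):

	// statsMap would be a *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
	mockLoader.EXPECT().
		LoadBM25Stats(mock.Anything, collectionID, mock.Anything).
		Return(statsMap, nil).
		Once()
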
@ -290,6 +290,49 @@ func (_c *MockSegment_ExistIndex_Call) RunAndReturn(run func(int64) bool) *MockS
	return _c
}

// GetBM25Stats provides a mock function with given fields:
func (_m *MockSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
	ret := _m.Called()

	var r0 map[int64]*storage.BM25Stats
	if rf, ok := ret.Get(0).(func() map[int64]*storage.BM25Stats); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(map[int64]*storage.BM25Stats)
		}
	}

	return r0
}

// MockSegment_GetBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBM25Stats'
type MockSegment_GetBM25Stats_Call struct {
	*mock.Call
}

// GetBM25Stats is a helper method to define mock.On call
func (_e *MockSegment_Expecter) GetBM25Stats() *MockSegment_GetBM25Stats_Call {
	return &MockSegment_GetBM25Stats_Call{Call: _e.mock.On("GetBM25Stats")}
}

func (_c *MockSegment_GetBM25Stats_Call) Run(run func()) *MockSegment_GetBM25Stats_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

func (_c *MockSegment_GetBM25Stats_Call) Return(_a0 map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
	_c.Call.Return(_a0)
	return _c
}

func (_c *MockSegment_GetBM25Stats_Call) RunAndReturn(run func() map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
	_c.Call.Return(run)
	return _c
}

// GetIndex provides a mock function with given fields: fieldID
func (_m *MockSegment) GetIndex(fieldID int64) *IndexedFieldInfo {
	ret := _m.Called(fieldID)

@ -1570,6 +1613,39 @@ func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Ca
	return _c
}

// UpdateBM25Stats provides a mock function with given fields: stats
func (_m *MockSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
	_m.Called(stats)
}

// MockSegment_UpdateBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBM25Stats'
type MockSegment_UpdateBM25Stats_Call struct {
	*mock.Call
}

// UpdateBM25Stats is a helper method to define mock.On call
//   - stats map[int64]*storage.BM25Stats
func (_e *MockSegment_Expecter) UpdateBM25Stats(stats interface{}) *MockSegment_UpdateBM25Stats_Call {
	return &MockSegment_UpdateBM25Stats_Call{Call: _e.mock.On("UpdateBM25Stats", stats)}
}

func (_c *MockSegment_UpdateBM25Stats_Call) Run(run func(stats map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run(args[0].(map[int64]*storage.BM25Stats))
	})
	return _c
}

func (_c *MockSegment_UpdateBM25Stats_Call) Return() *MockSegment_UpdateBM25Stats_Call {
	_c.Call.Return()
	return _c
}

func (_c *MockSegment_UpdateBM25Stats_Call) RunAndReturn(run func(map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
	_c.Call.Return(run)
	return _c
}

// UpdateBloomFilter provides a mock function with given fields: pks
func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
	_m.Called(pks)

@ -91,6 +91,8 @@ type baseSegment struct {
	isLazyLoad bool
	channel    metautil.Channel

	bm25Stats map[int64]*storage.BM25Stats

	resourceUsageCache *atomic.Pointer[ResourceUsage]

	needUpdatedVersion *atomic.Int64 // only for lazy load mode update index

@ -107,6 +109,7 @@ func newBaseSegment(collection *Collection, segmentType SegmentType, version int
		version:        atomic.NewInt64(version),
		segmentType:    segmentType,
		bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
		bm25Stats:      make(map[int64]*storage.BM25Stats),
		channel:        channel,
		isLazyLoad:     isLazyLoad(collection, segmentType),

@ -185,6 +188,20 @@ func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
	s.bloomFilterSet.UpdateBloomFilter(pks)
}

// UpdateBM25Stats merges the incoming per-field stats into the segment's stats.
func (s *baseSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
	for fieldID, incoming := range stats {
		if current, ok := s.bm25Stats[fieldID]; ok {
			current.Merge(incoming)
		} else {
			s.bm25Stats[fieldID] = incoming
		}
	}
}

func (s *baseSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
	return s.bm25Stats
}

// MayPkExist returns true if the given PK may exist: the PK falls within the
// segment's PK range and the bloom filter reports positive. It may return
// true even if the PK does not actually exist.

@ -87,6 +87,10 @@ type Segment interface {
	MayPkExist(lc *storage.LocationsCache) bool
	BatchPkExist(lc *storage.BatchLocationsCache) []bool

	// BM25 stats
	UpdateBM25Stats(stats map[int64]*storage.BM25Stats)
	GetBM25Stats() map[int64]*storage.BM25Stats

	// Read operations
	Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
	Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error)

@ -77,6 +77,9 @@ type Loader interface {
	// LoadBloomFilterSet loads the needed statslog for RemoteSegment.
	LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)

	// LoadBM25Stats loads the BM25 statslog for RemoteSegment.
	LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)

	// LoadIndex appends an index to the segment and removes the vector binlogs.
	LoadIndex(ctx context.Context,
		segment Segment,

@ -543,6 +546,47 @@ func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentTyp
	return nil
}

func (loader *segmentLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
	segmentNum := len(infos)
	if segmentNum == 0 {
		return nil, nil
	}

	segments := lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 {
		return info.GetSegmentID()
	})
	log.Info("start loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Int("segmentNum", segmentNum))

	loadedStats := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
	loadRemoteBM25Func := func(idx int) error {
		loadInfo := infos[idx]
		segmentID := loadInfo.SegmentID
		stats := make(map[int64]*storage.BM25Stats)

		log.Info("loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64("segment", segmentID))
		logpaths := loader.filterBM25Stats(loadInfo.Bm25Logs)
		err := loader.loadBm25Stats(ctx, segmentID, stats, logpaths)
		if err != nil {
			log.Warn("load remote segment bm25 stats failed",
				zap.Int64("segmentID", segmentID),
				zap.Error(err),
			)
			return err
		}
		loadedStats.Insert(segmentID, stats)
		return nil
	}

	err := funcutil.ProcessFuncParallel(segmentNum, segmentNum, loadRemoteBM25Func, "loadRemoteBM25Func")
	if err != nil {
		// no partial success here
		log.Warn("failed to load bm25 stats for remote segment", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Error(err))
		return nil, err
	}

	return loadedStats, nil
}

func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
	log := log.Ctx(ctx).With(
		zap.Int64("collectionID", collectionID),

@ -826,6 +870,16 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
		if err != nil {
			return err
		}

		if len(loadInfo.Bm25Logs) > 0 {
			log.Info("loading bm25 stats...")
			bm25StatsLogs := loader.filterBM25Stats(loadInfo.Bm25Logs)

			err = loader.loadBm25Stats(ctx, segment.ID(), segment.bm25Stats, bm25StatsLogs)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

@ -898,6 +952,26 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
	return result, storage.DefaultStatsType
}

func (loader *segmentLoader) filterBM25Stats(fieldBinlogs []*datapb.FieldBinlog) map[int64][]string {
	result := make(map[int64][]string)
	for _, fieldBinlog := range fieldBinlogs {
		logpaths := []string{}
		for _, binlog := range fieldBinlog.GetBinlogs() {
			_, logidx := path.Split(binlog.GetLogPath())
			// if a special compound stats log exists,
			// load only that one file
			if logidx == storage.CompoundStatsType.LogIdx() {
				logpaths = []string{binlog.GetLogPath()}
				break
			} else {
				logpaths = append(logpaths, binlog.GetLogPath())
			}
		}
		result[fieldBinlog.FieldID] = logpaths
	}
	return result
}

func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
	runningGroup, _ := errgroup.WithContext(ctx)
	for _, field := range fields {

@ -989,6 +1063,51 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
	return segment.LoadIndex(ctx, indexInfo, fieldType)
}

func (loader *segmentLoader) loadBm25Stats(ctx context.Context, segmentID int64, stats map[int64]*storage.BM25Stats, binlogPaths map[int64][]string) error {
	log := log.Ctx(ctx).With(
		zap.Int64("segmentID", segmentID),
	)
	if len(binlogPaths) == 0 {
		log.Info("there are no bm25 stats logs saved with segment")
		return nil
	}

	pathList := []string{}
	fieldList := []int64{}
	fieldOffset := []int{}
	for fieldId, logpaths := range binlogPaths {
		pathList = append(pathList, logpaths...)
		fieldList = append(fieldList, fieldId)
		fieldOffset = append(fieldOffset, len(logpaths))
	}

	startTs := time.Now()
	values, err := loader.cm.MultiRead(ctx, pathList)
	if err != nil {
		return err
	}

	cnt := 0
	for i, fieldID := range fieldList {
		newStats, ok := stats[fieldID]
		if !ok {
			newStats = storage.NewBM25Stats()
			stats[fieldID] = newStats
		}

		for j := 0; j < fieldOffset[i]; j++ {
			err := newStats.Deserialize(values[cnt+j])
			if err != nil {
				return err
			}
		}
		cnt += fieldOffset[i]
		log.Info("Successfully load bm25 stats", zap.Duration("time", time.Since(startTs)), zap.Int64("numRow", newStats.NumRow()), zap.Int64("fieldID", fieldID))
	}

	return nil
}

func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
	binlogPaths []string, logType storage.StatsLogType,
) error {

@ -70,14 +70,16 @@ func (suite *SegmentLoaderSuite) SetupSuite() {
}

func (suite *SegmentLoaderSuite) SetupTest() {
	ctx := context.Background()

	// TODO:: cpp chunk manager not support local chunk manager
	// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
	// 	fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
	chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)

	// Dependencies
	suite.manager = NewManager()
	suite.loader = NewLoader(suite.manager, suite.chunkManager)
	initcore.InitRemoteChunkManager(paramtable.Get())

@ -92,6 +94,22 @@ func (suite *SegmentLoaderSuite) SetupTest() {
	suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
}

func (suite *SegmentLoaderSuite) SetupBM25() {
	// Dependencies
	suite.manager = NewManager()
	suite.loader = NewLoader(suite.manager, suite.chunkManager)
	initcore.InitRemoteChunkManager(paramtable.Get())

	suite.schema = GenTestBM25CollectionSchema("test")
	indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema)
	loadMeta := &querypb.LoadMetaInfo{
		LoadType:     querypb.LoadType_LoadCollection,
		CollectionID: suite.collectionID,
		PartitionIDs: []int64{suite.partitionID},
	}
	suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
}

func (suite *SegmentLoaderSuite) TearDownTest() {
	ctx := context.Background()
	for i := 0; i < suite.segmentNum; i++ {

@ -407,6 +425,41 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
	}
}

func (suite *SegmentLoaderSuite) TestLoadBm25Stats() {
	suite.SetupBM25()
	msgLength := 1
	sparseFieldID := simpleSparseFloatVectorField.id
	loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)

	for i := 0; i < suite.segmentNum; i++ {
		segmentID := suite.segmentID + int64(i)

		bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager)
		suite.NoError(err)

		loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{
			SegmentID:     segmentID,
			PartitionID:   suite.partitionID,
			CollectionID:  suite.collectionID,
			Bm25Logs:      []*datapb.FieldBinlog{bm25logs},
			NumOfRows:     int64(msgLength),
			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
		})
	}

	statsMap, err := suite.loader.LoadBM25Stats(context.Background(), suite.collectionID, loadInfos...)
	suite.NoError(err)

	for i := 0; i < suite.segmentNum; i++ {
		segmentID := suite.segmentID + int64(i)
		stats, ok := statsMap.Get(segmentID)
		suite.True(ok)
		fieldStats, ok := stats[sparseFieldID]
		suite.True(ok)
		suite.Equal(int64(msgLength), fieldStats.NumRow())
	}
}

func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
	ctx := context.Background()
	loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)

@ -21,10 +21,10 @@ import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"

	"go.uber.org/zap"
	"golang.org/x/exp/maps"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/bloomfilter"

@ -347,7 +347,7 @@ func (m *BM25Stats) AppendFieldData(datas ...*SparseFloatVectorFieldData) {
// AppendBytes updates BM25Stats from raw sparse vector rows.
func (m *BM25Stats) AppendBytes(datas ...[]byte) {
	for _, data := range datas {
		dim := typeutil.SparseFloatRowElementCount(data)
		for i := 0; i < dim; i++ {
			index := typeutil.SparseFloatRowIndexAt(data, i)
			value := typeutil.SparseFloatRowValueAt(data, i)

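Both the replaced len(data)/8 and the new SparseFloatRowElementCount assume Milvus's sparse row encoding of 8 bytes per element (a uint32 token index followed by a float32 value, presumably little-endian). A small decode sketch with the typeutil helpers already used above; sparseRowToMap is a hypothetical helper, not a function from the source:

	// Decode one sparse row into a map; the inverse of the AppendBytes walk.
	func sparseRowToMap(row []byte) map[uint32]float32 {
		count := typeutil.SparseFloatRowElementCount(row)
		out := make(map[uint32]float32, count)
		for i := 0; i < count; i++ {
			out[typeutil.SparseFloatRowIndexAt(row, i)] = typeutil.SparseFloatRowValueAt(row, i)
		}
		return out
	}
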
@ -454,17 +454,19 @@ func (m *BM25Stats) Deserialize(bs []byte) error {
		m.rowsWithToken[keys[i]] += values[i]
	}

	log.Info("test-- deserialize", zap.Int64("numrow", m.numRow), zap.Int64("tokenNum", m.numToken))
	return nil
}

// BuildIDF rewrites a sparse term-frequency row into an IDF-weighted row of the same layout.
func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
	dim := typeutil.SparseFloatRowElementCount(tf)
	idf = make([]byte, len(tf))
	for idx := 0; idx < dim; idx++ {
		key := typeutil.SparseFloatRowIndexAt(tf, idx)
		value := typeutil.SparseFloatRowValueAt(tf, idx)
		nq := m.rowsWithToken[key]
		typeutil.SparseFloatRowSetAt(idf, idx, key, value*float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5))))
	}
	return
}

func (m *BM25Stats) GetAvgdl() float64 {

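Spelled out, the weight applied per term in BuildIDF is the standard BM25 IDF with the +1 inside the logarithm (so it never goes negative): for a term t occurring in n_t of the N indexed rows,

	idf(t) = ln(1 + (N - n_t + 0.5) / (n_t + 0.5))

Each query term frequency is multiplied by this factor while the sparse row layout is preserved, which is why the output buffer is simply allocated at len(tf).
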
@@ -358,7 +358,7 @@ func readDoubleArray(blobReaders []io.Reader) []float64 {
	return ret
}

-func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema) (idata *InsertData, err error) {
+func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema, skipFunction bool) (idata *InsertData, err error) {
	blobReaders := make([]io.Reader, 0)
	for _, blob := range msg.RowData {
		blobReaders = append(blobReaders, bytes.NewReader(blob.GetValue()))
@@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
	}

	for _, field := range collSchema.Fields {
-		if field.GetIsFunctionOutput() {
+		if skipFunction && field.GetIsFunctionOutput() {
			continue
		}
@@ -696,7 +696,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche

func InsertMsgToInsertData(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (idata *InsertData, err error) {
	if msg.IsRowBased() {
-		return RowBasedInsertMsgToInsertData(msg, schema)
+		return RowBasedInsertMsgToInsertData(msg, schema, true)
	}
	return ColumnBasedInsertMsgToInsertData(msg, schema)
}
@@ -1272,7 +1272,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert

func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msgstream.InsertMsg) (*segcorepb.InsertRecord, error) {
	if msg.IsRowBased() {
-		insertData, err := RowBasedInsertMsgToInsertData(msg, schema)
+		insertData, err := RowBasedInsertMsgToInsertData(msg, schema, false)
		if err != nil {
			return nil, err
		}
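Taken together, the two call sites explain the new flag: InsertMsgToInsertData decodes a row-based client payload, which presumably does not carry function-output fields (they are produced server-side, e.g. BM25 sparse vectors), so it passes true to skip them; TransferInsertMsgToInsertRecord runs where those fields must be preserved, so it passes false. A self-contained sketch of the filter (the field struct is a pared-down stand-in, not the schemapb type):

	package main

	import "fmt"

	type field struct {
		name             string
		isFunctionOutput bool
	}

	// decodableFields mirrors the loop in RowBasedInsertMsgToInsertData:
	// when skipFunction is set, function-output fields are left out.
	func decodableFields(fields []field, skipFunction bool) []string {
		var out []string
		for _, f := range fields {
			if skipFunction && f.isFunctionOutput {
				continue
			}
			out = append(out, f.name)
		}
		return out
	}

	func main() {
		schema := []field{
			{name: "id"},
			{name: "text"},
			{name: "text_sparse", isFunctionOutput: true}, // e.g. a BM25 output field
		}
		fmt.Println(decodableFields(schema, true))  // [id text]
		fmt.Println(decodableFields(schema, false)) // [id text text_sparse]
	}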
@@ -1281,7 +1281,8 @@ func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msg

	// column base insert msg
	insertRecord := &segcorepb.InsertRecord{
-		NumRows: int64(msg.NumRows),
+		NumRows:    int64(msg.NumRows),
+		FieldsData: make([]*schemapb.FieldData, 0),
	}

	insertRecord.FieldsData = append(insertRecord.FieldsData, msg.FieldsData...)
@@ -1035,7 +1035,7 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
	fieldIDs = fieldIDs[:len(fieldIDs)-2]
	msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim)

-	idata, err := RowBasedInsertMsgToInsertData(msg, schema)
+	idata, err := RowBasedInsertMsgToInsertData(msg, schema, false)
	assert.NoError(t, err)
	for idx, fID := range fieldIDs {
		column := columns[idx]

@@ -1096,7 +1096,7 @@ func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
			},
		},
	}
-	_, err := RowBasedInsertMsgToInsertData(msg, schema)
+	_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
	assert.Error(t, err)
}

@@ -1139,7 +1139,7 @@ func TestRowBasedInsertMsgToInsertBFloat16VectorDataError(t *testing.T) {
			},
		},
	}
-	_, err := RowBasedInsertMsgToInsertData(msg, schema)
+	_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
	assert.Error(t, err)
}
@@ -0,0 +1,184 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.

package function

import (
	schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	mock "github.com/stretchr/testify/mock"
)

// MockFunctionRunner is an autogenerated mock type for the FunctionRunner type
type MockFunctionRunner struct {
	mock.Mock
}

type MockFunctionRunner_Expecter struct {
	mock *mock.Mock
}

func (_m *MockFunctionRunner) EXPECT() *MockFunctionRunner_Expecter {
	return &MockFunctionRunner_Expecter{mock: &_m.Mock}
}

// BatchRun provides a mock function with given fields: inputs
func (_m *MockFunctionRunner) BatchRun(inputs ...interface{}) ([]interface{}, error) {
	var _ca []interface{}
	_ca = append(_ca, inputs...)
	ret := _m.Called(_ca...)

	var r0 []interface{}
	var r1 error
	if rf, ok := ret.Get(0).(func(...interface{}) ([]interface{}, error)); ok {
		return rf(inputs...)
	}
	if rf, ok := ret.Get(0).(func(...interface{}) []interface{}); ok {
		r0 = rf(inputs...)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]interface{})
		}
	}

	if rf, ok := ret.Get(1).(func(...interface{}) error); ok {
		r1 = rf(inputs...)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

// MockFunctionRunner_BatchRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchRun'
type MockFunctionRunner_BatchRun_Call struct {
	*mock.Call
}

// BatchRun is a helper method to define mock.On call
//   - inputs ...interface{}
func (_e *MockFunctionRunner_Expecter) BatchRun(inputs ...interface{}) *MockFunctionRunner_BatchRun_Call {
	return &MockFunctionRunner_BatchRun_Call{Call: _e.mock.On("BatchRun",
		append([]interface{}{}, inputs...)...)}
}

func (_c *MockFunctionRunner_BatchRun_Call) Run(run func(inputs ...interface{})) *MockFunctionRunner_BatchRun_Call {
	_c.Call.Run(func(args mock.Arguments) {
		variadicArgs := make([]interface{}, len(args)-0)
		for i, a := range args[0:] {
			if a != nil {
				variadicArgs[i] = a.(interface{})
			}
		}
		run(variadicArgs...)
	})
	return _c
}

func (_c *MockFunctionRunner_BatchRun_Call) Return(_a0 []interface{}, _a1 error) *MockFunctionRunner_BatchRun_Call {
	_c.Call.Return(_a0, _a1)
	return _c
}

func (_c *MockFunctionRunner_BatchRun_Call) RunAndReturn(run func(...interface{}) ([]interface{}, error)) *MockFunctionRunner_BatchRun_Call {
	_c.Call.Return(run)
	return _c
}

// GetOutputFields provides a mock function with given fields:
func (_m *MockFunctionRunner) GetOutputFields() []*schemapb.FieldSchema {
	ret := _m.Called()

	var r0 []*schemapb.FieldSchema
	if rf, ok := ret.Get(0).(func() []*schemapb.FieldSchema); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]*schemapb.FieldSchema)
		}
	}

	return r0
}

// MockFunctionRunner_GetOutputFields_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOutputFields'
type MockFunctionRunner_GetOutputFields_Call struct {
	*mock.Call
}

// GetOutputFields is a helper method to define mock.On call
func (_e *MockFunctionRunner_Expecter) GetOutputFields() *MockFunctionRunner_GetOutputFields_Call {
	return &MockFunctionRunner_GetOutputFields_Call{Call: _e.mock.On("GetOutputFields")}
}

func (_c *MockFunctionRunner_GetOutputFields_Call) Run(run func()) *MockFunctionRunner_GetOutputFields_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

func (_c *MockFunctionRunner_GetOutputFields_Call) Return(_a0 []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
	_c.Call.Return(_a0)
	return _c
}

func (_c *MockFunctionRunner_GetOutputFields_Call) RunAndReturn(run func() []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
	_c.Call.Return(run)
	return _c
}

// GetSchema provides a mock function with given fields:
func (_m *MockFunctionRunner) GetSchema() *schemapb.FunctionSchema {
	ret := _m.Called()

	var r0 *schemapb.FunctionSchema
	if rf, ok := ret.Get(0).(func() *schemapb.FunctionSchema); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*schemapb.FunctionSchema)
		}
	}

	return r0
}

// MockFunctionRunner_GetSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSchema'
type MockFunctionRunner_GetSchema_Call struct {
	*mock.Call
}

// GetSchema is a helper method to define mock.On call
func (_e *MockFunctionRunner_Expecter) GetSchema() *MockFunctionRunner_GetSchema_Call {
	return &MockFunctionRunner_GetSchema_Call{Call: _e.mock.On("GetSchema")}
}

func (_c *MockFunctionRunner_GetSchema_Call) Run(run func()) *MockFunctionRunner_GetSchema_Call {
	_c.Call.Run(func(args mock.Arguments) {
		run()
	})
	return _c
}

func (_c *MockFunctionRunner_GetSchema_Call) Return(_a0 *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
	_c.Call.Return(_a0)
	return _c
}

func (_c *MockFunctionRunner_GetSchema_Call) RunAndReturn(run func() *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
	_c.Call.Return(run)
	return _c
}

// NewMockFunctionRunner creates a new instance of MockFunctionRunner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockFunctionRunner(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockFunctionRunner {
	mock := &MockFunctionRunner{}
	mock.Mock.Test(t)

	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}
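A sketch of how this generated mock might be used from a test in the same function package; the schema values are invented for illustration, while NewMockFunctionRunner, EXPECT, BatchRun, and Return are exactly the generated helpers above (testing and testify/mock imports assumed):

	func TestFunctionRunnerMockSketch(t *testing.T) {
		runner := NewMockFunctionRunner(t) // expectations asserted on cleanup

		runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{
			{Name: "text_sparse", DataType: schemapb.DataType_SparseFloatVector}, // hypothetical field
		})
		runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{nil}, nil)

		fields := runner.GetOutputFields()
		if len(fields) != 1 || fields[0].GetName() != "text_sparse" {
			t.Fatalf("unexpected output fields: %v", fields)
		}
		if _, err := runner.BatchRun([]string{"some text"}); err != nil {
			t.Fatal(err)
		}
	}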
@@ -6,12 +6,26 @@ import (
	"math"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

+func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
+	placeholderGroup := &commonpb.PlaceholderGroup{
+		Placeholders: []*commonpb.PlaceholderValue{{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_SparseFloatVector,
+			Values: contents,
+		}},
+	}
+
+	bytes, _ := proto.Marshal(placeholderGroup)
+	return bytes
+}
+
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
	placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
	if err != nil {

@@ -93,6 +107,14 @@ func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.Place
			Values: [][]byte{bytes},
		}
		return placeholderValue, nil
+	case schemapb.DataType_VarChar:
+		strs := fieldData.GetScalars().GetStringData().GetData()
+		placeholderValue := &commonpb.PlaceholderValue{
+			Tag:    "$0",
+			Type:   commonpb.PlaceholderType_VarChar,
+			Values: lo.Map(strs, func(str string, _ int) []byte { return []byte(str) }),
+		}
+		return placeholderValue, nil
	default:
		return nil, errors.New("field is not a vector field")
	}

@@ -157,3 +179,7 @@ func flattenedBFloat16VectorsToByteVectors(flattenedVectors []byte, dimension in

	return result
}
+
+func GetVarCharFromPlaceholder(holder *commonpb.PlaceholderValue) []string {
+	return lo.Map(holder.Values, func(bytes []byte, _ int) string { return string(bytes) })
+}
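The new DataType_VarChar branch and GetVarCharFromPlaceholder are inverses: query strings go into a placeholder as raw bytes under PlaceholderType_VarChar, presumably so BM25 text queries can travel the same placeholder path as vectors, and come back out as strings. A minimal round-trip sketch using only the types shown above:

	package main

	import (
		"fmt"

		"github.com/samber/lo"

		"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	)

	func main() {
		queries := []string{"hello world", "bm25 full text search"}

		// Forward direction, as in fieldDataToPlaceholderValue's VarChar case.
		holder := &commonpb.PlaceholderValue{
			Tag:    "$0",
			Type:   commonpb.PlaceholderType_VarChar,
			Values: lo.Map(queries, func(s string, _ int) []byte { return []byte(s) }),
		}

		// Reverse direction, as in GetVarCharFromPlaceholder.
		back := lo.Map(holder.Values, func(b []byte, _ int) string { return string(b) })
		fmt.Println(back) // [hello world bm25 full text search]
	}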
@@ -36,4 +36,6 @@ const (

	// SUPERSTRUCTURE represents superstructure distance
	SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
+
+	BM25 MetricType = "BM25"
)
@@ -21,5 +21,5 @@ import "strings"
// PositivelyRelated return if metricType are "ip" or "IP"
func PositivelyRelated(metricType string) bool {
	mUpper := strings.ToUpper(metricType)
-	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE)
+	return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE) || mUpper == strings.ToUpper(BM25)
}
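Adding BM25 here matters for result ordering: like IP and COSINE, BM25 is a similarity (a higher score is a better match), while L2 is a distance (lower is better). A standalone sketch of the kind of decision this predicate drives; the hit struct and sortHits are illustrative, not Milvus API:

	package main

	import (
		"fmt"
		"sort"
	)

	func positivelyRelated(metric string) bool {
		return metric == "IP" || metric == "COSINE" || metric == "BM25"
	}

	type hit struct {
		id    int64
		score float32
	}

	// sortHits orders hits best-first: descending for similarity metrics
	// (IP, COSINE, BM25), ascending for distances such as L2.
	func sortHits(hits []hit, metric string) {
		sort.Slice(hits, func(i, j int) bool {
			if positivelyRelated(metric) {
				return hits[i].score > hits[j].score
			}
			return hits[i].score < hits[j].score
		})
	}

	func main() {
		hits := []hit{{1, 1.2}, {2, 7.5}, {3, 4.1}}
		sortHits(hits, "BM25")
		fmt.Println(hits) // [{2 7.5} {3 4.1} {1 1.2}]
	}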
@@ -212,7 +212,7 @@ func TestIndexAutoSparseVector(t *testing.T) {
	for _, unsupportedMt := range hp.UnsupportedSparseVecMetricsType {
		idx := index.NewAutoIndex(unsupportedMt)
		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idx))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
	}

	// auto index with different metric type on sparse vec

@@ -829,11 +829,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
	for _, mt := range hp.UnsupportedSparseVecMetricsType {
		idxInverted := index.NewSparseInvertedIndex(mt, 0.2)
		_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")

		idxWand := index.NewSparseWANDIndex(mt, 0.2)
		_, err = mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
-		common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
+		common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
	}

	// create index with invalid drop_ratio_build