feat: support load and query with bm25 metric (#36071)

relate: https://github.com/milvus-io/milvus/issues/35853

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
pull/36715/head
aoiasd 2024-10-11 10:23:20 +08:00 committed by GitHub
parent 90285830de
commit db34572c56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 2372 additions and 71 deletions

View File

@ -532,6 +532,9 @@ generate-mockery-utils: getdeps
# proxy_client_manager.go
$(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage
$(INSTALL_PATH)/mockery --name=ProxyWatcherInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_watcher.go --with-expecter --structname=MockProxyWatcher --inpackage
# function
$(INSTALL_PATH)/mockery --name=FunctionRunner --dir=$(PWD)/internal/util/function --output=$(PWD)/internal/util/function --filename=mock_function.go --with-expecter --structname=MockFunctionRunner --inpackage
generate-mockery-kv: getdeps
$(INSTALL_PATH)/mockery --name=TxnKV --dir=$(PWD)/pkg/kv --output=$(PWD)/internal/kv/mocks --filename=txn_kv.go --with-expecter

View File

@ -410,7 +410,8 @@ inline bool
IsFloatVectorMetricType(const MetricType& metric_type) {
return metric_type == knowhere::metric::L2 ||
metric_type == knowhere::metric::IP ||
metric_type == knowhere::metric::COSINE;
metric_type == knowhere::metric::COSINE ||
metric_type == knowhere::metric::BM25;
}
inline bool

View File

@ -160,13 +160,15 @@ inline bool
IsFloatMetricType(const knowhere::MetricType& metric_type) {
return IsMetricType(metric_type, knowhere::metric::L2) ||
IsMetricType(metric_type, knowhere::metric::IP) ||
IsMetricType(metric_type, knowhere::metric::COSINE);
IsMetricType(metric_type, knowhere::metric::COSINE) ||
IsMetricType(metric_type, knowhere::metric::BM25);
}
inline bool
PositivelyRelated(const knowhere::MetricType& metric_type) {
return IsMetricType(metric_type, knowhere::metric::IP) ||
IsMetricType(metric_type, knowhere::metric::COSINE);
IsMetricType(metric_type, knowhere::metric::COSINE) ||
IsMetricType(metric_type, knowhere::metric::BM25);
}
inline std::string

View File

@ -409,7 +409,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
milvus::tracer::AddEvent("finish_knowhere_index_search");
if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError,
"failed to search: {}: {}",
"failed to search: config={} {}: {}",
search_conf.dump(),
KnowhereStatusString(res.error()),
res.what());
}

View File

@ -52,6 +52,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.materialized_view_involved =
query_info_proto.materialized_view_involved();
if (query_info_proto.bm25_avgdl() > 0) {
search_info.search_params_[knowhere::meta::BM25_AVGDL] =
query_info_proto.bm25_avgdl();
}
if (query_info_proto.group_by_field_id() > 0) {
auto group_by_field_id =
FieldId(query_info_proto.group_by_field_id());

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "IndexConfigGenerator.h"
#include "knowhere/comp/index_param.h"
#include "log/Log.h"
namespace milvus::segcore {
@ -49,15 +50,28 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout,
std::to_string(config_.get_nlist());
build_params_[knowhere::indexparam::SSIZE] = std::to_string(
std::max((int)(config_.get_chunk_rows() / config_.get_nlist()), 48));
if (is_sparse && metric_type_ == knowhere::metric::BM25) {
build_params_[knowhere::meta::BM25_K1] =
index_meta_.GetIndexParams().at(knowhere::meta::BM25_K1);
build_params_[knowhere::meta::BM25_B] =
index_meta_.GetIndexParams().at(knowhere::meta::BM25_B);
build_params_[knowhere::meta::BM25_AVGDL] =
index_meta_.GetIndexParams().at(knowhere::meta::BM25_AVGDL);
}
search_params_[knowhere::indexparam::NPROBE] =
std::to_string(config_.get_nprobe());
// note for sparse vector index: drop_ratio_build is not allowed for growing
// segment index.
LOG_INFO(
"VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}",
"VecIndexConfig: origin_index_type={}, index_type={}, metric_type={}, "
"config={}",
origin_index_type_,
index_type_,
metric_type_);
metric_type_,
build_params_.dump());
}
int64_t
@ -100,6 +114,11 @@ VecIndexConfig::GetSearchConf(const SearchInfo& searchInfo) {
searchParam.search_params_[key] = searchInfo.search_params_[key];
}
}
if (metric_type_ == knowhere::metric::BM25) {
searchParam.search_params_[knowhere::meta::BM25_AVGDL] =
searchInfo.search_params_[knowhere::meta::BM25_AVGDL];
}
return searchParam;
}

View File

@ -25,6 +25,7 @@
#include "indexbuilder/ScalarIndexCreator.h"
#include "indexbuilder/VecIndexCreator.h"
#include "indexbuilder/index_c.h"
#include "knowhere/comp/index_param.h"
#include "pb/index_cgo_msg.pb.h"
#include "storage/Types.h"
@ -100,6 +101,14 @@ generate_build_conf(const milvus::IndexType& index_type,
};
} else if (index_type == knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX ||
index_type == knowhere::IndexEnum::INDEX_SPARSE_WAND) {
if (metric_type == knowhere::metric::BM25) {
return knowhere::Json{
{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},
{knowhere::meta::BM25_K1, "1.2"},
{knowhere::meta::BM25_B, "0.75"},
{knowhere::meta::BM25_AVGDL, "100"}};
}
return knowhere::Json{
{knowhere::meta::METRIC_TYPE, metric_type},
{knowhere::indexparam::DROP_RATIO_BUILD, "0.1"},

View File

@ -652,7 +652,7 @@ func (s *L0CompactionTaskSuite) TestPorcessStateTrans() {
s.Equal(datapb.CompactionTaskState_failed, t.GetState())
})
s.Run("test unkonwn task", func() {
s.Run("test unknown task", func() {
t := s.generateTestL0Task(datapb.CompactionTaskState_unknown)
got := t.Process()

View File

@ -73,7 +73,7 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e
}
func (eNode *embeddingNode) Name() string {
return fmt.Sprintf("embeddingNode-%s-%s", "BM25test", eNode.channelName)
return fmt.Sprintf("embeddingNode-%s", eNode.channelName)
}
func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, inputFieldId, outputFieldId int64, data *storage.InsertData, meta map[int64]*storage.BM25Stats) error {

View File

@ -96,6 +96,7 @@ message SubSearchRequest {
string metricType = 9;
int64 group_by_field_id = 10;
int64 group_size = 11;
int64 field_id = 12;
}
message SearchRequest {
@ -124,6 +125,7 @@ message SearchRequest {
common.ConsistencyLevel consistency_level = 22;
int64 group_by_field_id = 23;
int64 group_size = 24;
int64 field_id = 25;
}
message SubSearchResults {

View File

@ -64,6 +64,8 @@ message QueryInfo {
bool materialized_view_involved = 7;
int64 group_size = 8;
bool group_strict_size = 9;
double bm25_avgdl = 10;
int64 query_field_id =11;
}
message ColumnInfo {

View File

@ -367,6 +367,7 @@ message SegmentLoadInfo {
int64 storageVersion = 18;
bool is_sorted = 19;
map<int64, data.TextIndexStats> textStatsLogs = 20;
repeated data.FieldBinlog bm25logs = 21;
}
message FieldIndexInfo {

View File

@ -361,8 +361,12 @@ func (cit *createIndexTask) parseIndexParams() error {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType)
}
} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) {
if metricType != metric.IP {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP is the supported metric type for sparse index")
if metricType != metric.IP && metricType != metric.BM25 {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP&BM25 is the supported metric type for sparse index")
}
if metricType == metric.BM25 && cit.functionSchema.GetType() != schemapb.FunctionType_BM25 {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only BM25 Function output field support BM25 metric type")
}
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {

View File

@ -370,6 +370,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
GroupSize: t.rankParams.GetGroupSize(),
}
internalSubReq.FieldId = queryInfo.GetQueryFieldId()
// set PartitionIDs for sub search
if t.partitionKeyMode {
// isolatioin has tighter constraint, check first
@ -449,6 +450,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
}
t.SearchRequest.Offset = offset
t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()
if t.partitionKeyMode {
// isolatioin has tighter constraint, check first
@ -511,6 +513,8 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
}
queryInfo.QueryFieldId = annField.GetFieldID()
plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo)
if planErr != nil {
log.Warn("failed to create query plan", zap.Error(planErr),

View File

@ -81,6 +81,7 @@ func PackSegmentLoadInfo(segment *datapb.SegmentInfo, channelCheckpoint *msgpb.M
NumOfRows: segment.NumOfRows,
Statslogs: segment.Statslogs,
Deltalogs: segment.Deltalogs,
Bm25Logs: segment.Bm25Statslogs,
InsertChannel: segment.InsertChannel,
IndexInfos: indexes,
StartPosition: segment.GetStartPosition(),

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@ -43,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/common"
@ -54,6 +56,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/lifetime"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -110,7 +113,9 @@ type shardDelegator struct {
lifetime lifetime.Lifetime[lifetime.State]
distribution *distribution
distribution *distribution
idfOracle IDFOracle
segmentManager segments.SegmentManager
tsafeManager tsafe.Manager
pkOracle pkoracle.PkOracle
@ -135,6 +140,10 @@ type shardDelegator struct {
// in order to make add/remove growing be atomic, need lock before modify these meta info
growingSegmentLock sync.RWMutex
partitionStatsMut sync.RWMutex
// fieldId -> functionRunner map for search function field
functionRunners map[UniqueID]function.FunctionRunner
hasBM25Field bool
}
// getLogger returns the zap logger with pre-defined shard attributes.
@ -235,6 +244,19 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest
}()
}
// build idf for bm25 search
if req.GetReq().GetMetricType() == metric.BM25 {
avgdl, err := sd.buildBM25IDF(req.GetReq())
if err != nil {
return nil, err
}
if avgdl <= 0 {
log.Warn("search bm25 from empty data, skip search", zap.String("channel", sd.vchannelName), zap.Float64("avgdl", avgdl))
return []*internalpb.SearchResults{}, nil
}
}
// get final sealedNum after possible segment prune
sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) })
log.Debug("search segments...",
@ -335,6 +357,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
IsAdvanced: false,
GroupByFieldId: subReq.GetGroupByFieldId(),
GroupSize: subReq.GetGroupSize(),
FieldId: subReq.GetFieldId(),
}
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
@ -862,6 +885,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
excludedSegments := NewExcludedSegments(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.GetAsDuration(time.Second))
idfOracle := NewIDFOracle(collection.Schema().GetFunctions())
sd := &shardDelegator{
collectionID: collectionID,
replicaID: replicaID,
@ -871,7 +895,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(),
distribution: NewDistribution(idfOracle),
deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock),
pkOracle: pkoracle.NewPkOracle(),
tsafeManager: tsafeManager,
@ -880,9 +904,25 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
factory: factory,
queryHook: queryHook,
chunkManager: chunkManager,
idfOracle: idfOracle,
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
excludedSegments: excludedSegments,
functionRunners: make(map[int64]function.FunctionRunner),
}
for _, tf := range collection.Schema().GetFunctions() {
if tf.GetType() == schemapb.FunctionType_BM25 {
functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
if err != nil {
return nil, err
}
sd.functionRunners[tf.OutputFieldIds[0]] = functionRunner
if tf.GetType() == schemapb.FunctionType_BM25 {
sd.hasBM25Field = true
}
}
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)
if sd.lifetime.Add(lifetime.NotStopped) == nil {

View File

@ -27,11 +27,14 @@ import (
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@ -63,10 +66,12 @@ import (
// InsertData
type InsertData struct {
RowIDs []int64
PrimaryKeys []storage.PrimaryKey
Timestamps []uint64
InsertRecord *segcorepb.InsertRecord
RowIDs []int64
PrimaryKeys []storage.PrimaryKey
Timestamps []uint64
InsertRecord *segcorepb.InsertRecord
BM25Stats map[int64]*storage.BM25Stats
StartPosition *msgpb.MsgPosition
PartitionID int64
}
@ -149,6 +154,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
if !sd.pkOracle.Exists(growing, paramtable.GetNodeID()) {
// register created growing segment after insert, avoid to add empty growing to delegator
sd.pkOracle.Register(growing, paramtable.GetNodeID())
sd.idfOracle.Register(segmentID, insertData.BM25Stats, segments.SegmentTypeGrowing)
sd.segmentManager.Put(context.Background(), segments.SegmentTypeGrowing, growing)
sd.addGrowing(SegmentEntry{
NodeID: paramtable.GetNodeID(),
@ -158,10 +164,12 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
TargetVersion: initialTargetVersion,
})
}
sd.growingSegmentLock.Unlock()
}
log.Debug("insert into growing segment",
sd.growingSegmentLock.Unlock()
} else {
sd.idfOracle.UpdateGrowing(growing.ID(), insertData.BM25Stats)
}
log.Info("insert into growing segment",
zap.Int64("collectionID", growing.Collection()),
zap.Int64("segmentID", segmentID),
zap.Int("rowCount", len(insertData.RowIDs)),
@ -375,8 +383,11 @@ func (sd *shardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.Segm
segmentIDs = lo.Map(loaded, func(segment segments.Segment, _ int) int64 { return segment.ID() })
log.Info("load growing segments done", zap.Int64s("segmentIDs", segmentIDs))
for _, candidate := range loaded {
sd.pkOracle.Register(candidate, paramtable.GetNodeID())
for _, segment := range loaded {
sd.pkOracle.Register(segment, paramtable.GetNodeID())
if sd.hasBM25Field {
sd.idfOracle.Register(segment.ID(), segment.GetBM25Stats(), segments.SegmentTypeGrowing)
}
}
sd.addGrowing(lo.Map(loaded, func(segment segments.Segment, _ int) SegmentEntry {
return SegmentEntry{
@ -472,6 +483,16 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
infos := lo.Filter(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) bool {
return !sd.pkOracle.Exists(pkoracle.NewCandidateKey(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed), targetNodeID)
})
var bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
if sd.hasBM25Field {
bm25Stats, err = sd.loader.LoadBM25Stats(ctx, req.GetCollectionID(), infos...)
if err != nil {
log.Warn("failed to load bm25 stats for segment", zap.Error(err))
return err
}
}
candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...)
if err != nil {
log.Warn("failed to load bloom filter set for segment", zap.Error(err))
@ -479,7 +500,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg
}
log.Debug("load delete...")
err = sd.loadStreamDelete(ctx, candidates, infos, req, targetNodeID, worker)
err = sd.loadStreamDelete(ctx, candidates, bm25Stats, infos, req, targetNodeID, worker)
if err != nil {
log.Warn("load stream delete failed", zap.Error(err))
return err
@ -552,6 +573,7 @@ func (sd *shardDelegator) RefreshLevel0DeletionStats() {
func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
candidates []*pkoracle.BloomFilterSet,
bm25Stats *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats],
infos []*querypb.SegmentLoadInfo,
req *querypb.LoadSegmentsRequest,
targetNodeID int64,
@ -665,6 +687,14 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
)
sd.pkOracle.Register(candidate, targetNodeID)
}
if bm25Stats != nil {
bm25Stats.Range(func(segmentID int64, stats map[int64]*storage.BM25Stats) bool {
sd.idfOracle.Register(segmentID, stats, segments.SegmentTypeSealed)
return false
})
}
log.Info("load delete done")
return nil
@ -963,3 +993,47 @@ func (sd *shardDelegator) TryCleanExcludedSegments(ts uint64) {
sd.excludedSegments.CleanInvalid(ts)
}
}
func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64, error) {
pb := &commonpb.PlaceholderGroup{}
proto.Unmarshal(req.GetPlaceholderGroup(), pb)
if len(pb.Placeholders) != 1 || len(pb.Placeholders[0].Values) != 1 {
return 0, merr.WrapErrParameterInvalidMsg("please provide varchar for bm25")
}
holder := pb.Placeholders[0]
if holder.Type != commonpb.PlaceholderType_VarChar {
return 0, fmt.Errorf("can't build BM25 IDF for data not varchar")
}
str := funcutil.GetVarCharFromPlaceholder(holder)
functionRunner, ok := sd.functionRunners[req.GetFieldId()]
if !ok {
return 0, fmt.Errorf("functionRunner not found for field: %d", req.GetFieldId())
}
// get search text term frequency
output, err := functionRunner.BatchRun(str)
if err != nil {
return 0, err
}
tfArray, ok := output[0].(*schemapb.SparseFloatArray)
if !ok {
return 0, fmt.Errorf("functionRunner return unknown data")
}
idfSparseVector, avgdl, err := sd.idfOracle.BuildIDF(req.GetFieldId(), tfArray)
if err != nil {
return 0, err
}
err = SetBM25Params(req, avgdl)
if err != nil {
return 0, err
}
req.PlaceholderGroup = funcutil.SparseVectorDataToPlaceholderGroupBytes(idfSparseVector)
return avgdl, nil
}

View File

@ -26,14 +26,19 @@ import (
"time"
"github.com/cockroachdb/errors"
"github.com/pingcap/log"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@ -42,10 +47,12 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/bloomfilter"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/metric"
@ -95,13 +102,7 @@ func (s *DelegatorDataSuite) TearDownSuite() {
paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.CleanExcludeSegInterval.Key)
}
func (s *DelegatorDataSuite) SetupTest() {
s.workerManager = &cluster.MockManager{}
s.manager = segments.NewManager()
s.tsafeManager = tsafe.NewTSafeReplica()
s.loader = &segments.MockLoader{}
// init schema
func (s *DelegatorDataSuite) genNormalCollection() {
s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
@ -154,7 +155,59 @@ func (s *DelegatorDataSuite) SetupTest() {
LoadType: querypb.LoadType_LoadCollection,
PartitionIDs: []int64{1001, 1002},
})
}
func (s *DelegatorDataSuite) genCollectionWithFunction() {
s.manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
{
Name: "id",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: true,
}, {
Name: "vector",
FieldID: 101,
IsPrimaryKey: false,
DataType: schemapb.DataType_SparseFloatVector,
}, {
Name: "text",
FieldID: 102,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: "256",
},
},
},
},
Functions: []*schemapb.FunctionSchema{{
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{102},
OutputFieldIds: []int64{101},
}},
}, nil, nil)
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
s.NoError(err)
s.delegator = delegator.(*shardDelegator)
}
func (s *DelegatorDataSuite) SetupTest() {
s.workerManager = &cluster.MockManager{}
s.manager = segments.NewManager()
s.tsafeManager = tsafe.NewTSafeReplica()
s.loader = &segments.MockLoader{}
// init schema
s.genNormalCollection()
s.mq = &msgstream.MockMsgStream{}
s.rootPath = s.Suite.T().Name()
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
@ -471,6 +524,127 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
s.False(s.delegator.distribution.Serviceable())
}
func (s *DelegatorDataSuite) TestLoadGrowingWithBM25() {
s.genCollectionWithFunction()
mockSegment := segments.NewMockSegment(s.T())
s.loader.EXPECT().Load(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]segments.Segment{mockSegment}, nil)
mockSegment.EXPECT().Partition().Return(111)
mockSegment.EXPECT().ID().Return(111)
mockSegment.EXPECT().Type().Return(commonpb.SegmentState_Growing)
mockSegment.EXPECT().GetBM25Stats().Return(map[int64]*storage.BM25Stats{})
err := s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{{SegmentID: 1}}, 1)
s.NoError(err)
}
func (s *DelegatorDataSuite) TestLoadSegmentsWithBm25() {
s.genCollectionWithFunction()
s.Run("normal_run", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
statsMap := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
stats := storage.NewBM25Stats()
stats.Append(map[uint32]float32{1: 1})
statsMap.Insert(1, map[int64]*storage.BM25Stats{101: stats})
s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(statsMap, nil)
s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything).
Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
})
}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
return nil
})
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
Level: datapb.SegmentLevel_L1,
InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
},
},
})
s.NoError(err)
sealed, _ := s.delegator.GetSegmentInfo(false)
s.Require().Equal(1, len(sealed))
s.Equal(int64(1), sealed[0].NodeID)
s.ElementsMatch([]SegmentEntry{
{
SegmentID: 100,
NodeID: 1,
PartitionID: 500,
TargetVersion: unreadableTargetVersion,
Level: datapb.SegmentLevel_L1,
},
}, sealed[0].Segments)
})
s.Run("loadBM25_failed", func() {
defer func() {
s.workerManager.ExpectedCalls = nil
s.loader.ExpectedCalls = nil
}()
s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(nil, fmt.Errorf("mock error"))
workers := make(map[int64]*cluster.MockWorker)
worker1 := &cluster.MockWorker{}
workers[1] = worker1
worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(nil)
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker {
return workers[nodeID]
}, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{
Base: commonpbutil.NewMsgBase(),
DstNodeID: 1,
CollectionID: s.collectionID,
Infos: []*querypb.SegmentLoadInfo{
{
SegmentID: 100,
PartitionID: 500,
StartPosition: &msgpb.MsgPosition{Timestamp: 20000},
DeltaPosition: &msgpb.MsgPosition{Timestamp: 20000},
Level: datapb.SegmentLevel_L1,
InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", s.collectionID),
},
},
})
s.Error(err)
})
}
func (s *DelegatorDataSuite) TestLoadSegments() {
s.Run("normal_run", func() {
defer func() {
@ -883,6 +1057,214 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
})
}
func (s *DelegatorDataSuite) TestBuildBM25IDF() {
s.genCollectionWithFunction()
genBM25Stats := func(start uint32, end uint32) map[int64]*storage.BM25Stats {
result := make(map[int64]*storage.BM25Stats)
result[101] = storage.NewBM25Stats()
for i := start; i < end; i++ {
row := map[uint32]float32{i: 1}
result[101].Append(row)
}
return result
}
genSnapShot := func(seals, grows []int64, targetVersion int64) *snapshot {
snapshot := &snapshot{
dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}},
targetVersion: targetVersion,
}
newSeal := []SegmentEntry{}
for _, seg := range seals {
newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
}
newGrow := []SegmentEntry{}
for _, seg := range grows {
newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: targetVersion})
}
log.Info("Test-", zap.Any("shanshot", snapshot), zap.Any("seg", newSeal))
snapshot.dist[0].Segments = newSeal
snapshot.growing = newGrow
return snapshot
}
genStringFieldData := func(strs ...string) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 102,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strs,
},
},
},
},
}
}
s.Run("normal case", func() {
// register sealed
sealedSegs := []int64{1, 2, 3, 4}
for _, segID := range sealedSegs {
// every segment stats only has one token, avgdl = 1
s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
s.delegator.idfOracle.SyncDistribution(snapshot)
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
plan, err := proto.Marshal(&planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
VectorAnns: &planpb.VectorANNS{
QueryInfo: &planpb.QueryInfo{},
},
},
})
s.NoError(err)
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
SerializedExprPlan: plan,
FieldId: 101,
}
avgdl, err := s.delegator.buildBM25IDF(req)
s.NoError(err)
s.Equal(float64(1), avgdl)
// check avgdl in plan
newplan := &planpb.PlanNode{}
err = proto.Unmarshal(req.GetSerializedExprPlan(), newplan)
s.NoError(err)
annplan, ok := newplan.GetNode().(*planpb.PlanNode_VectorAnns)
s.Require().True(ok)
s.Equal(avgdl, annplan.VectorAnns.QueryInfo.Bm25Avgdl)
// check idf in placeholder
placeholder := &commonpb.PlaceholderGroup{}
err = proto.Unmarshal(req.GetPlaceholderGroup(), placeholder)
s.Require().NoError(err)
s.Equal(placeholder.GetPlaceholders()[0].GetType(), commonpb.PlaceholderType_SparseFloatVector)
})
s.Run("invalid place holder type error", func() {
placeholderGroupBytes, err := proto.Marshal(&commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{{Type: commonpb.PlaceholderType_SparseFloatVector}},
})
s.NoError(err)
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 101,
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
})
s.Run("no function runner error", func() {
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 103, // invalid field id
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
})
s.Run("function runner run failed error", func() {
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
oldRunner := s.delegator.functionRunners
mockRunner := function.NewMockFunctionRunner(s.T())
s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
mockRunner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock err"))
defer func() {
s.delegator.functionRunners = oldRunner
}()
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 101,
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
})
s.Run("function runner output type error", func() {
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
oldRunner := s.delegator.functionRunners
mockRunner := function.NewMockFunctionRunner(s.T())
s.delegator.functionRunners = map[int64]function.FunctionRunner{101: mockRunner}
mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
defer func() {
s.delegator.functionRunners = oldRunner
}()
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 101,
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
})
s.Run("idf oracle build idf error", func() {
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
oldRunner := s.delegator.functionRunners
mockRunner := function.NewMockFunctionRunner(s.T())
s.delegator.functionRunners = map[int64]function.FunctionRunner{103: mockRunner}
mockRunner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{&schemapb.SparseFloatArray{Contents: [][]byte{typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{1: 1})}}}, nil)
defer func() {
s.delegator.functionRunners = oldRunner
}()
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 103, // invalid field
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
log.Info("test", zap.Error(err))
})
s.Run("set avgdl failed", func() {
// register sealed
sealedSegs := []int64{1, 2, 3, 4}
for _, segID := range sealedSegs {
// every segment stats only has one token, avgdl = 1
s.delegator.idfOracle.Register(segID, genBM25Stats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
snapshot := genSnapShot([]int64{1, 2, 3, 4}, []int64{}, 100)
s.delegator.idfOracle.SyncDistribution(snapshot)
placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(genStringFieldData("test bm25 data"))
s.NoError(err)
req := &internalpb.SearchRequest{
PlaceholderGroup: placeholderGroupBytes,
FieldId: 101,
}
_, err = s.delegator.buildBM25IDF(req)
s.Error(err)
})
}
func (s *DelegatorDataSuite) TestReleaseSegment() {
s.loader.EXPECT().
Load(mock.Anything, s.collectionID, segments.SegmentTypeGrowing, int64(0), mock.Anything).

View File

@ -178,6 +178,84 @@ func (s *DelegatorSuite) TearDownTest() {
s.delegator = nil
}
func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
s.Run("init function failed", func() {
manager := segments.NewManager()
manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
{
Name: "id",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: true,
}, {
Name: "vector",
FieldID: 101,
IsPrimaryKey: false,
DataType: schemapb.DataType_SparseFloatVector,
},
},
Functions: []*schemapb.FunctionSchema{{
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{102},
OutputFieldIds: []int64{101, 103}, // invalid output field
}},
}, nil, nil)
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
s.Error(err)
})
s.Run("init function failed", func() {
manager := segments.NewManager()
manager.Collection.PutOrRef(s.collectionID, &schemapb.CollectionSchema{
Name: "TestCollection",
Fields: []*schemapb.FieldSchema{
{
Name: "id",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
AutoID: true,
}, {
Name: "vector",
FieldID: 101,
IsPrimaryKey: false,
DataType: schemapb.DataType_SparseFloatVector,
}, {
Name: "text",
FieldID: 102,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.MaxLengthKey,
Value: "256",
},
},
},
},
Functions: []*schemapb.FunctionSchema{{
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{102},
OutputFieldIds: []int64{101},
}},
}, nil, nil)
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
s.NoError(err)
})
}
func (s *DelegatorSuite) TestBasicInfo() {
s.Equal(s.collectionID, s.delegator.Collection())
s.Equal(s.version, s.delegator.Version())

View File

@ -74,6 +74,8 @@ type distribution struct {
// current is the snapshot for quick usage for search/query
// generated for each change of distribution
current *atomic.Pointer[snapshot]
idfOracle IDFOracle
// protects current & segments
mut sync.RWMutex
}
@ -89,7 +91,7 @@ type SegmentEntry struct {
}
// NewDistribution creates a new distribution instance with all field initialized.
func NewDistribution() *distribution {
func NewDistribution(idfOracle IDFOracle) *distribution {
dist := &distribution{
serviceable: atomic.NewBool(false),
growingSegments: make(map[UniqueID]SegmentEntry),
@ -98,6 +100,7 @@ func NewDistribution() *distribution {
current: atomic.NewPointer[snapshot](nil),
offlines: typeutil.NewSet[int64](),
targetVersion: atomic.NewInt64(initialTargetVersion),
idfOracle: idfOracle,
}
dist.genSnapshot()
@ -367,6 +370,7 @@ func (d *distribution) genSnapshot() chan struct{} {
d.current.Store(newSnapShot)
// shall be a new one
d.snapshots.GetOrInsert(d.snapshotVersion, newSnapShot)
d.idfOracle.SyncDistribution(newSnapShot)
// first snapshot, return closed chan
if last == nil {

View File

@ -21,6 +21,8 @@ import (
"time"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
type DistributionSuite struct {
@ -29,7 +31,7 @@ type DistributionSuite struct {
}
func (s *DistributionSuite) SetupTest() {
s.dist = NewDistribution()
s.dist = NewDistribution(NewIDFOracle([]*schemapb.FunctionSchema{}))
s.Equal(initialTargetVersion, s.dist.getTargetVersion())
}

View File

@ -0,0 +1,262 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"fmt"
"sync"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
)
type IDFOracle interface {
// Activate(segmentID int64, state commonpb.SegmentState) error
// Deactivate(segmentID int64, state commonpb.SegmentState) error
SyncDistribution(snapshot *snapshot)
UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats)
Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState)
Remove(segmentID int64, state commonpb.SegmentState)
BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error)
}
type bm25Stats struct {
stats map[int64]*storage.BM25Stats
activate bool
targetVersion int64
}
func (s *bm25Stats) Merge(stats map[int64]*storage.BM25Stats) {
for fieldID, newstats := range stats {
if stats, ok := s.stats[fieldID]; ok {
stats.Merge(newstats)
} else {
log.Panic("merge failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
}
}
}
func (s *bm25Stats) Minus(stats map[int64]*storage.BM25Stats) {
for fieldID, newstats := range stats {
if stats, ok := s.stats[fieldID]; ok {
stats.Minus(newstats)
} else {
log.Panic("minus failed, BM25 stats not exist", zap.Int64("fieldID", fieldID))
}
}
}
func (s *bm25Stats) GetStats(fieldID int64) (*storage.BM25Stats, error) {
stats, ok := s.stats[fieldID]
if !ok {
return nil, fmt.Errorf("field not found in idf oracle BM25 stats")
}
return stats, nil
}
func (s *bm25Stats) NumRow() int64 {
for _, stats := range s.stats {
return stats.NumRow()
}
return 0
}
func newBm25Stats(functions []*schemapb.FunctionSchema) *bm25Stats {
stats := &bm25Stats{
stats: make(map[int64]*storage.BM25Stats),
}
for _, function := range functions {
if function.GetType() == schemapb.FunctionType_BM25 {
stats.stats[function.GetOutputFieldIds()[0]] = storage.NewBM25Stats()
}
}
return stats
}
type idfOracle struct {
sync.RWMutex
current *bm25Stats
growing map[int64]*bm25Stats
sealed map[int64]*bm25Stats
targetVersion int64
}
func (o *idfOracle) Register(segmentID int64, stats map[int64]*storage.BM25Stats, state commonpb.SegmentState) {
o.Lock()
defer o.Unlock()
switch state {
case segments.SegmentTypeGrowing:
if _, ok := o.growing[segmentID]; ok {
return
}
o.growing[segmentID] = &bm25Stats{
stats: stats,
activate: true,
targetVersion: initialTargetVersion,
}
o.current.Merge(stats)
case segments.SegmentTypeSealed:
if _, ok := o.sealed[segmentID]; ok {
return
}
o.sealed[segmentID] = &bm25Stats{
stats: stats,
activate: false,
targetVersion: initialTargetVersion,
}
default:
log.Warn("register segment with unknown state", zap.String("stats", state.String()))
return
}
}
func (o *idfOracle) UpdateGrowing(segmentID int64, stats map[int64]*storage.BM25Stats) {
if len(stats) == 0 {
return
}
o.Lock()
defer o.Unlock()
old, ok := o.growing[segmentID]
if !ok {
return
}
old.Merge(stats)
if old.activate {
o.current.Merge(stats)
}
}
func (o *idfOracle) Remove(segmentID int64, state commonpb.SegmentState) {
o.Lock()
defer o.Unlock()
switch state {
case segments.SegmentTypeGrowing:
if stats, ok := o.growing[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.growing, segmentID)
}
case segments.SegmentTypeSealed:
if stats, ok := o.sealed[segmentID]; ok {
if stats.activate {
o.current.Minus(stats.stats)
}
delete(o.sealed, segmentID)
}
default:
return
}
}
func (o *idfOracle) activate(stats *bm25Stats) {
stats.activate = true
o.current.Merge(stats.stats)
}
func (o *idfOracle) deactivate(stats *bm25Stats) {
stats.activate = false
o.current.Minus(stats.stats)
}
func (o *idfOracle) SyncDistribution(snapshot *snapshot) {
o.Lock()
defer o.Unlock()
sealed, growing := snapshot.Peek()
for _, item := range sealed {
for _, segment := range item.Segments {
if stats, ok := o.sealed[segment.SegmentID]; ok {
stats.targetVersion = segment.TargetVersion
} else {
log.Warn("idf oracle lack some sealed segment", zap.Int64("segmentID", segment.SegmentID))
}
}
}
for _, segment := range growing {
if stats, ok := o.growing[segment.SegmentID]; ok {
stats.targetVersion = segment.TargetVersion
} else {
log.Warn("idf oracle lack some growing segment", zap.Int64("segmentID", segment.SegmentID))
}
}
o.targetVersion = snapshot.targetVersion
for _, stats := range o.sealed {
if !stats.activate && stats.targetVersion == o.targetVersion {
o.activate(stats)
} else if stats.activate && stats.targetVersion != o.targetVersion {
o.deactivate(stats)
}
}
for _, stats := range o.growing {
if !stats.activate && (stats.targetVersion == o.targetVersion || stats.targetVersion == initialTargetVersion) {
o.activate(stats)
} else if stats.activate && (stats.targetVersion != o.targetVersion && stats.targetVersion != initialTargetVersion) {
o.deactivate(stats)
}
}
log.Debug("sync distribution finished", zap.Int64("version", o.targetVersion), zap.Int64("numrow", o.current.NumRow()))
}
func (o *idfOracle) BuildIDF(fieldID int64, tfs *schemapb.SparseFloatArray) ([][]byte, float64, error) {
o.RLock()
defer o.RUnlock()
stats, err := o.current.GetStats(fieldID)
if err != nil {
return nil, 0, err
}
idfBytes := make([][]byte, len(tfs.GetContents()))
for i, tf := range tfs.GetContents() {
idf := stats.BuildIDF(tf)
idfBytes[i] = idf
}
return idfBytes, stats.GetAvgdl(), nil
}
func NewIDFOracle(functions []*schemapb.FunctionSchema) IDFOracle {
return &idfOracle{
current: newBm25Stats(functions),
growing: make(map[int64]*bm25Stats),
sealed: make(map[int64]*bm25Stats),
}
}

View File

@ -0,0 +1,198 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package delegator
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type IDFOracleSuite struct {
suite.Suite
collectionID int64
collectionSchema *schemapb.CollectionSchema
idfOracle *idfOracle
targetVersion int64
snapshot *snapshot
}
func (suite *IDFOracleSuite) SetupSuite() {
suite.collectionID = 111
suite.collectionSchema = &schemapb.CollectionSchema{
Functions: []*schemapb.FunctionSchema{{
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
}},
}
}
func (suite *IDFOracleSuite) SetupTest() {
suite.idfOracle = NewIDFOracle(suite.collectionSchema.GetFunctions()).(*idfOracle)
suite.snapshot = &snapshot{
dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}},
}
suite.targetVersion = 0
}
func (suite *IDFOracleSuite) genStats(start uint32, end uint32) map[int64]*storage.BM25Stats {
result := make(map[int64]*storage.BM25Stats)
result[102] = storage.NewBM25Stats()
for i := start; i < end; i++ {
row := map[uint32]float32{i: 1}
result[102].Append(row)
}
return result
}
// update test snapshot
func (suite *IDFOracleSuite) updateSnapshot(seals, grows, drops []int64) *snapshot {
suite.targetVersion++
snapshot := &snapshot{
dist: []SnapshotItem{{1, make([]SegmentEntry, 0)}},
targetVersion: suite.targetVersion,
}
dropSet := typeutil.NewSet[int64]()
dropSet.Insert(drops...)
newSeal := []SegmentEntry{}
for _, seg := range suite.snapshot.dist[0].Segments {
if !dropSet.Contain(seg.SegmentID) {
seg.TargetVersion = suite.targetVersion
}
newSeal = append(newSeal, seg)
}
for _, seg := range seals {
newSeal = append(newSeal, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
}
newGrow := []SegmentEntry{}
for _, seg := range suite.snapshot.growing {
if !dropSet.Contain(seg.SegmentID) {
seg.TargetVersion = suite.targetVersion
} else {
seg.TargetVersion = redundantTargetVersion
}
newGrow = append(newGrow, seg)
}
for _, seg := range grows {
newGrow = append(newGrow, SegmentEntry{NodeID: 1, SegmentID: seg, TargetVersion: suite.targetVersion})
}
snapshot.dist[0].Segments = newSeal
snapshot.growing = newGrow
suite.snapshot = snapshot
return snapshot
}
func (suite *IDFOracleSuite) TestSealed() {
// register sealed
sealedSegs := []int64{1, 2, 3, 4}
for _, segID := range sealedSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
// reduplicate register
for _, segID := range sealedSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Sealed)
}
// register sealed segment but all deactvate
suite.Zero(suite.idfOracle.current.NumRow())
// update and sync snapshot make all sealed activate
suite.updateSnapshot(sealedSegs, []int64{}, []int64{})
suite.idfOracle.SyncDistribution(suite.snapshot)
suite.Equal(int64(4), suite.idfOracle.current.NumRow())
releasedSeg := []int64{1, 2, 3}
suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
suite.idfOracle.SyncDistribution(suite.snapshot)
suite.Equal(int64(1), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Sealed)
}
sparse := typeutil.CreateAndSortSparseFloatRow(map[uint32]float32{4: 1})
bytes, avgdl, err := suite.idfOracle.BuildIDF(102, &schemapb.SparseFloatArray{Contents: [][]byte{sparse}, Dim: 1})
suite.NoError(err)
suite.Equal(float64(1), avgdl)
suite.Equal(map[uint32]float32{4: 0.2876821}, typeutil.SparseFloatBytesToMap(bytes[0]))
}
func (suite *IDFOracleSuite) TestGrow() {
// register grow
growSegs := []int64{1, 2, 3, 4}
for _, segID := range growSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
}
// reduplicate register
for _, segID := range growSegs {
suite.idfOracle.Register(segID, suite.genStats(uint32(segID), uint32(segID)+1), commonpb.SegmentState_Growing)
}
// register sealed segment but all deactvate
suite.Equal(int64(4), suite.idfOracle.current.NumRow())
suite.updateSnapshot([]int64{}, growSegs, []int64{})
releasedSeg := []int64{1, 2, 3}
suite.updateSnapshot([]int64{}, []int64{}, releasedSeg)
suite.idfOracle.SyncDistribution(suite.snapshot)
suite.Equal(int64(1), suite.idfOracle.current.NumRow())
suite.idfOracle.UpdateGrowing(4, suite.genStats(5, 6))
suite.Equal(int64(2), suite.idfOracle.current.NumRow())
for _, segID := range releasedSeg {
suite.idfOracle.Remove(segID, commonpb.SegmentState_Growing)
}
}
func (suite *IDFOracleSuite) TestStats() {
stats := newBm25Stats([]*schemapb.FunctionSchema{{
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
}})
suite.Panics(func() {
stats.Merge(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
})
suite.Panics(func() {
stats.Minus(map[int64]*storage.BM25Stats{103: storage.NewBM25Stats()})
})
_, err := stats.GetStats(103)
suite.Error(err)
_, err = stats.GetStats(102)
suite.NoError(err)
}
func TestIDFOracle(t *testing.T) {
suite.Run(t, new(IDFOracleSuite))
}

View File

@ -0,0 +1,64 @@
package delegator
import (
"fmt"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func BuildSparseFieldData(field *schemapb.FieldSchema, sparseArray *schemapb.SparseFloatArray) *schemapb.FieldData {
return &schemapb.FieldData{
Type: field.GetDataType(),
FieldName: field.GetName(),
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: sparseArray.GetDim(),
Data: &schemapb.VectorField_SparseFloatVector{
SparseFloatVector: sparseArray,
},
},
},
FieldId: field.GetFieldID(),
}
}
func SetBM25Params(req *internalpb.SearchRequest, avgdl float64) error {
log := log.With(zap.Int64("collection", req.GetCollectionID()))
serializedPlan := req.GetSerializedExprPlan()
// plan not found
if serializedPlan == nil {
log.Warn("serialized plan not found")
return merr.WrapErrParameterInvalid("serialized search plan", "nil")
}
plan := planpb.PlanNode{}
err := proto.Unmarshal(serializedPlan, &plan)
if err != nil {
log.Warn("failed to unmarshal plan", zap.Error(err))
return merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
}
switch plan.GetNode().(type) {
case *planpb.PlanNode_VectorAnns:
queryInfo := plan.GetVectorAnns().GetQueryInfo()
queryInfo.Bm25Avgdl = avgdl
serializedExprPlan, err := proto.Marshal(&plan)
if err != nil {
log.Warn("failed to marshal optimized plan", zap.Error(err))
return merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
}
req.SerializedExprPlan = serializedExprPlan
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
default:
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
}
return nil
}

View File

@ -60,6 +60,7 @@ func loadL0Segments(ctx context.Context, delegator delegator.ShardDelegator, req
NumOfRows: segmentInfo.NumOfRows,
Statslogs: segmentInfo.Statslogs,
Deltalogs: segmentInfo.Deltalogs,
Bm25Logs: segmentInfo.Bm25Statslogs,
InsertChannel: segmentInfo.InsertChannel,
StartPosition: segmentInfo.GetStartPosition(),
Level: segmentInfo.GetLevel(),
@ -101,6 +102,7 @@ func loadGrowingSegments(ctx context.Context, delegator delegator.ShardDelegator
NumOfRows: segmentInfo.NumOfRows,
Statslogs: segmentInfo.Statslogs,
Deltalogs: segmentInfo.Deltalogs,
Bm25Logs: segmentInfo.Bm25Statslogs,
InsertChannel: segmentInfo.InsertChannel,
})
} else {

View File

@ -0,0 +1,211 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/function"
base "github.com/milvus-io/milvus/internal/util/pipeline"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type embeddingNode struct {
*BaseNode
collectionID int64
channel string
manager *DataManager
functionRunners []function.FunctionRunner
}
func newEmbeddingNode(collectionID int64, channelName string, manager *DataManager, maxQueueLength int32) (*embeddingNode, error) {
collection := manager.Collection.Get(collectionID)
if collection == nil {
log.Error("embeddingNode init failed with collection not exist", zap.Int64("collection", collectionID))
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
if len(collection.Schema().GetFunctions()) == 0 {
return nil, nil
}
node := &embeddingNode{
BaseNode: base.NewBaseNode(fmt.Sprintf("EmbeddingNode-%s", channelName), maxQueueLength),
collectionID: collectionID,
channel: channelName,
manager: manager,
functionRunners: make([]function.FunctionRunner, 0),
}
for _, tf := range collection.Schema().GetFunctions() {
functionRunner, err := function.NewFunctionRunner(collection.Schema(), tf)
if err != nil {
return nil, err
}
node.functionRunners = append(node.functionRunners, functionRunner)
}
return node, nil
}
func (eNode *embeddingNode) Name() string {
return fmt.Sprintf("embeddingNode-%s", eNode.channel)
}
func (eNode *embeddingNode) addInsertData(insertDatas map[UniqueID]*delegator.InsertData, msg *InsertMsg, collection *Collection) error {
iData, ok := insertDatas[msg.SegmentID]
if !ok {
iData = &delegator.InsertData{
PartitionID: msg.PartitionID,
BM25Stats: make(map[int64]*storage.BM25Stats),
StartPosition: &msgpb.MsgPosition{
Timestamp: msg.BeginTs(),
ChannelName: msg.GetShardName(),
},
}
insertDatas[msg.SegmentID] = iData
}
err := eNode.embedding(msg, iData.BM25Stats)
if err != nil {
log.Error("failed to function data", zap.Error(err))
return err
}
insertRecord, err := storage.TransferInsertMsgToInsertRecord(collection.Schema(), msg)
if err != nil {
err = fmt.Errorf("failed to get primary keys, err = %d", err)
log.Error(err.Error(), zap.String("channel", eNode.channel))
return err
}
if iData.InsertRecord == nil {
iData.InsertRecord = insertRecord
} else {
err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
if err != nil {
log.Warn("failed to merge field data", zap.String("channel", eNode.channel), zap.Error(err))
return err
}
iData.InsertRecord.NumRows += insertRecord.NumRows
}
pks, err := segments.GetPrimaryKeys(msg, collection.Schema())
if err != nil {
log.Warn("failed to get primary keys from insert message", zap.String("channel", eNode.channel), zap.Error(err))
return err
}
iData.PrimaryKeys = append(iData.PrimaryKeys, pks...)
iData.RowIDs = append(iData.RowIDs, msg.RowIDs...)
iData.Timestamps = append(iData.Timestamps, msg.Timestamps...)
log.Debug("pipeline embedding insert msg",
zap.Int64("collectionID", eNode.collectionID),
zap.Int64("segmentID", msg.SegmentID),
zap.Int("insertRowNum", len(pks)),
zap.Uint64("timestampMin", msg.BeginTimestamp),
zap.Uint64("timestampMax", msg.EndTimestamp))
return nil
}
func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
functionSchema := runner.GetSchema()
inputFieldID := functionSchema.GetInputFieldIds()[0]
outputFieldID := functionSchema.GetOutputFieldIds()[0]
outputField := runner.GetOutputFields()[0]
data, err := GetEmbeddingFieldData(msg.GetFieldsData(), inputFieldID)
if data == nil || err != nil {
return merr.WrapErrFieldNotFound(fmt.Sprint(inputFieldID))
}
output, err := runner.BatchRun(data)
if err != nil {
return err
}
sparseArray, ok := output[0].(*schemapb.SparseFloatArray)
if !ok {
return fmt.Errorf("BM25 runner return unknown type output")
}
if _, ok := stats[outputFieldID]; !ok {
stats[outputFieldID] = storage.NewBM25Stats()
}
stats[outputFieldID].AppendBytes(sparseArray.GetContents()...)
msg.FieldsData = append(msg.FieldsData, delegator.BuildSparseFieldData(outputField, sparseArray))
return nil
}
func (eNode *embeddingNode) embedding(msg *msgstream.InsertMsg, stats map[int64]*storage.BM25Stats) error {
for _, functionRunner := range eNode.functionRunners {
functionSchema := functionRunner.GetSchema()
switch functionSchema.GetType() {
case schemapb.FunctionType_BM25:
err := eNode.bm25Embedding(functionRunner, msg, stats)
if err != nil {
return err
}
default:
log.Warn("pipeline embedding with unknown function type", zap.Any("type", functionSchema.GetType()))
return fmt.Errorf("unknown function type")
}
}
return nil
}
func (eNode *embeddingNode) Operate(in Msg) Msg {
nodeMsg := in.(*insertNodeMsg)
nodeMsg.insertDatas = make(map[int64]*delegator.InsertData)
collection := eNode.manager.Collection.Get(eNode.collectionID)
if collection == nil {
log.Error("embeddingNode with collection not exist", zap.Int64("collection", eNode.collectionID))
panic("embeddingNode with collection not exist")
}
for _, msg := range nodeMsg.insertMsgs {
err := eNode.addInsertData(nodeMsg.insertDatas, msg, collection)
if err != nil {
panic(err)
}
}
return nodeMsg
}
func GetEmbeddingFieldData(datas []*schemapb.FieldData, fieldID int64) ([]string, error) {
for _, data := range datas {
if data.GetFieldId() == fieldID {
return data.GetScalars().GetStringData().GetData(), nil
}
}
return nil, fmt.Errorf("field %d not found", fieldID)
}

View File

@ -0,0 +1,281 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pipeline
import (
"fmt"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
// test of embedding node
type EmbeddingNodeSuite struct {
suite.Suite
// datas
collectionID int64
collectionSchema *schemapb.CollectionSchema
channel string
msgs []*InsertMsg
// mocks
manager *segments.Manager
segManager *segments.MockSegmentManager
colManager *segments.MockCollectionManager
}
func (suite *EmbeddingNodeSuite) SetupTest() {
paramtable.Init()
suite.collectionID = 111
suite.channel = "test-channel"
suite.collectionSchema = &schemapb.CollectionSchema{
Name: "test-collection",
Fields: []*schemapb.FieldSchema{
{
FieldID: common.TimeStampField,
Name: common.TimeStampFieldName,
DataType: schemapb.DataType_Int64,
}, {
Name: "pk",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
}, {
Name: "text",
FieldID: 101,
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{},
}, {
Name: "sparse",
FieldID: 102,
DataType: schemapb.DataType_SparseFloatVector,
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{{
Name: "BM25",
Type: schemapb.FunctionType_BM25,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
}},
}
suite.msgs = []*msgstream.InsertMsg{{
BaseMsg: msgstream.BaseMsg{},
InsertRequest: &msgpb.InsertRequest{
SegmentID: 1,
NumRows: 3,
Version: msgpb.InsertDataVersion_ColumnBased,
Timestamps: []uint64{1, 1, 1},
FieldsData: []*schemapb.FieldData{
{
FieldId: 100,
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
},
}, {
FieldId: 101,
Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}},
},
},
},
},
}}
suite.segManager = segments.NewMockSegmentManager(suite.T())
suite.colManager = segments.NewMockCollectionManager(suite.T())
suite.manager = &segments.Manager{
Collection: suite.colManager,
Segment: suite.segManager,
}
}
func (suite *EmbeddingNodeSuite) TestCreateEmbeddingNode() {
suite.Run("collection not found", func() {
suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.Error(err)
})
suite.Run("function invalid", func() {
collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
collection.Schema().Functions = []*schemapb.FunctionSchema{{}}
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.Error(err)
})
suite.Run("normal case", func() {
collSchema := proto.Clone(suite.collectionSchema).(*schemapb.CollectionSchema)
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, collSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
_, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
})
}
func (suite *EmbeddingNodeSuite) TestOperator() {
suite.Run("collection not found", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
suite.colManager.EXPECT().Get(suite.collectionID).Return(nil).Once()
suite.Panics(func() {
node.Operate(&insertNodeMsg{})
})
})
suite.Run("add InsertData Failed", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
suite.Panics(func() {
node.Operate(&insertNodeMsg{
insertMsgs: []*msgstream.InsertMsg{{
BaseMsg: msgstream.BaseMsg{},
InsertRequest: &msgpb.InsertRequest{
SegmentID: 1,
NumRows: 3,
Version: msgpb.InsertDataVersion_ColumnBased,
FieldsData: []*schemapb.FieldData{
{
FieldId: 100,
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}},
},
},
},
},
}},
})
})
})
suite.Run("normal case", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Times(2)
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
suite.NotPanics(func() {
output := node.Operate(&insertNodeMsg{
insertMsgs: suite.msgs,
})
msg, ok := output.(*insertNodeMsg)
suite.Require().True(ok)
suite.Require().NotNil(msg.insertDatas)
suite.Require().Equal(int64(3), msg.insertDatas[1].BM25Stats[102].NumRow())
suite.Require().Equal(int64(3), msg.insertDatas[1].InsertRecord.GetNumRows())
})
})
}
func (suite *EmbeddingNodeSuite) TestAddInsertData() {
suite.Run("transfer insert msg failed", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
// transfer insert msg failed because rowbase data not support sparse vector
insertDatas := make(map[int64]*delegator.InsertData)
rowBaseReq := proto.Clone(suite.msgs[0].InsertRequest).(*msgpb.InsertRequest)
rowBaseReq.Version = msgpb.InsertDataVersion_RowBased
rowBaseMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{},
InsertRequest: rowBaseReq,
}
err = node.addInsertData(insertDatas, rowBaseMsg, collection)
suite.Error(err)
})
suite.Run("merge failed data failed", func() {
// remove pk
suite.collectionSchema.Fields[1].IsPrimaryKey = false
defer func() {
suite.collectionSchema.Fields[1].IsPrimaryKey = true
}()
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
insertDatas := make(map[int64]*delegator.InsertData)
err = node.addInsertData(insertDatas, suite.msgs[0], collection)
suite.Error(err)
})
}
func (suite *EmbeddingNodeSuite) TestBM25Embedding() {
suite.Run("function run failed", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
runner := function.NewMockFunctionRunner(suite.T())
runner.EXPECT().BatchRun(mock.Anything).Return(nil, fmt.Errorf("mock error"))
runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})
err = node.bm25Embedding(runner, suite.msgs[0], nil)
suite.Error(err)
})
suite.Run("output with unknown type failed", func() {
collection := segments.NewCollectionWithoutSegcoreForTest(suite.collectionID, suite.collectionSchema)
suite.colManager.EXPECT().Get(suite.collectionID).Return(collection).Once()
node, err := newEmbeddingNode(suite.collectionID, suite.channel, suite.manager, 128)
suite.NoError(err)
runner := function.NewMockFunctionRunner(suite.T())
runner.EXPECT().BatchRun(mock.Anything).Return([]interface{}{1}, nil)
runner.EXPECT().GetSchema().Return(suite.collectionSchema.GetFunctions()[0])
runner.EXPECT().GetOutputFields().Return([]*schemapb.FieldSchema{nil})
err = node.bm25Embedding(runner, suite.msgs[0], nil)
suite.Error(err)
})
}
func TestEmbeddingNode(t *testing.T) {
suite.Run(t, new(EmbeddingNodeSuite))
}

View File

@ -62,7 +62,7 @@ func (iNode *insertNode) addInsertData(insertDatas map[UniqueID]*delegator.Inser
} else {
err := typeutil.MergeFieldData(iData.InsertRecord.FieldsData, insertRecord.FieldsData)
if err != nil {
log.Error("failed to merge field data", zap.Error(err))
log.Error("failed to merge field data", zap.String("channel", iNode.channel), zap.Error(err))
panic(err)
}
iData.InsertRecord.NumRows += insertRecord.NumRows
@ -95,21 +95,23 @@ func (iNode *insertNode) Operate(in Msg) Msg {
return nodeMsg.insertMsgs[i].BeginTs() < nodeMsg.insertMsgs[j].BeginTs()
})
insertDatas := make(map[UniqueID]*delegator.InsertData)
collection := iNode.manager.Collection.Get(iNode.collectionID)
if collection == nil {
log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
panic("insertNode with collection not exist")
// build insert data if no embedding node
if nodeMsg.insertDatas == nil {
collection := iNode.manager.Collection.Get(iNode.collectionID)
if collection == nil {
log.Error("insertNode with collection not exist", zap.Int64("collection", iNode.collectionID))
panic("insertNode with collection not exist")
}
nodeMsg.insertDatas = make(map[UniqueID]*delegator.InsertData)
// get InsertData and merge datas of same segment
for _, msg := range nodeMsg.insertMsgs {
iNode.addInsertData(nodeMsg.insertDatas, msg, collection)
}
}
// get InsertData and merge datas of same segment
for _, msg := range nodeMsg.insertMsgs {
iNode.addInsertData(insertDatas, msg, collection)
}
iNode.delegator.ProcessInsert(insertDatas)
iNode.delegator.ProcessInsert(nodeMsg.insertDatas)
}
metrics.QueryNodeWaitProcessingMsgCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.DeleteLabel).Inc()
return &deleteNodeMsg{

View File

@ -19,15 +19,17 @@ package pipeline
import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/querynodev2/collector"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
)
type insertNodeMsg struct {
insertMsgs []*InsertMsg
deleteMsgs []*DeleteMsg
timeRange TimeRange
insertMsgs []*InsertMsg
deleteMsgs []*DeleteMsg
insertDatas map[int64]*delegator.InsertData
timeRange TimeRange
}
type deleteNodeMsg struct {

View File

@ -31,7 +31,8 @@ type Pipeline interface {
type pipeline struct {
base.StreamPipeline
collectionID UniqueID
collectionID UniqueID
embeddingNode embeddingNode
}
func (p *pipeline) Close() {
@ -54,8 +55,21 @@ func NewPipeLine(
}
filterNode := newFilterNode(collectionID, channel, manager, delegator, pipelineQueueLength)
embeddingNode, err := newEmbeddingNode(collectionID, channel, manager, pipelineQueueLength)
if err != nil {
return nil, err
}
insertNode := newInsertNode(collectionID, channel, manager, delegator, pipelineQueueLength)
deleteNode := newDeleteNode(collectionID, channel, manager, tSafeManager, delegator, pipelineQueueLength)
p.Add(filterNode, insertNode, deleteNode)
// skip add embedding node when collection has no function.
if embeddingNode != nil {
p.Add(filterNode, embeddingNode, insertNode, deleteNode)
} else {
p.Add(filterNode, insertNode, deleteNode)
}
return p, nil
}

View File

@ -299,6 +299,18 @@ func NewCollectionWithoutSchema(collectionID int64, loadType querypb.LoadType) *
}
}
// new collection without segcore prepare
// ONLY FOR TEST
func NewCollectionWithoutSegcoreForTest(collectionID int64, schema *schemapb.CollectionSchema) *Collection {
coll := &Collection{
id: collectionID,
partitions: typeutil.NewConcurrentSet[int64](),
refCount: atomic.NewUint32(0),
}
coll.schema.Store(schema)
return coll
}
// deleteCollection delete collection and free the collection memory
func DeleteCollection(collection *Collection) {
/*

View File

@ -47,6 +47,7 @@ import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils"
@ -220,6 +221,11 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema {
DataType: param.dataType,
ElementType: schemapb.DataType_Int32,
}
if param.dataType == schemapb.DataType_VarChar {
field.TypeParams = []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "128"},
}
}
return field
}
@ -263,6 +269,35 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
return fieldVec
}
func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema {
fieldRowID := genConstantFieldSchema(rowIDField)
fieldTimestamp := genConstantFieldSchema(timestampField)
pkFieldSchema := genPKFieldSchema(simpleInt64Field)
textFieldSchema := genConstantFieldSchema(simpleVarCharField)
sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField)
sparseFieldSchema.IsFunctionOutput = true
schema := &schemapb.CollectionSchema{
Name: collectionName,
Fields: []*schemapb.FieldSchema{
fieldRowID,
fieldTimestamp,
pkFieldSchema,
textFieldSchema,
sparseFieldSchema,
},
Functions: []*schemapb.FunctionSchema{{
Name: "BM25",
Type: schemapb.FunctionType_BM25,
InputFieldNames: []string{textFieldSchema.GetName()},
InputFieldIds: []int64{textFieldSchema.GetFieldID()},
OutputFieldNames: []string{sparseFieldSchema.GetName()},
OutputFieldIds: []int64{sparseFieldSchema.GetFieldID()},
}},
}
return schema
}
// some tests do not yet support sparse float vector, see comments of
// GenSparseFloatVecDataset in indexcgowrapper/dataset.go
func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema {
@ -671,6 +706,32 @@ func SaveDeltaLog(collectionID int64,
return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}
func SaveBM25Log(collectionID int64, partitionID int64, segmentID int64, fieldID int64, msgLength int, cm storage.ChunkManager) (*datapb.FieldBinlog, error) {
stats := storage.NewBM25Stats()
for i := 0; i < msgLength; i++ {
stats.Append(map[uint32]float32{1: 1})
}
bytes, err := stats.Serialize()
if err != nil {
return nil, err
}
kvs := make(map[string][]byte, 1)
key := path.Join(cm.RootPath(), common.SegmentBm25LogPath, metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID, 1001))
kvs[key] = bytes
fieldBinlog := &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{{
LogPath: key,
TimestampFrom: 100,
TimestampTo: 200,
}},
}
return fieldBinlog, cm.MultiWrite(context.Background(), kvs)
}
func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
fieldSchema *schemapb.FieldSchema,
indexInfo *indexpb.IndexInfo,

View File

@ -14,6 +14,10 @@ import (
pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
querypb "github.com/milvus-io/milvus/internal/proto/querypb"
storage "github.com/milvus-io/milvus/internal/storage"
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
)
// MockLoader is an autogenerated mock type for the Loader type
@ -101,6 +105,76 @@ func (_c *MockLoader_Load_Call) RunAndReturn(run func(context.Context, int64, co
return _c
}
// LoadBM25Stats provides a mock function with given fields: ctx, collectionID, infos
func (_m *MockLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
_va := make([]interface{}, len(infos))
for _i := range infos {
_va[_i] = infos[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, collectionID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)); ok {
return rf(ctx, collectionID, infos...)
}
if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats]); ok {
r0 = rf(ctx, collectionID, infos...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats])
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) error); ok {
r1 = rf(ctx, collectionID, infos...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockLoader_LoadBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBM25Stats'
type MockLoader_LoadBM25Stats_Call struct {
*mock.Call
}
// LoadBM25Stats is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - infos ...*querypb.SegmentLoadInfo
func (_e *MockLoader_Expecter) LoadBM25Stats(ctx interface{}, collectionID interface{}, infos ...interface{}) *MockLoader_LoadBM25Stats_Call {
return &MockLoader_LoadBM25Stats_Call{Call: _e.mock.On("LoadBM25Stats",
append([]interface{}{ctx, collectionID}, infos...)...)}
}
func (_c *MockLoader_LoadBM25Stats_Call) Run(run func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBM25Stats_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(*querypb.SegmentLoadInfo)
}
}
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
func (_c *MockLoader_LoadBM25Stats_Call) Return(_a0 *typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], _a1 error) *MockLoader_LoadBM25Stats_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockLoader_LoadBM25Stats_Call) RunAndReturn(run func(context.Context, int64, ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)) *MockLoader_LoadBM25Stats_Call {
_c.Call.Return(run)
return _c
}
// LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, version, infos
func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
_va := make([]interface{}, len(infos))

View File

@ -290,6 +290,49 @@ func (_c *MockSegment_ExistIndex_Call) RunAndReturn(run func(int64) bool) *MockS
return _c
}
// GetBM25Stats provides a mock function with given fields:
func (_m *MockSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
ret := _m.Called()
var r0 map[int64]*storage.BM25Stats
if rf, ok := ret.Get(0).(func() map[int64]*storage.BM25Stats); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]*storage.BM25Stats)
}
}
return r0
}
// MockSegment_GetBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBM25Stats'
type MockSegment_GetBM25Stats_Call struct {
*mock.Call
}
// GetBM25Stats is a helper method to define mock.On call
func (_e *MockSegment_Expecter) GetBM25Stats() *MockSegment_GetBM25Stats_Call {
return &MockSegment_GetBM25Stats_Call{Call: _e.mock.On("GetBM25Stats")}
}
func (_c *MockSegment_GetBM25Stats_Call) Run(run func()) *MockSegment_GetBM25Stats_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockSegment_GetBM25Stats_Call) Return(_a0 map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockSegment_GetBM25Stats_Call) RunAndReturn(run func() map[int64]*storage.BM25Stats) *MockSegment_GetBM25Stats_Call {
_c.Call.Return(run)
return _c
}
// GetIndex provides a mock function with given fields: fieldID
func (_m *MockSegment) GetIndex(fieldID int64) *IndexedFieldInfo {
ret := _m.Called(fieldID)
@ -1570,6 +1613,39 @@ func (_c *MockSegment_Unpin_Call) RunAndReturn(run func()) *MockSegment_Unpin_Ca
return _c
}
// UpdateBM25Stats provides a mock function with given fields: stats
func (_m *MockSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
_m.Called(stats)
}
// MockSegment_UpdateBM25Stats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBM25Stats'
type MockSegment_UpdateBM25Stats_Call struct {
*mock.Call
}
// UpdateBM25Stats is a helper method to define mock.On call
// - stats map[int64]*storage.BM25Stats
func (_e *MockSegment_Expecter) UpdateBM25Stats(stats interface{}) *MockSegment_UpdateBM25Stats_Call {
return &MockSegment_UpdateBM25Stats_Call{Call: _e.mock.On("UpdateBM25Stats", stats)}
}
func (_c *MockSegment_UpdateBM25Stats_Call) Run(run func(stats map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(map[int64]*storage.BM25Stats))
})
return _c
}
func (_c *MockSegment_UpdateBM25Stats_Call) Return() *MockSegment_UpdateBM25Stats_Call {
_c.Call.Return()
return _c
}
func (_c *MockSegment_UpdateBM25Stats_Call) RunAndReturn(run func(map[int64]*storage.BM25Stats)) *MockSegment_UpdateBM25Stats_Call {
_c.Call.Return(run)
return _c
}
// UpdateBloomFilter provides a mock function with given fields: pks
func (_m *MockSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
_m.Called(pks)

View File

@ -91,6 +91,8 @@ type baseSegment struct {
isLazyLoad bool
channel metautil.Channel
bm25Stats map[int64]*storage.BM25Stats
resourceUsageCache *atomic.Pointer[ResourceUsage]
needUpdatedVersion *atomic.Int64 // only for lazy load mode update index
@ -107,6 +109,7 @@ func newBaseSegment(collection *Collection, segmentType SegmentType, version int
version: atomic.NewInt64(version),
segmentType: segmentType,
bloomFilterSet: pkoracle.NewBloomFilterSet(loadInfo.GetSegmentID(), loadInfo.GetPartitionID(), segmentType),
bm25Stats: make(map[int64]*storage.BM25Stats),
channel: channel,
isLazyLoad: isLazyLoad(collection, segmentType),
@ -185,6 +188,20 @@ func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) {
s.bloomFilterSet.UpdateBloomFilter(pks)
}
func (s *baseSegment) UpdateBM25Stats(stats map[int64]*storage.BM25Stats) {
for fieldID, new := range stats {
if current, ok := s.bm25Stats[fieldID]; ok {
current.Merge(new)
} else {
s.bm25Stats[fieldID] = new
}
}
}
func (s *baseSegment) GetBM25Stats() map[int64]*storage.BM25Stats {
return s.bm25Stats
}
// MayPkExist returns true if the given PK exists in the PK range and being positive through the bloom filter,
// false otherwise,
// may returns true even the PK doesn't exist actually

View File

@ -87,6 +87,10 @@ type Segment interface {
MayPkExist(lc *storage.LocationsCache) bool
BatchPkExist(lc *storage.BatchLocationsCache) []bool
// BM25 stats
UpdateBM25Stats(stats map[int64]*storage.BM25Stats)
GetBM25Stats() map[int64]*storage.BM25Stats
// Read operations
Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error)

View File

@ -77,6 +77,9 @@ type Loader interface {
// LoadBloomFilterSet loads needed statslog for RemoteSegment.
LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)
// LoadBM25Stats loads BM25 statslog for RemoteSegment
LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)
// LoadIndex append index for segment and remove vector binlogs.
LoadIndex(ctx context.Context,
segment Segment,
@ -543,6 +546,47 @@ func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentTyp
return nil
}
func (loader *segmentLoader) LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error) {
segmentNum := len(infos)
if segmentNum == 0 {
return nil, nil
}
segments := lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 {
return info.GetSegmentID()
})
log.Info("start loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Int("segmentNum", segmentNum))
loadedStats := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]()
loadRemoteBM25Func := func(idx int) error {
loadInfo := infos[idx]
segmentID := loadInfo.SegmentID
stats := make(map[int64]*storage.BM25Stats)
log.Info("loading bm25 stats for remote...", zap.Int64("collectionID", collectionID), zap.Int64("segment", segmentID))
logpaths := loader.filterBM25Stats(loadInfo.Bm25Logs)
err := loader.loadBm25Stats(ctx, segmentID, stats, logpaths)
if err != nil {
log.Warn("load remote segment bm25 stats failed",
zap.Int64("segmentID", segmentID),
zap.Error(err),
)
return err
}
loadedStats.Insert(segmentID, stats)
return nil
}
err := funcutil.ProcessFuncParallel(segmentNum, segmentNum, loadRemoteBM25Func, "loadRemoteBM25Func")
if err != nil {
// no partial success here
log.Warn("failed to load bm25 stats for remote segment", zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", segments), zap.Error(err))
return nil, err
}
return loadedStats, nil
}
func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collectionID),
@ -826,6 +870,16 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context,
if err != nil {
return err
}
if len(loadInfo.Bm25Logs) > 0 {
log.Info("loading bm25 stats...")
bm25StatsLogs := loader.filterBM25Stats(loadInfo.Bm25Logs)
err = loader.loadBm25Stats(ctx, segment.ID(), segment.bm25Stats, bm25StatsLogs)
if err != nil {
return err
}
}
}
return nil
}
@ -898,6 +952,26 @@ func (loader *segmentLoader) filterPKStatsBinlogs(fieldBinlogs []*datapb.FieldBi
return result, storage.DefaultStatsType
}
func (loader *segmentLoader) filterBM25Stats(fieldBinlogs []*datapb.FieldBinlog) map[int64][]string {
result := make(map[int64][]string, 0)
for _, fieldBinlog := range fieldBinlogs {
logpaths := []string{}
for _, binlog := range fieldBinlog.GetBinlogs() {
_, logidx := path.Split(binlog.GetLogPath())
// if special status log exist
// only load one file
if logidx == storage.CompoundStatsType.LogIdx() {
logpaths = []string{binlog.GetLogPath()}
break
} else {
logpaths = append(logpaths, binlog.GetLogPath())
}
}
result[fieldBinlog.FieldID] = logpaths
}
return result
}
func loadSealedSegmentFields(ctx context.Context, collection *Collection, segment *LocalSegment, fields []*datapb.FieldBinlog, rowCount int64) error {
runningGroup, _ := errgroup.WithContext(ctx)
for _, field := range fields {
@ -989,6 +1063,51 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS
return segment.LoadIndex(ctx, indexInfo, fieldType)
}
func (loader *segmentLoader) loadBm25Stats(ctx context.Context, segmentID int64, stats map[int64]*storage.BM25Stats, binlogPaths map[int64][]string) error {
log := log.Ctx(ctx).With(
zap.Int64("segmentID", segmentID),
)
if len(binlogPaths) == 0 {
log.Info("there are no bm25 stats logs saved with segment")
return nil
}
pathList := []string{}
fieldList := []int64{}
fieldOffset := []int{}
for fieldId, logpaths := range binlogPaths {
pathList = append(pathList, logpaths...)
fieldList = append(fieldList, fieldId)
fieldOffset = append(fieldOffset, len(logpaths))
}
startTs := time.Now()
values, err := loader.cm.MultiRead(ctx, pathList)
if err != nil {
return err
}
cnt := 0
for i, fieldID := range fieldList {
newStats, ok := stats[fieldID]
if !ok {
newStats = storage.NewBM25Stats()
stats[fieldID] = newStats
}
for j := 0; j < fieldOffset[i]; j++ {
err := newStats.Deserialize(values[cnt+j])
if err != nil {
return err
}
}
cnt += fieldOffset[i]
log.Info("Successfully load bm25 stats", zap.Duration("time", time.Since(startTs)), zap.Int64("numRow", newStats.NumRow()), zap.Int64("fieldID", fieldID))
}
return nil
}
func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet,
binlogPaths []string, logType storage.StatsLogType,
) error {

View File

@ -70,14 +70,16 @@ func (suite *SegmentLoaderSuite) SetupSuite() {
}
func (suite *SegmentLoaderSuite) SetupTest() {
// Dependencies
suite.manager = NewManager()
ctx := context.Background()
// TODO:: cpp chunk manager not support local chunk manager
// suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath(
// fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63())))
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
// Dependencies
suite.manager = NewManager()
suite.loader = NewLoader(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get())
@ -92,6 +94,22 @@ func (suite *SegmentLoaderSuite) SetupTest() {
suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
}
func (suite *SegmentLoaderSuite) SetupBM25() {
// Dependencies
suite.manager = NewManager()
suite.loader = NewLoader(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get())
suite.schema = GenTestBM25CollectionSchema("test")
indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema)
loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID,
PartitionIDs: []int64{suite.partitionID},
}
suite.manager.Collection.PutOrRef(suite.collectionID, suite.schema, indexMeta, loadMeta)
}
func (suite *SegmentLoaderSuite) TearDownTest() {
ctx := context.Background()
for i := 0; i < suite.segmentNum; i++ {
@ -407,6 +425,41 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
}
}
func (suite *SegmentLoaderSuite) TestLoadBm25Stats() {
suite.SetupBM25()
msgLength := 1
sparseFieldID := simpleSparseFloatVectorField.id
loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)
for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i)
bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager)
suite.NoError(err)
loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{
SegmentID: segmentID,
PartitionID: suite.partitionID,
CollectionID: suite.collectionID,
Bm25Logs: []*datapb.FieldBinlog{bm25logs},
NumOfRows: int64(msgLength),
InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
})
}
statsMap, err := suite.loader.LoadBM25Stats(context.Background(), suite.collectionID, loadInfos...)
suite.NoError(err)
for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i)
stats, ok := statsMap.Get(segmentID)
suite.True(ok)
fieldStats, ok := stats[sparseFieldID]
suite.True(ok)
suite.Equal(int64(msgLength), fieldStats.NumRow())
}
}
func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
ctx := context.Background()
loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)

View File

@ -21,10 +21,10 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"maps"
"math"
"go.uber.org/zap"
"golang.org/x/exp/maps"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/bloomfilter"
@ -347,7 +347,7 @@ func (m *BM25Stats) AppendFieldData(datas ...*SparseFloatVectorFieldData) {
// Update BM25Stats by sparse vector bytes
func (m *BM25Stats) AppendBytes(datas ...[]byte) {
for _, data := range datas {
dim := len(data) / 8
dim := typeutil.SparseFloatRowElementCount(data)
for i := 0; i < dim; i++ {
index := typeutil.SparseFloatRowIndexAt(data, i)
value := typeutil.SparseFloatRowValueAt(data, i)
@ -454,17 +454,19 @@ func (m *BM25Stats) Deserialize(bs []byte) error {
m.rowsWithToken[keys[i]] += values[i]
}
log.Info("test-- deserialize", zap.Int64("numrow", m.numRow), zap.Int64("tokenNum", m.numToken))
return nil
}
func (m *BM25Stats) BuildIDF(tf map[uint32]float32) map[uint32]float32 {
vector := make(map[uint32]float32)
for key, value := range tf {
func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) {
dim := typeutil.SparseFloatRowElementCount(tf)
idf = make([]byte, len(tf))
for idx := 0; idx < dim; idx++ {
key := typeutil.SparseFloatRowIndexAt(tf, idx)
value := typeutil.SparseFloatRowValueAt(tf, idx)
nq := m.rowsWithToken[key]
vector[key] = value * float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5)))
typeutil.SparseFloatRowSetAt(idf, idx, key, value*float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5))))
}
return vector
return
}
func (m *BM25Stats) GetAvgdl() float64 {

View File

@ -358,7 +358,7 @@ func readDoubleArray(blobReaders []io.Reader) []float64 {
return ret
}
func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema) (idata *InsertData, err error) {
func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemapb.CollectionSchema, skipFunction bool) (idata *InsertData, err error) {
blobReaders := make([]io.Reader, 0)
for _, blob := range msg.RowData {
blobReaders = append(blobReaders, bytes.NewReader(blob.GetValue()))
@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap
}
for _, field := range collSchema.Fields {
if field.GetIsFunctionOutput() {
if skipFunction && field.GetIsFunctionOutput() {
continue
}
@ -696,7 +696,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche
func InsertMsgToInsertData(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) (idata *InsertData, err error) {
if msg.IsRowBased() {
return RowBasedInsertMsgToInsertData(msg, schema)
return RowBasedInsertMsgToInsertData(msg, schema, true)
}
return ColumnBasedInsertMsgToInsertData(msg, schema)
}
@ -1272,7 +1272,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msgstream.InsertMsg) (*segcorepb.InsertRecord, error) {
if msg.IsRowBased() {
insertData, err := RowBasedInsertMsgToInsertData(msg, schema)
insertData, err := RowBasedInsertMsgToInsertData(msg, schema, false)
if err != nil {
return nil, err
}
@ -1281,7 +1281,8 @@ func TransferInsertMsgToInsertRecord(schema *schemapb.CollectionSchema, msg *msg
// column base insert msg
insertRecord := &segcorepb.InsertRecord{
NumRows: int64(msg.NumRows),
NumRows: int64(msg.NumRows),
FieldsData: make([]*schemapb.FieldData, 0),
}
insertRecord.FieldsData = append(insertRecord.FieldsData, msg.FieldsData...)

View File

@ -1035,7 +1035,7 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
fieldIDs = fieldIDs[:len(fieldIDs)-2]
msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim)
idata, err := RowBasedInsertMsgToInsertData(msg, schema)
idata, err := RowBasedInsertMsgToInsertData(msg, schema, false)
assert.NoError(t, err)
for idx, fID := range fieldIDs {
column := columns[idx]
@ -1096,7 +1096,7 @@ func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
},
},
}
_, err := RowBasedInsertMsgToInsertData(msg, schema)
_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
assert.Error(t, err)
}
@ -1139,7 +1139,7 @@ func TestRowBasedInsertMsgToInsertBFloat16VectorDataError(t *testing.T) {
},
},
}
_, err := RowBasedInsertMsgToInsertData(msg, schema)
_, err := RowBasedInsertMsgToInsertData(msg, schema, false)
assert.Error(t, err)
}

View File

@ -0,0 +1,184 @@
// Code generated by mockery v2.32.4. DO NOT EDIT.
package function
import (
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
mock "github.com/stretchr/testify/mock"
)
// MockFunctionRunner is an autogenerated mock type for the FunctionRunner type
type MockFunctionRunner struct {
mock.Mock
}
type MockFunctionRunner_Expecter struct {
mock *mock.Mock
}
func (_m *MockFunctionRunner) EXPECT() *MockFunctionRunner_Expecter {
return &MockFunctionRunner_Expecter{mock: &_m.Mock}
}
// BatchRun provides a mock function with given fields: inputs
func (_m *MockFunctionRunner) BatchRun(inputs ...interface{}) ([]interface{}, error) {
var _ca []interface{}
_ca = append(_ca, inputs...)
ret := _m.Called(_ca...)
var r0 []interface{}
var r1 error
if rf, ok := ret.Get(0).(func(...interface{}) ([]interface{}, error)); ok {
return rf(inputs...)
}
if rf, ok := ret.Get(0).(func(...interface{}) []interface{}); ok {
r0 = rf(inputs...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]interface{})
}
}
if rf, ok := ret.Get(1).(func(...interface{}) error); ok {
r1 = rf(inputs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockFunctionRunner_BatchRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchRun'
type MockFunctionRunner_BatchRun_Call struct {
*mock.Call
}
// BatchRun is a helper method to define mock.On call
// - inputs ...interface{}
func (_e *MockFunctionRunner_Expecter) BatchRun(inputs ...interface{}) *MockFunctionRunner_BatchRun_Call {
return &MockFunctionRunner_BatchRun_Call{Call: _e.mock.On("BatchRun",
append([]interface{}{}, inputs...)...)}
}
func (_c *MockFunctionRunner_BatchRun_Call) Run(run func(inputs ...interface{})) *MockFunctionRunner_BatchRun_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]interface{}, len(args)-0)
for i, a := range args[0:] {
if a != nil {
variadicArgs[i] = a.(interface{})
}
}
run(variadicArgs...)
})
return _c
}
func (_c *MockFunctionRunner_BatchRun_Call) Return(_a0 []interface{}, _a1 error) *MockFunctionRunner_BatchRun_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockFunctionRunner_BatchRun_Call) RunAndReturn(run func(...interface{}) ([]interface{}, error)) *MockFunctionRunner_BatchRun_Call {
_c.Call.Return(run)
return _c
}
// GetOutputFields provides a mock function with given fields:
func (_m *MockFunctionRunner) GetOutputFields() []*schemapb.FieldSchema {
ret := _m.Called()
var r0 []*schemapb.FieldSchema
if rf, ok := ret.Get(0).(func() []*schemapb.FieldSchema); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*schemapb.FieldSchema)
}
}
return r0
}
// MockFunctionRunner_GetOutputFields_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOutputFields'
type MockFunctionRunner_GetOutputFields_Call struct {
*mock.Call
}
// GetOutputFields is a helper method to define mock.On call
func (_e *MockFunctionRunner_Expecter) GetOutputFields() *MockFunctionRunner_GetOutputFields_Call {
return &MockFunctionRunner_GetOutputFields_Call{Call: _e.mock.On("GetOutputFields")}
}
func (_c *MockFunctionRunner_GetOutputFields_Call) Run(run func()) *MockFunctionRunner_GetOutputFields_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockFunctionRunner_GetOutputFields_Call) Return(_a0 []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockFunctionRunner_GetOutputFields_Call) RunAndReturn(run func() []*schemapb.FieldSchema) *MockFunctionRunner_GetOutputFields_Call {
_c.Call.Return(run)
return _c
}
// GetSchema provides a mock function with given fields:
func (_m *MockFunctionRunner) GetSchema() *schemapb.FunctionSchema {
ret := _m.Called()
var r0 *schemapb.FunctionSchema
if rf, ok := ret.Get(0).(func() *schemapb.FunctionSchema); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*schemapb.FunctionSchema)
}
}
return r0
}
// MockFunctionRunner_GetSchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSchema'
type MockFunctionRunner_GetSchema_Call struct {
*mock.Call
}
// GetSchema is a helper method to define mock.On call
func (_e *MockFunctionRunner_Expecter) GetSchema() *MockFunctionRunner_GetSchema_Call {
return &MockFunctionRunner_GetSchema_Call{Call: _e.mock.On("GetSchema")}
}
func (_c *MockFunctionRunner_GetSchema_Call) Run(run func()) *MockFunctionRunner_GetSchema_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockFunctionRunner_GetSchema_Call) Return(_a0 *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockFunctionRunner_GetSchema_Call) RunAndReturn(run func() *schemapb.FunctionSchema) *MockFunctionRunner_GetSchema_Call {
_c.Call.Return(run)
return _c
}
// NewMockFunctionRunner creates a new instance of MockFunctionRunner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockFunctionRunner(t interface {
mock.TestingT
Cleanup(func())
}) *MockFunctionRunner {
mock := &MockFunctionRunner{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -6,12 +6,26 @@ import (
"math"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte {
placeholderGroup := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{{
Tag: "$0",
Type: commonpb.PlaceholderType_SparseFloatVector,
Values: contents,
}},
}
bytes, _ := proto.Marshal(placeholderGroup)
return bytes
}
func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) {
placeholderValue, err := fieldDataToPlaceholderValue(fieldData)
if err != nil {
@ -93,6 +107,14 @@ func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.Place
Values: [][]byte{bytes},
}
return placeholderValue, nil
case schemapb.DataType_VarChar:
strs := fieldData.GetScalars().GetStringData().GetData()
placeholderValue := &commonpb.PlaceholderValue{
Tag: "$0",
Type: commonpb.PlaceholderType_VarChar,
Values: lo.Map(strs, func(str string, _ int) []byte { return []byte(str) }),
}
return placeholderValue, nil
default:
return nil, errors.New("field is not a vector field")
}
@ -157,3 +179,7 @@ func flattenedBFloat16VectorsToByteVectors(flattenedVectors []byte, dimension in
return result
}
func GetVarCharFromPlaceholder(holder *commonpb.PlaceholderValue) []string {
return lo.Map(holder.Values, func(bytes []byte, _ int) string { return string(bytes) })
}

View File

@ -36,4 +36,6 @@ const (
// SUPERSTRUCTURE represents superstructure distance
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
BM25 MetricType = "BM25"
)

View File

@ -21,5 +21,5 @@ import "strings"
// PositivelyRelated return if metricType are "ip" or "IP"
func PositivelyRelated(metricType string) bool {
mUpper := strings.ToUpper(metricType)
return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE)
return mUpper == strings.ToUpper(IP) || mUpper == strings.ToUpper(COSINE) || mUpper == strings.ToUpper(BM25)
}

View File

@ -212,7 +212,7 @@ func TestIndexAutoSparseVector(t *testing.T) {
for _, unsupportedMt := range hp.UnsupportedSparseVecMetricsType {
idx := index.NewAutoIndex(unsupportedMt)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idx))
common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
}
// auto index with different metric type on sparse vec
@ -829,11 +829,11 @@ func TestCreateSparseIndexInvalidParams(t *testing.T) {
for _, mt := range hp.UnsupportedSparseVecMetricsType {
idxInverted := index.NewSparseInvertedIndex(mt, 0.2)
_, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxInverted))
common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
idxWand := index.NewSparseWANDIndex(mt, 0.2)
_, err = mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, common.DefaultSparseVecFieldName, idxWand))
common.CheckErr(t, err, false, "only IP is the supported metric type for sparse index")
common.CheckErr(t, err, false, "only IP&BM25 is the supported metric type for sparse index")
}
// create index with invalid drop_ratio_build