mirror of https://github.com/milvus-io/milvus.git
Enable search with pagination (#19044)
See also: #19043. Signed-off-by: yangxuan <xuan.yang@zilliz.com> Signed-off-by: yangxuan <xuan.yang@zilliz.com> (pull/19209/head)
parent
992d163988
commit
200df76864
|
@ -44,6 +44,7 @@ type searchTask struct {
|
|||
collectionName string
|
||||
schema *schemapb.CollectionSchema
|
||||
|
||||
offset int64
|
||||
resultBuf chan *internalpb.SearchResults
|
||||
toReduceResults []*internalpb.SearchResults
|
||||
|
||||
|
@ -88,48 +89,64 @@ func getPartitionIDs(ctx context.Context, collectionName string, partitionNames
|
|||
return partitionIDs, nil
|
||||
}
|
||||
|
||||
func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, error) {
|
||||
// parseQueryInfo returns QueryInfo and offset
|
||||
func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, int64, error) {
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(TopKKey + " not found in search_params")
|
||||
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
||||
}
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
}
|
||||
if err := validateTopK(topK); err != nil {
|
||||
return nil, fmt.Errorf("invalid limit, %w", err)
|
||||
return nil, 0, fmt.Errorf("invalid limit, %w", err)
|
||||
}
|
||||
|
||||
var offset int64
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
}
|
||||
}
|
||||
|
||||
queryTopK := topK + offset
|
||||
if err := validateTopK(queryTopK); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(MetricTypeKey + " not found in search_params")
|
||||
return nil, 0, errors.New(MetricTypeKey + " not found in search_params")
|
||||
}
|
||||
|
||||
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, errors.New(SearchParamsKey + " not found in search_params")
|
||||
return nil, 0, errors.New(SearchParamsKey + " not found in search_params")
|
||||
}
|
||||
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
return &planpb.QueryInfo{
|
||||
Topk: topK,
|
||||
Topk: queryTopK,
|
||||
MetricType: metricType,
|
||||
SearchParams: searchParams,
|
||||
RoundDecimal: roundDecimal,
|
||||
}, nil
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string) (outputFieldIDs []UniqueID, err error) {
|
||||
|
@ -222,10 +239,11 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
return errors.New(AnnsFieldKey + " not found in search_params")
|
||||
}
|
||||
|
||||
queryInfo, err := parseQueryInfo(t.request.GetSearchParams())
|
||||
queryInfo, offset, err := parseQueryInfo(t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.offset = offset
|
||||
|
||||
plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo)
|
||||
if err != nil {
|
||||
|
@ -242,6 +260,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||
plan.OutputFieldIds = outputFieldIDs
|
||||
|
||||
|
@ -272,16 +291,18 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
|
||||
t.SearchRequest.GuaranteeTimestamp = guaranteeTs
|
||||
|
||||
deadline, ok := t.TraceCtx().Deadline()
|
||||
if ok {
|
||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||
}
|
||||
|
||||
t.SearchRequest.Dsl = t.request.Dsl
|
||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||
if t.SearchRequest.Nq, err = getNq(t.request); err != nil {
|
||||
nq, err := getNq(t.request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
|
||||
log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()),
|
||||
zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs),
|
||||
zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()))
|
||||
|
@ -327,24 +348,23 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||
func (t *searchTask) PostExecute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(t.TraceCtx(), "Proxy-Search-PostExecute")
|
||||
defer sp.Finish()
|
||||
|
||||
tr := timerecord.NewTimeRecorder("searchTask PostExecute")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
select {
|
||||
// in case timeout happened
|
||||
case <-t.TraceCtx().Done():
|
||||
log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
|
||||
return nil
|
||||
default:
|
||||
log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
|
||||
close(t.resultBuf)
|
||||
for res := range t.resultBuf {
|
||||
t.toReduceResults = append(t.toReduceResults, res)
|
||||
log.Ctx(ctx).Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
|
||||
}
|
||||
var (
|
||||
Nq = t.SearchRequest.GetNq()
|
||||
Topk = t.SearchRequest.GetTopk()
|
||||
MetricType = t.SearchRequest.GetMetricType()
|
||||
)
|
||||
|
||||
if err := t.collectSearchResults(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode all search results
|
||||
tr.CtxRecord(ctx, "decodeResultStart")
|
||||
validSearchResults, err := decodeSearchResults(ctx, t.toReduceResults)
|
||||
if err != nil {
|
||||
|
@ -353,57 +373,31 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10),
|
||||
metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
|
||||
log.Ctx(ctx).Debug("proxy search post execute stage 2", zap.Int64("msgID", t.ID()),
|
||||
zap.Int("len(validSearchResults)", len(validSearchResults)))
|
||||
if len(validSearchResults) <= 0 {
|
||||
log.Ctx(ctx).Warn("search result is empty", zap.Int64("msgID", t.ID()))
|
||||
|
||||
t.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "search result is empty",
|
||||
},
|
||||
CollectionName: t.collectionName,
|
||||
}
|
||||
// add information if any
|
||||
if len(t.toReduceResults) > 0 {
|
||||
t.result.Results = &schemapb.SearchResultData{
|
||||
NumQueries: t.toReduceResults[0].NumQueries,
|
||||
Topks: make([]int64, t.toReduceResults[0].NumQueries),
|
||||
}
|
||||
}
|
||||
t.fillInEmptyResult(Nq)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reduce all search results
|
||||
log.Ctx(ctx).Debug("proxy search post execute reduce", zap.Int64("msgID", t.ID()), zap.Int("number of valid search results", len(validSearchResults)))
|
||||
tr.CtxRecord(ctx, "reduceResultStart")
|
||||
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.result, err = reduceSearchResultData(ctx, validSearchResults, t.toReduceResults[0].NumQueries,
|
||||
t.toReduceResults[0].TopK, t.toReduceResults[0].MetricType, primaryFieldSchema.DataType)
|
||||
|
||||
t.result, err = reduceSearchResultData(ctx, validSearchResults, Nq, Topk, MetricType, primaryFieldSchema.DataType, t.offset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
|
||||
t.result.CollectionName = t.collectionName
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.CollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
|
||||
for k, fieldName := range t.request.OutputFields {
|
||||
for _, field := range schema.Fields {
|
||||
if t.result.Results.FieldsData[k] != nil && field.Name == fieldName {
|
||||
t.result.Results.FieldsData[k].FieldName = field.Name
|
||||
t.result.Results.FieldsData[k].FieldId = field.FieldID
|
||||
t.result.Results.FieldsData[k].Type = field.DataType
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
t.result.CollectionName = t.collectionName
|
||||
t.fillInFieldInfo()
|
||||
|
||||
log.Ctx(ctx).Debug("Search post execute done", zap.Int64("msgID", t.ID()))
|
||||
return nil
|
||||
}
|
||||
|
@ -435,6 +429,50 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInEmptyResult(numQueries int64) {
|
||||
t.result = &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "search result is empty",
|
||||
},
|
||||
CollectionName: t.collectionName,
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: numQueries,
|
||||
Topks: make([]int64, numQueries),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *searchTask) fillInFieldInfo() {
|
||||
if len(t.request.OutputFields) != 0 && len(t.result.Results.FieldsData) != 0 {
|
||||
for i, name := range t.request.OutputFields {
|
||||
for _, field := range t.schema.Fields {
|
||||
if t.result.Results.FieldsData[i] != nil && field.Name == name {
|
||||
t.result.Results.FieldsData[i].FieldName = field.Name
|
||||
t.result.Results.FieldsData[i].FieldId = field.FieldID
|
||||
t.result.Results.FieldsData[i].Type = field.DataType
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *searchTask) collectSearchResults(ctx context.Context) error {
|
||||
select {
|
||||
case <-t.TraceCtx().Done():
|
||||
log.Ctx(ctx).Debug("wait to finish timeout!", zap.Int64("msgID", t.ID()))
|
||||
return fmt.Errorf("search task wait to finish timeout, msgID=%d", t.ID())
|
||||
default:
|
||||
log.Ctx(ctx).Debug("all searches are finished or canceled", zap.Int64("msgID", t.ID()))
|
||||
close(t.resultBuf)
|
||||
for res := range t.resultBuf {
|
||||
t.toReduceResults = append(t.toReduceResults, res)
|
||||
log.Ctx(ctx).Debug("proxy receives one search result", zap.Int64("sourceID", res.GetBase().GetSourceID()), zap.Int64("msgID", t.ID()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkIfLoaded check if collection was loaded into QueryNode
|
||||
func checkIfLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, searchPartitionIDs []UniqueID) (bool, error) {
|
||||
info, err := globalMetaCache.GetCollectionInfo(ctx, collectionName)
|
||||
|
@ -508,45 +546,56 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
|
|||
return nil
|
||||
}
|
||||
|
||||
func selectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
|
||||
sel := -1
|
||||
maxDistance := minFloat32 // distance here means score :)
|
||||
for i, offset := range offsets { // query num, the number of ways to merge
|
||||
if offset >= dataArray[i].Topks[qi] {
|
||||
func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, subSearchNqOffset [][]int64, cursors []int64, qi int64) (int, int64) {
|
||||
var (
|
||||
subSearchIdx = -1
|
||||
resultDataIdx int64 = -1
|
||||
)
|
||||
maxScore := minFloat32
|
||||
for i := range cursors {
|
||||
if cursors[i] >= subSearchResultData[i].Topks[qi] {
|
||||
continue
|
||||
}
|
||||
idx := resultOffsets[i][qi] + offset
|
||||
distance := dataArray[i].Scores[idx]
|
||||
if distance > maxDistance {
|
||||
sel = i
|
||||
maxDistance = distance
|
||||
sIdx := subSearchNqOffset[i][qi] + cursors[i]
|
||||
sScore := subSearchResultData[i].Scores[sIdx]
|
||||
if sScore > maxScore {
|
||||
subSearchIdx = i
|
||||
resultDataIdx = sIdx
|
||||
|
||||
maxScore = sScore
|
||||
}
|
||||
}
|
||||
return sel
|
||||
return subSearchIdx, resultDataIdx
|
||||
}
|
||||
|
||||
func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType) (*milvuspb.SearchResults, error) {
|
||||
func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, nq int64, topk int64, metricType string, pkType schemapb.DataType, offset int64) (*milvuspb.SearchResults, error) {
|
||||
tr := timerecord.NewTimeRecorder("reduceSearchResultData")
|
||||
defer func() {
|
||||
tr.CtxElapse(ctx, "done")
|
||||
}()
|
||||
|
||||
log.Ctx(ctx).Debug("reduceSearchResultData", zap.Int("len(searchResultData)", len(searchResultData)),
|
||||
zap.Int64("nq", nq), zap.Int64("topk", topk), zap.String("metricType", metricType))
|
||||
limit := topk - offset
|
||||
log.Ctx(ctx).Debug("reduceSearchResultData",
|
||||
zap.Int("len(subSearchResultData)", len(subSearchResultData)),
|
||||
zap.Int64("nq", nq),
|
||||
zap.Int64("offset", offset),
|
||||
zap.Int64("limit", limit),
|
||||
zap.String("metricType", metricType))
|
||||
|
||||
ret := &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: 0,
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||
Scores: make([]float32, 0),
|
||||
FieldsData: make([]*schemapb.FieldData, len(subSearchResultData[0].FieldsData)),
|
||||
Scores: []float32{},
|
||||
Ids: &schemapb.IDs{},
|
||||
Topks: make([]int64, 0),
|
||||
Topks: []int64{},
|
||||
},
|
||||
}
|
||||
|
||||
switch pkType {
|
||||
case schemapb.DataType_Int64:
|
||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||
|
@ -564,13 +613,12 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
return nil, errors.New("unsupported pk type")
|
||||
}
|
||||
|
||||
for i, sData := range searchResultData {
|
||||
log.Ctx(ctx).Debug("reduceSearchResultData",
|
||||
for i, sData := range subSearchResultData {
|
||||
log.Ctx(ctx).Debug("subSearchResultData",
|
||||
zap.Int("result No.", i),
|
||||
zap.Int64("nq", sData.NumQueries),
|
||||
zap.Int64("topk", sData.TopK),
|
||||
zap.Any("len(topks)", len(sData.Topks)),
|
||||
zap.Any("len(FieldsData)", len(sData.FieldsData)))
|
||||
zap.Any("length of FieldsData", len(sData.FieldsData)))
|
||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||
return ret, err
|
||||
|
@ -578,34 +626,61 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
||||
}
|
||||
|
||||
resultOffsets := make([][]int64, len(searchResultData))
|
||||
for i := 0; i < len(searchResultData); i++ {
|
||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||
var (
|
||||
subSearchNum = len(subSearchResultData)
|
||||
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||
)
|
||||
for i := 0; i < subSearchNum; i++ {
|
||||
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||
for j := int64(1); j < nq; j++ {
|
||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
}
|
||||
|
||||
var skipDupCnt int64
|
||||
var realTopK int64 = -1
|
||||
for i := int64(0); i < nq; i++ {
|
||||
offsets := make([]int64, len(searchResultData))
|
||||
var (
|
||||
skipDupCnt int64
|
||||
realTopK int64 = -1
|
||||
)
|
||||
|
||||
var idSet = make(map[interface{}]struct{})
|
||||
var j int64
|
||||
for j = 0; j < topk; {
|
||||
sel := selectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||
if sel == -1 {
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
|
||||
var (
|
||||
// cursor of current data of each subSearch for merging the j-th data of TopK.
|
||||
// sum(cursors) == j
|
||||
cursors = make([]int64, subSearchNum)
|
||||
|
||||
j int64
|
||||
idSet = make(map[interface{}]struct{})
|
||||
)
|
||||
|
||||
// skip offset results
|
||||
for k := int64(0); k < offset; k++ {
|
||||
subSearchIdx, _ := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
}
|
||||
idx := resultOffsets[sel][i] + offsets[sel]
|
||||
|
||||
id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
|
||||
score := searchResultData[sel].Scores[idx]
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
|
||||
// keep limit results
|
||||
for j = 0; j < limit; {
|
||||
// From all the sub-query result sets of the i-th query vector,
|
||||
// find the sub-query result set index of the score j-th data,
|
||||
// and the index of the data in schemapb.SearchResultData
|
||||
subSearchIdx, resultDataIdx := selectHighestScoreIndex(subSearchResultData, subSearchNqOffset, cursors, i)
|
||||
if subSearchIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
id := typeutil.GetPK(subSearchResultData[subSearchIdx].GetIds(), resultDataIdx)
|
||||
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
typeutil.AppendFieldData(ret.Results.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
||||
typeutil.AppendPKs(ret.Results.Ids, id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
idSet[id] = struct{}{}
|
||||
|
@ -614,7 +689,7 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
// skip entity with same id
|
||||
skipDupCnt++
|
||||
}
|
||||
offsets[sel]++
|
||||
cursors[subSearchIdx]++
|
||||
}
|
||||
if realTopK != -1 && realTopK != j {
|
||||
log.Ctx(ctx).Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
|
||||
|
@ -624,8 +699,12 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
ret.Results.Topks = append(ret.Results.Topks, realTopK)
|
||||
}
|
||||
log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
ret.Results.TopK = realTopK
|
||||
|
||||
if skipDupCnt > 0 {
|
||||
log.Info("skip duplicated search result", zap.Int64("count", skipDupCnt))
|
||||
}
|
||||
|
||||
ret.Results.TopK = realTopK // realTopK is the topK of the nq-th query
|
||||
if !distance.PositivelyRelated(metricType) {
|
||||
for k := range ret.Results.Scores {
|
||||
ret.Results.Scores[k] *= -1
|
||||
|
@ -635,17 +714,17 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
//func printSearchResultData(data *schemapb.SearchResultData, header string) {
|
||||
// size := len(data.Ids.GetIntId().Data)
|
||||
// if size != len(data.Scores) {
|
||||
// log.Error("SearchResultData length mis-match")
|
||||
// }
|
||||
// log.Debug("==== SearchResultData ====",
|
||||
// zap.String("header", header), zap.Int64("nq", data.NumQueries), zap.Int64("topk", data.TopK))
|
||||
// for i := 0; i < size; i++ {
|
||||
// log.Debug("", zap.Int("i", i), zap.Int64("id", data.Ids.GetIntId().Data[i]), zap.Float32("score", data.Scores[i]))
|
||||
// }
|
||||
//}
|
||||
// func printSearchResultData(data *schemapb.SearchResultData, header string) {
|
||||
// size := len(data.GetIds().GetIntId().GetData())
|
||||
// if size != len(data.Scores) {
|
||||
// log.Error("SearchResultData length mis-match")
|
||||
// }
|
||||
// log.Debug("==== SearchResultData ====",
|
||||
// zap.String("header", header), zap.Int64("nq", data.NumQueries), zap.Int64("topk", data.TopK))
|
||||
// for i := 0; i < size; i++ {
|
||||
// log.Debug("", zap.Int("i", i), zap.Int64("id", data.GetIds().GetIntId().Data[i]), zap.Float32("score", data.Scores[i]))
|
||||
// }
|
||||
// }
|
||||
|
||||
func (t *searchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -1102,302 +1103,370 @@ func Test_checkSearchResultData(t *testing.T) {
|
|||
topk int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
description string
|
||||
wantErr bool
|
||||
|
||||
args args
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
{"data.NumQueries != nq", true,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{NumQueries: 100},
|
||||
nq: 10,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
}},
|
||||
{"data.TopK != topk", true,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{NumQueries: 1, TopK: 1},
|
||||
nq: 1,
|
||||
topk: 10,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
}},
|
||||
{"size of IntId != NumQueries * TopK", true,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 1,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2}, // != nq * topk
|
||||
},
|
||||
},
|
||||
},
|
||||
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2}}}},
|
||||
},
|
||||
nq: 1,
|
||||
topk: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
}},
|
||||
{"size of StrID != NumQueries * TopK", true,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 1,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"1", "2"}, // != nq * topk
|
||||
},
|
||||
},
|
||||
},
|
||||
IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{"1", "2"}}}},
|
||||
},
|
||||
nq: 1,
|
||||
topk: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
}},
|
||||
{"size of score != nq * topK", true,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 1,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{0.99, 0.98}, // != nq * topk
|
||||
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}},
|
||||
Scores: []float32{0.99, 0.98},
|
||||
},
|
||||
nq: 1,
|
||||
topk: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
}},
|
||||
{"correct params", false,
|
||||
args{
|
||||
data: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 1,
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{0.99},
|
||||
},
|
||||
IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}},
|
||||
Scores: []float32{0.99}},
|
||||
nq: 1,
|
||||
topk: 1,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := checkSearchResultData(tt.args.data, tt.args.nq, tt.args.topk); (err != nil) != tt.wantErr {
|
||||
t.Errorf("checkSearchResultData(%v, %v, %v) error = %v, wantErr %v",
|
||||
tt.args.data, tt.args.nq, tt.args.topk, err, tt.wantErr)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk)
|
||||
|
||||
if test.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selectSearchResultData_int(t *testing.T) {
|
||||
type args struct {
|
||||
dataArray []*schemapb.SearchResultData
|
||||
resultOffsets [][]int64
|
||||
offsets []int64
|
||||
topk int64
|
||||
nq int64
|
||||
qi int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
dataArray: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 7, 5, 3, 1},
|
||||
func TestTaskSearch_selectHighestScoreIndex(t *testing.T) {
|
||||
t.Run("Integer ID", func(t *testing.T) {
|
||||
type args struct {
|
||||
subSearchResultData []*schemapb.SearchResultData
|
||||
subSearchNqOffset [][]int64
|
||||
cursors []int64
|
||||
topk int64
|
||||
nq int64
|
||||
}
|
||||
tests := []struct {
|
||||
description string
|
||||
args args
|
||||
|
||||
expectedIdx []int
|
||||
expectedDataIdx []int
|
||||
}{
|
||||
{
|
||||
description: "reduce 2 subSearchResultData",
|
||||
args: args{
|
||||
subSearchResultData: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 8, 5, 3, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 8, 6, 4, 2},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 7, 6, 4, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
subSearchNqOffset: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
cursors: []int64{0, 0},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
},
|
||||
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
offsets: []int64{0, 1},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
qi: 0,
|
||||
expectedIdx: []int{1, 0, 1},
|
||||
expectedDataIdx: []int{0, 2, 4},
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
for nqNum := int64(0); nqNum < test.args.nq; nqNum++ {
|
||||
idx, dataIdx := selectHighestScoreIndex(test.args.subSearchResultData, test.args.subSearchNqOffset, test.args.cursors, nqNum)
|
||||
assert.Equal(t, test.expectedIdx[nqNum], idx)
|
||||
assert.Equal(t, test.expectedDataIdx[nqNum], int(dataIdx))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("String ID", func(t *testing.T) {
|
||||
type args struct {
|
||||
subSearchResultData []*schemapb.SearchResultData
|
||||
subSearchNqOffset [][]int64
|
||||
cursors []int64
|
||||
topk int64
|
||||
nq int64
|
||||
}
|
||||
tests := []struct {
|
||||
description string
|
||||
args args
|
||||
|
||||
expectedIdx []int
|
||||
expectedDataIdx []int
|
||||
}{
|
||||
{
|
||||
description: "reduce 2 subSearchResultData",
|
||||
args: args{
|
||||
subSearchResultData: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"11", "9", "8", "5", "3", "1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"12", "10", "7", "6", "4", "2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.7, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
subSearchNqOffset: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
cursors: []int64{0, 0},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
},
|
||||
expectedIdx: []int{1, 0, 1},
|
||||
expectedDataIdx: []int{0, 2, 4},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
for nqNum := int64(0); nqNum < test.args.nq; nqNum++ {
|
||||
idx, dataIdx := selectHighestScoreIndex(test.args.subSearchResultData, test.args.subSearchNqOffset, test.args.cursors, nqNum)
|
||||
assert.Equal(t, test.expectedIdx[nqNum], idx)
|
||||
assert.Equal(t, test.expectedDataIdx[nqNum], int(dataIdx))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
||||
var (
|
||||
topk int64 = 5
|
||||
nq int64 = 2
|
||||
)
|
||||
|
||||
data := [][]int64{
|
||||
{10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
{20, 19, 18, 17, 16, 15, 14, 13, 12, 11},
|
||||
{30, 29, 28, 27, 26, 25, 24, 23, 22, 21},
|
||||
{40, 39, 38, 37, 36, 35, 34, 33, 32, 31},
|
||||
{50, 49, 48, 47, 46, 45, 44, 43, 42, 41},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
|
||||
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
|
||||
|
||||
score := [][]float32{
|
||||
{10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
|
||||
{20, 19, 18, 17, 16, 15, 14, 13, 12, 11},
|
||||
{30, 29, 28, 27, 26, 25, 24, 23, 22, 21},
|
||||
{40, 39, 38, 37, 36, 35, 34, 33, 32, 31},
|
||||
{50, 49, 48, 47, 46, 45, 44, 43, 42, 41},
|
||||
}
|
||||
|
||||
resultScore := []float32{-50, -49, -48, -47, -46, -45, -44, -43, -42, -41}
|
||||
|
||||
t.Run("Offset limit", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
offset int64
|
||||
limit int64
|
||||
|
||||
outScore []float32
|
||||
outData []int64
|
||||
}{
|
||||
{"offset 0, limit 5", 0, 5,
|
||||
[]float32{-50, -49, -48, -47, -46, -45, -44, -43, -42, -41},
|
||||
[]int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}},
|
||||
{"offset 1, limit 4", 1, 4,
|
||||
[]float32{-49, -48, -47, -46, -44, -43, -42, -41},
|
||||
[]int64{49, 48, 47, 46, 44, 43, 42, 41}},
|
||||
{"offset 2, limit 3", 2, 3,
|
||||
[]float32{-48, -47, -46, -43, -42, -41},
|
||||
[]int64{48, 47, 46, 43, 42, 41}},
|
||||
{"offset 3, limit 2", 3, 2,
|
||||
[]float32{-47, -46, -42, -41},
|
||||
[]int64{47, 46, 42, 41}},
|
||||
{"offset 4, limit 1", 4, 1,
|
||||
[]float32{-46, -41},
|
||||
[]int64{46, 41}},
|
||||
}
|
||||
|
||||
var results []*schemapb.SearchResultData
|
||||
for i := range data {
|
||||
r := getSearchResultData(nq, topk)
|
||||
|
||||
r.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: data[i]}}
|
||||
r.Scores = score[i]
|
||||
r.Topks = []int64{5, 5}
|
||||
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
||||
assert.Equal(t, test.limit, reduced.GetResults().GetTopK())
|
||||
assert.InDeltaSlice(t, test.outScore, reduced.GetResults().GetScores(), 10e-8)
|
||||
})
|
||||
}
|
||||
|
||||
lessThanLimitTests := []struct {
|
||||
description string
|
||||
offset int64
|
||||
limit int64
|
||||
|
||||
outLimit int64
|
||||
outScore []float32
|
||||
outData []int64
|
||||
}{
|
||||
{"offset 0, limit 6", 0, 6, 5,
|
||||
[]float32{-50, -49, -48, -47, -46, -45, -44, -43, -42, -41},
|
||||
[]int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}},
|
||||
{"offset 1, limit 5", 1, 5, 4,
|
||||
[]float32{-49, -48, -47, -46, -44, -43, -42, -41},
|
||||
[]int64{49, 48, 47, 46, 44, 43, 42, 41}},
|
||||
{"offset 2, limit 4", 2, 4, 3,
|
||||
[]float32{-48, -47, -46, -43, -42, -41},
|
||||
[]int64{48, 47, 46, 43, 42, 41}},
|
||||
{"offset 3, limit 3", 3, 3, 2,
|
||||
[]float32{-47, -46, -42, -41},
|
||||
[]int64{47, 46, 42, 41}},
|
||||
{"offset 4, limit 2", 4, 2, 1,
|
||||
[]float32{-46, -41},
|
||||
[]int64{46, 41}},
|
||||
{"offset 5, limit 1", 5, 1, 0,
|
||||
[]float32{},
|
||||
[]int64{}},
|
||||
}
|
||||
|
||||
for _, test := range lessThanLimitTests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, test.offset)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
||||
assert.Equal(t, test.outLimit, reduced.GetResults().GetTopK())
|
||||
assert.InDeltaSlice(t, test.outScore, reduced.GetResults().GetScores(), 10e-8)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Int64 ID", func(t *testing.T) {
|
||||
resultData := []int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}
|
||||
|
||||
var results []*schemapb.SearchResultData
|
||||
for i := range data {
|
||||
r := getSearchResultData(nq, topk)
|
||||
|
||||
r.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: data[i]}}
|
||||
r.Scores = score[i]
|
||||
r.Topks = []int64{5, 5}
|
||||
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_Int64, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
|
||||
assert.Equal(t, int64(5), reduced.GetResults().GetTopK())
|
||||
assert.InDeltaSlice(t, resultScore, reduced.GetResults().GetScores(), 10e-8)
|
||||
})
|
||||
|
||||
t.Run("String ID", func(t *testing.T) {
|
||||
resultData := []string{"50", "49", "48", "47", "46", "45", "44", "43", "42", "41"}
|
||||
|
||||
var results []*schemapb.SearchResultData
|
||||
for i := range data {
|
||||
r := getSearchResultData(nq, topk)
|
||||
|
||||
var strData []string
|
||||
for _, d := range data[i] {
|
||||
strData = append(strData, strconv.FormatInt(d, 10))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
r.Ids.IdField = &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: strData}}
|
||||
r.Scores = score[i]
|
||||
r.Topks = []int64{5, 5}
|
||||
|
||||
// Test_selectSearchResultData_str is a table-driven test of
// selectSearchResultData using string (StrId) primary keys.
//
// The single case supplies two SearchResultData entries, each carrying six
// string IDs, six scores and Topks {2, 2, 2}, together with per-result
// offsets {{0, 2, 4}, {0, 2, 4}}, cursor offsets {0, 1} and query index 1,
// and expects the call to return 0.
// NOTE(review): the return value is presumably the index into dataArray of
// the selected result; confirm against selectSearchResultData.
//
// The topk and nq fields of args are populated but are not passed to
// selectSearchResultData. The case sets no name, so the subtest runs with an
// empty name.
func Test_selectSearchResultData_str(t *testing.T) {
|
||||
type args struct {
|
||||
dataArray []*schemapb.SearchResultData
|
||||
resultOffsets [][]int64
|
||||
offsets []int64
|
||||
topk int64
|
||||
nq int64
|
||||
qi int64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
dataArray: []*schemapb.SearchResultData{
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"11", "9", "7", "5", "3", "1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"12", "10", "8", "6", "4", "2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
},
|
||||
resultOffsets: [][]int64{{0, 2, 4}, {0, 2, 4}},
|
||||
offsets: []int64{0, 1},
|
||||
topk: 2,
|
||||
nq: 3,
|
||||
qi: 1,
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
// Run each case as a subtest and compare the returned value with want.
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := selectSearchResultData(tt.args.dataArray, tt.args.resultOffsets, tt.args.offsets, tt.args.qi); got != tt.want {
|
||||
t.Errorf("selectSearchResultData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
func Test_reduceSearchResultData_int(t *testing.T) {
|
||||
topk := 2
|
||||
nq := 3
|
||||
results := []*schemapb.SearchResultData{
|
||||
{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{11, 9, 7, 5, 3, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{12, 10, 8, 6, 4, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
}
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, nq, topk, distance.L2, schemapb.DataType_VarChar, 0)
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_Int64)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []int64{3, 4, 7, 8, 11, 12}, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||
// hard to compare floating point value.
|
||||
// TODO: compare scores.
|
||||
}
|
||||
|
||||
func Test_reduceSearchResultData_str(t *testing.T) {
|
||||
topk := 2
|
||||
nq := 3
|
||||
results := []*schemapb.SearchResultData{
|
||||
{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"11", "9", "7", "5", "3", "1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.1, 0.9, 0.7, 0.5, 0.3, 0.1},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
{
|
||||
NumQueries: int64(nq),
|
||||
TopK: int64(topk),
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"12", "10", "8", "6", "4", "2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: []float32{1.2, 1.0, 0.8, 0.6, 0.4, 0.2},
|
||||
Topks: []int64{2, 2, 2},
|
||||
},
|
||||
}
|
||||
|
||||
reduced, err := reduceSearchResultData(context.TODO(), results, int64(nq), int64(topk), distance.L2, schemapb.DataType_VarChar)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []string{"3", "4", "7", "8", "11", "12"}, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||
// hard to compare floating point value.
|
||||
// TODO: compare scores.
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
|
||||
assert.Equal(t, int64(5), reduced.GetResults().GetTopK())
|
||||
assert.InDeltaSlice(t, resultScore, reduced.GetResults().GetScores(), 10e-8)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_checkIfLoaded(t *testing.T) {
|
||||
|
@ -1687,6 +1756,21 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
t.Run("parseQueryInfo no error", func(t *testing.T) {
|
||||
var targetOffset int64 = 200
|
||||
|
||||
sp := getValidSearchParams()
|
||||
sp = append(sp, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Value: strconv.FormatInt(targetOffset, 10),
|
||||
})
|
||||
|
||||
info, offset, err := parseQueryInfo(sp)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
assert.Equal(t, targetOffset, offset)
|
||||
})
|
||||
|
||||
t.Run("parseQueryInfo error", func(t *testing.T) {
|
||||
spNoTopk := []*commonpb.KeyValuePair{{
|
||||
Key: AnnsFieldKey,
|
||||
|
@ -1707,10 +1791,17 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
Value: "10",
|
||||
})
|
||||
|
||||
spInvalidTopkPlusOffset := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Value: "65535",
|
||||
})
|
||||
|
||||
spNoSearchParams := append(spNoMetricType, &commonpb.KeyValuePair{
|
||||
Key: MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
})
|
||||
|
||||
// no roundDecimal is valid
|
||||
noRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10}`,
|
||||
|
@ -1726,6 +1817,11 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
Value: "invalid",
|
||||
})
|
||||
|
||||
spInvalidOffset := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Value: "invalid",
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
invalidParams []*commonpb.KeyValuePair
|
||||
|
@ -1733,18 +1829,32 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||
{"No_topk", spNoTopk},
|
||||
{"Invalid_topk", spInvalidTopk},
|
||||
{"Invalid_topk_65536", spInvalidTopk65536},
|
||||
{"Invalid_topk_plus_offset", spInvalidTopkPlusOffset},
|
||||
{"No_Metric_type", spNoMetricType},
|
||||
{"No_search_params", spNoSearchParams},
|
||||
{"Invalid_round_decimal", spInvalidRoundDecimal},
|
||||
{"Invalid_round_decimal_1000", spInvalidRoundDecimal2},
|
||||
{"Invalid_offset", spInvalidOffset},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, err := parseQueryInfo(test.invalidParams)
|
||||
info, offset, err := parseQueryInfo(test.invalidParams)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, info)
|
||||
assert.Zero(t, offset)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// getSearchResultData returns a pointer to a new schemapb.SearchResultData
// whose NumQueries and TopK are set from nq and topk. Ids, Scores and Topks
// are initialized to empty, non-nil values; the tests in this file assign
// them after the call.
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
|
||||
result := schemapb.SearchResultData{
|
||||
NumQueries: nq,
|
||||
TopK: topk,
|
||||
Ids: &schemapb.IDs{},
|
||||
Scores: []float32{},
|
||||
Topks: []int64{},
|
||||
}
|
||||
return &result
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue