milvus/internal/util/segcore/reduce.go

171 lines
5.6 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/plan_c.h"
#include "segcore/reduce_c.h"
*/
import "C"
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
)
// SliceInfo describes how a batch of search requests is partitioned into
// slices. The two fields are parallel: entry i of each describes slice i.
type SliceInfo struct {
// SliceNQs holds the number of queries assigned to each slice.
SliceNQs []int64
// SliceTopKs holds the topK that applies to each slice.
SliceTopKs []int64
}
// Go-side aliases for the handle types declared in the included C headers.
type (
// SearchResultDataBlobs is the CSearchResultDataBlobs in C++
SearchResultDataBlobs = C.CSearchResultDataBlobs
// StreamSearchReducer is the CSearchStreamReducer in C++
StreamSearchReducer = C.CSearchStreamReducer
)
// ParseSliceInfo splits every request's queries into slices holding at most
// nqPerSlice queries each.
//
// For request i, originNQs[i] queries are cut into originNQs[i]/nqPerSlice
// full slices followed by one tail slice carrying the remainder (if any).
// Every slice produced from request i inherits originTopKs[i].
//
// originTopKs must be at least as long as originNQs; a shorter slice causes
// an index-out-of-range panic as soon as a slice is emitted for the missing
// entry.
//
// A non-positive nqPerSlice cannot partition anything, so an empty (but
// non-nil) SliceInfo is returned in that case.
func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *SliceInfo {
	sInfo := &SliceInfo{
		SliceNQs:   make([]int64, 0),
		SliceTopKs: make([]int64, 0),
	}
	// Guard against zero (division by zero) and negative sizes (which would
	// otherwise yield meaningless tail slices through the % operator).
	if nqPerSlice <= 0 {
		return sInfo
	}
	for i, nq := range originNQs {
		// Count full slices in int64 so the quotient is never truncated on
		// platforms where int is 32 bits wide.
		for remaining := nq / nqPerSlice; remaining > 0; remaining-- {
			sInfo.SliceNQs = append(sInfo.SliceNQs, nqPerSlice)
			sInfo.SliceTopKs = append(sInfo.SliceTopKs, originTopKs[i])
		}
		if tailSliceSize := nq % nqPerSlice; tailSliceSize > 0 {
			sInfo.SliceNQs = append(sInfo.SliceNQs, tailSliceSize)
			sInfo.SliceTopKs = append(sInfo.SliceTopKs, originTopKs[i])
		}
	}
	return sInfo
}
// NewStreamReducer creates a C-side stream search reducer for the given plan
// and slice layout.
//
// sliceNQs and sliceTopKs must be non-empty and of equal length; pointers to
// their first elements, together with the slice count, are handed to
// C.NewStreamReducer. ctx is currently not used by this function.
//
// An error is returned when plan (or its underlying C plan) is nil, when the
// slice arguments are invalid, or when the C call reports a failure status.
func NewStreamReducer(ctx context.Context,
	plan *SearchPlan,
	sliceNQs []int64,
	sliceTopKs []int64,
) (StreamSearchReducer, error) {
	// Check plan itself first so a nil *SearchPlan yields an error instead of
	// a nil-pointer dereference.
	if plan == nil || plan.cSearchPlan == nil {
		return nil, fmt.Errorf("nil search plan")
	}
	if len(sliceNQs) == 0 {
		return nil, fmt.Errorf("empty slice nqs is not allowed")
	}
	if len(sliceNQs) != len(sliceTopKs) {
		return nil, fmt.Errorf("unaligned sliceNQs(len=%d) and sliceTopKs(len=%d)", len(sliceNQs), len(sliceTopKs))
	}

	// The length checks above guarantee index 0 exists in both slices.
	cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
	cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
	cNumSlices := C.int64_t(len(sliceNQs))

	var streamReducer StreamSearchReducer
	status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer)
	if err := ConsumeCStatusIntoError(&status); err != nil {
		return nil, errors.Wrap(err, "NewStreamReducer failed")
	}
	return streamReducer, nil
}
// StreamReduceSearchResult feeds a single search result into streamReducer
// through C.StreamReduce.
//
// An error is returned when newResult is nil or when the C call reports a
// failure status. ctx is currently not used by this function.
func StreamReduceSearchResult(ctx context.Context,
	newResult *SearchResult, streamReducer StreamSearchReducer,
) error {
	// Reject nil results up front, mirroring ReduceSearchResultsAndFillData,
	// rather than panicking on the field access below.
	if newResult == nil {
		return fmt.Errorf("nil searchResult detected when StreamReduceSearchResult")
	}

	// The C API takes a pointer to an array plus a count; wrap the single
	// result in a one-element slice.
	cSearchResults := []C.CSearchResult{newResult.cSearchResult}
	status := C.StreamReduce(streamReducer, &cSearchResults[0], 1)
	if err := ConsumeCStatusIntoError(&status); err != nil {
		return errors.Wrap(err, "StreamReduceSearchResult failed")
	}
	return nil
}
// GetStreamReduceResult asks streamReducer for its reduced result via
// C.GetStreamReduceResult and returns the resulting data blobs handle.
//
// An error is returned when the C call reports a failure status. ctx is
// currently not used by this function.
func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) {
	var cSearchResultDataBlobs SearchResultDataBlobs
	status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs)
	if err := ConsumeCStatusIntoError(&status); err != nil {
		// Name this function in the wrap so the failure is attributed to the
		// streaming path rather than ReduceSearchResultsAndFillData.
		return nil, errors.Wrap(err, "GetStreamReduceResult failed")
	}
	return cSearchResultDataBlobs, nil
}
// ReduceSearchResultsAndFillData reduces the given per-segment search results
// through C.ReduceSearchResultsAndFillData and returns the resulting data
// blobs handle.
//
// sliceNQs and sliceTopKs must be non-empty and of equal length, and
// searchResults must be non-empty with no nil entries. numSegments is
// forwarded to the C call unchanged.
// NOTE(review): numSegments is presumably the number of entries in
// searchResults — confirm against callers and the C implementation.
//
// An error is returned when any argument is invalid or when the C call
// reports a failure status.
func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searchResults []*SearchResult,
	numSegments int64, sliceNQs []int64, sliceTopKs []int64,
) (SearchResultDataBlobs, error) {
	// Check plan itself first so a nil *SearchPlan yields an error instead of
	// a nil-pointer dereference.
	if plan == nil || plan.cSearchPlan == nil {
		return nil, fmt.Errorf("nil search plan")
	}
	if len(sliceNQs) == 0 {
		return nil, fmt.Errorf("empty slice nqs is not allowed")
	}
	if len(sliceNQs) != len(sliceTopKs) {
		return nil, fmt.Errorf("unaligned sliceNQs(len=%d) and sliceTopKs(len=%d)", len(sliceNQs), len(sliceTopKs))
	}
	// Taking &cSearchResults[0] below requires at least one element.
	if len(searchResults) == 0 {
		return nil, fmt.Errorf("empty search results is not allowed")
	}

	cSearchResults := make([]C.CSearchResult, 0, len(searchResults))
	for _, res := range searchResults {
		if res == nil {
			return nil, fmt.Errorf("nil searchResult detected when reduceSearchResultsAndFillData")
		}
		cSearchResults = append(cSearchResults, res.cSearchResult)
	}

	cSearchResultPtr := &cSearchResults[0]
	cNumSegments := C.int64_t(numSegments)
	cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0])
	cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0])
	cNumSlices := C.int64_t(len(sliceNQs))

	var cSearchResultDataBlobs SearchResultDataBlobs
	// The trace context derived from ctx is passed as the first C argument.
	traceCtx := ParseCTraceContext(ctx)
	status := C.ReduceSearchResultsAndFillData(traceCtx.ctx, &cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
		cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
	if err := ConsumeCStatusIntoError(&status); err != nil {
		return nil, errors.Wrap(err, "ReduceSearchResultsAndFillData failed")
	}
	return cSearchResultDataBlobs, nil
}
// GetSearchResultDataBlob fetches the blob at blobIndex from
// cSearchResultDataBlobs through C.GetSearchResultDataBlob and returns its
// bytes as produced by getCProtoBlob.
//
// An error is returned when the C call reports a failure status. ctx is
// currently not used by this function.
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) {
	var cProto C.CProto
	cIndex := C.int32_t(blobIndex)

	status := C.GetSearchResultDataBlob(&cProto, cSearchResultDataBlobs, cIndex)
	err := ConsumeCStatusIntoError(&status)
	if err != nil {
		return nil, errors.Wrap(err, "marshal failed")
	}

	data := getCProtoBlob(&cProto)
	return data, nil
}
// DeleteSearchResultDataBlobs hands cSearchResultDataBlobs to
// C.DeleteSearchResultDataBlobs.
// NOTE(review): presumably this frees the C-side blobs, so the handle must
// not be used afterwards — confirm against the C implementation.
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
C.DeleteSearchResultDataBlobs(cSearchResultDataBlobs)
}
// DeleteStreamReduceHelper hands cStreamReduceHelper to
// C.DeleteStreamSearchReducer.
// NOTE(review): presumably this frees the C-side reducer, so the handle must
// not be used afterwards — confirm against the C implementation.
func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) {
C.DeleteStreamSearchReducer(cStreamReduceHelper)
}