Change API retrieve return type from CProtoResult to CProto (#11555)

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
pull/11714/head
Cai Yudong 2021-11-12 10:04:49 +08:00 committed by GitHub
parent eb41afc661
commit 5fdc6626cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 86 additions and 40 deletions

View File

@ -16,14 +16,6 @@
namespace milvus {
inline CProtoResult
AllocCProtoResult(const google::protobuf::Message& msg) {
auto size = msg.ByteSize();
void* buffer = malloc(size);
msg.SerializePartialToArray(buffer, size);
return CProtoResult{CStatus{Success}, CProto{buffer, size}};
}
inline CStatus
SuccessCStatus() {
return CStatus{Success, ""};

View File

@ -53,11 +53,6 @@ typedef struct CLoadDeletedRecordInfo {
int64_t row_count;
} CLoadDeletedRecordInfo;
typedef struct CProtoResult {
CStatus status;
CProto proto;
} CProtoResult;
#ifdef __cplusplus
}
#endif

View File

@ -87,6 +87,30 @@ Search(CSegmentInterface c_segment,
}
}
void
DeleteRetrieveResult(CRetrieveResult* retrieve_result) {
std::free((void*)(retrieve_result->proto_blob));
}
CStatus
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result) {
try {
auto segment = (const milvus::segcore::SegmentInterface*)c_segment;
auto plan = (const milvus::query::RetrievePlan*)c_plan;
auto retrieve_result = segment->Retrieve(plan, timestamp);
auto size = retrieve_result->ByteSize();
void* buffer = malloc(size);
retrieve_result->SerializePartialToArray(buffer, size);
result->proto_blob = buffer;
result->proto_size = size;
return milvus::SuccessCStatus();
} catch (std::exception& e) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment) {
auto segment = (milvus::segcore::SegmentInterface*)c_segment;
@ -237,15 +261,3 @@ DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id) {
return milvus::FailureCStatus(UnexpectedError, e.what());
}
}
CProtoResult
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp) {
try {
auto segment = (const milvus::segcore::SegmentInterface*)c_segment;
auto plan = (const milvus::query::RetrievePlan*)c_plan;
auto result = segment->Retrieve(plan, timestamp);
return milvus::AllocCProtoResult(*result);
} catch (std::exception& e) {
return CProtoResult{milvus::FailureCStatus(UnexpectedError, e.what())};
}
}

View File

@ -25,7 +25,7 @@ extern "C" {
typedef void* CSegmentInterface;
typedef void* CSearchResult;
typedef void* CRetrieveResult;
typedef CProto CRetrieveResult;
////////////////////////////// common interfaces //////////////////////////////
CSegmentInterface
@ -44,8 +44,11 @@ Search(CSegmentInterface c_segment,
uint64_t timestamp,
CSearchResult* result);
CProtoResult
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp);
void
DeleteRetrieveResult(CRetrieveResult* retrieve_result);
CStatus
Retrieve(CSegmentInterface c_segment, CRetrievePlan c_plan, uint64_t timestamp, CRetrieveResult* result);
int64_t
GetMemoryUsageInBytes(CSegmentInterface c_segment);

View File

@ -24,6 +24,7 @@
#include "index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h"
#include "pb/milvus.pb.h"
#include "pb/plan.pb.h"
#include "query/ExprImpl.h"
#include "segcore/Collection.h"
#include "segcore/reduce_c.h"
#include "test_utils/DataGen.h"
@ -351,6 +352,44 @@ TEST(CApiTest, SearchTestWithExpr) {
DeleteSegment(segment);
}
TEST(CApiTest, RetrieveTestWithExpr) {
auto collection = NewCollection(get_default_schema_config());
auto segment = NewSegment(collection, 0, Growing);
int N = 10000;
auto [raw_data, timestamps, uids] = generate_data(N);
auto line_sizeof = (sizeof(int) + sizeof(float) * DIM);
int64_t offset;
PreInsert(segment, N, &offset);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
ASSERT_EQ(ins_res.error_code, Success);
auto schema = ((milvus::segcore::Collection*)collection)->get_schema();
auto plan = std::make_unique<query::RetrievePlan>(*schema);
// create retrieve plan "age in [0]"
auto term_expr = std::make_unique<query::TermExprImpl<int64_t>>();
term_expr->field_offset_ = FieldOffset(1);
term_expr->data_type_ = DataType::INT32;
term_expr->terms_.emplace_back(0);
plan->plan_node_ = std::make_unique<query::RetrievePlanNode>();
plan->plan_node_->predicate_ = std::move(term_expr);
std::vector<FieldOffset> target_offsets{FieldOffset(0), FieldOffset(1)};
plan->field_offsets_ = target_offsets;
CRetrieveResult retrieve_result;
auto res = Retrieve(segment, plan.release(), timestamps[0], &retrieve_result);
ASSERT_EQ(res.error_code, Success);
DeleteRetrievePlan(plan.release());
DeleteRetrieveResult(&retrieve_result);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, GetMemoryUsageInBytesTest) {
auto collection = NewCollection(get_default_schema_config());
auto segment = NewSegment(collection, 0, Growing);

View File

@ -80,17 +80,12 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error {
return errors.New(finalMsg)
}
// HandleCProtoResult deal with the result proto returned from CGO
func HandleCProtoResult(cRes *C.CProtoResult, msg proto.Message) error {
// HandleCProto deal with the result proto returned from CGO
func HandleCProto(cRes *C.CProto, msg proto.Message) error {
// Standalone CProto is protobuf created by C side,
// Passed from c side
// memory is managed manually
err := HandleCStatus(&cRes.status, "")
if err != nil {
return err
}
cpro := cRes.proto
blob := C.GoBytes(unsafe.Pointer(cpro.proto_blob), C.int32_t(cpro.proto_size))
defer C.free(cpro.proto_blob)
blob := C.GoBytes(unsafe.Pointer(cRes.proto_blob), C.int32_t(cRes.proto_size))
defer C.free(cRes.proto_blob)
return proto.Unmarshal(blob, msg)
}

View File

@ -35,6 +35,11 @@ type MarshaledHits struct {
cMarshaledHits C.CMarshaledHits
}
// RetrieveResult contains a pointer to the retrieve result in C++ memory
type RetrieveResult struct {
cRetrieveResult C.CRetrieveResult
}
func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, numSegments int64) error {
if plan.cSearchPlan == nil {
return errors.New("nil search plan")

View File

@ -316,10 +316,15 @@ func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, erro
if s.segmentPtr == nil {
return nil, errors.New("null seg core pointer")
}
resProto := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, C.uint64_t(plan.Timestamp))
var retrieveResult RetrieveResult
ts := C.uint64_t(plan.Timestamp)
status := C.Retrieve(s.segmentPtr, plan.cRetrievePlan, ts, &retrieveResult.cRetrieveResult)
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
return nil, err
}
result := new(segcorepb.RetrieveResults)
err := HandleCProtoResult(&resProto, result)
if err != nil {
if err := HandleCProto(&retrieveResult.cRetrieveResult, result); err != nil {
return nil, err
}
return result, nil