Refactor collection's cgo call (#28055)

Signed-off-by: Enwei Jiao <enwei.jiao@zilliz.com>
pull/28110/head
Enwei Jiao 2023-11-02 13:02:13 +08:00 committed by GitHub
parent 7bd44bd671
commit f8dd589755
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 70 additions and 75 deletions

View File

@ -17,31 +17,29 @@
namespace milvus::segcore {
Collection::Collection(const std::string_view collection_proto)
: schema_proto_(collection_proto) {
parse();
// Constructs a Collection from an already-deserialized schema message.
// `schema` must be non-null (enforced by Assert); the collection name is
// copied from it and the internal Schema is produced by Schema::ParseFrom.
Collection::Collection(const milvus::proto::schema::CollectionSchema* schema) {
    Assert(schema != nullptr);
    const milvus::proto::schema::CollectionSchema& schema_msg = *schema;
    collection_name_ = schema_msg.name();
    schema_ = Schema::ParseFrom(schema_msg);
}
void
Collection::parse() {
// if (schema_proto_.empty()) {
// // TODO: remove hard code use unittests are ready
// std::cout << "WARN: Use default schema" << std::endl;
// auto schema = std::make_shared<Schema>();
// schema->AddDebugField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
// schema->AddDebugField("age", DataType::INT32);
// collection_name_ = "default-collection";
// schema_ = schema;
// return;
// }
Assert(!schema_proto_.empty());
Collection::Collection(const std::string_view schema_proto) {
milvus::proto::schema::CollectionSchema collection_schema;
auto suc = google::protobuf::TextFormat::ParseFromString(
schema_proto_, &collection_schema);
std::string(schema_proto), &collection_schema);
if (!suc) {
std::cerr << "unmarshal schema string failed" << std::endl;
LOG_SEGCORE_WARNING_ << "unmarshal schema string failed";
}
collection_name_ = collection_schema.name();
schema_ = Schema::ParseFrom(collection_schema);
}
Collection::Collection(const void* schema_proto, const int64_t length) {
Assert(schema_proto != nullptr);
milvus::proto::schema::CollectionSchema collection_schema;
auto suc = collection_schema.ParseFromArray(schema_proto, length);
if (!suc) {
LOG_SEGCORE_WARNING_ << "unmarshal schema string failed";
}
collection_name_ = collection_schema.name();
@ -49,20 +47,18 @@ Collection::parse() {
}
void
Collection::parseIndexMeta(const std::string_view index_meta_proto_) {
Assert(!index_meta_proto_.empty());
Collection::parseIndexMeta(const void* index_proto, const int64_t length) {
Assert(index_proto != nullptr);
milvus::proto::segcore::CollectionIndexMeta protobuf_indexMeta;
auto suc = google::protobuf::TextFormat::ParseFromString(
std::string(index_meta_proto_), &protobuf_indexMeta);
milvus::proto::segcore::CollectionIndexMeta indexMeta;
auto suc = indexMeta.ParseFromArray(index_proto, length);
if (!suc) {
LOG_SEGCORE_ERROR_ << "unmarshal index meta string failed" << std::endl;
LOG_SEGCORE_ERROR_ << "unmarshal index meta string failed";
return;
}
index_meta_ = std::shared_ptr<CollectionIndexMeta>(
new CollectionIndexMeta(protobuf_indexMeta));
index_meta_ = std::make_shared<CollectionIndexMeta>(indexMeta);
LOG_SEGCORE_INFO_ << "index meta info : " << index_meta_->ToString();
}

View File

@ -21,13 +21,12 @@ namespace milvus::segcore {
class Collection {
public:
explicit Collection(const std::string_view collection_proto);
explicit Collection(const milvus::proto::schema::CollectionSchema* schema);
explicit Collection(const std::string_view schema_proto);
explicit Collection(const void* collection_proto, const int64_t length);
void
parse();
void
parseIndexMeta(const std::string_view index_meta_proto_blob);
parseIndexMeta(const void* index_meta_proto_blob, const int64_t length);
public:
SchemaPtr&
@ -47,7 +46,6 @@ class Collection {
private:
std::string collection_name_;
std::string schema_proto_;
SchemaPtr schema_;
IndexMetaPtr index_meta_;
};

View File

@ -18,17 +18,18 @@
#include "segcore/Collection.h"
CCollection
NewCollection(const char* schema_proto_blob) {
auto proto = std::string(schema_proto_blob);
auto collection = std::make_unique<milvus::segcore::Collection>(proto);
NewCollection(const void* schema_proto_blob, const int64_t length) {
auto collection = std::make_unique<milvus::segcore::Collection>(
schema_proto_blob, length);
return (void*)collection.release();
}
void
SetIndexMeta(CCollection collection, const char* index_meta_proto_blob) {
SetIndexMeta(CCollection collection,
const void* proto_blob,
const int64_t length) {
auto col = (milvus::segcore::Collection*)collection;
auto proto = std::string_view(index_meta_proto_blob);
col->parseIndexMeta(proto);
col->parseIndexMeta(proto_blob, length);
}
void

View File

@ -11,6 +11,8 @@
#pragma once
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
@ -18,10 +20,12 @@ extern "C" {
typedef void* CCollection;
CCollection
NewCollection(const char* schema_proto_blob);
NewCollection(const void* schema_proto_blob, const int64_t length);
void
SetIndexMeta(CCollection collection, const char* index_meta_proto_blob);
SetIndexMeta(CCollection collection,
const void* proto_blob,
const int64_t length);
void
DeleteCollection(CCollection collection);

View File

@ -275,7 +275,12 @@ TEST(CApiTest, CollectionTest) {
TEST(CApiTest, SetIndexMetaTest) {
auto collection = NewCollection(get_default_schema_config());
SetIndexMeta(collection, get_default_index_meta());
milvus::proto::segcore::CollectionIndexMeta indexMeta;
indexMeta.ParseFromString(get_default_index_meta());
char buffer[indexMeta.ByteSizeLong()];
indexMeta.SerializeToArray(buffer, indexMeta.ByteSizeLong());
SetIndexMeta(collection, buffer, indexMeta.ByteSizeLong());
DeleteCollection(collection);
}

View File

@ -18,6 +18,7 @@
#include "query/ExprImpl.h"
#include "segcore/Reduce.h"
#include "segcore/reduce_c.h"
#include "test_utils/DataGen.h"
#include "test_utils/PbHelper.h"
#include "test_utils/indexbuilder_test_utils.h"
@ -253,35 +254,11 @@ generate_collection_schema(std::string metric_type, int dim, bool is_fp16) {
return schema_string;
}
// Test-local helper: copies the NUL-terminated schema string, builds a
// heap-allocated Collection from it, and returns the object as an opaque
// CCollection handle. Ownership is released to the caller.
CCollection
NewCollection(const char* schema_proto_blob) {
    const std::string schema_text(schema_proto_blob);
    auto owned = std::make_unique<milvus::segcore::Collection>(schema_text);
    return static_cast<void*>(owned.release());
}
TEST(Float16, CApiCPlan) {
std::string schema_string =
generate_collection_schema(knowhere::metric::L2, 16, true);
auto collection = NewCollection(schema_string.c_str());
// const char* dsl_string = R"(
// {
// "bool": {
// "vector": {
// "fakevec": {
// "metric_type": "L2",
// "params": {
// "nprobe": 10
// },
// "query": "$0",
// "topk": 10,
// "round_decimal": 3
// }
// }
// }
// })";
milvus::proto::plan::PlanNode plan_node;
auto vector_anns = plan_node.mutable_vector_anns();
vector_anns->set_vector_type(
@ -416,4 +393,4 @@ TEST(Float16, ExecWithPredicate) {
query::Json json = SearchResultToJson(*sr);
std::cout << json.dump(2);
}
}

View File

@ -25,12 +25,14 @@
#include "index/StringIndexSort.h"
#include "index/VectorMemIndex.h"
#include "query/SearchOnIndex.h"
#include "segcore/Collection.h"
#include "segcore/SegmentGrowingImpl.h"
#include "segcore/SegmentSealedImpl.h"
#include "segcore/Utils.h"
#include "knowhere/comp/index_param.h"
#include "PbHelper.h"
#include "segcore/collection_c.h"
using boost::algorithm::starts_with;
@ -1012,4 +1014,11 @@ GenRandomIds(int rows, int64_t seed = 42) {
return ids_ds;
}
// Test helper overload taking a NUL-terminated schema string. It copies the
// string, constructs a heap-allocated Collection from it, and returns the
// object as an opaque CCollection handle. Ownership is released to the
// caller (NOTE(review): presumably freed via DeleteCollection — confirm).
inline CCollection
NewCollection(const char* schema_proto_blob) {
    const std::string schema_text(schema_proto_blob);
    auto owned = std::make_unique<milvus::segcore::Collection>(schema_text);
    return static_cast<void*>(owned.release());
}
} // namespace milvus::segcore

View File

@ -195,16 +195,21 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
CCollection
NewCollection(const void* schema_proto_blob, const int64_t length);
*/
schemaBlob := proto.MarshalTextString(schema)
cSchemaBlob := C.CString(schemaBlob)
defer C.free(unsafe.Pointer(cSchemaBlob))
schemaBlob, err := proto.Marshal(schema)
if err != nil {
log.Warn("marshal schema failed", zap.Error(err))
return nil
}
collection := C.NewCollection(cSchemaBlob)
collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))
if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob := proto.MarshalTextString(indexMeta)
cIndexMetaBlob := C.CString(indexMetaBlob)
C.SetIndexMeta(collection, cIndexMetaBlob)
indexMetaBlob, err := proto.Marshal(indexMeta)
if err != nil {
log.Warn("marshal index meta failed", zap.Error(err))
return nil
}
C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))
}
return &Collection{