diff --git a/core/src/dog_segment/ConcurrentVector.h b/core/src/dog_segment/ConcurrentVector.h index f87b8fee9a..82dcab3339 100644 --- a/core/src/dog_segment/ConcurrentVector.h +++ b/core/src/dog_segment/ConcurrentVector.h @@ -115,7 +115,7 @@ class ConcurrentVector : public VectorBase { void set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override { - set_data(element_count, static_cast(source), element_count); + set_data(element_offset, static_cast(source), element_count); } void diff --git a/core/src/dog_segment/SegmentBase.h b/core/src/dog_segment/SegmentBase.h index c916f6e9b5..578239686d 100644 --- a/core/src/dog_segment/SegmentBase.h +++ b/core/src/dog_segment/SegmentBase.h @@ -45,7 +45,7 @@ class SegmentBase { // query contains metadata of virtual Status - Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) = 0; + Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0; // // THIS FUNCTION IS REMOVED // virtual Status diff --git a/core/src/dog_segment/SegmentDefs.h b/core/src/dog_segment/SegmentDefs.h index ba5c3666ed..20dafe8bfd 100644 --- a/core/src/dog_segment/SegmentDefs.h +++ b/core/src/dog_segment/SegmentDefs.h @@ -2,6 +2,7 @@ #include #include +#include #include "utils/Types.h" // #include "knowhere/index/Index.h" diff --git a/core/src/dog_segment/SegmentNaive.cpp b/core/src/dog_segment/SegmentNaive.cpp index 9c7c29b87b..020ab58b6e 100644 --- a/core/src/dog_segment/SegmentNaive.cpp +++ b/core/src/dog_segment/SegmentNaive.cpp @@ -1,8 +1,9 @@ #include - +#include #include #include #include +#include namespace milvus::dog_segment { int @@ -171,14 +172,32 @@ SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, Query } Status -SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) { +SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) { // TODO: enable delete // TODO: enable index - auto& field = schema_->operator[](query->field_name); + + if(query_info == nullptr) { + query_info = std::make_shared(); + query_info->field_name = "fakevec"; + query_info->topK = 10; + query_info->num_queries = 1; + + auto dim = schema_->operator[]("fakevec").get_dim(); + std::default_random_engine e(42); + std::uniform_real_distribution<> dis(0.0, 1.0); + query_info->query_raw_data.resize(query_info->num_queries * dim); + for(auto& x: query_info->query_raw_data) { + x = dis(e); + } + } + + auto& field = schema_->operator[](query_info->field_name); assert(field.get_data_type() == DataType::VECTOR_FLOAT); auto dim = field.get_dim(); - auto topK = query->topK; + auto topK = query_info->topK; + auto num_queries = query_info->num_queries; + int64_t barrier = [&] { @@ -197,24 +216,68 @@ SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResu return beg; }(); + + if(topK > barrier) { + topK = barrier; + } + + auto get_L2_distance = [dim](const float* a, const float* b) { + float L2_distance = 0; + for(auto i = 0; i < dim; ++i) { + auto d = a[i] - b[i]; + L2_distance += d * d; + } + return L2_distance; + }; + + std::vector>> records(num_queries); // TODO: optimize auto vec_ptr = std::static_pointer_cast>(record_.entity_vec_[0]); for(int64_t i = 0; i < barrier; ++i) { auto element = vec_ptr->get_element(i); + for(auto query_id = 0; query_id < num_queries; ++query_id) { + auto query_blob = query_info->query_raw_data.data() + query_id * dim; + auto dis = get_L2_distance(query_blob, element); + auto& record = records[query_id]; + if(record.size() < topK) { + record.emplace(dis, i); + } else if(record.top().first > dis) { + record.emplace(dis, i); + record.pop(); + } + } + } - throw std::runtime_error("unimplemented"); + + result.num_queries_ = num_queries; + result.topK_ = topK; + auto row_num = topK * num_queries; + result.row_num_ = topK * num_queries; + + result.result_ids_.resize(row_num); + result.result_distances_.resize(row_num); + + for(int q_id = 0; q_id < num_queries; ++q_id) { + // reverse + for(int i = 0; i < topK; ++i) { + auto dst_id = topK - 1 - i + q_id * topK; + auto [dis, offset] = records[q_id].top(); + records[q_id].pop(); + result.result_ids_[dst_id] = record_.uids_[offset]; + result.result_distances_[dst_id] = dis; + } } return Status::OK(); - // find end of binary - // throw std::runtime_error("unimplemented"); - // auto record_ptr = GetMutableRecord(); - // if (record_ptr) { - // return QueryImpl(*record_ptr, query, timestamp, result); - // } else { - // assert(ready_immutable_); - // return QueryImpl(*record_immutable_, query, timestamp, result); - // } +// find end of binary +// throw std::runtime_error("unimplemented"); +// auto record_ptr = GetMutableRecord(); +// if (record_ptr) { +// return QueryImpl(*record_ptr, query, timestamp, result); +// } else { +// assert(ready_immutable_); +// return QueryImpl(*record_immutable_, query, timestamp, result); +// } } Status diff --git a/core/src/dog_segment/SegmentNaive.h b/core/src/dog_segment/SegmentNaive.h index e58bfa9e18..5b32ede440 100644 --- a/core/src/dog_segment/SegmentNaive.h +++ b/core/src/dog_segment/SegmentNaive.h @@ -58,7 +58,7 @@ class SegmentNaive : public SegmentBase { // query contains metadata of Status - Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) override; + Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override; // stop receive insert requests // will move data to immutable vector or something diff --git a/core/src/utils/Types.h b/core/src/utils/Types.h index 8da7b795e7..f394e5b23d 100644 --- a/core/src/utils/Types.h +++ b/core/src/utils/Types.h @@ -137,8 +137,10 @@ struct AttrsData { /////////////////////////////////////////////////////////////////////////////////////////////////// struct QueryResult { - uint64_t row_num_; - engine::ResultIds result_ids_; + uint64_t row_num_; // row_num_ = topK * num_queries_ + uint64_t topK_; + uint64_t num_queries_; // currently must be 1 + engine::ResultIds result_ids_; // top1, top2, ..; engine::ResultDistances result_distances_; // engine::DataChunkPtr data_chunk_; }; diff --git a/core/unittest/test_c_api.cpp b/core/unittest/test_c_api.cpp index 4d56f8f5ec..149e17355d 100644 --- a/core/unittest/test_c_api.cpp +++ b/core/unittest/test_c_api.cpp @@ -134,11 +134,11 @@ TEST(CApiTest, SearchTest) { auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); assert(ins_res == 0); - long result_ids; - float result_distances; - auto sea_res = Search(segment, nullptr, 0, &result_ids, &result_distances); + long result_ids[10]; + float result_distances[10]; + auto sea_res = Search(segment, nullptr, 0, result_ids, result_distances); assert(sea_res == 0); - assert(result_ids == 104490); + assert(result_ids[0] == 100911); DeleteCollection(collection); DeletePartition(partition); diff --git a/storage/conf/conf.go b/storage/conf/conf.go deleted file mode 100644 index ce7a0e5a86..0000000000 --- a/storage/conf/conf.go +++ /dev/null @@ -1,31 +0,0 @@ -package conf - -import ( - "fmt" - "path" - "os" - "github.com/BurntSushi/toml" -) - -type StorageConfig struct { - Driver string -} - -var config *StorageConfig = new(StorageConfig) - -func GetConfig() *StorageConfig { - return config -} - -func init() { - //读取配置文件 - dirPath, _ := os.Getwd() - filePath := path.Join(dirPath, "config/storage.toml") - fmt.Println("aaa") - fmt.Println(filePath) - fmt.Println("bbb") - _, err := toml.DecodeFile(filePath, config) - if err != nil { - fmt.Println(err) - } -} diff --git a/storage/conf/conf_test.go b/storage/conf/conf_test.go deleted file mode 100644 index b5a0714d65..0000000000 --- a/storage/conf/conf_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package conf_test - -import ( - "fmt" - "os" - "storage/pkg/conf" - "testing" -) - -func TestMain(m *testing.M) { - exitCode := m.Run() - fmt.Println("haha") - config := conf.GetConfig() - fmt.Println(config.Driver) - os.Exit(exitCode) -} diff --git a/storage/pkg/conf/conf.go b/storage/pkg/conf/conf.go deleted file mode 100644 index ce7a0e5a86..0000000000 --- a/storage/pkg/conf/conf.go +++ /dev/null @@ -1,31 +0,0 @@ -package conf - -import ( - "fmt" - "path" - "os" - "github.com/BurntSushi/toml" -) - -type StorageConfig struct { - Driver string -} - -var config *StorageConfig = new(StorageConfig) - -func GetConfig() *StorageConfig { - return config -} - -func init() { - //读取配置文件 - dirPath, _ := os.Getwd() - filePath := path.Join(dirPath, "config/storage.toml") - fmt.Println("aaa") - fmt.Println(filePath) - fmt.Println("bbb") - _, err := toml.DecodeFile(filePath, config) - if err != nil { - fmt.Println(err) - } -} diff --git a/storage/pkg/conf/conf_test.go b/storage/pkg/conf/conf_test.go deleted file mode 100644 index b5a0714d65..0000000000 --- a/storage/pkg/conf/conf_test.go +++ /dev/null @@ -1,16 +0,0 @@ -package conf_test - -import ( - "fmt" - "os" - "storage/pkg/conf" - "testing" -) - -func TestMain(m *testing.M) { - exitCode := m.Run() - fmt.Println("haha") - config := conf.GetConfig() - fmt.Println(config.Driver) - os.Exit(exitCode) -}