mirror of https://github.com/milvus-io/milvus.git
Enable Query of Segment Naive
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
f319803eab
commit
78992a98b0
|
@ -115,7 +115,7 @@ class ConcurrentVector : public VectorBase {
|
||||||
|
|
||||||
void
|
void
|
||||||
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
|
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
|
||||||
set_data(element_count, static_cast<const Type*>(source), element_count);
|
set_data(element_offset, static_cast<const Type*>(source), element_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
|
|
@ -45,7 +45,7 @@ class SegmentBase {
|
||||||
|
|
||||||
// query contains metadata of
|
// query contains metadata of
|
||||||
virtual Status
|
virtual Status
|
||||||
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) = 0;
|
Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
|
||||||
|
|
||||||
// // THIS FUNCTION IS REMOVED
|
// // THIS FUNCTION IS REMOVED
|
||||||
// virtual Status
|
// virtual Status
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "utils/Types.h"
|
#include "utils/Types.h"
|
||||||
// #include "knowhere/index/Index.h"
|
// #include "knowhere/index/Index.h"
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
#include <dog_segment/SegmentNaive.h>
|
#include <dog_segment/SegmentNaive.h>
|
||||||
|
#include <random>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
namespace milvus::dog_segment {
|
namespace milvus::dog_segment {
|
||||||
int
|
int
|
||||||
|
@ -171,14 +172,32 @@ SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, Query
|
||||||
}
|
}
|
||||||
|
|
||||||
Status
|
Status
|
||||||
SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) {
|
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||||
// TODO: enable delete
|
// TODO: enable delete
|
||||||
// TODO: enable index
|
// TODO: enable index
|
||||||
auto& field = schema_->operator[](query->field_name);
|
|
||||||
|
if(query_info == nullptr) {
|
||||||
|
query_info = std::make_shared<query::Query>();
|
||||||
|
query_info->field_name = "fakevec";
|
||||||
|
query_info->topK = 10;
|
||||||
|
query_info->num_queries = 1;
|
||||||
|
|
||||||
|
auto dim = schema_->operator[]("fakevec").get_dim();
|
||||||
|
std::default_random_engine e(42);
|
||||||
|
std::uniform_real_distribution<> dis(0.0, 1.0);
|
||||||
|
query_info->query_raw_data.resize(query_info->num_queries * dim);
|
||||||
|
for(auto& x: query_info->query_raw_data) {
|
||||||
|
x = dis(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& field = schema_->operator[](query_info->field_name);
|
||||||
assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||||
|
|
||||||
auto dim = field.get_dim();
|
auto dim = field.get_dim();
|
||||||
auto topK = query->topK;
|
auto topK = query_info->topK;
|
||||||
|
auto num_queries = query_info->num_queries;
|
||||||
|
|
||||||
|
|
||||||
int64_t barrier = [&]
|
int64_t barrier = [&]
|
||||||
{
|
{
|
||||||
|
@ -197,24 +216,68 @@ SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResu
|
||||||
return beg;
|
return beg;
|
||||||
}();
|
}();
|
||||||
|
|
||||||
|
|
||||||
|
if(topK > barrier) {
|
||||||
|
topK = barrier;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_L2_distance = [dim](const float* a, const float* b) {
|
||||||
|
float L2_distance = 0;
|
||||||
|
for(auto i = 0; i < dim; ++i) {
|
||||||
|
auto d = a[i] - b[i];
|
||||||
|
L2_distance += d * d;
|
||||||
|
}
|
||||||
|
return L2_distance;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]);
|
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]);
|
||||||
for(int64_t i = 0; i < barrier; ++i) {
|
for(int64_t i = 0; i < barrier; ++i) {
|
||||||
auto element = vec_ptr->get_element(i);
|
auto element = vec_ptr->get_element(i);
|
||||||
|
for(auto query_id = 0; query_id < num_queries; ++query_id) {
|
||||||
|
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
|
||||||
|
auto dis = get_L2_distance(query_blob, element);
|
||||||
|
auto& record = records[query_id];
|
||||||
|
if(record.size() < topK) {
|
||||||
|
record.emplace(dis, i);
|
||||||
|
} else if(record.top().first > dis) {
|
||||||
|
record.emplace(dis, i);
|
||||||
|
record.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
throw std::runtime_error("unimplemented");
|
|
||||||
|
result.num_queries_ = num_queries;
|
||||||
|
result.topK_ = topK;
|
||||||
|
auto row_num = topK * num_queries;
|
||||||
|
result.row_num_ = topK * num_queries;
|
||||||
|
|
||||||
|
result.result_ids_.resize(row_num);
|
||||||
|
result.result_distances_.resize(row_num);
|
||||||
|
|
||||||
|
for(int q_id = 0; q_id < num_queries; ++q_id) {
|
||||||
|
// reverse
|
||||||
|
for(int i = 0; i < topK; ++i) {
|
||||||
|
auto dst_id = topK - 1 - i + q_id * topK;
|
||||||
|
auto [dis, offset] = records[q_id].top();
|
||||||
|
records[q_id].pop();
|
||||||
|
result.result_ids_[dst_id] = record_.uids_[offset];
|
||||||
|
result.result_distances_[dst_id] = dis;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
// find end of binary
|
// find end of binary
|
||||||
// throw std::runtime_error("unimplemented");
|
// throw std::runtime_error("unimplemented");
|
||||||
// auto record_ptr = GetMutableRecord();
|
// auto record_ptr = GetMutableRecord();
|
||||||
// if (record_ptr) {
|
// if (record_ptr) {
|
||||||
// return QueryImpl(*record_ptr, query, timestamp, result);
|
// return QueryImpl(*record_ptr, query, timestamp, result);
|
||||||
// } else {
|
// } else {
|
||||||
// assert(ready_immutable_);
|
// assert(ready_immutable_);
|
||||||
// return QueryImpl(*record_immutable_, query, timestamp, result);
|
// return QueryImpl(*record_immutable_, query, timestamp, result);
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
Status
|
Status
|
||||||
|
|
|
@ -58,7 +58,7 @@ class SegmentNaive : public SegmentBase {
|
||||||
|
|
||||||
// query contains metadata of
|
// query contains metadata of
|
||||||
Status
|
Status
|
||||||
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) override;
|
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||||
|
|
||||||
// stop receive insert requests
|
// stop receive insert requests
|
||||||
// will move data to immutable vector or something
|
// will move data to immutable vector or something
|
||||||
|
|
|
@ -137,8 +137,10 @@ struct AttrsData {
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
struct QueryResult {
|
struct QueryResult {
|
||||||
uint64_t row_num_;
|
uint64_t row_num_; // row_num_ = topK * num_queries_
|
||||||
engine::ResultIds result_ids_;
|
uint64_t topK_;
|
||||||
|
uint64_t num_queries_; // currently must be 1
|
||||||
|
engine::ResultIds result_ids_; // top1, top2, ..;
|
||||||
engine::ResultDistances result_distances_;
|
engine::ResultDistances result_distances_;
|
||||||
// engine::DataChunkPtr data_chunk_;
|
// engine::DataChunkPtr data_chunk_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -134,11 +134,11 @@ TEST(CApiTest, SearchTest) {
|
||||||
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
||||||
assert(ins_res == 0);
|
assert(ins_res == 0);
|
||||||
|
|
||||||
long result_ids;
|
long result_ids[10];
|
||||||
float result_distances;
|
float result_distances[10];
|
||||||
auto sea_res = Search(segment, nullptr, 0, &result_ids, &result_distances);
|
auto sea_res = Search(segment, nullptr, 0, result_ids, result_distances);
|
||||||
assert(sea_res == 0);
|
assert(sea_res == 0);
|
||||||
assert(result_ids == 104490);
|
assert(result_ids[0] == 100911);
|
||||||
|
|
||||||
DeleteCollection(collection);
|
DeleteCollection(collection);
|
||||||
DeletePartition(partition);
|
DeletePartition(partition);
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
package conf
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"path"
|
|
||||||
"os"
|
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
)
|
|
||||||
|
|
||||||
type StorageConfig struct {
|
|
||||||
Driver string
|
|
||||||
}
|
|
||||||
|
|
||||||
var config *StorageConfig = new(StorageConfig)
|
|
||||||
|
|
||||||
func GetConfig() *StorageConfig {
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
//读取配置文件
|
|
||||||
dirPath, _ := os.Getwd()
|
|
||||||
filePath := path.Join(dirPath, "config/storage.toml")
|
|
||||||
fmt.Println("aaa")
|
|
||||||
fmt.Println(filePath)
|
|
||||||
fmt.Println("bbb")
|
|
||||||
_, err := toml.DecodeFile(filePath, config)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,16 +0,0 @@
|
||||||
package conf_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"storage/pkg/conf"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
exitCode := m.Run()
|
|
||||||
fmt.Println("haha")
|
|
||||||
config := conf.GetConfig()
|
|
||||||
fmt.Println(config.Driver)
|
|
||||||
os.Exit(exitCode)
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
package conf
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"path"
|
|
||||||
"os"
|
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
)
|
|
||||||
|
|
||||||
type StorageConfig struct {
|
|
||||||
Driver string
|
|
||||||
}
|
|
||||||
|
|
||||||
var config *StorageConfig = new(StorageConfig)
|
|
||||||
|
|
||||||
func GetConfig() *StorageConfig {
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
//读取配置文件
|
|
||||||
dirPath, _ := os.Getwd()
|
|
||||||
filePath := path.Join(dirPath, "config/storage.toml")
|
|
||||||
fmt.Println("aaa")
|
|
||||||
fmt.Println(filePath)
|
|
||||||
fmt.Println("bbb")
|
|
||||||
_, err := toml.DecodeFile(filePath, config)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,16 +0,0 @@
|
||||||
package conf_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"storage/pkg/conf"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
exitCode := m.Run()
|
|
||||||
fmt.Println("haha")
|
|
||||||
config := conf.GetConfig()
|
|
||||||
fmt.Println(config.Driver)
|
|
||||||
os.Exit(exitCode)
|
|
||||||
}
|
|
Loading…
Reference in New Issue