mirror of https://github.com/milvus-io/milvus.git
Enable Query of Segment Naive
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>pull/4973/head^2
parent
f319803eab
commit
78992a98b0
|
@ -115,7 +115,7 @@ class ConcurrentVector : public VectorBase {
|
|||
|
||||
void
|
||||
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
|
||||
set_data(element_count, static_cast<const Type*>(source), element_count);
|
||||
set_data(element_offset, static_cast<const Type*>(source), element_count);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -45,7 +45,7 @@ class SegmentBase {
|
|||
|
||||
// query contains metadata of
|
||||
virtual Status
|
||||
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) = 0;
|
||||
Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
|
||||
|
||||
// // THIS FUNCTION IS REMOVED
|
||||
// virtual Status
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <assert.h>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "utils/Types.h"
|
||||
// #include "knowhere/index/Index.h"
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
#include <dog_segment/SegmentNaive.h>
|
||||
|
||||
#include <random>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include <queue>
|
||||
|
||||
namespace milvus::dog_segment {
|
||||
int
|
||||
|
@ -171,14 +172,32 @@ SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, Query
|
|||
}
|
||||
|
||||
Status
|
||||
SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) {
|
||||
SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
|
||||
// TODO: enable delete
|
||||
// TODO: enable index
|
||||
auto& field = schema_->operator[](query->field_name);
|
||||
|
||||
if(query_info == nullptr) {
|
||||
query_info = std::make_shared<query::Query>();
|
||||
query_info->field_name = "fakevec";
|
||||
query_info->topK = 10;
|
||||
query_info->num_queries = 1;
|
||||
|
||||
auto dim = schema_->operator[]("fakevec").get_dim();
|
||||
std::default_random_engine e(42);
|
||||
std::uniform_real_distribution<> dis(0.0, 1.0);
|
||||
query_info->query_raw_data.resize(query_info->num_queries * dim);
|
||||
for(auto& x: query_info->query_raw_data) {
|
||||
x = dis(e);
|
||||
}
|
||||
}
|
||||
|
||||
auto& field = schema_->operator[](query_info->field_name);
|
||||
assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||
|
||||
auto dim = field.get_dim();
|
||||
auto topK = query->topK;
|
||||
auto topK = query_info->topK;
|
||||
auto num_queries = query_info->num_queries;
|
||||
|
||||
|
||||
int64_t barrier = [&]
|
||||
{
|
||||
|
@ -197,24 +216,68 @@ SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResu
|
|||
return beg;
|
||||
}();
|
||||
|
||||
|
||||
if(topK > barrier) {
|
||||
topK = barrier;
|
||||
}
|
||||
|
||||
auto get_L2_distance = [dim](const float* a, const float* b) {
|
||||
float L2_distance = 0;
|
||||
for(auto i = 0; i < dim; ++i) {
|
||||
auto d = a[i] - b[i];
|
||||
L2_distance += d * d;
|
||||
}
|
||||
return L2_distance;
|
||||
};
|
||||
|
||||
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);
|
||||
// TODO: optimize
|
||||
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]);
|
||||
for(int64_t i = 0; i < barrier; ++i) {
|
||||
auto element = vec_ptr->get_element(i);
|
||||
for(auto query_id = 0; query_id < num_queries; ++query_id) {
|
||||
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
|
||||
auto dis = get_L2_distance(query_blob, element);
|
||||
auto& record = records[query_id];
|
||||
if(record.size() < topK) {
|
||||
record.emplace(dis, i);
|
||||
} else if(record.top().first > dis) {
|
||||
record.emplace(dis, i);
|
||||
record.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("unimplemented");
|
||||
|
||||
result.num_queries_ = num_queries;
|
||||
result.topK_ = topK;
|
||||
auto row_num = topK * num_queries;
|
||||
result.row_num_ = topK * num_queries;
|
||||
|
||||
result.result_ids_.resize(row_num);
|
||||
result.result_distances_.resize(row_num);
|
||||
|
||||
for(int q_id = 0; q_id < num_queries; ++q_id) {
|
||||
// reverse
|
||||
for(int i = 0; i < topK; ++i) {
|
||||
auto dst_id = topK - 1 - i + q_id * topK;
|
||||
auto [dis, offset] = records[q_id].top();
|
||||
records[q_id].pop();
|
||||
result.result_ids_[dst_id] = record_.uids_[offset];
|
||||
result.result_distances_[dst_id] = dis;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
// find end of binary
|
||||
// throw std::runtime_error("unimplemented");
|
||||
// auto record_ptr = GetMutableRecord();
|
||||
// if (record_ptr) {
|
||||
// return QueryImpl(*record_ptr, query, timestamp, result);
|
||||
// } else {
|
||||
// assert(ready_immutable_);
|
||||
// return QueryImpl(*record_immutable_, query, timestamp, result);
|
||||
// }
|
||||
// find end of binary
|
||||
// throw std::runtime_error("unimplemented");
|
||||
// auto record_ptr = GetMutableRecord();
|
||||
// if (record_ptr) {
|
||||
// return QueryImpl(*record_ptr, query, timestamp, result);
|
||||
// } else {
|
||||
// assert(ready_immutable_);
|
||||
// return QueryImpl(*record_immutable_, query, timestamp, result);
|
||||
// }
|
||||
}
|
||||
|
||||
Status
|
||||
|
|
|
@ -58,7 +58,7 @@ class SegmentNaive : public SegmentBase {
|
|||
|
||||
// query contains metadata of
|
||||
Status
|
||||
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) override;
|
||||
Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
|
||||
|
||||
// stop receive insert requests
|
||||
// will move data to immutable vector or something
|
||||
|
|
|
@ -137,8 +137,10 @@ struct AttrsData {
|
|||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct QueryResult {
|
||||
uint64_t row_num_;
|
||||
engine::ResultIds result_ids_;
|
||||
uint64_t row_num_; // row_num_ = topK * num_queries_
|
||||
uint64_t topK_;
|
||||
uint64_t num_queries_; // currently must be 1
|
||||
engine::ResultIds result_ids_; // top1, top2, ..;
|
||||
engine::ResultDistances result_distances_;
|
||||
// engine::DataChunkPtr data_chunk_;
|
||||
};
|
||||
|
|
|
@ -134,11 +134,11 @@ TEST(CApiTest, SearchTest) {
|
|||
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
|
||||
assert(ins_res == 0);
|
||||
|
||||
long result_ids;
|
||||
float result_distances;
|
||||
auto sea_res = Search(segment, nullptr, 0, &result_ids, &result_distances);
|
||||
long result_ids[10];
|
||||
float result_distances[10];
|
||||
auto sea_res = Search(segment, nullptr, 0, result_ids, result_distances);
|
||||
assert(sea_res == 0);
|
||||
assert(result_ids == 104490);
|
||||
assert(result_ids[0] == 100911);
|
||||
|
||||
DeleteCollection(collection);
|
||||
DeletePartition(partition);
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"os"
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
type StorageConfig struct {
|
||||
Driver string
|
||||
}
|
||||
|
||||
var config *StorageConfig = new(StorageConfig)
|
||||
|
||||
func GetConfig() *StorageConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func init() {
|
||||
//读取配置文件
|
||||
dirPath, _ := os.Getwd()
|
||||
filePath := path.Join(dirPath, "config/storage.toml")
|
||||
fmt.Println("aaa")
|
||||
fmt.Println(filePath)
|
||||
fmt.Println("bbb")
|
||||
_, err := toml.DecodeFile(filePath, config)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package conf_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"storage/pkg/conf"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
exitCode := m.Run()
|
||||
fmt.Println("haha")
|
||||
config := conf.GetConfig()
|
||||
fmt.Println(config.Driver)
|
||||
os.Exit(exitCode)
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"os"
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
type StorageConfig struct {
|
||||
Driver string
|
||||
}
|
||||
|
||||
var config *StorageConfig = new(StorageConfig)
|
||||
|
||||
func GetConfig() *StorageConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func init() {
|
||||
//读取配置文件
|
||||
dirPath, _ := os.Getwd()
|
||||
filePath := path.Join(dirPath, "config/storage.toml")
|
||||
fmt.Println("aaa")
|
||||
fmt.Println(filePath)
|
||||
fmt.Println("bbb")
|
||||
_, err := toml.DecodeFile(filePath, config)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
package conf_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"storage/pkg/conf"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
exitCode := m.Run()
|
||||
fmt.Println("haha")
|
||||
config := conf.GetConfig()
|
||||
fmt.Println(config.Driver)
|
||||
os.Exit(exitCode)
|
||||
}
|
Loading…
Reference in New Issue