Enable Query of Segment Naive

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-09-10 09:42:19 +08:00 committed by yefu.chen
parent f319803eab
commit 78992a98b0
11 changed files with 89 additions and 117 deletions

View File

@ -115,7 +115,7 @@ class ConcurrentVector : public VectorBase {
void void
set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override { set_data_raw(ssize_t element_offset, void* source, ssize_t element_count) override {
set_data(element_count, static_cast<const Type*>(source), element_count); set_data(element_offset, static_cast<const Type*>(source), element_count);
} }
void void

View File

@ -45,7 +45,7 @@ class SegmentBase {
// query contains metadata of // query contains metadata of
virtual Status virtual Status
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) = 0; Query(query::QueryPtr query, Timestamp timestamp, QueryResult& results) = 0;
// // THIS FUNCTION IS REMOVED // // THIS FUNCTION IS REMOVED
// virtual Status // virtual Status

View File

@ -2,6 +2,7 @@
#include <vector> #include <vector>
#include <assert.h> #include <assert.h>
#include <stdexcept>
#include "utils/Types.h" #include "utils/Types.h"
// #include "knowhere/index/Index.h" // #include "knowhere/index/Index.h"

View File

@ -1,8 +1,9 @@
#include <dog_segment/SegmentNaive.h> #include <dog_segment/SegmentNaive.h>
#include <random>
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include <queue>
namespace milvus::dog_segment { namespace milvus::dog_segment {
int int
@ -171,14 +172,32 @@ SegmentNaive::QueryImpl(const query::QueryPtr& query, Timestamp timestamp, Query
} }
Status Status
SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& result) { SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& result) {
// TODO: enable delete // TODO: enable delete
// TODO: enable index // TODO: enable index
auto& field = schema_->operator[](query->field_name);
if(query_info == nullptr) {
query_info = std::make_shared<query::Query>();
query_info->field_name = "fakevec";
query_info->topK = 10;
query_info->num_queries = 1;
auto dim = schema_->operator[]("fakevec").get_dim();
std::default_random_engine e(42);
std::uniform_real_distribution<> dis(0.0, 1.0);
query_info->query_raw_data.resize(query_info->num_queries * dim);
for(auto& x: query_info->query_raw_data) {
x = dis(e);
}
}
auto& field = schema_->operator[](query_info->field_name);
assert(field.get_data_type() == DataType::VECTOR_FLOAT); assert(field.get_data_type() == DataType::VECTOR_FLOAT);
auto dim = field.get_dim(); auto dim = field.get_dim();
auto topK = query->topK; auto topK = query_info->topK;
auto num_queries = query_info->num_queries;
int64_t barrier = [&] int64_t barrier = [&]
{ {
@ -197,24 +216,68 @@ SegmentNaive::Query(const query::QueryPtr& query, Timestamp timestamp, QueryResu
return beg; return beg;
}(); }();
if(topK > barrier) {
topK = barrier;
}
auto get_L2_distance = [dim](const float* a, const float* b) {
float L2_distance = 0;
for(auto i = 0; i < dim; ++i) {
auto d = a[i] - b[i];
L2_distance += d * d;
}
return L2_distance;
};
std::vector<std::priority_queue<std::pair<float, int>>> records(num_queries);
// TODO: optimize // TODO: optimize
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]); auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_[0]);
for(int64_t i = 0; i < barrier; ++i) { for(int64_t i = 0; i < barrier; ++i) {
auto element = vec_ptr->get_element(i); auto element = vec_ptr->get_element(i);
for(auto query_id = 0; query_id < num_queries; ++query_id) {
auto query_blob = query_info->query_raw_data.data() + query_id * dim;
auto dis = get_L2_distance(query_blob, element);
auto& record = records[query_id];
if(record.size() < topK) {
record.emplace(dis, i);
} else if(record.top().first > dis) {
record.emplace(dis, i);
record.pop();
}
}
}
throw std::runtime_error("unimplemented");
result.num_queries_ = num_queries;
result.topK_ = topK;
auto row_num = topK * num_queries;
result.row_num_ = topK * num_queries;
result.result_ids_.resize(row_num);
result.result_distances_.resize(row_num);
for(int q_id = 0; q_id < num_queries; ++q_id) {
// reverse
for(int i = 0; i < topK; ++i) {
auto dst_id = topK - 1 - i + q_id * topK;
auto [dis, offset] = records[q_id].top();
records[q_id].pop();
result.result_ids_[dst_id] = record_.uids_[offset];
result.result_distances_[dst_id] = dis;
}
} }
return Status::OK(); return Status::OK();
// find end of binary // find end of binary
// throw std::runtime_error("unimplemented"); // throw std::runtime_error("unimplemented");
// auto record_ptr = GetMutableRecord(); // auto record_ptr = GetMutableRecord();
// if (record_ptr) { // if (record_ptr) {
// return QueryImpl(*record_ptr, query, timestamp, result); // return QueryImpl(*record_ptr, query, timestamp, result);
// } else { // } else {
// assert(ready_immutable_); // assert(ready_immutable_);
// return QueryImpl(*record_immutable_, query, timestamp, result); // return QueryImpl(*record_immutable_, query, timestamp, result);
// } // }
} }
Status Status

View File

@ -58,7 +58,7 @@ class SegmentNaive : public SegmentBase {
// query contains metadata of // query contains metadata of
Status Status
Query(const query::QueryPtr& query, Timestamp timestamp, QueryResult& results) override; Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult& results) override;
// stop receive insert requests // stop receive insert requests
// will move data to immutable vector or something // will move data to immutable vector or something

View File

@ -137,8 +137,10 @@ struct AttrsData {
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
struct QueryResult { struct QueryResult {
uint64_t row_num_; uint64_t row_num_; // row_num_ = topK * num_queries_
engine::ResultIds result_ids_; uint64_t topK_;
uint64_t num_queries_; // currently must be 1
engine::ResultIds result_ids_; // top1, top2, ..;
engine::ResultDistances result_distances_; engine::ResultDistances result_distances_;
// engine::DataChunkPtr data_chunk_; // engine::DataChunkPtr data_chunk_;
}; };

View File

@ -134,11 +134,11 @@ TEST(CApiTest, SearchTest) {
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res == 0); assert(ins_res == 0);
long result_ids; long result_ids[10];
float result_distances; float result_distances[10];
auto sea_res = Search(segment, nullptr, 0, &result_ids, &result_distances); auto sea_res = Search(segment, nullptr, 0, result_ids, result_distances);
assert(sea_res == 0); assert(sea_res == 0);
assert(result_ids == 104490); assert(result_ids[0] == 100911);
DeleteCollection(collection); DeleteCollection(collection);
DeletePartition(partition); DeletePartition(partition);

View File

@ -1,31 +0,0 @@
package conf
import (
"fmt"
"path"
"os"
"github.com/BurntSushi/toml"
)
// StorageConfig is the target structure that the storage TOML file is decoded into.
type StorageConfig struct {
// Driver is a plain string setting; presumably it names the storage backend — TODO confirm.
Driver string
}
// config is the single package-level StorageConfig instance.
var config = &StorageConfig{}

// GetConfig returns the pointer to the package-level StorageConfig.
// Every call yields the same pointer, so callers share one instance.
func GetConfig() *StorageConfig {
	return config
}
// init decodes "config/storage.toml", resolved against the current working
// directory, into the package-level config.
//
// Failures are printed to stdout and never stop initialization.
func init() {
	// Read the configuration file.
	dirPath, err := os.Getwd()
	if err != nil {
		// This error used to be discarded silently. Report it and carry on
		// with whatever dirPath holds (expected to be empty on failure), so
		// the lookup below degrades to a relative path instead of aborting.
		fmt.Println(err)
	}
	filePath := path.Join(dirPath, "config/storage.toml")
	// NOTE(review): "aaa"/"bbb" look like leftover debug markers; they are
	// kept because they are part of the package's observable output.
	fmt.Println("aaa")
	fmt.Println(filePath)
	fmt.Println("bbb")
	if _, err := toml.DecodeFile(filePath, config); err != nil {
		fmt.Println(err)
	}
}

View File

@ -1,16 +0,0 @@
package conf_test
import (
"fmt"
"os"
"storage/pkg/conf"
"testing"
)
// TestMain runs the package's tests, then prints a marker line followed by the
// Driver value of the loaded configuration, and exits with the tests' status.
func TestMain(m *testing.M) {
	status := m.Run()
	fmt.Println("haha")
	fmt.Println(conf.GetConfig().Driver)
	os.Exit(status)
}

View File

@ -1,31 +0,0 @@
package conf
import (
"fmt"
"path"
"os"
"github.com/BurntSushi/toml"
)
// StorageConfig is the target structure that the storage TOML file is decoded into.
type StorageConfig struct {
// Driver is a plain string setting; presumably it names the storage backend — TODO confirm.
Driver string
}
// config is the single package-level StorageConfig instance.
var config = &StorageConfig{}

// GetConfig returns the pointer to the package-level StorageConfig.
// Every call yields the same pointer, so callers share one instance.
func GetConfig() *StorageConfig {
	return config
}
// init decodes "config/storage.toml", resolved against the current working
// directory, into the package-level config.
//
// Failures are printed to stdout and never stop initialization.
func init() {
	// Read the configuration file.
	dirPath, err := os.Getwd()
	if err != nil {
		// This error used to be discarded silently. Report it and carry on
		// with whatever dirPath holds (expected to be empty on failure), so
		// the lookup below degrades to a relative path instead of aborting.
		fmt.Println(err)
	}
	filePath := path.Join(dirPath, "config/storage.toml")
	// NOTE(review): "aaa"/"bbb" look like leftover debug markers; they are
	// kept because they are part of the package's observable output.
	fmt.Println("aaa")
	fmt.Println(filePath)
	fmt.Println("bbb")
	if _, err := toml.DecodeFile(filePath, config); err != nil {
		fmt.Println(err)
	}
}

View File

@ -1,16 +0,0 @@
package conf_test
import (
"fmt"
"os"
"storage/pkg/conf"
"testing"
)
// TestMain runs the package's tests, then prints a marker line followed by the
// Driver value of the loaded configuration, and exits with the tests' status.
func TestMain(m *testing.M) {
	status := m.Run()
	fmt.Println("haha")
	fmt.Println(conf.GetConfig().Driver)
	os.Exit(status)
}