Fix BruteForce, add restriction for indexing builder

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2020-09-18 11:03:28 +08:00 committed by yefu.chen
parent 3acd55f675
commit fe9040f326
26 changed files with 114 additions and 371 deletions

View File

@ -4,8 +4,6 @@ import (
"github.com/czs007/suvlim/storage/pkg/types"
yaml "gopkg.in/yaml.v2"
"io/ioutil"
"path"
"runtime"
)
// yaml.MapSlice
@ -60,14 +58,10 @@ func init() {
load_config()
}
func getCurrentFileDir() string {
_, fpath, _, _ := runtime.Caller(0)
return path.Dir(fpath)
}
func load_config() {
filePath := path.Join(getCurrentFileDir(), "config.yaml")
source, err := ioutil.ReadFile(filePath)
//var config ServerConfig
filename := "../conf/config.yaml"
source, err := ioutil.ReadFile(filename)
if err != nil {
panic(err)
}

View File

@ -15,8 +15,8 @@ master:
etcd:
address: localhost
port: 2379
rootpath: suvlim
port: 0
rootpath: a
segthreshold: 10000
timesync:

View File

@ -8,6 +8,7 @@
#include <knowhere/index/vector_index/adapter/VectorAdapter.h>
#include <knowhere/index/vector_index/VecIndexFactory.h>
#include <faiss/utils/distances.h>
#include <tbb/iterators.h>
namespace milvus::dog_segment {
@ -330,6 +331,28 @@ SegmentNaive::QueryImpl(query::QueryPtr query_info, Timestamp timestamp, QueryRe
return Status::OK();
}
void
merge_into(int64_t queries, int64_t topk, float *distances, int64_t *uids, const float *new_distances, const int64_t *new_uids) {
for(int64_t qn = 0; qn < queries; ++qn) {
auto base = qn * topk;
auto dst_dis = distances + base;
auto dst_uids = uids + base;
auto src_dis = new_distances + base;
auto src_uids = new_uids + base;
std::vector<float> buf_dis(2*topk);
std::vector<int64_t> buf_uids(2*topk);
auto zip_src = tbb::make_zip_iterator(src_dis, src_uids);
auto zip_dst = tbb::make_zip_iterator(dst_dis, dst_uids);
auto zip_buf = tbb::make_zip_iterator(buf_dis.data(), buf_uids.data());
auto fuck = zip_src + 1;
std::merge(zip_dst, zip_dst + topk, zip_src, zip_src + topk, zip_buf);
std::copy_n(zip_buf, topk, zip_dst);
}
}
Status
SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestamp, QueryResult &results) {
auto ins_barrier = get_barrier(record_, timestamp);
@ -343,12 +366,50 @@ SegmentNaive::QueryBruteForceImpl(query::QueryPtr query_info, Timestamp timestam
auto bitmap = bitmap_holder->bitmap_ptr;
auto topK = query_info->topK;
auto num_queries = query_info->num_queries;
auto total_count = topK * num_queries;
// TODO: optimize
auto the_offset_opt = schema_->get_offset(query_info->field_name);
assert(the_offset_opt.has_value());
auto vec_ptr = std::static_pointer_cast<ConcurrentVector<float>>(record_.entity_vec_.at(the_offset_opt.value()));
throw std::runtime_error("unimplemented");
std::vector<int64_t> final_uids(total_count);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
auto max_chunk = (ins_barrier + DefaultElementPerChunk - 1) / DefaultElementPerChunk;
for (int chunk_id = 0; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
faiss::float_maxheap_array_t buf = {
(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()};
auto src_data = vec_ptr->get_chunk(chunk_id).data();
auto nsize = chunk_id != max_chunk - 1? DefaultElementPerChunk: ins_barrier - chunk_id * DefaultElementPerChunk;
auto offset = chunk_id * DefaultElementPerChunk;
faiss::knn_L2sqr(query_info->query_raw_data.data(), src_data, dim, num_queries, nsize, &buf, bitmap, offset);
if(chunk_id == 0) {
final_uids = buf_uids;
final_dis = buf_dis;
} else {
merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
}
}
for(auto& id: final_uids) {
id = record_.uids_[id];
}
results.result_ids_ = std::move(final_uids);
results.result_distances_ = std::move(final_dis);
results.topK_ = topK;
results.num_queries_ = num_queries;
results.row_num_ = total_count;
// throw std::runtime_error("unimplemented");
return Status::OK();
}
@ -445,7 +506,7 @@ SegmentNaive::Query(query::QueryPtr query_info, Timestamp timestamp, QueryResult
if (index_ready_) {
return QueryImpl(query_info, timestamp, result);
} else {
return QuerySlowImpl(query_info, timestamp, result);
return QueryBruteForceImpl(query_info, timestamp, result);
}
}
@ -493,6 +554,9 @@ knowhere::IndexPtr SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry &entry
Status
SegmentNaive::BuildIndex() {
if(record_.ack_responder_.GetAck() < 1024 * 4) {
return Status(SERVER_BUILD_INDEX_ERROR, "too few elements");
}
for (auto&[index_name, entry]: index_meta_->get_entries()) {
assert(entry.index_name == index_name);
const auto &field = (*schema_)[entry.field_name];

View File

@ -1,10 +1,2 @@
#include <algorithm>
#include <tbb/iterators.h>
namespace milvus::dog_segment {
class SmartBruteForce {
SmartBruteForce(int64_t queries, int64_t topK) {
}
};
}

View File

@ -279,7 +279,8 @@ static void knn_L2sqr_sse (
#pragma omp parallel for schedule(static)
for (size_t j = 0; j < ny; j++) {
if(!bitset_base || !bitset_base->test(j + offset)) {
auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset);
if(!test_bit) {
size_t thread_no = omp_get_thread_num();
const float *y_j = y + j * d;
const float *x_i = x + x_from * d;
@ -343,7 +344,8 @@ static void knn_L2sqr_sse (
}
for (size_t j = 0; j < ny; j++) {
if (!bitset_base || !bitset_base->test(j + offset)) {
auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset);
if (!test_bit) {
float disij = fvec_L2sqr (x_i, y_j, d);
if (disij < val_[0]) {
maxheap_swap_top (k, val_, ids_, disij, j);
@ -473,7 +475,8 @@ static void knn_L2sqr_blas (const float * x,
const float *ip_line = ip_block + (i - i0) * (j1 - j0);
for (size_t j = j0; j < j1; j++) {
if(!bitset_base || !bitset_base->test(j + offset)){
auto test_bit = bitset_base && j + offset < bitset_base->capacity() && bitset_base->test(j + offset);
if(!test_bit){
float ip = *ip_line;
float dis = x_norms[i] + y_norms[j] - 2 * ip;

View File

@ -21,43 +21,20 @@ var (
errTxnFailed = errors.New("failed to commit transaction")
)
type EtcdKVBase struct {
type etcdKVBase struct {
client *clientv3.Client
rootPath string
}
// NewEtcdKVBase creates a new etcd kv.
func NewEtcdKVBase(client *clientv3.Client, rootPath string) *EtcdKVBase {
return &EtcdKVBase{
func NewEtcdKVBase(client *clientv3.Client, rootPath string) *etcdKVBase {
return &etcdKVBase{
client: client,
rootPath: rootPath,
}
}
func (kv *EtcdKVBase) LoadWithPrefix(key string) ( []string, []string) {
key = path.Join(kv.rootPath, key)
println("in loadWithPrefix,", key)
resp, err := etcdutil.EtcdKVGet(kv.client, key,clientv3.WithPrefix())
if err != nil {
return [] string {}, [] string {}
}
var keys []string
var values []string
for _,kvs := range resp.Kvs{
//println(len(kvs.))
if len(kvs.Key) <= 0{
println("KKK")
continue
}
keys = append(keys, string(kvs.Key))
values = append(values, string(kvs.Value))
}
//println(keys)
//println(values)
return keys, values
}
func (kv *EtcdKVBase) Load(key string) (string, error) {
func (kv *etcdKVBase) Load(key string) (string, error) {
key = path.Join(kv.rootPath, key)
resp, err := etcdutil.EtcdKVGet(kv.client, key)
@ -72,7 +49,7 @@ func (kv *EtcdKVBase) Load(key string) (string, error) {
return string(resp.Kvs[0].Value), nil
}
func (kv *EtcdKVBase) Save(key, value string) error {
func (kv *etcdKVBase) Save(key, value string) error {
key = path.Join(kv.rootPath, key)
txn := NewSlowLogTxn(kv.client)
@ -87,7 +64,7 @@ func (kv *EtcdKVBase) Save(key, value string) error {
return nil
}
func (kv *EtcdKVBase) Remove(key string) error {
func (kv *etcdKVBase) Remove(key string) error {
key = path.Join(kv.rootPath, key)
txn := NewSlowLogTxn(kv.client)
@ -102,18 +79,12 @@ func (kv *EtcdKVBase) Remove(key string) error {
return nil
}
func (kv *EtcdKVBase) Watch(key string) clientv3.WatchChan {
func (kv *etcdKVBase) Watch(key string) clientv3.WatchChan {
key = path.Join(kv.rootPath, key)
rch := kv.client.Watch(context.Background(), key)
return rch
}
func (kv *EtcdKVBase) WatchWithPrefix(key string) clientv3.WatchChan {
key = path.Join(kv.rootPath, key)
rch := kv.client.Watch(context.Background(), key, clientv3.WithPrefix())
return rch
}
// SlowLogTxn wraps etcd transaction and log slow one.
type SlowLogTxn struct {
clientv3.Txn

View File

@ -7,6 +7,4 @@ type Base interface {
Save(key, value string) error
Remove(key string) error
Watch(key string) clientv3.WatchChan
WatchWithPrefix(key string) clientv3.WatchChan
LoadWithPrefix(key string) ( []string, []string)
}

View File

@ -2,9 +2,9 @@ package reader
/*
#cgo CFLAGS: -I${SRCDIR}/../../core/include
#cgo CFLAGS: -I../core/include
#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib
#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib
#include "collection_c.h"
#include "partition_c.h"
@ -16,7 +16,6 @@ import "C"
type Collection struct {
CollectionPtr C.CCollection
CollectionName string
CollectionID uint64
Partitions []*Partition
}

View File

@ -1,26 +0,0 @@
package main
import (
reader "github.com/czs007/suvlim/reader/read_node"
"sync"
)
func main() {
pulsarURL := "pulsar://localhost:6650"
numOfQueryNode := 2
go reader.StartQueryNode(pulsarURL, numOfQueryNode, 0)
reader.StartQueryNode(pulsarURL, numOfQueryNode, 1)
}
func main2() {
wg := sync.WaitGroup{}
//ctx, cancel := context.WithCancel(context.Background())
//defer cancel()
wg.Add(1)
reader.StartQueryNode2()
wg.Wait()
}

View File

@ -2,9 +2,9 @@ package reader
/*
#cgo CFLAGS: -I${SRCDIR}/../../core/include
#cgo CFLAGS: -I../core/include
#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib
#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib
#include "collection_c.h"
#include "partition_c.h"

View File

@ -2,9 +2,9 @@ package reader
/*
#cgo CFLAGS: -I${SRCDIR}/../../core/include
#cgo CFLAGS: -I../core/include
#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib
#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib
#include "collection_c.h"
#include "partition_c.h"
@ -15,14 +15,11 @@ import "C"
import (
"fmt"
msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
"github.com/czs007/suvlim/reader/message_client"
"sort"
"sync"
"sync/atomic"
msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
"github.com/czs007/suvlim/pkg/master/kv"
"github.com/czs007/suvlim/reader/message_client"
//"github.com/stretchr/testify/assert"
)
type InsertData struct {
@ -57,17 +54,16 @@ type QueryNodeDataBuffer struct {
}
type QueryNode struct {
QueryNodeId uint64
Collections []*Collection
SegmentsMap map[int64]*Segment
messageClient *message_client.MessageClient
QueryNodeId uint64
Collections []*Collection
SegmentsMap map[int64]*Segment
messageClient *message_client.MessageClient
//mc *message_client.MessageClient
queryNodeTimeSync *QueryNodeTime
buffer QueryNodeDataBuffer
deletePreprocessData DeletePreprocessData
deleteData DeleteData
insertData InsertData
kvBase *kv.EtcdKVBase
}
func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode {
@ -91,12 +87,12 @@ func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode {
}
return &QueryNode{
QueryNodeId: queryNodeId,
Collections: nil,
SegmentsMap: segmentsMap,
messageClient: &mc,
queryNodeTimeSync: queryNodeTimeSync,
buffer: buffer,
QueryNodeId: queryNodeId,
Collections: nil,
SegmentsMap: segmentsMap,
messageClient: &mc,
queryNodeTimeSync: queryNodeTimeSync,
buffer: buffer,
}
}
@ -123,12 +119,12 @@ func CreateQueryNode(queryNodeId uint64, timeSync uint64, mc *message_client.Mes
}
return &QueryNode{
QueryNodeId: queryNodeId,
Collections: nil,
SegmentsMap: segmentsMap,
messageClient: mc,
queryNodeTimeSync: queryNodeTimeSync,
buffer: buffer,
QueryNodeId: queryNodeId,
Collections: nil,
SegmentsMap: segmentsMap,
messageClient: mc,
queryNodeTimeSync: queryNodeTimeSync,
buffer: buffer,
}
}
@ -177,7 +173,7 @@ func (node *QueryNode) DeleteCollection(collection *Collection) {
////////////////////////////////////////////////////////////////////////////////////////////////////
func (node *QueryNode) PrepareBatchMsg() []int {
var msgLen = node.messageClient.PrepareBatchMsg()
var msgLen= node.messageClient.PrepareBatchMsg()
return msgLen
}
@ -193,7 +189,7 @@ func (node *QueryNode) InitQueryNodeCollection() {
////////////////////////////////////////////////////////////////////////////////////////////////////
func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
func (node *QueryNode) RunInsertDelete(wg * sync.WaitGroup) {
for {
// TODO: get timeRange from message client
var msgLen = node.PrepareBatchMsg()
@ -275,7 +271,7 @@ func (node *QueryNode) MessagesPreprocess(insertDeleteMessages []*msgPb.InsertOr
}
// 2. Remove invalid messages from buffer.
tmpInsertOrDeleteBuffer := make([]*msgPb.InsertOrDeleteMsg, 0)
tmpInsertOrDeleteBuffer := make([]*msgPb.InsertOrDeleteMsg ,0)
for i, isValid := range node.buffer.validInsertDeleteBuffer {
if isValid {
tmpInsertOrDeleteBuffer = append(tmpInsertOrDeleteBuffer, node.buffer.InsertDeleteBuffer[i])

View File

@ -1,233 +0,0 @@
package reader
import (
"context"
"fmt"
"github.com/czs007/suvlim/pkg/master/mock"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/czs007/suvlim/conf"
"github.com/czs007/suvlim/pkg/master/kv"
"go.etcd.io/etcd/clientv3"
"go.etcd.io/etcd/mvcc/mvccpb"
)
const (
CollectonPrefix = "/collection/"
SegmentPrefix = "/segment/"
)
func GetCollectionObjId(key string) string {
prefix := conf.Config.Etcd.Rootpath + CollectonPrefix
return strings.TrimPrefix(key, prefix)
}
func GetSegmentObjId(key string) string {
prefix := conf.Config.Etcd.Rootpath + SegmentPrefix
return strings.TrimPrefix(key, prefix)
}
func isCollectionObj(key string) bool {
prefix := conf.Config.Etcd.Rootpath + CollectonPrefix
prefix = strings.TrimSpace(prefix)
println("prefix is :$", prefix)
index := strings.Index(key, prefix)
println("index is :", index)
return index == 0
}
func isSegmentObj(key string) bool {
prefix := conf.Config.Etcd.Rootpath + SegmentPrefix
prefix = strings.TrimSpace(prefix)
index := strings.Index(key, prefix)
return index == 0
}
func printCollectionStruct(obj *mock.Collection){
v := reflect.ValueOf(obj)
v = reflect.Indirect(v)
typeOfS := v.Type()
for i := 0; i< v.NumField(); i++ {
if typeOfS.Field(i).Name == "GrpcMarshalString"{
continue
}
fmt.Printf("Field: %s\tValue: %v\n", typeOfS.Field(i).Name, v.Field(i).Interface())
}
}
func printSegmentStruct(obj *mock.Segment){
v := reflect.ValueOf(obj)
v = reflect.Indirect(v)
typeOfS := v.Type()
for i := 0; i< v.NumField(); i++ {
fmt.Printf("Field: %s\tValue: %v\n", typeOfS.Field(i).Name, v.Field(i).Interface())
}
}
func (node *QueryNode) processCollectionCreate(id string, value string) {
println(fmt.Sprintf("Create Collection:$%s$", id))
collection, err := mock.JSON2Collection(value)
if err != nil {
println("error of json 2 collection")
println(err.Error())
}
printCollectionStruct(collection)
}
func (node *QueryNode) processSegmentCreate(id string, value string) {
println("Create Segment: ", id)
segment, err := mock.JSON2Segment(value)
if err != nil {
println("error of json 2 segment")
println(err.Error())
}
printSegmentStruct(segment)
}
func (node *QueryNode) processCreate(key string, msg string) {
println("process create", key, ":", msg)
if isCollectionObj(key){
objID := GetCollectionObjId(key)
node.processCollectionCreate(objID, msg)
}else if isSegmentObj(key){
objID := GetSegmentObjId(key)
node.processSegmentCreate(objID, msg)
}else {
println("can not process create msg:", key)
}
}
func (node *QueryNode) processSegmentModify(id string, value string) {
println("Modify Segment: ", id)
segment, err := mock.JSON2Segment(value)
if err != nil {
println("error of json 2 segment")
println(err.Error())
}
printSegmentStruct(segment)
}
func (node *QueryNode) processCollectionModify(id string, value string) {
println("Modify Collection: ", id)
collection, err := mock.JSON2Collection(value)
if err != nil {
println("error of json 2 collection")
println(err.Error())
}
printCollectionStruct(collection)
}
func (node *QueryNode) processModify(key string, msg string){
println("process modify")
if isCollectionObj(key){
objID := GetCollectionObjId(key)
node.processCollectionModify(objID, msg)
}else if isSegmentObj(key){
objID := GetSegmentObjId(key)
node.processSegmentModify(objID, msg)
}else {
println("can not process modify msg:", key)
}
}
func (node *QueryNode) processSegmentDelete(id string){
println("Delete segment: ", id)
}
func (node *QueryNode) processCollectionDelete(id string){
println("Delete collection: ", id)
}
func (node *QueryNode) processDelete(key string){
println("process delete")
if isCollectionObj(key){
objID := GetCollectionObjId(key)
node.processCollectionDelete(objID)
}else if isSegmentObj(key){
objID := GetSegmentObjId(key)
node.processSegmentDelete(objID)
}else {
println("can not process delete msg:", key)
}
}
func (node *QueryNode) processResp(resp clientv3.WatchResponse) error {
err := resp.Err()
if err != nil {
return err
}
for _, ev := range resp.Events {
if ev.IsCreate() {
key := string(ev.Kv.Key)
msg := string(ev.Kv.Value)
node.processCreate(key, msg)
} else if ev.IsModify() {
key := string(ev.Kv.Key)
msg := string(ev.Kv.Value)
node.processModify(key, msg)
} else if ev.Type == mvccpb.DELETE {
key := string(ev.Kv.Key)
node.processDelete(key)
} else {
println("Unrecognized etcd msg!")
}
}
return nil
}
func (node *QueryNode) loadCollections() error {
keys, values := node.kvBase.LoadWithPrefix(CollectonPrefix)
for i:= range keys{
objID := GetCollectionObjId(keys[i])
node.processCollectionCreate(objID, values[i])
}
return nil
}
func (node *QueryNode) loadSegments() error {
keys, values := node.kvBase.LoadWithPrefix(SegmentPrefix)
for i:= range keys{
objID := GetSegmentObjId(keys[i])
node.processSegmentCreate(objID, values[i])
}
return nil
}
func (node *QueryNode) InitFromMeta() error {
//pass
etcdAddr := "http://"
etcdAddr += conf.Config.Etcd.Address
etcdPort := conf.Config.Etcd.Port
etcdAddr = etcdAddr + ":" + strconv.FormatInt(int64(etcdPort), 10)
cli, _ := clientv3.New(clientv3.Config{
Endpoints: []string{etcdAddr},
DialTimeout: 5 * time.Second,
})
defer cli.Close()
node.kvBase = kv.NewEtcdKVBase(cli, conf.Config.Etcd.Rootpath)
node.loadCollections()
node.loadSegments()
return nil
}
func (node *QueryNode) RunMetaService(ctx context.Context, wg *sync.WaitGroup) {
node.InitFromMeta()
metaChan := node.kvBase.WatchWithPrefix("")
for {
select {
case <-ctx.Done():
wg.Done()
println("DONE!!!!!!")
return
case resp := <-metaChan:
node.processResp(resp)
}
}
}

View File

@ -1,11 +1,9 @@
package reader
import (
"context"
"github.com/czs007/suvlim/reader/message_client"
"log"
"sync"
"github.com/czs007/suvlim/reader/message_client"
)
func StartQueryNode(pulsarURL string, numOfQueryNode int, messageClientID int) {
@ -34,16 +32,3 @@ func StartQueryNode(pulsarURL string, numOfQueryNode int, messageClientID int) {
wg.Wait()
qn.Close()
}
func StartQueryNode2() {
ctx := context.Background()
qn := CreateQueryNode(0, 0, nil)
//qn.InitQueryNodeCollection()
wg := sync.WaitGroup{}
wg.Add(1)
//go qn.RunInsertDelete(&wg)
//go qn.RunSearch(&wg)
go qn.RunMetaService(ctx, &wg)
wg.Wait()
qn.Close()
}

View File

@ -2,9 +2,9 @@ package reader
/*
#cgo CFLAGS: -I${SRCDIR}/../../core/include
#cgo CFLAGS: -I../core/include
#cgo LDFLAGS: -L${SRCDIR}/../../core/lib -lmilvus_dog_segment -Wl,-rpath=${SRCDIR}/../../core/lib
#cgo LDFLAGS: -L../core/lib -lmilvus_dog_segment -Wl,-rpath=../core/lib
#include "collection_c.h"
#include "partition_c.h"