Add verifier

Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
pull/4973/head^2
FluorineDog 2021-01-05 20:13:54 +08:00 committed by yefu.chen
parent bccf8d2098
commit e16a7478da
40 changed files with 417 additions and 2898 deletions

View File

@@ -1,315 +0,0 @@
package main
import (
"context"
"crypto/md5"
"flag"
"fmt"
"log"
"math/rand"
"os"
"sync"
"sync/atomic"
"time"
"github.com/pivotal-golang/bytefmt"
"github.com/zilliztech/milvus-distributed/internal/storage"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
// Global variables
var durationSecs, threads, loops, numVersion, batchOpSize int
var valueSize uint64
var valueData []byte
var batchValueData [][]byte
var counter, totalKeyCount, keyNum int32
var endTime, setFinish, getFinish, deleteFinish time.Time
var totalKeys [][]byte
var logFileName = "benchmark.log"
var logFile *os.File
var store storagetype.Store
var wg sync.WaitGroup
func runSet() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&keyNum, 1)
key := []byte(fmt.Sprint("key", num))
for ver := 1; ver <= numVersion; ver++ {
atomic.AddInt32(&counter, 1)
err := store.PutRow(context.Background(), key, valueData, "empty", uint64(ver))
if err != nil {
log.Fatalf("Error setting key %s, %s", key, err.Error())
//atomic.AddInt32(&setCount, -1)
}
}
}
// Remember last done time
setFinish = time.Now()
wg.Done()
}
func runBatchSet() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&keyNum, int32(batchOpSize))
keys := make([][]byte, batchOpSize)
versions := make([]uint64, batchOpSize)
batchSuffix := make([]string, batchOpSize)
for n := batchOpSize; n > 0; n-- {
keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
}
for ver := 1; ver <= numVersion; ver++ {
atomic.AddInt32(&counter, 1)
for i := range versions {
versions[i] = uint64(ver) // use the version as the write timestamp, mirroring runSet
}
err := store.PutRows(context.Background(), keys, batchValueData, batchSuffix, versions)
if err != nil {
log.Fatalf("Error setting batch keys %s %s", keys, err.Error())
//atomic.AddInt32(&batchSetCount, -1)
}
}
}
setFinish = time.Now()
wg.Done()
}
func runGet() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&counter, 1)
//num := atomic.AddInt32(&keyNum, 1)
//key := []byte(fmt.Sprint("key", num))
num = num % totalKeyCount
key := totalKeys[num]
_, err := store.GetRow(context.Background(), key, uint64(numVersion))
if err != nil {
log.Fatalf("Error getting key %s, %s", key, err.Error())
//atomic.AddInt32(&getCount, -1)
}
}
// Remember last done time
getFinish = time.Now()
wg.Done()
}
func runBatchGet() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&keyNum, int32(batchOpSize))
//keys := make([][]byte, batchOpSize)
//for n := batchOpSize; n > 0; n-- {
// keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
//}
end := num % totalKeyCount
if end < int32(batchOpSize) {
end = int32(batchOpSize)
}
start := end - int32(batchOpSize)
keys := totalKeys[start:end]
versions := make([]uint64, batchOpSize)
for i := range versions {
versions[i] = uint64(numVersion)
}
atomic.AddInt32(&counter, 1)
_, err := store.GetRows(context.Background(), keys, versions)
if err != nil {
log.Fatalf("Error getting key %s, %s", keys, err.Error())
//atomic.AddInt32(&batchGetCount, -1)
}
}
// Remember last done time
getFinish = time.Now()
wg.Done()
}
func runDelete() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&counter, 1)
//num := atomic.AddInt32(&keyNum, 1)
//key := []byte(fmt.Sprint("key", num))
num = num % totalKeyCount
key := totalKeys[num]
err := store.DeleteRow(context.Background(), key, uint64(numVersion))
if err != nil {
log.Fatalf("Error getting key %s, %s", key, err.Error())
//atomic.AddInt32(&deleteCount, -1)
}
}
// Remember last done time
deleteFinish = time.Now()
wg.Done()
}
func runBatchDelete() {
for time.Now().Before(endTime) {
num := atomic.AddInt32(&keyNum, int32(batchOpSize))
//keys := make([][]byte, batchOpSize)
//for n := batchOpSize; n > 0; n-- {
// keys[n-1] = []byte(fmt.Sprint("key", num-int32(n)))
//}
end := num % totalKeyCount
if end < int32(batchOpSize) {
end = int32(batchOpSize)
}
start := end - int32(batchOpSize)
keys := totalKeys[start:end]
atomic.AddInt32(&counter, 1)
versions := make([]uint64, batchOpSize)
for i := range versions {
versions[i] = uint64(numVersion)
}
err := store.DeleteRows(context.Background(), keys, versions)
if err != nil {
log.Fatalf("Error getting key %s, %s", keys, err.Error())
//atomic.AddInt32(&batchDeleteCount, -1)
}
}
// Remember last done time
deleteFinish = time.Now()
wg.Done()
}
func main() {
// Parse command line
myflag := flag.NewFlagSet("myflag", flag.ExitOnError)
myflag.IntVar(&durationSecs, "d", 5, "Duration of each test in seconds")
myflag.IntVar(&threads, "t", 1, "Number of threads to run")
myflag.IntVar(&loops, "l", 1, "Number of times to repeat test")
var sizeArg string
var storeType string
myflag.StringVar(&sizeArg, "z", "1k", "Size of objects in bytes with postfix K, M, and G")
myflag.StringVar(&storeType, "s", "s3", "Storage type, tikv or minio or s3")
myflag.IntVar(&numVersion, "v", 1, "Max versions for each key")
myflag.IntVar(&batchOpSize, "b", 100, "Batch operation kv pair number")
if err := myflag.Parse(os.Args[1:]); err != nil {
os.Exit(1)
}
// Check the arguments
var err error
if valueSize, err = bytefmt.ToBytes(sizeArg); err != nil {
log.Fatalf("Invalid -z argument for object size: %v", err)
}
var option = storagetype.Option{TikvAddress: "localhost:2379", Type: storeType, BucketName: "zilliz-hz"}
store, err = storage.NewStore(context.Background(), option)
if err != nil {
log.Fatalf("Error when creating storage " + err.Error())
}
logFile, err = os.OpenFile(logFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0777)
if err != nil {
log.Fatalf("Prepare log file error, " + err.Error())
}
// Echo the parameters
log.Printf("Benchmark log will write to file %s\n", logFile.Name())
fmt.Fprintf(logFile, "Parameters: duration=%d, threads=%d, loops=%d, valueSize=%s, batchSize=%d, versions=%d\n", durationSecs, threads, loops, sizeArg, batchOpSize, numVersion)
// Init test data
valueData = make([]byte, valueSize)
rand.Read(valueData)
hasher := md5.New()
hasher.Write(valueData)
batchValueData = make([][]byte, batchOpSize)
for i := range batchValueData {
batchValueData[i] = make([]byte, valueSize)
rand.Read(batchValueData[i])
hasher := md5.New()
hasher.Write(batchValueData[i])
}
// Loop running the tests
for loop := 1; loop <= loops; loop++ {
// reset counters
counter = 0
keyNum = 0
totalKeyCount = 0
totalKeys = nil
// Run the batchSet case
// key seq start from setCount
counter = 0
startTime := time.Now()
endTime = startTime.Add(time.Second * time.Duration(durationSecs))
for n := 1; n <= threads; n++ {
wg.Add(1)
go runBatchSet()
}
wg.Wait()
setTime := setFinish.Sub(startTime).Seconds()
bps := float64(uint64(counter)*valueSize*uint64(batchOpSize)) / setTime
fmt.Fprintf(logFile, "Loop %d: BATCH PUT time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
loop, setTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/setTime, float64(counter*int32(batchOpSize))/setTime)
// Record all test keys
//totalKeyCount = keyNum
//totalKeys = make([][]byte, totalKeyCount)
//for i := int32(0); i < totalKeyCount; i++ {
// totalKeys[i] = []byte(fmt.Sprint("key", i))
//}
//
//// Run the get case
//counter = 0
//startTime = time.Now()
//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
//for n := 1; n <= threads; n++ {
// wg.Add(1)
// go runGet()
//}
//wg.Wait()
//
//getTime := getFinish.Sub(startTime).Seconds()
//bps = float64(uint64(counter)*valueSize) / getTime
//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: GET time %.1f secs, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
// loop, getTime, counter, bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter)/getTime))
// Run the batchGet case
//counter = 0
//startTime = time.Now()
//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
//for n := 1; n <= threads; n++ {
// wg.Add(1)
// go runBatchGet()
//}
//wg.Wait()
//
//getTime = getFinish.Sub(startTime).Seconds()
//bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / getTime
//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH GET time %.1f secs, batchs = %d, kv pairs = %d, speed = %sB/sec, %.1f operations/sec, %.1f kv/sec.\n",
// loop, getTime, counter, counter*int32(batchOpSize), bytefmt.ByteSize(uint64(bps)), float64(counter)/getTime, float64(counter * int32(batchOpSize))/getTime))
//
//// Run the delete case
//counter = 0
//startTime = time.Now()
//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
//for n := 1; n <= threads; n++ {
// wg.Add(1)
// go runDelete()
//}
//wg.Wait()
//
//deleteTime := deleteFinish.Sub(startTime).Seconds()
//bps = float64(uint64(counter)*valueSize) / deleteTime
//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: Delete time %.1f secs, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
// loop, deleteTime, counter, float64(counter)/deleteTime, float64(counter)/deleteTime))
//
//// Run the batchDelete case
//counter = 0
//startTime = time.Now()
//endTime = startTime.Add(time.Second * time.Duration(durationSecs))
//for n := 1; n <= threads; n++ {
// wg.Add(1)
// go runBatchDelete()
//}
//wg.Wait()
//
//deleteTime = deleteFinish.Sub(startTime).Seconds()
//bps = float64(uint64(counter)*valueSize*uint64(batchOpSize)) / deleteTime
//fmt.Fprint(logFile, fmt.Sprintf("Loop %d: BATCH DELETE time %.1f secs, batches = %d, kv pairs = %d, %.1f operations/sec, %.1f kv/sec.\n",
// loop, deleteTime, counter, counter*int32(batchOpSize), float64(counter)/deleteTime, float64(counter*int32(batchOpSize))/deleteTime))
// Print line mark
lineMark := "\n"
fmt.Fprint(logFile, lineMark)
}
log.Print("Benchmark test done.")
}
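For reference, the deleted benchmark was driven entirely by the flags registered at the top of main(); a typical invocation (all values hypothetical) would have been:

    go run . -d 10 -t 4 -l 2 -z 1M -s minio -v 1 -b 100

i.e. two 10-second rounds of batched PUTs with 4 goroutines, 100 kv pairs of 1M each per batch, against the minio backend.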

View File

@@ -55,7 +55,6 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con
}
auto stoi_closure = [](const std::string& s) -> int { return std::stoi(s); };
auto stof_closure = [](const std::string& s) -> float { return std::stof(s); };
/***************************** meta *******************************/
check_parameter<int>(conf, milvus::knowhere::meta::DIM, stoi_closure, std::nullopt);
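check_parameter itself is outside this hunk; judging from the call sites above, it parses an optionally present string entry of conf into a typed value via the given closure, with std::nullopt meaning "no default". A plausible shape, illustrative only (the real definition is not part of this diff):

    template <typename T>
    void
    check_parameter(knowhere::Config& conf,
                    const std::string& key,
                    std::function<T(const std::string&)> parse,
                    std::optional<T> default_value) {
        if (!conf.contains(key)) {
            if (default_value.has_value()) {
                conf[key] = default_value.value();  // fall back to the default
            }
        } else if (conf[key].is_string()) {
            conf[key] = parse(conf[key].get<std::string>());  // e.g. "1024" -> 1024
        }
    }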
@@ -89,7 +88,7 @@ IndexWrapper::parse_impl(const std::string& serialized_params_str, knowhere::Con
check_parameter<int>(conf, milvus::knowhere::IndexParams::edge_size, stoi_closure, std::nullopt);
/************************** NGT Search Params *****************************/
check_parameter<float>(conf, milvus::knowhere::IndexParams::epsilon, stof_closure, std::nullopt);
check_parameter<int>(conf, milvus::knowhere::IndexParams::epsilon, stoi_closure, std::nullopt);
check_parameter<int>(conf, milvus::knowhere::IndexParams::max_search_edges, stoi_closure, std::nullopt);
/************************** NGT_PANNG Params *****************************/
@@ -275,12 +274,6 @@ IndexWrapper::QueryWithParam(const knowhere::DatasetPtr& dataset, const char* se
std::unique_ptr<IndexWrapper::QueryResult>
IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Config& conf) {
auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
std::call_once(raw_data_loaded_, load_raw_data_closure);
}
auto res = index_->Query(dataset, conf, nullptr);
auto ids = res->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto distances = res->Get<float*>(milvus::knowhere::meta::DISTANCE);
@@ -298,19 +291,5 @@ IndexWrapper::QueryImpl(const knowhere::DatasetPtr& dataset, const knowhere::Con
return std::move(query_res);
}
void
IndexWrapper::LoadRawData() {
auto index_type = get_index_type();
if (is_in_nm_list(index_type)) {
auto bs = index_->Serialize(config_);
auto bptr = std::make_shared<milvus::knowhere::Binary>();
auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction
bptr->data = std::shared_ptr<uint8_t[]>(static_cast<uint8_t*>(raw_data_.data()), deleter);
bptr->size = raw_data_.size();
bs.Append(RAW_DATA, bptr);
index_->Load(bs);
}
}
} // namespace indexbuilder
} // namespace milvus

View File

@@ -66,9 +66,6 @@ class IndexWrapper {
void
StoreRawData(const knowhere::DatasetPtr& dataset);
void
LoadRawData();
template <typename T>
void
check_parameter(knowhere::Config& conf,
@@ -95,7 +92,6 @@ class IndexWrapper {
milvus::json index_config_;
knowhere::Config config_;
std::vector<uint8_t> raw_data_;
std::once_flag raw_data_loaded_;
};
} // namespace indexbuilder

View File

@@ -4,13 +4,15 @@ set(MILVUS_QUERY_SRCS
generated/PlanNode.cpp
generated/Expr.cpp
visitors/ShowPlanNodeVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
visitors/ShowExprVisitor.cpp
visitors/ExecPlanNodeVisitor.cpp
visitors/ExecExprVisitor.cpp
visitors/VerifyPlanNodeVisitor.cpp
visitors/VerifyExprVisitor.cpp
Plan.cpp
Search.cpp
SearchOnSealed.cpp
BruteForceSearch.cpp
)
add_library(milvus_query ${MILVUS_QUERY_SRCS})
target_link_libraries(milvus_query milvus_proto milvus_utils)
target_link_libraries(milvus_query milvus_proto milvus_utils knowhere)

View File

@@ -21,6 +21,7 @@
#include <boost/align/aligned_allocator.hpp>
#include <boost/algorithm/string.hpp>
#include <algorithm>
#include "query/generated/VerifyPlanNodeVisitor.h"
namespace milvus::query {
@@ -138,6 +139,8 @@ Parser::CreatePlanImpl(const std::string& dsl_str) {
if (predicate != nullptr) {
vec_node->predicate_ = std::move(predicate);
}
VerifyPlanNodeVisitor verifier;
vec_node->accept(verifier);
auto plan = std::make_unique<Plan>(schema);
plan->tag2field_ = std::move(tag2field_);

View File

@@ -21,7 +21,7 @@
#include "ExprVisitor.h"
namespace milvus::query {
class ExecExprVisitor : ExprVisitor {
class ExecExprVisitor : public ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;

View File

@@ -19,7 +19,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ExecPlanNodeVisitor : PlanNodeVisitor {
class ExecPlanNodeVisitor : public PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;

View File

@@ -19,7 +19,7 @@
#include "ExprVisitor.h"
namespace milvus::query {
class ShowExprVisitor : ExprVisitor {
class ShowExprVisitor : public ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;

View File

@@ -20,7 +20,7 @@
#include "PlanNodeVisitor.h"
namespace milvus::query {
class ShowPlanNodeVisitor : PlanNodeVisitor {
class ShowPlanNodeVisitor : public PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;
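All four visitor headers above make the same one-word change: the base class is now inherited publicly. With the default private inheritance of `class`, the derived visitor cannot be converted to its base type outside the class, so passing it to a node's accept(PlanNodeVisitor&) does not compile. A minimal illustration of the difference (names hypothetical):

    struct VisitorBase {
        virtual void visit(int& node) = 0;  // stand-in for the real visit overloads
    };

    class BrokenVisitor : VisitorBase {  // `class` defaults to private inheritance
     public:
        void visit(int& node) override {}
    };

    class WorkingVisitor : public VisitorBase {  // what this commit changes to
     public:
        void visit(int& node) override {}
    };

    void accept(VisitorBase& v);  // what node.accept(visitor) expects

    // accept(broken);   // error: VisitorBase is an inaccessible base of BrokenVisitor
    // accept(working);  // OK: public inheritance allows the upcast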

View File

@@ -0,0 +1,36 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#error TODO: copy this file out, and modify the content.
#include "query/generated/VerifyExprVisitor.h"
namespace milvus::query {
void
VerifyExprVisitor::visit(BoolUnaryExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(BoolBinaryExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(TermExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(RangeExpr& expr) {
// TODO
}
} // namespace milvus::query

View File

@@ -0,0 +1,40 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
// Generated File
// DO NOT EDIT
#include <optional>
#include <boost/dynamic_bitset.hpp>
#include <utility>
#include <deque>
#include "segcore/SegmentSmallIndex.h"
#include "query/ExprImpl.h"
#include "ExprVisitor.h"
namespace milvus::query {
class VerifyExprVisitor : public ExprVisitor {
public:
void
visit(BoolUnaryExpr& expr) override;
void
visit(BoolBinaryExpr& expr) override;
void
visit(TermExpr& expr) override;
void
visit(RangeExpr& expr) override;
public:
};
} // namespace milvus::query

View File

@@ -0,0 +1,26 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#error TODO: copy this file out, and modify the content.
#include "query/generated/VerifyPlanNodeVisitor.h"
namespace milvus::query {
void
VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) {
// TODO
}
void
VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) {
// TODO
}
} // namespace milvus::query

View File

@@ -0,0 +1,37 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
// Generated File
// DO NOT EDIT
#include "utils/Json.h"
#include "query/PlanImpl.h"
#include "segcore/SegmentBase.h"
#include <utility>
#include "PlanNodeVisitor.h"
namespace milvus::query {
class VerifyPlanNodeVisitor : public PlanNodeVisitor {
public:
void
visit(FloatVectorANNS& node) override;
void
visit(BinaryVectorANNS& node) override;
public:
using RetType = QueryResult;
VerifyPlanNodeVisitor() = default;
private:
std::optional<RetType> ret_;
};
} // namespace milvus::query

View File

@@ -0,0 +1,35 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "query/generated/VerifyExprVisitor.h"
namespace milvus::query {
void
VerifyExprVisitor::visit(BoolUnaryExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(BoolBinaryExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(TermExpr& expr) {
// TODO
}
void
VerifyExprVisitor::visit(RangeExpr& expr) {
// TODO
}
} // namespace milvus::query

View File

@@ -0,0 +1,85 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "query/generated/VerifyPlanNodeVisitor.h"
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
#include "segcore/SegmentSmallIndex.h"
#include "knowhere/index/vector_index/ConfAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
namespace milvus::query {
#if 1
namespace impl {
// THIS CONTAINS EXTRA BODY FOR VISITOR
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
class VerifyPlanNodeVisitor : PlanNodeVisitor {
public:
using RetType = QueryResult;
VerifyPlanNodeVisitor() = default;
private:
std::optional<RetType> ret_;
};
} // namespace impl
#endif
static knowhere::IndexType
InferIndexType(const Json& search_params) {
// ivf -> nprobe
// nsg -> search_length
// hnsw/rhnsw/*pq/*sq -> ef
// annoy -> search_k
// ngtpanng / ngtonng -> max_search_edges / epsilon
static const std::map<std::string, knowhere::IndexType> key_list = [] {
std::map<std::string, knowhere::IndexType> list;
namespace ip = knowhere::IndexParams;
namespace ie = knowhere::IndexEnum;
list.emplace(ip::nprobe, ie::INDEX_FAISS_IVFFLAT);
list.emplace(ip::search_length, ie::INDEX_NSG);
list.emplace(ip::ef, ie::INDEX_HNSW);
list.emplace(ip::search_k, ie::INDEX_ANNOY);
list.emplace(ip::max_search_edges, ie::INDEX_NGTONNG);
list.emplace(ip::epsilon, ie::INDEX_NGTONNG);
return list;
}();
auto dbg_str = search_params.dump();
for (auto& kv : search_params.items()) {
std::string key = kv.key();
if (key_list.count(key)) {
return key_list.at(key);
}
}
PanicInfo("failed to infer index type");
}
void
VerifyPlanNodeVisitor::visit(FloatVectorANNS& node) {
auto& search_params = node.query_info_.search_params_;
auto inferred_type = InferIndexType(search_params);
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(inferred_type);
auto index_mode = knowhere::IndexMode::MODE_CPU;
// mock the API: topk will be passed from the placeholder at query time
auto params_copy = search_params;
params_copy[knowhere::meta::TOPK] = 10;
// NOTE: the second parameter is not checked in knowhere, may be redundant
auto passed = adapter->CheckSearch(params_copy, inferred_type, index_mode);
AssertInfo(passed, "invalid search params");
}
void
VerifyPlanNodeVisitor::visit(BinaryVectorANNS& node) {
// TODO
}
} // namespace milvus::query
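To make the verifier concrete, here is the path a query takes through the code above, assuming the nlohmann-backed Json alias from utils/Json.h (values hypothetical):

    // search params as parsed from the DSL, e.g. an IVF-style query
    Json search_params = Json::parse(R"({"nprobe": 10})");
    auto inferred_type = InferIndexType(search_params);  // "nprobe" -> INDEX_FAISS_IVFFLAT
    auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(inferred_type);
    auto params_copy = search_params;
    params_copy[knowhere::meta::TOPK] = 10;  // topk is injected from the placeholder at query time
    auto passed = adapter->CheckSearch(params_copy, inferred_type, knowhere::IndexMode::MODE_CPU);
    AssertInfo(passed, "invalid search params");  // a bad combination aborts plan creation

A query whose params carry none of the known keys hits InferIndexType's PanicInfo, and an out-of-range value is rejected by CheckSearch, before any search executes.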

View File

@@ -24,5 +24,6 @@ target_link_libraries(milvus_segcore
dl backtrace
milvus_common
milvus_query
milvus_utils
)

View File

@@ -24,10 +24,8 @@ target_link_libraries(all_tests
gtest_main
milvus_segcore
milvus_indexbuilder
knowhere
log
pthread
milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)

View File

@@ -137,7 +137,7 @@ TEST(CApiTest, SearchTest) {
auto offset = PreInsert(segment, N);
auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N);
assert(ins_res.error_code == Success);
ASSERT_EQ(ins_res.error_code, Success);
const char* dsl_string = R"(
{
@@ -176,11 +176,11 @@ TEST(CApiTest, SearchTest) {
void* plan = nullptr;
auto status = CreatePlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
ASSERT_EQ(status.error_code, Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
assert(status.error_code == Success);
ASSERT_EQ(status.error_code, Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
@@ -189,7 +189,7 @@ TEST(CApiTest, SearchTest) {
CQueryResult search_result;
auto res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, &search_result);
assert(res.error_code == Success);
ASSERT_EQ(res.error_code, Success);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
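The assert-to-ASSERT_EQ change above is more than style: assert() expands to nothing when NDEBUG is defined, as it typically is in release builds, so these checks could silently vanish; gtest's ASSERT_EQ is always evaluated and fails the test on mismatch.

    assert(res.error_code == Success);   // compiled out under NDEBUG; failures pass unnoticed
    ASSERT_EQ(res.error_code, Success);  // always checked; aborts this test on failure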

View File

@@ -11,8 +11,6 @@
#include <tuple>
#include <map>
#include <limits>
#include <math.h>
#include <gtest/gtest.h>
#include <google/protobuf/text_format.h>
@@ -43,16 +41,16 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4},
};
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
// {milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, metric_type},
@@ -61,9 +59,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
// {milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4},
#ifdef MILVUS_GPU_VERSION
@@ -73,9 +71,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
// {milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, 4},
@@ -86,9 +84,9 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
// {milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, metric_type},
@@ -97,14 +95,13 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::Metric::TYPE, metric_type},
};
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_NSG) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::IndexParams::nlist, 163},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nprobe, 8},
{milvus::knowhere::IndexParams::knng, 20},
{milvus::knowhere::IndexParams::search_length, 40},
@@ -130,14 +127,17 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
#endif
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_HNSW) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM}, {milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::meta::DIM, DIM},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16},
{milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200},
{milvus::knowhere::Metric::TYPE, metric_type},
};
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_ANNOY) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::n_trees, 4},
{milvus::knowhere::IndexParams::search_k, 100},
{milvus::knowhere::Metric::TYPE, metric_type},
@@ -146,7 +146,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWFlat) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16},
{milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200},
@@ -156,7 +156,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWPQ) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16},
{milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200},
@@ -167,7 +167,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_RHNSWSQ) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::IndexParams::M, 16},
{milvus::knowhere::IndexParams::efConstruction, 200},
{milvus::knowhere::IndexParams::ef, 200},
@@ -177,7 +177,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_NGTPANNG) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::IndexParams::edge_size, 10},
{milvus::knowhere::IndexParams::epsilon, 0.1},
@@ -189,7 +189,7 @@ generate_conf(const milvus::knowhere::IndexType& index_type, const milvus::knowh
} else if (index_type == milvus::knowhere::IndexEnum::INDEX_NGTONNG) {
return milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
// {milvus::knowhere::meta::TOPK, 10},
{milvus::knowhere::Metric::TYPE, metric_type},
{milvus::knowhere::IndexParams::edge_size, 20},
{milvus::knowhere::IndexParams::epsilon, 0.1},
@@ -234,99 +234,6 @@ GenDataset(int64_t N, const milvus::knowhere::MetricType& metric_type, bool is_b
return milvus::segcore::DataGen(schema, N);
}
}
using QueryResultPtr = std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult>;
void
PrintQueryResult(const QueryResultPtr& result) {
auto nq = result->nq;
auto k = result->topk;
std::stringstream ss_id;
std::stringstream ss_dist;
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
ss_id << result->ids[i * k + j] << " ";
ss_dist << result->distances[i * k + j] << " ";
}
ss_id << std::endl;
ss_dist << std::endl;
}
std::cout << "id\n" << ss_id.str() << std::endl;
std::cout << "dist\n" << ss_dist.str() << std::endl;
}
float
L2(const float* point_a, const float* point_b, int dim) {
float dis = 0;
for (auto i = 0; i < dim; i++) {
auto c_a = point_a[i];
auto c_b = point_b[i];
dis += pow(c_b - c_a, 2);
}
return dis;
}
int
hamming_weight(uint8_t n) {
int count = 0;
while (n != 0) {
count += n & 1;
n >>= 1;
}
return count;
}
float
Jaccard(const uint8_t* point_a, const uint8_t* point_b, int dim) {
float dis;
int len = dim / 8;
float intersection = 0;
float union_num = 0;
for (int i = 0; i < len; i++) {
intersection += hamming_weight(point_a[i] & point_b[i]);
union_num += hamming_weight(point_a[i] | point_b[i]);
}
dis = 1 - (intersection / union_num);
return dis;
}
float
CountDistance(const void* point_a,
const void* point_b,
int dim,
const milvus::knowhere::MetricType& metric,
bool is_binary = false) {
if (point_a == nullptr || point_b == nullptr) {
return std::numeric_limits<float>::max();
}
if (metric == milvus::knowhere::Metric::L2) {
return L2(static_cast<const float*>(point_a), static_cast<const float*>(point_b), dim);
} else if (metric == milvus::knowhere::Metric::JACCARD) {
return Jaccard(static_cast<const uint8_t*>(point_a), static_cast<const uint8_t*>(point_b), dim);
} else {
return std::numeric_limits<float>::max();
}
}
void
CheckDistances(const QueryResultPtr& result,
const milvus::knowhere::DatasetPtr& base_dataset,
const milvus::knowhere::DatasetPtr& query_dataset,
const milvus::knowhere::MetricType& metric,
const float threshold = 1.0e-5) {
auto base_vecs = base_dataset->Get<float*>(milvus::knowhere::meta::TENSOR);
auto query_vecs = query_dataset->Get<float*>(milvus::knowhere::meta::TENSOR);
auto dim = base_dataset->Get<int64_t>(milvus::knowhere::meta::DIM);
auto nq = result->nq;
auto k = result->topk;
for (auto i = 0; i < nq; i++) {
for (auto j = 0; j < k; ++j) {
auto dis = result->distances[i * k + j];
auto id = result->ids[i * k + j];
auto count_dis = CountDistance(query_vecs + i * dim, base_vecs + id * dim, dim, metric);
// assert(std::abs(dis - count_dis) < threshold);
}
}
}
} // namespace
using Param = std::pair<milvus::knowhere::IndexType, milvus::knowhere::MetricType>;
@@ -340,26 +247,8 @@ class IndexWrapperTest : public ::testing::TestWithParam<Param> {
metric_type = param.second;
std::tie(type_params, index_params) = generate_params(index_type, metric_type);
std::map<std::string, bool> is_binary_map = {
{milvus::knowhere::IndexEnum::INDEX_FAISS_IDMAP, false},
{milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, false},
{milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, false},
{milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, false},
{milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, true},
{milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, true},
#ifdef MILVUS_SUPPORT_SPTAG
{milvus::knowhere::IndexEnum::INDEX_SPTAG_KDT_RNT, false},
{milvus::knowhere::IndexEnum::INDEX_SPTAG_BKT_RNT, false},
#endif
{milvus::knowhere::IndexEnum::INDEX_HNSW, false},
{milvus::knowhere::IndexEnum::INDEX_ANNOY, false},
{milvus::knowhere::IndexEnum::INDEX_RHNSWFlat, false},
{milvus::knowhere::IndexEnum::INDEX_RHNSWPQ, false},
{milvus::knowhere::IndexEnum::INDEX_RHNSWSQ, false},
{milvus::knowhere::IndexEnum::INDEX_NGTPANNG, false},
{milvus::knowhere::IndexEnum::INDEX_NGTONNG, false},
{milvus::knowhere::IndexEnum::INDEX_NSG, false},
};
std::map<std::string, bool> is_binary_map = {{milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, false},
{milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, true}};
is_binary = is_binary_map[index_type];
@@ -373,13 +262,9 @@ class IndexWrapperTest : public ::testing::TestWithParam<Param> {
if (!is_binary) {
xb_data = dataset.get_col<float>(0);
xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
xq_data = dataset.get_col<float>(0);
xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data());
} else {
xb_bin_data = dataset.get_col<uint8_t>(0);
xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_bin_data.data());
xq_bin_data = dataset.get_col<uint8_t>(0);
xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_bin_data.data());
}
}
@@ -397,9 +282,6 @@ class IndexWrapperTest : public ::testing::TestWithParam<Param> {
std::vector<float> xb_data;
std::vector<uint8_t> xb_bin_data;
std::vector<milvus::knowhere::IDType> ids;
milvus::knowhere::DatasetPtr xq_dataset;
std::vector<float> xq_data;
std::vector<uint8_t> xq_bin_data;
};
TEST(PQ, Build) {
@@ -426,47 +308,6 @@ TEST(IVFFLATNM, Build) {
ASSERT_NO_THROW(index->AddWithoutIds(xb_dataset, conf));
}
TEST(IVFFLATNM, Query) {
auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
auto metric_type = milvus::knowhere::Metric::L2;
auto conf = generate_conf(index_type, metric_type);
auto index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type);
auto dataset = GenDataset(NB, metric_type, false);
auto xb_data = dataset.get_col<float>(0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
ASSERT_NO_THROW(index->Train(xb_dataset, conf));
ASSERT_NO_THROW(index->AddWithoutIds(xb_dataset, conf));
auto bs = index->Serialize(conf);
auto bptr = std::make_shared<milvus::knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_data.data(), [&](uint8_t*) {});
bptr->size = DIM * NB * sizeof(float);
bs.Append(RAW_DATA, bptr);
index->Load(bs);
auto xq_data = dataset.get_col<float>(0);
auto xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data());
auto result = index->Query(xq_dataset, conf, nullptr);
}
TEST(NSG, Query) {
auto index_type = milvus::knowhere::IndexEnum::INDEX_NSG;
auto metric_type = milvus::knowhere::Metric::L2;
auto conf = generate_conf(index_type, metric_type);
auto index = milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_type);
auto dataset = GenDataset(NB, metric_type, false);
auto xb_data = dataset.get_col<float>(0);
auto xb_dataset = milvus::knowhere::GenDataset(NB, DIM, xb_data.data());
index->BuildAll(xb_dataset, conf);
auto bs = index->Serialize(conf);
auto bptr = std::make_shared<milvus::knowhere::Binary>();
bptr->data = std::shared_ptr<uint8_t[]>((uint8_t*)xb_data.data(), [&](uint8_t*) {});
bptr->size = DIM * NB * sizeof(float);
bs.Append(RAW_DATA, bptr);
index->Load(bs);
auto xq_data = dataset.get_col<float>(0);
auto xq_dataset = milvus::knowhere::GenDataset(NQ, DIM, xq_data.data());
auto result = index->Query(xq_dataset, conf, nullptr);
}
TEST(BINFLAT, Build) {
auto index_type = milvus::knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
auto metric_type = milvus::knowhere::Metric::JACCARD;
@@ -644,7 +485,12 @@ TEST_P(IndexWrapperTest, Dim) {
TEST_P(IndexWrapperTest, BuildWithoutIds) {
auto index =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
if (milvus::indexbuilder::is_in_need_id_list(index_type)) {
ASSERT_ANY_THROW(index->BuildWithoutIds(xb_dataset));
} else {
ASSERT_NO_THROW(index->BuildWithoutIds(xb_dataset));
}
}
TEST_P(IndexWrapperTest, Codec) {
@@ -665,16 +511,3 @@ TEST_P(IndexWrapperTest, Codec) {
ASSERT_EQ(strcmp(binary.data, copy_binary.data), 0);
}
}
TEST_P(IndexWrapperTest, Query) {
auto index_wrapper =
std::make_unique<milvus::indexbuilder::IndexWrapper>(type_params_str.c_str(), index_params_str.c_str());
index_wrapper->BuildWithoutIds(xb_dataset);
std::unique_ptr<milvus::indexbuilder::IndexWrapper::QueryResult> query_result = index_wrapper->Query(xq_dataset);
ASSERT_EQ(query_result->topk, K);
ASSERT_EQ(query_result->nq, NQ);
ASSERT_EQ(query_result->distances.size(), query_result->topk * query_result->nq);
ASSERT_EQ(query_result->ids.size(), query_result->topk * query_result->nq);
}

View File

@@ -11,9 +11,6 @@ import (
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"go.etcd.io/etcd/clientv3"
"github.com/zilliztech/milvus-distributed/internal/allocator"
@@ -71,19 +68,16 @@ func CreateBuilder(ctx context.Context) (*Builder, error) {
idAllocator, err := allocator.NewIDAllocator(b.loopCtx, Params.MasterAddress)
minIOEndPoint := Params.MinIOAddress
minIOAccessKeyID := Params.MinIOAccessKeyID
minIOSecretAccessKey := Params.MinIOSecretAccessKey
minIOUseSSL := Params.MinIOUseSSL
minIOClient, err := minio.New(minIOEndPoint, &minio.Options{
Creds: credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""),
Secure: minIOUseSSL,
})
if err != nil {
return nil, err
option := &miniokv.Option{
Address: Params.MinIOAddress,
AccessKeyID: Params.MinIOAccessKeyID,
SecretAccessKeyID: Params.MinIOSecretAccessKey,
UseSSL: Params.MinIOUseSSL,
BucketName: Params.MinioBucketName,
CreateBucket: true,
}
b.kv, err = miniokv.NewMinIOKV(b.loopCtx, minIOClient, Params.MinioBucketName)
b.kv, err = miniokv.NewMinIOKV(b.loopCtx, option)
if err != nil {
return nil, err
}

View File

@@ -2,11 +2,15 @@ package miniokv
import (
"context"
"fmt"
"io"
"log"
"strings"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zilliztech/milvus-distributed/internal/errors"
)
type MinIOKV struct {
@@ -15,24 +19,46 @@ type MinIOKV struct {
bucketName string
}
// NewMinIOKV creates a new MinIO kv.
func NewMinIOKV(ctx context.Context, client *minio.Client, bucketName string) (*MinIOKV, error) {
type Option struct {
Address string
AccessKeyID string
BucketName string
SecretAccessKeyID string
UseSSL bool
CreateBucket bool // create the bucket if it does not exist
}
bucketExists, err := client.BucketExists(ctx, bucketName)
func NewMinIOKV(ctx context.Context, option *Option) (*MinIOKV, error) {
minIOClient, err := minio.New(option.Address, &minio.Options{
Creds: credentials.NewStaticV4(option.AccessKeyID, option.SecretAccessKeyID, ""),
Secure: option.UseSSL,
})
if err != nil {
return nil, err
}
if !bucketExists {
err = client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
if err != nil {
return nil, err
bucketExists, err := minIOClient.BucketExists(ctx, option.BucketName)
if err != nil {
return nil, err
}
if option.CreateBucket {
if !bucketExists {
err = minIOClient.MakeBucket(ctx, option.BucketName, minio.MakeBucketOptions{})
if err != nil {
return nil, err
}
}
} else {
if !bucketExists {
return nil, errors.New(fmt.Sprintf("bucket %s does not exist", option.BucketName))
}
}
return &MinIOKV{
ctx: ctx,
minioClient: client,
bucketName: bucketName,
minioClient: minIOClient,
bucketName: option.BucketName,
}, nil
}
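Every consumer now builds the client from an Option instead of wiring up a minio.Client by hand; a minimal usage sketch (endpoint and credentials hypothetical):

    option := &miniokv.Option{
        Address:           "localhost:9000",
        AccessKeyID:       "minioadmin",
        SecretAccessKeyID: "minioadmin",
        UseSSL:            false,
        BucketName:        "test-bucket",
        CreateBucket:      true, // create the bucket when it does not exist
    }
    kv, err := miniokv.NewMinIOKV(context.Background(), option)
    if err != nil {
        log.Fatal(err)
    }

This keeps bucket-existence policy in one place: callers that must not create buckets set CreateBucket to false and get an error for a missing bucket instead.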

View File

@@ -5,8 +5,6 @@ import (
"strconv"
"testing"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
@@ -15,24 +13,31 @@ import (
var Params paramtable.BaseTable
func TestMinIOKV_Load(t *testing.T) {
Params.Init()
func newMinIOKVClient(ctx context.Context, bucketName string) (*miniokv.MinIOKV, error) {
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
useSSL, _ := strconv.ParseBool(useSSLStr)
option := &miniokv.Option{
Address: endPoint,
AccessKeyID: accessKeyID,
SecretAccessKeyID: secretAccessKey,
UseSSL: useSSL,
BucketName: bucketName,
CreateBucket: true,
}
client, err := miniokv.NewMinIOKV(ctx, option)
return client, err
}
func TestMinIOKV_Load(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
MinIOKV, err := newMinIOKVClient(ctx, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
@@ -79,25 +84,14 @@ func TestMinIOKV_Load(t *testing.T) {
}
func TestMinIOKV_MultiSave(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
MinIOKV, err := newMinIOKVClient(ctx, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")
err = MinIOKV.Save("key_1", "111")
@@ -117,25 +111,13 @@ func TestMinIOKV_MultiSave(t *testing.T) {
}
func TestMinIOKV_Remove(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Params.Init()
endPoint, _ := Params.Load("_MinioAddress")
accessKeyID, _ := Params.Load("minio.accessKeyID")
secretAccessKey, _ := Params.Load("minio.secretAccessKey")
useSSLStr, _ := Params.Load("minio.useSSL")
useSSL, _ := strconv.ParseBool(useSSLStr)
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
assert.Nil(t, err)
bucketName := "fantastic-tech-test"
MinIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, bucketName)
MinIOKV, err := newMinIOKVClient(ctx, bucketName)
assert.Nil(t, err)
defer MinIOKV.RemoveWithPrefix("")

View File

@@ -11,9 +11,6 @@ import (
"strings"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
@@ -38,16 +35,17 @@ type loadIndexService struct {
func newLoadIndexService(ctx context.Context, replica collectionReplica) *loadIndexService {
ctx1, cancel := context.WithCancel(ctx)
// init minio
minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{
Creds: credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""),
Secure: Params.MinioUseSSLStr,
})
if err != nil {
panic(err)
option := &minioKV.Option{
Address: Params.MinioEndPoint,
AccessKeyID: Params.MinioAccessKeyID,
SecretAccessKeyID: Params.MinioSecretAccessKey,
UseSSL: Params.MinioUseSSLStr,
CreateBucket: true,
BucketName: Params.MinioBucketName,
}
MinioKV, err := minioKV.NewMinIOKV(ctx1, minioClient, Params.MinioBucketName)
// TODO: load bucketName from config
MinioKV, err := minioKV.NewMinIOKV(ctx1, option)
if err != nil {
panic(err)
}

View File

@@ -5,8 +5,6 @@ import (
"sort"
"testing"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/indexbuilder"
@@ -68,13 +66,16 @@ func TestLoadIndexService(t *testing.T) {
binarySet, err := index.Serialize()
assert.Equal(t, err, nil)
//save index to minio
minioClient, err := minio.New(Params.MinioEndPoint, &minio.Options{
Creds: credentials.NewStaticV4(Params.MinioAccessKeyID, Params.MinioSecretAccessKey, ""),
Secure: Params.MinioUseSSLStr,
})
assert.Equal(t, err, nil)
minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, minioClient, Params.MinioBucketName)
option := &minioKV.Option{
Address: Params.MinioEndPoint,
AccessKeyID: Params.MinioAccessKeyID,
SecretAccessKeyID: Params.MinioSecretAccessKey,
UseSSL: Params.MinioUseSSLStr,
BucketName: Params.MinioBucketName,
CreateBucket: true,
}
minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option)
assert.Equal(t, err, nil)
indexPaths := make([]string, 0)
for _, index := range binarySet {

View File

@@ -1,134 +0,0 @@
package s3driver
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
var option = storagetype.Option{BucketName: "zilliz-hz"}
var ctx = context.Background()
var client, err = NewS3Driver(ctx, option)
func TestS3Driver_PutRowAndGetRow(t *testing.T) {
err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3)
assert.Nil(t, err)
object, _ := client.GetRow(ctx, []byte("bar"), 1)
assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object))
object, _ = client.GetRow(ctx, []byte("bar"), 2)
assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object))
object, _ = client.GetRow(ctx, []byte("bar"), 5)
assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object))
object, _ = client.GetRow(ctx, []byte("bar1"), 5)
assert.Equal(t, "testkeybarorbar_1", string(object))
}
func TestS3Driver_DeleteRow(t *testing.T) {
err = client.DeleteRow(ctx, []byte("bar"), 5)
assert.Nil(t, err)
object, _ := client.GetRow(ctx, []byte("bar"), 6)
assert.Nil(t, object)
err = client.DeleteRow(ctx, []byte("bar1"), 5)
assert.Nil(t, err)
object2, _ := client.GetRow(ctx, []byte("bar1"), 6)
assert.Nil(t, object2)
}
func TestS3Driver_GetSegments(t *testing.T) {
err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1)
assert.Nil(t, err)
segments, err := client.GetSegments(ctx, []byte("seg"), 4)
assert.Nil(t, err)
assert.Equal(t, 2, len(segments))
if segments[0] == "SegmentA" {
assert.Equal(t, "SegmentA", segments[0])
assert.Equal(t, "SegmentB", segments[1])
} else {
assert.Equal(t, "SegmentB", segments[0])
assert.Equal(t, "SegmentA", segments[1])
}
}
func TestS3Driver_PutRowsAndGetRows(t *testing.T) {
keys := [][]byte{[]byte("foo"), []byte("bar")}
values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")}
segments := []string{"segmentA", "segmentB"}
timestamps := []uint64{1, 2}
err = client.PutRows(ctx, keys, values, segments, timestamps)
assert.Nil(t, err)
objects, err := client.GetRows(ctx, keys, timestamps)
assert.Nil(t, err)
assert.Equal(t, "The key is foo!", string(objects[0]))
assert.Equal(t, "The key is bar!", string(objects[1]))
}
func TestS3Driver_DeleteRows(t *testing.T) {
keys := [][]byte{[]byte("foo"), []byte("bar")}
timestamps := []uint64{3, 3}
err := client.DeleteRows(ctx, keys, timestamps)
assert.Nil(t, err)
objects, err := client.GetRows(ctx, keys, timestamps)
assert.Nil(t, err)
assert.Nil(t, objects[0])
assert.Nil(t, objects[1])
}
func TestS3Driver_PutLogAndGetLog(t *testing.T) {
err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("delete"), []byte("This is delete log!"), 2, 10)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8)
assert.Nil(t, err)
channels := []int{5, 8, 9, 10, 11, 12, 13}
logValues, err := client.GetLog(ctx, 0, 5, channels)
assert.Nil(t, err)
assert.Equal(t, "This is select log!", string(logValues[0]))
assert.Equal(t, "This is update log!", string(logValues[1]))
assert.Equal(t, "This is delete log!", string(logValues[2]))
assert.Equal(t, "This is insert log!", string(logValues[3]))
}
func TestS3Driver_Segment(t *testing.T) {
err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!"))
assert.Nil(t, err)
segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA")
assert.Equal(t, "This is segmentA's index!", string(segmentIndex))
assert.Nil(t, err)
err = client.DeleteSegmentIndex(ctx, "segmentA")
assert.Nil(t, err)
}
func TestS3Driver_SegmentDL(t *testing.T) {
err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!"))
assert.Nil(t, err)
segmentDL, err := client.GetSegmentDL(ctx, "segmentB")
assert.Nil(t, err)
assert.Equal(t, "This is segmentB's delete log!", string(segmentDL))
err = client.DeleteSegmentDL(ctx, "segmentB")
assert.Nil(t, err)
}

View File

@@ -1,173 +0,0 @@
package s3driver
import (
"bytes"
"context"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
. "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
type S3Store struct {
client *s3.S3
}
func NewS3Store(config aws.Config) (*S3Store, error) {
sess := session.Must(session.NewSession(&config))
service := s3.New(sess)
return &S3Store{
client: service,
}, nil
}
func (s *S3Store) Put(ctx context.Context, key Key, value Value) error {
_, err := s.client.PutObjectWithContext(ctx, &s3.PutObjectInput{
Bucket: aws.String(bucketName),
Key: aws.String(string(key)),
Body: bytes.NewReader(value),
})
//sess := session.Must(session.NewSessionWithOptions(session.Options{
// SharedConfigState: session.SharedConfigEnable,
//}))
//uploader := s3manager.NewUploader(sess)
//
//_, err := uploader.Upload(&s3manager.UploadInput{
// Bucket: aws.String(bucketName),
// Key: aws.String(string(key)),
// Body: bytes.NewReader(value),
//})
return err
}
func (s *S3Store) Get(ctx context.Context, key Key) (Value, error) {
object, err := s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucketName),
Key: aws.String(string(key)),
})
if err != nil {
return nil, err
}
//TODO: get size
size := 256 * 1024
buf := make([]byte, size)
n, err := object.Body.Read(buf)
if err != nil && err != io.EOF {
return nil, err
}
return buf[:n], nil
}
func (s *S3Store) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) {
objectsOutput, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(bucketName),
Prefix: aws.String(string(prefix)),
})
var objectsKeys []Key
var objectsValues []Value
if objectsOutput != nil && err == nil {
for _, object := range objectsOutput.Contents {
objectsKeys = append(objectsKeys, []byte(*object.Key))
if !keyOnly {
value, err := s.Get(ctx, []byte(*object.Key))
if err != nil {
return nil, nil, err
}
objectsValues = append(objectsValues, value)
}
}
} else {
return nil, nil, err
}
return objectsKeys, objectsValues, nil
}
func (s *S3Store) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) {
var keys []Key
var values []Value
limitCount := uint(limit)
objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(bucketName),
Prefix: aws.String(string(keyStart)),
})
if err == nil && objects != nil {
for _, object := range objects.Contents {
if *object.Key >= string(keyEnd) {
keys = append(keys, []byte(*object.Key))
if !keyOnly {
value, err := s.Get(ctx, []byte(*object.Key))
if err != nil {
return nil, nil, err
}
values = append(values, value)
}
limitCount--
if limitCount <= 0 {
break
}
}
}
}
return keys, values, err
}
func (s *S3Store) Delete(ctx context.Context, key Key) error {
_, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(bucketName),
Key: aws.String(string(key)),
})
return err
}
func (s *S3Store) DeleteByPrefix(ctx context.Context, prefix Key) error {
objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(bucketName),
Prefix: aws.String(string(prefix)),
})
if objects != nil && err == nil {
for _, object := range objects.Contents {
_, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(bucketName),
Key: object.Key,
})
if err != nil {
return err
}
}
}
return nil
}
func (s *S3Store) DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error {
objects, err := s.client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{
Bucket: aws.String(bucketName),
Prefix: aws.String(string(keyStart)),
})
if objects != nil && err == nil {
for _, object := range objects.Contents {
if *object.Key > string(keyEnd) {
_, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{
Bucket: aws.String(bucketName),
Key: object.Key,
})
if err != nil {
return err
}
}
}
}
return nil
}

View File

@@ -1,339 +0,0 @@
package s3driver
import (
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec"
. "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
type S3Driver struct {
driver *S3Store
}
var bucketName string
func NewS3Driver(ctx context.Context, option Option) (*S3Driver, error) {
// TODO: read config
bucketName = option.BucketName
S3Client, err := NewS3Store(aws.Config{
Region: aws.String(endpoints.CnNorthwest1RegionID)})
if err != nil {
return nil, err
}
return &S3Driver{
S3Client,
}, nil
}
func (s *S3Driver) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error {
minioKey, err := codec.MvccEncode(key, timestamp, suffix)
if err != nil {
return err
}
err = s.driver.Put(ctx, minioKey, value)
return err
}
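Everything in this driver funnels through codec.MvccEncode, which is not part of this diff; from the call sites it folds a key, a timestamp, and a suffix into one object key. A purely illustrative sketch of such an encoder (the real codec lives under internal/storage/internal/minio/codec and may differ):

    func MvccEncode(key []byte, ts uint64, suffix string) ([]byte, error) {
        // hypothetical layout: a zero-padded decimal timestamp keeps all versions
        // of one row adjacent and lexicographically ordered by version
        return []byte(fmt.Sprintf("%s_%020d_%s", string(key), ts, suffix)), nil
    }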
func (s *S3Driver) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
keyEnd, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, nil, nil, err
}
keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly)
if err != nil {
return nil, nil, nil, err
}
var timestamps []Timestamp
for _, key := range keys {
_, timestamp, _, _ := codec.MvccDecode(key)
timestamps = append(timestamps, timestamp)
}
return timestamps, keys, values, nil
}
func (s *S3Driver) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
keyStart, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, nil, nil, err
}
keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly)
if err != nil {
return nil, nil, nil, err
}
var timestamps []Timestamp
for _, key := range keys {
_, timestamp, _, _ := codec.MvccDecode(key)
timestamps = append(timestamps, timestamp)
}
return timestamps, keys, values, nil
}
//scan(ctx context.Context, key Key, start Timestamp, end Timestamp, withValue bool) ([]Timestamp, []Key, []Value, error)
func (s *S3Driver) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error {
keyEnd, err := codec.MvccEncode(key, timestamp, "delete")
if err != nil {
return err
}
err = s.driver.DeleteRange(ctx, key, keyEnd)
return err
}
func (s *S3Driver) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error {
keys, _, err := s.driver.GetByPrefix(ctx, key, true)
if err != nil {
return err
}
if len(keys) == 0 {
return nil
}
keyStart, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return err
}
return s.driver.DeleteRange(ctx, keyStart, keys[len(keys)-1])
}
func (s *S3Driver) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error {
keyStart, err := codec.MvccEncode(key, start, "")
if err != nil {
return err
}
keyEnd, err := codec.MvccEncode(key, end, "")
if err != nil {
return err
}
err = s.driver.DeleteRange(ctx, keyStart, keyEnd)
return err
}
func (s *S3Driver) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) {
minioKey, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, err
}
keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false)
if values == nil || keys == nil {
return nil, err
}
_, _, suffix, err := codec.MvccDecode(keys[0])
if err != nil {
return nil, err
}
if suffix == "delete" {
return nil, nil
}
return values[0], err
}
func (s *S3Driver) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) {
var values []Value
for i, key := range keys {
value, err := s.GetRow(ctx, key, timestamps[i])
if err != nil {
return nil, err
}
values = append(values, value)
}
return values, nil
}
func (s *S3Driver) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error {
minioKey, err := codec.MvccEncode(key, timestamp, segment)
if err != nil {
return err
}
err = s.driver.Put(ctx, minioKey, value)
return err
}
func (s *S3Driver) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error {
maxThread := 100
batchSize := 1
keysLength := len(keys)
if keysLength/batchSize > maxThread {
batchSize = keysLength / maxThread
}
batchNums := keysLength / batchSize
if keysLength%batchSize != 0 {
batchNums = keysLength/batchSize + 1
}
errCh := make(chan error, len(keys)) // buffered so pending sends never block after an early return
f := func(ctx2 context.Context, keys2 []Key, values2 []Value, segments2 []string, timestamps2 []Timestamp) {
for i := 0; i < len(keys2); i++ {
err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i])
errCh <- err
}
}
for i := 0; i < batchNums; i++ {
j := i
go func() {
start, end := j*batchSize, (j+1)*batchSize
if len(keys) < end {
end = len(keys)
}
f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end])
}()
}
for i := 0; i < len(keys); i++ {
if err := <-errCh; err != nil {
return err
}
}
return nil
}
func (s *S3Driver) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) {
keyEnd, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, err
}
keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true)
if err != nil {
return nil, err
}
segmentsSet := map[string]bool{}
for _, key := range keys {
_, _, segment, err := codec.MvccDecode(key)
if err != nil {
panic("must no error")
}
if segment != "delete" {
segmentsSet[segment] = true
}
}
var segments []string
for k, v := range segmentsSet {
if v {
segments = append(segments, k)
}
}
return segments, nil
}
func (s *S3Driver) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error {
minioKey, err := codec.MvccEncode(key, timestamp, "delete")
if err != nil {
return err
}
value := []byte("0")
err = s.driver.Put(ctx, minioKey, value)
return err
}
func (s *S3Driver) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error {
maxThread := 100
batchSize := 1
keysLength := len(keys)
if keysLength/batchSize > maxThread {
batchSize = keysLength / maxThread
}
batchNums := keysLength / batchSize
if keysLength%batchSize != 0 {
batchNums = keysLength/batchSize + 1
}
errCh := make(chan error, len(keys)) // buffered so pending sends never block after an early return
f := func(ctx2 context.Context, keys2 []Key, timestamps2 []Timestamp) {
for i := 0; i < len(keys2); i++ {
err := s.DeleteRow(ctx2, keys2[i], timestamps2[i])
errCh <- err
}
}
for i := 0; i < batchNums; i++ {
j := i
go func() {
start, end := j*batchSize, (j+1)*batchSize
if len(keys) < end {
end = len(keys)
}
f(ctx, keys[start:end], timestamps[start:end])
}()
}
for i := 0; i < len(keys); i++ {
if err := <-errCh; err != nil {
return err
}
}
return nil
}
func (s *S3Driver) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error {
logKey := codec.LogEncode(key, timestamp, channel)
err := s.driver.Put(ctx, logKey, value)
return err
}
func (s *S3Driver) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error) {
keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false)
if err != nil {
return nil, err
}
var resultValues []Value
for i, key := range keys {
_, ts, channel, err := codec.LogDecode(string(key))
if err != nil {
return nil, err
}
if ts >= start && ts <= end {
for j := 0; j < len(channels); j++ {
if channel == channels[j] {
resultValues = append(resultValues, values[i])
}
}
}
}
return resultValues, nil
}
func (s *S3Driver) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) {
return s.driver.Get(ctx, codec.SegmentEncode(segment, "index"))
}
func (s *S3Driver) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error {
return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index)
}
func (s *S3Driver) DeleteSegmentIndex(ctx context.Context, segment string) error {
return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index"))
}
func (s *S3Driver) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) {
return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL"))
}
func (s *S3Driver) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error {
return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log)
}
func (s *S3Driver) DeleteSegmentDL(ctx context.Context, segment string) error {
return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL"))
}

View File

@ -1,101 +0,0 @@
package codec
import (
"errors"
"fmt"
)
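// MvccEncode lays a key out as "<key>_<%016x of ^ts>_<suffix>"; complementing
// the timestamp makes newer versions sort lexicographically first under the
// same key prefix.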
func MvccEncode(key []byte, ts uint64, suffix string) ([]byte, error) {
return []byte(string(key) + "_" + fmt.Sprintf("%016x", ^ts) + "_" + suffix), nil
}
func MvccDecode(key []byte) (string, uint64, string, error) {
if len(key) < 16 {
return "", 0, "", errors.New("insufficient bytes to decode value")
}
suffixIndex := 0
TSIndex := 0
undersCount := 0
for i := len(key) - 1; i > 0; i-- {
if key[i] == byte('_') {
undersCount++
if undersCount == 1 {
suffixIndex = i + 1
}
if undersCount == 2 {
TSIndex = i + 1
break
}
}
}
if suffixIndex == 0 || TSIndex == 0 {
return "", 0, "", errors.New("malformed mvcc key")
}
var TS uint64
_, err := fmt.Sscanf(string(key[TSIndex:suffixIndex-1]), "%x", &TS)
if err != nil {
return "", 0, "", err
}
TS = ^TS
return string(key[0 : TSIndex-1]), TS, string(key[suffixIndex:]), nil
}
func LogEncode(key []byte, ts uint64, channel int) []byte {
suffix := string(key) + "_" + fmt.Sprintf("%d", channel)
logKey, err := MvccEncode([]byte("log"), ts, suffix)
if err != nil {
return nil
}
return logKey
}
func LogDecode(logKey string) (string, uint64, int, error) {
if len(logKey) < 16 {
return "", 0, 0, errors.New("insufficient bytes to decode value")
}
channelIndex := 0
keyIndex := 0
TSIndex := 0
undersCount := 0
for i := len(logKey) - 1; i > 0; i-- {
if logKey[i] == '_' {
undersCount++
if undersCount == 1 {
channelIndex = i + 1
}
if undersCount == 2 {
keyIndex = i + 1
}
if undersCount == 3 {
TSIndex = i + 1
break
}
}
}
if channelIndex == 0 || TSIndex == 0 || keyIndex == 0 || logKey[:TSIndex-1] != "log" {
return "", 0, 0, errors.New("key is wrong formatted")
}
var TS uint64
var channel int
_, err := fmt.Sscanf(logKey[TSIndex:keyIndex-1], "%x", &TS)
if err != nil {
return "", 0, 0, err
}
TS = ^TS
_, err = fmt.Sscanf(logKey[channelIndex:], "%d", &channel)
if err != nil {
return "", 0, 0, err
}
return logKey[keyIndex : channelIndex-1], TS, channel, nil
}
func SegmentEncode(segment string, suffix string) []byte {
return []byte(segment + "_" + suffix)
}
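As a sanity check of the scheme above, a hypothetical round trip (a sketch only, not part of this commit) behaves like this:

package codec_test

import (
"fmt"

"github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec"
)

// Sketch: newer timestamps sort first, and decoding recovers the parts.
func Example() {
k1, _ := codec.MvccEncode([]byte("row"), 1, "segA")
k2, _ := codec.MvccEncode([]byte("row"), 2, "segA")
fmt.Println(string(k2) < string(k1)) // true: ts=2 encodes before ts=1

key, ts, suffix, _ := codec.MvccDecode(k1)
fmt.Println(key, ts, suffix) // row 1 segA

lk := codec.LogEncode([]byte("insert"), 7, 3)
logKey, logTs, channel, _ := codec.LogDecode(string(lk))
fmt.Println(logKey, logTs, channel) // insert 7 3
}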

View File

@ -1,361 +0,0 @@
package miniodriver
import (
"context"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zilliztech/milvus-distributed/internal/storage/internal/minio/codec"
storageType "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
type MinioDriver struct {
driver *minioStore
}
var bucketName string
func NewMinioDriver(ctx context.Context, option storageType.Option) (*MinioDriver, error) {
// TODO: read the endpoint and credentials from configuration
var endPoint = "localhost:9000"
var accessKeyID = "testminio"
var secretAccessKey = "testminio"
var useSSL = false
bucketName = option.BucketName // package-level bucketName, shared with the minioStore methods
minioClient, err := minio.New(endPoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
Secure: useSSL,
})
if err != nil {
return nil, err
}
bucketExists, err := minioClient.BucketExists(ctx, bucketName)
if err != nil {
return nil, err
}
if !bucketExists {
err = minioClient.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{})
if err != nil {
return nil, err
}
}
return &MinioDriver{
&minioStore{
client: minioClient,
},
}, nil
}
func (s *MinioDriver) put(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, suffix string) error {
minioKey, err := codec.MvccEncode(key, timestamp, suffix)
if err != nil {
return err
}
err = s.driver.Put(ctx, minioKey, value)
return err
}
func (s *MinioDriver) scanLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) {
keyEnd, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, nil, nil, err
}
keys, values, err := s.driver.Scan(ctx, key, []byte(keyEnd), -1, keyOnly)
if err != nil {
return nil, nil, nil, err
}
var timestamps []storageType.Timestamp
for _, key := range keys {
_, timestamp, _, _ := codec.MvccDecode(key)
timestamps = append(timestamps, timestamp)
}
return timestamps, keys, values, nil
}
func (s *MinioDriver) scanGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp, keyOnly bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error) {
keyStart, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, nil, nil, err
}
keys, values, err := s.driver.Scan(ctx, key, keyStart, -1, keyOnly)
if err != nil {
return nil, nil, nil, err
}
var timestamps []storageType.Timestamp
for _, key := range keys {
_, timestamp, _, _ := codec.MvccDecode(key)
timestamps = append(timestamps, timestamp)
}
return timestamps, keys, values, nil
}
//scan(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp, withValue bool) ([]storageType.Timestamp, []storageType.Key, []storageType.Value, error)
func (s *MinioDriver) deleteLE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
keyEnd, err := codec.MvccEncode(key, timestamp, "delete")
if err != nil {
return err
}
err = s.driver.DeleteRange(ctx, key, keyEnd)
return err
}
func (s *MinioDriver) deleteGE(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
keys, _, err := s.driver.GetByPrefix(ctx, key, true)
if err != nil {
return err
}
if len(keys) == 0 {
return nil
}
keyStart, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return err
}
return s.driver.DeleteRange(ctx, keyStart, keys[len(keys)-1])
}
func (s *MinioDriver) deleteRange(ctx context.Context, key storageType.Key, start storageType.Timestamp, end storageType.Timestamp) error {
keyStart, err := codec.MvccEncode(key, start, "")
if err != nil {
return err
}
keyEnd, err := codec.MvccEncode(key, end, "")
if err != nil {
return err
}
err = s.driver.DeleteRange(ctx, keyStart, keyEnd)
return err
}
func (s *MinioDriver) GetRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) (storageType.Value, error) {
minioKey, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, err
}
keys, values, err := s.driver.Scan(ctx, append(key, byte('_')), minioKey, 1, false)
if values == nil || keys == nil {
return nil, err
}
_, _, suffix, err := codec.MvccDecode(keys[0])
if err != nil {
return nil, err
}
if suffix == "delete" {
return nil, nil
}
return values[0], err
}
func (s *MinioDriver) GetRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) ([]storageType.Value, error) {
var values []storageType.Value
for i, key := range keys {
value, err := s.GetRow(ctx, key, timestamps[i])
if err != nil {
return nil, err
}
values = append(values, value)
}
return values, nil
}
func (s *MinioDriver) PutRow(ctx context.Context, key storageType.Key, value storageType.Value, segment string, timestamp storageType.Timestamp) error {
minioKey, err := codec.MvccEncode(key, timestamp, segment)
if err != nil {
return err
}
err = s.driver.Put(ctx, minioKey, value)
return err
}
func (s *MinioDriver) PutRows(ctx context.Context, keys []storageType.Key, values []storageType.Value, segments []string, timestamps []storageType.Timestamp) error {
maxThread := 100
batchSize := 1
keysLength := len(keys)
if keysLength/batchSize > maxThread {
batchSize = keysLength / maxThread
}
batchNums := keysLength / batchSize
if keysLength%batchSize != 0 {
batchNums = keysLength/batchSize + 1
}
errCh := make(chan error, len(keys)) // buffered so pending sends never block after an early return
f := func(ctx2 context.Context, keys2 []storageType.Key, values2 []storageType.Value, segments2 []string, timestamps2 []storageType.Timestamp) {
for i := 0; i < len(keys2); i++ {
err := s.PutRow(ctx2, keys2[i], values2[i], segments2[i], timestamps2[i])
errCh <- err
}
}
for i := 0; i < batchNums; i++ {
j := i
go func() {
start, end := j*batchSize, (j+1)*batchSize
if len(keys) < end {
end = len(keys)
}
f(ctx, keys[start:end], values[start:end], segments[start:end], timestamps[start:end])
}()
}
for i := 0; i < len(keys); i++ {
if err := <-errCh; err != nil {
return err
}
}
return nil
}
func (s *MinioDriver) GetSegments(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) ([]string, error) {
keyEnd, err := codec.MvccEncode(key, timestamp, "")
if err != nil {
return nil, err
}
keys, _, err := s.driver.Scan(ctx, append(key, byte('_')), keyEnd, -1, true)
if err != nil {
return nil, err
}
segmentsSet := map[string]bool{}
for _, key := range keys {
_, _, segment, err := codec.MvccDecode(key)
if err != nil {
panic("must no error")
}
if segment != "delete" {
segmentsSet[segment] = true
}
}
var segments []string
for k, v := range segmentsSet {
if v {
segments = append(segments, k)
}
}
return segments, nil
}
func (s *MinioDriver) DeleteRow(ctx context.Context, key storageType.Key, timestamp storageType.Timestamp) error {
minioKey, err := codec.MvccEncode(key, timestamp, "delete")
if err != nil {
return err
}
value := []byte("0")
err = s.driver.Put(ctx, minioKey, value)
return err
}
func (s *MinioDriver) DeleteRows(ctx context.Context, keys []storageType.Key, timestamps []storageType.Timestamp) error {
maxThread := 100
batchSize := 1
keysLength := len(keys)
if keysLength/batchSize > maxThread {
batchSize = keysLength / maxThread
}
batchNums := keysLength / batchSize
if keysLength%batchSize != 0 {
batchNums = keysLength/batchSize + 1
}
errCh := make(chan error, len(keys)) // buffered so pending sends never block after an early return
f := func(ctx2 context.Context, keys2 []storageType.Key, timestamps2 []storageType.Timestamp) {
for i := 0; i < len(keys2); i++ {
err := s.DeleteRow(ctx2, keys2[i], timestamps2[i])
errCh <- err
}
}
for i := 0; i < batchNums; i++ {
j := i
go func() {
start, end := j*batchSize, (j+1)*batchSize
if len(keys) < end {
end = len(keys)
}
f(ctx, keys[start:end], timestamps[start:end])
}()
}
for i := 0; i < len(keys); i++ {
if err := <-errCh; err != nil {
return err
}
}
return nil
}
func (s *MinioDriver) PutLog(ctx context.Context, key storageType.Key, value storageType.Value, timestamp storageType.Timestamp, channel int) error {
logKey := codec.LogEncode(key, timestamp, channel)
err := s.driver.Put(ctx, logKey, value)
return err
}
func (s *MinioDriver) GetLog(ctx context.Context, start storageType.Timestamp, end storageType.Timestamp, channels []int) ([]storageType.Value, error) {
keys, values, err := s.driver.GetByPrefix(ctx, []byte("log_"), false)
if err != nil {
return nil, err
}
var resultValues []storageType.Value
for i, key := range keys {
_, ts, channel, err := codec.LogDecode(string(key))
if err != nil {
return nil, err
}
if ts >= start && ts <= end {
for j := 0; j < len(channels); j++ {
if channel == channels[j] {
resultValues = append(resultValues, values[i])
}
}
}
}
return resultValues, nil
}
func (s *MinioDriver) GetSegmentIndex(ctx context.Context, segment string) (storageType.SegmentIndex, error) {
return s.driver.Get(ctx, codec.SegmentEncode(segment, "index"))
}
func (s *MinioDriver) PutSegmentIndex(ctx context.Context, segment string, index storageType.SegmentIndex) error {
return s.driver.Put(ctx, codec.SegmentEncode(segment, "index"), index)
}
func (s *MinioDriver) DeleteSegmentIndex(ctx context.Context, segment string) error {
return s.driver.Delete(ctx, codec.SegmentEncode(segment, "index"))
}
func (s *MinioDriver) GetSegmentDL(ctx context.Context, segment string) (storageType.SegmentDL, error) {
return s.driver.Get(ctx, codec.SegmentEncode(segment, "DL"))
}
func (s *MinioDriver) PutSegmentDL(ctx context.Context, segment string, log storageType.SegmentDL) error {
return s.driver.Put(ctx, codec.SegmentEncode(segment, "DL"), log)
}
func (s *MinioDriver) DeleteSegmentDL(ctx context.Context, segment string) error {
return s.driver.Delete(ctx, codec.SegmentEncode(segment, "DL"))
}

View File

@ -1,130 +0,0 @@
package miniodriver
import (
"bytes"
"context"
"io"
"github.com/minio/minio-go/v7"
. "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
type minioStore struct {
client *minio.Client
}
func (s *minioStore) Put(ctx context.Context, key Key, value Value) error {
reader := bytes.NewReader(value)
_, err := s.client.PutObject(ctx, bucketName, string(key), reader, int64(len(value)), minio.PutObjectOptions{})
if err != nil {
return err
}
return err
}
func (s *minioStore) Get(ctx context.Context, key Key) (Value, error) {
object, err := s.client.GetObject(ctx, bucketName, string(key), minio.GetObjectOptions{})
if err != nil {
return nil, err
}
defer object.Close()
// Read the whole object; a fixed-size buffer would silently truncate larger values.
buf, err := io.ReadAll(object)
if err != nil {
return nil, err
}
return buf, nil
}
func (s *minioStore) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error) {
objects := s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)})
var objectsKeys []Key
var objectsValues []Value
for object := range objects {
objectsKeys = append(objectsKeys, []byte(object.Key))
if !keyOnly {
value, err := s.Get(ctx, []byte(object.Key))
if err != nil {
return nil, nil, err
}
objectsValues = append(objectsValues, value)
}
}
return objectsKeys, objectsValues, nil
}
func (s *minioStore) Scan(ctx context.Context, keyStart Key, keyEnd Key, limit int, keyOnly bool) ([]Key, []Value, error) {
var keys []Key
var values []Value
// A negative limit means "no limit": the unsigned counter wraps to a huge value.
limitCount := uint(limit)
for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) {
// Timestamps are complement-encoded, so entries at or before the
// target timestamp sort at or after keyEnd.
if object.Key >= string(keyEnd) {
keys = append(keys, []byte(object.Key))
if !keyOnly {
value, err := s.Get(ctx, []byte(object.Key))
if err != nil {
return nil, nil, err
}
values = append(values, value)
}
limitCount--
if limitCount <= 0 {
break
}
}
}
return keys, values, nil
}
func (s *minioStore) Delete(ctx context.Context, key Key) error {
err := s.client.RemoveObject(ctx, bucketName, string(key), minio.RemoveObjectOptions{})
return err
}
func (s *minioStore) DeleteByPrefix(ctx context.Context, prefix Key) error {
objectsCh := make(chan minio.ObjectInfo)
go func() {
defer close(objectsCh)
for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(prefix)}) {
objectsCh <- object
}
}()
for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
if rErr.Err != nil {
return rErr.Err
}
}
return nil
}
func (s *minioStore) DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error {
objectsCh := make(chan minio.ObjectInfo)
go func() {
defer close(objectsCh)
for object := range s.client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{Prefix: string(keyStart)}) {
if object.Key <= string(keyEnd) {
objectsCh <- object
}
}
}()
for rErr := range s.client.RemoveObjects(ctx, bucketName, objectsCh, minio.RemoveObjectsOptions{GovernanceBypass: true}) {
if rErr.Err != nil {
return rErr.Err
}
}
return nil
}

View File

@ -1,134 +0,0 @@
package miniodriver
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
var option = storagetype.Option{BucketName: "zilliz-hz"}
var ctx = context.Background()
var client, err = NewMinioDriver(ctx, option)
func TestMinioDriver_PutRowAndGetRow(t *testing.T) {
err = client.PutRow(ctx, []byte("bar"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("bar1"), []byte("testkeybarorbar_1"), "SegmentC", 3)
assert.Nil(t, err)
object, _ := client.GetRow(ctx, []byte("bar"), 5)
assert.Equal(t, "abcdefghijklmnoopqrstuvwxyz", string(object))
object, _ = client.GetRow(ctx, []byte("bar"), 2)
assert.Equal(t, "djhfkjsbdfbsdughorsgsdjhgoisdgh", string(object))
object, _ = client.GetRow(ctx, []byte("bar"), 5)
assert.Equal(t, "123854676ershdgfsgdfk,sdhfg;sdi8", string(object))
object, _ = client.GetRow(ctx, []byte("bar1"), 5)
assert.Equal(t, "testkeybarorbar_1", string(object))
}
func TestMinioDriver_DeleteRow(t *testing.T) {
err = client.DeleteRow(ctx, []byte("bar"), 5)
assert.Nil(t, err)
object, _ := client.GetRow(ctx, []byte("bar"), 6)
assert.Nil(t, object)
err = client.DeleteRow(ctx, []byte("bar1"), 5)
assert.Nil(t, err)
object2, _ := client.GetRow(ctx, []byte("bar1"), 6)
assert.Nil(t, object2)
}
func TestMinioDriver_GetSegments(t *testing.T) {
err = client.PutRow(ctx, []byte("seg"), []byte("abcdefghijklmnoopqrstuvwxyz"), "SegmentA", 1)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg"), []byte("djhfkjsbdfbsdughorsgsdjhgoisdgh"), "SegmentA", 2)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg"), []byte("123854676ershdgfsgdfk,sdhfg;sdi8"), "SegmentB", 3)
assert.Nil(t, err)
err = client.PutRow(ctx, []byte("seg2"), []byte("testkeybarorbar_1"), "SegmentC", 1)
assert.Nil(t, err)
segments, err := client.GetSegments(ctx, []byte("seg"), 4)
assert.Nil(t, err)
assert.Equal(t, 2, len(segments))
assert.ElementsMatch(t, []string{"SegmentA", "SegmentB"}, segments)
}
func TestMinioDriver_PutRowsAndGetRows(t *testing.T) {
keys := [][]byte{[]byte("foo"), []byte("bar")}
values := [][]byte{[]byte("The key is foo!"), []byte("The key is bar!")}
segments := []string{"segmentA", "segmentB"}
timestamps := []uint64{1, 2}
err = client.PutRows(ctx, keys, values, segments, timestamps)
assert.Nil(t, err)
objects, err := client.GetRows(ctx, keys, timestamps)
assert.Nil(t, err)
assert.Equal(t, "The key is foo!", string(objects[0]))
assert.Equal(t, "The key is bar!", string(objects[1]))
}
func TestMinioDriver_DeleteRows(t *testing.T) {
keys := [][]byte{[]byte("foo"), []byte("bar")}
timestamps := []uint64{3, 3}
err := client.DeleteRows(ctx, keys, timestamps)
assert.Nil(t, err)
objects, err := client.GetRows(ctx, keys, timestamps)
assert.Nil(t, err)
assert.Nil(t, objects[0])
assert.Nil(t, objects[1])
}
func TestMinioDriver_PutLogAndGetLog(t *testing.T) {
err = client.PutLog(ctx, []byte("insert"), []byte("This is insert log!"), 1, 11)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("delete"), []byte("This is delete log!"), 2, 10)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("update"), []byte("This is update log!"), 3, 9)
assert.Nil(t, err)
err = client.PutLog(ctx, []byte("select"), []byte("This is select log!"), 4, 8)
assert.Nil(t, err)
channels := []int{5, 8, 9, 10, 11, 12, 13}
logValues, err := client.GetLog(ctx, 0, 5, channels)
assert.Nil(t, err)
assert.Equal(t, "This is select log!", string(logValues[0]))
assert.Equal(t, "This is update log!", string(logValues[1]))
assert.Equal(t, "This is delete log!", string(logValues[2]))
assert.Equal(t, "This is insert log!", string(logValues[3]))
}
func TestMinioDriver_Segment(t *testing.T) {
err := client.PutSegmentIndex(ctx, "segmentA", []byte("This is segmentA's index!"))
assert.Nil(t, err)
segmentIndex, err := client.GetSegmentIndex(ctx, "segmentA")
assert.Equal(t, "This is segmentA's index!", string(segmentIndex))
assert.Nil(t, err)
err = client.DeleteSegmentIndex(ctx, "segmentA")
assert.Nil(t, err)
}
func TestMinioDriver_SegmentDL(t *testing.T) {
err := client.PutSegmentDL(ctx, "segmentB", []byte("This is segmentB's delete log!"))
assert.Nil(t, err)
segmentDL, err := client.GetSegmentDL(ctx, "segmentB")
assert.Nil(t, err)
assert.Equal(t, "This is segmentB's delete log!", string(segmentDL))
err = client.DeleteSegmentDL(ctx, "segmentB")
assert.Nil(t, err)
}

View File

@ -1,62 +0,0 @@
package codec
import (
"encoding/binary"
"errors"
"github.com/tikv/client-go/codec"
)
var (
Delimiter = byte('_')
DelimiterPlusOne = Delimiter + 0x01
DeleteMark = byte('d')
SegmentIndexMark = byte('i')
SegmentDLMark = byte('d')
)
// EncodeKey appends the delimiter, the descending-encoded timestamp, and the
// suffix to the key, producing: key | '_' | 8-byte ^timestamp | suffix.
// Note: the suffix must not contain the Delimiter byte.
func EncodeKey(key []byte, timestamp uint64, suffix string) []byte {
//TODO: should we encode key to memory comparable
ret := EncodeDelimiter(key, Delimiter)
ret = codec.EncodeUintDesc(ret, timestamp)
return append(ret, suffix...)
}
func DecodeKey(key []byte) ([]byte, uint64, string, error) {
if len(key) < 8 {
return nil, 0, "", errors.New("insufficient bytes to decode value")
}
lenDeKey := 0
for i := len(key) - 1; i > 0; i-- {
if key[i] == Delimiter {
lenDeKey = i
break
}
}
if lenDeKey == 0 || lenDeKey+8 > len(key) {
return nil, 0, "", errors.New("insufficient bytes to decode value")
}
tsBytes := key[lenDeKey+1 : lenDeKey+9]
ts := binary.BigEndian.Uint64(tsBytes)
suffix := string(key[lenDeKey+9:])
key = key[:lenDeKey] // everything before the delimiter is the original key
return key, ^ts, suffix, nil
}
// EncodeDelimiter append a delimiter byte to slice b, and return the appended slice.
func EncodeDelimiter(b []byte, delimiter byte) []byte {
return append(b, delimiter)
}
func EncodeSegment(segName []byte, segType byte) []byte {
segmentKey := []byte("segment")
segmentKey = append(segmentKey, Delimiter)
segmentKey = append(segmentKey, segName...)
return append(segmentKey, Delimiter, segType)
}
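For illustration, a hypothetical round trip through this binary variant (a sketch only, not part of this commit):

package codec_test

import (
"fmt"

. "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec"
)

// Sketch: the timestamp is complement-encoded into eight binary bytes,
// and DecodeKey un-complements it on the way out.
func ExampleEncodeKey() {
k := EncodeKey([]byte("row"), 42, "segA")
key, ts, suffix, err := DecodeKey(k)
fmt.Println(string(key), ts, suffix, err) // row 42 segA <nil>
}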

View File

@ -1,389 +0,0 @@
package tikvdriver
import (
"context"
"errors"
"strconv"
"strings"
"github.com/tikv/client-go/config"
"github.com/tikv/client-go/rawkv"
. "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec"
. "github.com/zilliztech/milvus-distributed/internal/storage/type"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
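// keyAddOne bumps the last byte of key to form the exclusive upper bound of a
// prefix scan. Note the assumption that the prefix does not end in 0xff; a
// 0xff tail would wrap around and produce a wrong range end.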
func keyAddOne(key Key) Key {
if key == nil {
return nil
}
lenKey := len(key)
ret := make(Key, lenKey)
copy(ret, key)
ret[lenKey-1] += 0x01
return ret
}
type tikvEngine struct {
client *rawkv.Client
conf config.Config
}
func (e tikvEngine) Put(ctx context.Context, key Key, value Value) error {
return e.client.Put(ctx, key, value)
}
func (e tikvEngine) BatchPut(ctx context.Context, keys []Key, values []Value) error {
return e.client.BatchPut(ctx, keys, values)
}
func (e tikvEngine) Get(ctx context.Context, key Key) (Value, error) {
return e.client.Get(ctx, key)
}
func (e tikvEngine) GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) (keys []Key, values []Value, err error) {
startKey := prefix
endKey := keyAddOne(prefix)
limit := e.conf.Raw.MaxScanLimit
for {
ks, vs, err := e.Scan(ctx, startKey, endKey, limit, keyOnly)
if err != nil {
return keys, values, err
}
keys = append(keys, ks...)
values = append(values, vs...)
if len(ks) < limit {
break
}
// update the start key, and exclude the start key
startKey = append(ks[len(ks)-1], '\000')
}
return
}
func (e tikvEngine) Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error) {
return e.client.Scan(ctx, startKey, endKey, limit, rawkv.ScanOption{KeyOnly: keyOnly})
}
func (e tikvEngine) Delete(ctx context.Context, key Key) error {
return e.client.Delete(ctx, key)
}
func (e tikvEngine) DeleteByPrefix(ctx context.Context, prefix Key) error {
startKey := prefix
endKey := keyAddOne(prefix)
return e.client.DeleteRange(ctx, startKey, endKey)
}
func (e tikvEngine) DeleteRange(ctx context.Context, startKey Key, endKey Key) error {
return e.client.DeleteRange(ctx, startKey, endKey)
}
func (e tikvEngine) Close() error {
return e.client.Close()
}
type TikvStore struct {
engine *tikvEngine
}
func NewTikvStore(ctx context.Context, option storagetype.Option) (*TikvStore, error) {
conf := config.Default()
client, err := rawkv.NewClient(ctx, []string{option.TikvAddress}, conf)
if err != nil {
return nil, err
}
return &TikvStore{
&tikvEngine{
client: client,
conf: conf,
},
}, nil
}
func (s *TikvStore) Name() string {
return "TiKV storage"
}
func (s *TikvStore) put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error {
return s.engine.Put(ctx, EncodeKey(key, timestamp, suffix), value)
}
func (s *TikvStore) scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
panic("implement me")
}
func (s *TikvStore) scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
panic("implement me")
}
func (s *TikvStore) scan(ctx context.Context, key Key, start Timestamp, end Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error) {
//startKey := EncodeKey(key, start, "")
//endKey := EncodeKey(EncodeDelimiter(key, DelimiterPlusOne), end, "")
//return s.engine.Scan(ctx, startKey, endKey, -1, keyOnly)
panic("implement me")
}
func (s *TikvStore) deleteLE(ctx context.Context, key Key, timestamp Timestamp) error {
panic("implement me")
}
func (s *TikvStore) deleteGE(ctx context.Context, key Key, timestamp Timestamp) error {
panic("implement me")
}
func (s *TikvStore) deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error {
panic("implement me")
}
func (s *TikvStore) GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error) {
startKey := EncodeKey(key, timestamp, "")
endKey := EncodeDelimiter(key, DelimiterPlusOne)
keys, values, err := s.engine.Scan(ctx, startKey, endKey, 1, false)
if err != nil || keys == nil {
return nil, err
}
_, _, suffix, err := DecodeKey(keys[0])
if err != nil {
return nil, err
}
// key is marked deleted
if suffix == string(DeleteMark) {
return nil, nil
}
return values[0], nil
}
// TODO: decide how to split keys into batches
var batchSize = 100
type kvPair struct {
key Key
value Value
err error
}
func batchKeys(keys []Key) [][]Key {
keysLen := len(keys)
numBatch := (keysLen-1)/batchSize + 1
batches := make([][]Key, numBatch)
for i := 0; i < numBatch; i++ {
batchStart := i * batchSize
batchEnd := batchStart + batchSize
// the last batch
if i == numBatch-1 {
batchEnd = keysLen
}
batches[i] = keys[batchStart:batchEnd]
}
return batches
}
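// Sketch: with batchSize = 100, 250 keys split into three batches.
// batches := batchKeys(make([]Key, 250))
// len(batches) == 3, len(batches[2]) == 50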
func (s *TikvStore) GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error) {
if len(keys) != len(timestamps) {
return nil, errors.New("the len of keys is not equal to the len of timestamps")
}
batches := batchKeys(keys)
ch := make(chan kvPair, len(keys))
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for n, b := range batches {
batch := b
numBatch := n
go func() {
for i, key := range batch {
select {
case <-ctx.Done():
return
default:
v, err := s.GetRow(ctx, key, timestamps[numBatch*batchSize+i])
ch <- kvPair{
key: key,
value: v,
err: err,
}
}
}
}()
}
var err error
var values []Value
kvMap := make(map[string]Value)
for i := 0; i < len(keys); i++ {
kv := <-ch
if kv.err != nil {
cancel()
if err == nil {
err = kv.err
}
}
kvMap[string(kv.key)] = kv.value
}
for _, key := range keys {
values = append(values, kvMap[string(key)])
}
return values, err
}
func (s *TikvStore) PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error {
return s.put(ctx, key, value, timestamp, segment)
}
func (s *TikvStore) PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error {
if len(keys) != len(values) {
return errors.New("the len of keys is not equal to the len of values")
}
if len(keys) != len(timestamps) {
return errors.New("the len of keys is not equal to the len of timestamps")
}
encodedKeys := make([]Key, len(keys))
for i, key := range keys {
encodedKeys[i] = EncodeKey(key, timestamps[i], segments[i])
}
return s.engine.BatchPut(ctx, encodedKeys, values)
}
func (s *TikvStore) DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error {
return s.put(ctx, key, Value{0x00}, timestamp, string(DeleteMark))
}
func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error {
encodeKeys := make([]Key, len(keys))
values := make([]Value, len(keys))
for i, key := range keys {
encodeKeys[i] = EncodeKey(key, timestamps[i], string(DeleteMark))
values[i] = Value{0x00}
}
return s.engine.BatchPut(ctx, encodeKeys, values)
}
//func (s *TikvStore) DeleteRows(ctx context.Context, keys []Key, timestamp Timestamp) error {
// batches := batchKeys(keys)
// ch := make(chan error, len(batches))
// ctx, cancel := context.WithCancel(ctx)
//
// for _, b := range batches {
// batch := b
// go func() {
// for _, key := range batch {
// select {
// case <-ctx.Done():
// return
// default:
// ch <- s.DeleteRow(ctx, key, timestamp)
// }
// }
// }()
// }
//
// var err error
// for i := 0; i < len(keys); i++ {
// if e := <-ch; e != nil {
// cancel()
// if err == nil {
// err = e
// }
// }
// }
// return err
//}
func (s *TikvStore) PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error {
suffix := string(EncodeDelimiter(key, DelimiterPlusOne)) + strconv.Itoa(channel)
return s.put(ctx, Key("log"), value, timestamp, suffix)
}
func (s *TikvStore) GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) (logs []Value, err error) {
key := Key("log")
startKey := EncodeKey(key, end, "")
endKey := EncodeKey(key, start, "")
// TODO: page with a loop; a single Scan returns at most MaxScanLimit keys
keys, values, err := s.engine.Scan(ctx, startKey, endKey, s.engine.conf.Raw.MaxScanLimit, false)
if err != nil || keys == nil {
return nil, err
}
for i, key := range keys {
_, _, suffix, err := DecodeKey(key)
log := values[i]
if err != nil {
return logs, err
}
// no channel filter: keep every log entry
if len(channels) == 0 {
logs = append(logs, log)
continue
}
slice := strings.Split(suffix, string(DelimiterPlusOne))
channel, err := strconv.Atoi(slice[len(slice)-1])
if err != nil {
return logs, err
}
for _, item := range channels {
if item == channel {
logs = append(logs, log)
break
}
}
}
return
}
func (s *TikvStore) GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error) {
return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentIndexMark))
}
func (s *TikvStore) PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error {
return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentIndexMark), index)
}
func (s *TikvStore) DeleteSegmentIndex(ctx context.Context, segment string) error {
return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentIndexMark))
}
func (s *TikvStore) GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error) {
return s.engine.Get(ctx, EncodeSegment([]byte(segment), SegmentDLMark))
}
func (s *TikvStore) PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error {
return s.engine.Put(ctx, EncodeSegment([]byte(segment), SegmentDLMark), log)
}
func (s *TikvStore) DeleteSegmentDL(ctx context.Context, segment string) error {
return s.engine.Delete(ctx, EncodeSegment([]byte(segment), SegmentDLMark))
}
func (s *TikvStore) GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error) {
keys, _, err := s.engine.GetByPrefix(ctx, EncodeDelimiter(key, Delimiter), true)
if err != nil {
return nil, err
}
segmentsSet := map[string]bool{}
for _, key := range keys {
_, ts, segment, err := DecodeKey(key)
if err != nil {
panic("must no error")
}
if ts <= timestamp && segment != string(DeleteMark) {
segmentsSet[segment] = true
}
}
var segments []string
for k, v := range segmentsSet {
if v {
segments = append(segments, k)
}
}
return segments, nil
}
func (s *TikvStore) Close() error {
return s.engine.Close()
}

View File

@ -1,293 +0,0 @@
package tikvdriver
import (
"bytes"
"context"
"fmt"
"math"
"os"
"sort"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv/codec"
. "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
//var store TikvStore
var store *TikvStore
var option = Option{TikvAddress: "localhost:2379"}
func TestMain(m *testing.M) {
store, _ = NewTikvStore(context.Background(), option)
exitCode := m.Run()
_ = store.Close()
os.Exit(exitCode)
}
func TestTikvEngine_Prefix(t *testing.T) {
ctx := context.Background()
prefix := Key("key")
engine := store.engine
value := Value("value")
// Put some key with same prefix
key := prefix
err := engine.Put(ctx, key, value)
require.Nil(t, err)
key = EncodeKey(prefix, 0, "")
err = engine.Put(ctx, key, value)
assert.Nil(t, err)
// Get by prefix
ks, _, err := engine.GetByPrefix(ctx, prefix, true)
assert.Equal(t, 2, len(ks))
assert.Nil(t, err)
// Delete by prefix
err = engine.DeleteByPrefix(ctx, prefix)
assert.Nil(t, err)
ks, _, err = engine.GetByPrefix(ctx, prefix, true)
assert.Equal(t, 0, len(ks))
assert.Nil(t, err)
//Test large amount keys
num := engine.conf.Raw.MaxScanLimit + 1
keys := make([]Key, num)
values := make([]Value, num)
for i := 0; i < num; i++ {
key = EncodeKey(prefix, uint64(i), "")
keys[i] = key
values[i] = value
}
err = engine.BatchPut(ctx, keys, values)
assert.Nil(t, err)
ks, _, err = engine.GetByPrefix(ctx, prefix, true)
assert.Nil(t, err)
assert.Equal(t, num, len(ks))
err = engine.DeleteByPrefix(ctx, prefix)
assert.Nil(t, err)
}
func TestTikvStore_Row(t *testing.T) {
ctx := context.Background()
key := Key("key")
// Add same row with different timestamp
err := store.PutRow(ctx, key, Value("value0"), "segment0", 0)
assert.Nil(t, err)
err = store.PutRow(ctx, key, Value("value1"), "segment0", 2)
assert.Nil(t, err)
// Get most recent row using key and timestamp
v, err := store.GetRow(ctx, key, 3)
assert.Nil(t, err)
assert.Equal(t, Value("value1"), v)
v, err = store.GetRow(ctx, key, 2)
assert.Nil(t, err)
assert.Equal(t, Value("value1"), v)
v, err = store.GetRow(ctx, key, 1)
assert.Nil(t, err)
assert.Equal(t, Value("value0"), v)
// Add a different row, but with same prefix
key1 := Key("key_y")
err = store.PutRow(ctx, key1, Value("valuey"), "segment0", 2)
assert.Nil(t, err)
// Get most recent row using key and timestamp
v, err = store.GetRow(ctx, key, 3)
assert.Nil(t, err)
assert.Equal(t, Value("value1"), v)
v, err = store.GetRow(ctx, key1, 3)
assert.Nil(t, err)
assert.Equal(t, Value("valuey"), v)
// Delete a row
err = store.DeleteRow(ctx, key, 4)
assert.Nil(t, err)
v, err = store.GetRow(ctx, key, 5)
assert.Nil(t, err)
assert.Nil(t, v)
// Clear test data
err = store.engine.DeleteByPrefix(ctx, key)
assert.Nil(t, err)
k, va, err := store.engine.GetByPrefix(ctx, key, false)
assert.Nil(t, err)
assert.Nil(t, k)
assert.Nil(t, va)
}
func TestTikvStore_BatchRow(t *testing.T) {
ctx := context.Background()
// Prepare test data
size := 0
var testKeys []Key
var testValues []Value
var segments []string
var timestamps []Timestamp
for i := 0; size/store.engine.conf.Raw.MaxBatchPutSize < 1; i++ {
key := fmt.Sprint("key", i)
size += len(key)
testKeys = append(testKeys, []byte(key))
value := fmt.Sprint("value", i)
size += len(value)
testValues = append(testValues, []byte(value))
segments = append(segments, "test")
v, err := store.GetRow(ctx, Key(key), math.MaxUint64)
assert.Nil(t, v)
assert.Nil(t, err)
}
// Batch put rows
for range testKeys {
timestamps = append(timestamps, 1)
}
err := store.PutRows(ctx, testKeys, testValues, segments, timestamps)
assert.Nil(t, err)
// Batch get rows
for i := range timestamps {
timestamps[i] = 2
}
checkValues, err := store.GetRows(ctx, testKeys, timestamps)
assert.NotNil(t, checkValues)
assert.Nil(t, err)
assert.Equal(t, len(checkValues), len(testValues))
for i := range testKeys {
assert.Equal(t, testValues[i], checkValues[i])
}
// Delete all test rows
for i := range timestamps {
timestamps[i] = math.MaxUint64
}
err = store.DeleteRows(ctx, testKeys, timestamps)
assert.Nil(t, err)
// Ensure all test row is deleted
for i := range timestamps {
timestamps[i] = math.MaxUint64
}
checkValues, err = store.GetRows(ctx, testKeys, timestamps)
assert.Nil(t, err)
for _, value := range checkValues {
assert.Nil(t, value)
}
// Clean test data
err = store.engine.DeleteByPrefix(ctx, Key("key"))
assert.Nil(t, err)
}
func TestTikvStore_GetSegments(t *testing.T) {
ctx := context.Background()
key := Key("key")
// Put rows
err := store.PutRow(ctx, key, Value{0}, "a", 1)
assert.Nil(t, err)
err = store.PutRow(ctx, key, Value{0}, "a", 2)
assert.Nil(t, err)
err = store.PutRow(ctx, key, Value{0}, "c", 3)
assert.Nil(t, err)
// Get segments
segs, err := store.GetSegments(ctx, key, 2)
assert.Nil(t, err)
assert.Equal(t, 1, len(segs))
assert.Equal(t, "a", segs[0])
segs, err = store.GetSegments(ctx, key, 3)
assert.Nil(t, err)
assert.Equal(t, 2, len(segs))
// Clean test data
err = store.engine.DeleteByPrefix(ctx, key)
assert.Nil(t, err)
}
func TestTikvStore_Log(t *testing.T) {
ctx := context.Background()
// Put some log
err := store.PutLog(ctx, Key("key1"), Value("value1"), 1, 1)
assert.Nil(t, err)
err = store.PutLog(ctx, Key("key1"), Value("value1_1"), 1, 2)
assert.Nil(t, err)
err = store.PutLog(ctx, Key("key2"), Value("value2"), 2, 1)
assert.Nil(t, err)
// Check log
log, err := store.GetLog(ctx, 0, 2, []int{1, 2})
if err != nil {
panic(err)
}
sort.Slice(log, func(i, j int) bool {
return bytes.Compare(log[i], log[j]) == -1
})
assert.Equal(t, log[0], Value("value1"))
assert.Equal(t, log[1], Value("value1_1"))
assert.Equal(t, log[2], Value("value2"))
// Delete test data
err = store.engine.DeleteByPrefix(ctx, Key("log"))
assert.Nil(t, err)
}
func TestTikvStore_SegmentIndex(t *testing.T) {
ctx := context.Background()
// Put segment index
err := store.PutSegmentIndex(ctx, "segment0", []byte("index0"))
assert.Nil(t, err)
err = store.PutSegmentIndex(ctx, "segment1", []byte("index1"))
assert.Nil(t, err)
// Get segment index
index, err := store.GetSegmentIndex(ctx, "segment0")
assert.Nil(t, err)
assert.Equal(t, []byte("index0"), index)
index, err = store.GetSegmentIndex(ctx, "segment1")
assert.Nil(t, err)
assert.Equal(t, []byte("index1"), index)
// Delete segment index
err = store.DeleteSegmentIndex(ctx, "segment0")
assert.Nil(t, err)
err = store.DeleteSegmentIndex(ctx, "segment1")
assert.Nil(t, err)
index, err = store.GetSegmentIndex(ctx, "segment0")
assert.Nil(t, err)
assert.Nil(t, index)
}
func TestTikvStore_DeleteSegmentDL(t *testing.T) {
ctx := context.Background()
// Put segment delete log
err := store.PutSegmentDL(ctx, "segment0", []byte("index0"))
assert.Nil(t, err)
err = store.PutSegmentDL(ctx, "segment1", []byte("index1"))
assert.Nil(t, err)
// Get segment delete log
index, err := store.GetSegmentDL(ctx, "segment0")
assert.Nil(t, err)
assert.Equal(t, []byte("index0"), index)
index, err = store.GetSegmentDL(ctx, "segment1")
assert.Nil(t, err)
assert.Equal(t, []byte("index1"), index)
// Delete segment delete log
err = store.DeleteSegmentDL(ctx, "segment0")
assert.Nil(t, err)
err = store.DeleteSegmentDL(ctx, "segment1")
assert.Nil(t, err)
index, err = store.GetSegmentDL(ctx, "segment0")
assert.Nil(t, err)
assert.Nil(t, index)
}

View File

@ -1,39 +0,0 @@
package storage
import (
"context"
"errors"
S3Driver "github.com/zilliztech/milvus-distributed/internal/storage/internal/S3"
minIODriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/minio"
tikvDriver "github.com/zilliztech/milvus-distributed/internal/storage/internal/tikv"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)
func NewStore(ctx context.Context, option storagetype.Option) (storagetype.Store, error) {
switch option.Type {
case storagetype.TIKVDriver:
store, err := tikvDriver.NewTikvStore(ctx, option)
if err != nil {
return nil, err
}
return store, nil
case storagetype.MinIODriver:
store, err := minIODriver.NewMinioDriver(ctx, option)
if err != nil {
return nil, err
}
return store, nil
case storagetype.S3DRIVER:
store, err := S3Driver.NewS3Driver(ctx, option)
if err != nil {
return nil, err
}
return store, nil
}
return nil, errors.New("unsupported driver")
}
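A minimal usage sketch from inside the repository (hypothetical wiring; it assumes a TiKV placement driver listening on localhost:2379):

package main

import (
"context"
"log"

"github.com/zilliztech/milvus-distributed/internal/storage"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
)

func main() {
ctx := context.Background()
store, err := storage.NewStore(ctx, storagetype.Option{
Type: storagetype.TIKVDriver,
TikvAddress: "localhost:2379",
})
if err != nil {
log.Fatal(err)
}
// Write one versioned row, then read it back as of a later timestamp.
if err := store.PutRow(ctx, []byte("row"), []byte("value"), "segment0", 1); err != nil {
log.Fatal(err)
}
v, err := store.GetRow(ctx, []byte("row"), 2)
if err != nil {
log.Fatal(err)
}
log.Printf("got %q", v)
}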

View File

@ -1,79 +0,0 @@
package storagetype
import (
"context"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)
type Key = []byte
type Value = []byte
type Timestamp = typeutil.Timestamp
type DriverType = string
type SegmentIndex = []byte
type SegmentDL = []byte
type Option struct {
Type DriverType
TikvAddress string
BucketName string
}
const (
MinIODriver DriverType = "MinIO"
TIKVDriver DriverType = "TIKV"
S3DRIVER DriverType = "S3"
)
/*
type Store interface {
Get(ctx context.Context, key Key, timestamp Timestamp) (Value, error)
BatchGet(ctx context.Context, keys [] Key, timestamp Timestamp) ([]Value, error)
Set(ctx context.Context, key Key, v Value, timestamp Timestamp) error
BatchSet(ctx context.Context, keys []Key, v []Value, timestamp Timestamp) error
Delete(ctx context.Context, key Key, timestamp Timestamp) error
BatchDelete(ctx context.Context, keys []Key, timestamp Timestamp) error
Close() error
}
*/
type storeEngine interface {
Put(ctx context.Context, key Key, value Value) error
Get(ctx context.Context, key Key) (Value, error)
GetByPrefix(ctx context.Context, prefix Key, keyOnly bool) ([]Key, []Value, error)
Scan(ctx context.Context, startKey Key, endKey Key, limit int, keyOnly bool) ([]Key, []Value, error)
Delete(ctx context.Context, key Key) error
DeleteByPrefix(ctx context.Context, prefix Key) error
DeleteRange(ctx context.Context, keyStart Key, keyEnd Key) error
}
type Store interface {
//put(ctx context.Context, key Key, value Value, timestamp Timestamp, suffix string) error
//scanLE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
//scanGE(ctx context.Context, key Key, timestamp Timestamp, keyOnly bool) ([]Timestamp, []Key, []Value, error)
//deleteLE(ctx context.Context, key Key, timestamp Timestamp) error
//deleteGE(ctx context.Context, key Key, timestamp Timestamp) error
//deleteRange(ctx context.Context, key Key, start Timestamp, end Timestamp) error
GetRow(ctx context.Context, key Key, timestamp Timestamp) (Value, error)
GetRows(ctx context.Context, keys []Key, timestamps []Timestamp) ([]Value, error)
PutRow(ctx context.Context, key Key, value Value, segment string, timestamp Timestamp) error
PutRows(ctx context.Context, keys []Key, values []Value, segments []string, timestamps []Timestamp) error
GetSegments(ctx context.Context, key Key, timestamp Timestamp) ([]string, error)
DeleteRow(ctx context.Context, key Key, timestamp Timestamp) error
DeleteRows(ctx context.Context, keys []Key, timestamps []Timestamp) error
PutLog(ctx context.Context, key Key, value Value, timestamp Timestamp, channel int) error
GetLog(ctx context.Context, start Timestamp, end Timestamp, channels []int) ([]Value, error)
GetSegmentIndex(ctx context.Context, segment string) (SegmentIndex, error)
PutSegmentIndex(ctx context.Context, segment string, index SegmentIndex) error
DeleteSegmentIndex(ctx context.Context, segment string) error
GetSegmentDL(ctx context.Context, segment string) (SegmentDL, error)
PutSegmentDL(ctx context.Context, segment string, log SegmentDL) error
DeleteSegmentDL(ctx context.Context, segment string) error
}

View File

@ -9,9 +9,6 @@ import (
"strconv"
"github.com/golang/protobuf/proto"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/kv"
miniokv "github.com/zilliztech/milvus-distributed/internal/kv/minio"
@ -360,19 +357,16 @@ func newDDNode(ctx context.Context, outCh chan *ddlFlushSyncMsg) *ddNode {
partitionRecords: make(map[UniqueID]interface{}),
}
- minIOEndPoint := Params.MinioAddress
- minIOAccessKeyID := Params.MinioAccessKeyID
- minIOSecretAccessKey := Params.MinioSecretAccessKey
- minIOUseSSL := Params.MinioUseSSL
- minIOClient, err := minio.New(minIOEndPoint, &minio.Options{
- Creds: credentials.NewStaticV4(minIOAccessKeyID, minIOSecretAccessKey, ""),
- Secure: minIOUseSSL,
- })
- if err != nil {
- panic(err)
- }
bucketName := Params.MinioBucketName
- minioKV, err := miniokv.NewMinIOKV(ctx, minIOClient, bucketName)
+ option := &miniokv.Option{
+ Address: Params.MinioAddress,
+ AccessKeyID: Params.MinioAccessKeyID,
+ SecretAccessKeyID: Params.MinioSecretAccessKey,
+ UseSSL: Params.MinioUseSSL,
+ BucketName: bucketName,
+ CreateBucket: true,
+ }
+ minioKV, err := miniokv.NewMinIOKV(ctx, option)
if err != nil {
panic(err)
}

View File

@ -11,8 +11,6 @@ import (
"unsafe"
"github.com/golang/protobuf/proto"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/zilliztech/milvus-distributed/internal/allocator"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
@ -610,20 +608,17 @@ func newInsertBufferNode(ctx context.Context, outCh chan *insertFlushSyncMsg) *i
kvClient := etcdkv.NewEtcdKV(cli, MetaRootPath)
// MinIO
- minioendPoint := Params.MinioAddress
- miniioAccessKeyID := Params.MinioAccessKeyID
- miniioSecretAccessKey := Params.MinioSecretAccessKey
- minioUseSSL := Params.MinioUseSSL
- minioBucketName := Params.MinioBucketName
- minioClient, err := minio.New(minioendPoint, &minio.Options{
- Creds: credentials.NewStaticV4(miniioAccessKeyID, miniioSecretAccessKey, ""),
- Secure: minioUseSSL,
- })
- if err != nil {
- panic(err)
- }
+ option := &miniokv.Option{
+ Address: Params.MinioAddress,
+ AccessKeyID: Params.MinioAccessKeyID,
+ SecretAccessKeyID: Params.MinioSecretAccessKey,
+ UseSSL: Params.MinioUseSSL,
+ CreateBucket: true,
+ BucketName: Params.MinioBucketName,
+ }
- minIOKV, err := miniokv.NewMinIOKV(ctx, minioClient, minioBucketName)
+ minIOKV, err := miniokv.NewMinIOKV(ctx, option)
if err != nil {
panic(err)
}

View File

@ -58,6 +58,10 @@ if __name__ == "__main__":
'visitor_name': "ExecExprVisitor",
"parameter_name": 'expr',
},
+ {
+ 'visitor_name': "VerifyExprVisitor",
+ "parameter_name": 'expr',
+ },
],
'PlanNode': [
{
@ -68,7 +72,10 @@ if __name__ == "__main__":
'visitor_name': "ExecPlanNodeVisitor",
"parameter_name": 'node',
},
+ {
+ 'visitor_name': "VerifyPlanNodeVisitor",
+ "parameter_name": 'node',
+ },
]
}
extract_extra_body(visitor_info, query_path)

View File

@ -13,7 +13,7 @@
#include "@@base_visitor@@.h"
namespace @@namespace@@ {
- class @@visitor_name@@ : @@base_visitor@@ {
+ class @@visitor_name@@ : public @@base_visitor@@ {
public:
@@body@@