Add parquet payload

Signed-off-by: neza2017 <yefu.chen@zilliz.com>
pull/4973/head^2
neza2017 2020-12-05 16:11:03 +08:00 committed by yefu.chen
parent 63c8f60c6e
commit 70710dee47
104 changed files with 4045 additions and 3531 deletions

.gitignore vendored (4 lines changed)

@ -55,7 +55,3 @@ cmake_build/
.DS_Store
*.swp
cwrapper_build
**/.clangd/*
**/compile_commands.json
**/.lint


@ -1,9 +1,9 @@
timeout(time: 10, unit: 'MINUTES') {
timeout(time: 5, unit: 'MINUTES') {
dir ("scripts") {
sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"'
}
sh '. ./scripts/before-install.sh && make install'
sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install'
dir ("scripts") {
withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {


@ -4,10 +4,7 @@ try {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
dir ('build/docker/deploy') {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d master'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d proxy'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=1 -d querynode'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=2 -d querynode'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
}
dir ('build/docker/test') {


@ -41,9 +41,9 @@ fmt:
lint:
@echo "Running $@ check"
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/go/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
ruleguard:
@echo "Running $@ check"
@ -65,11 +65,9 @@ build-go:
build-cpp:
@(env bash $(PWD)/scripts/core_build.sh)
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release)
build-cpp-with-unittest:
@(env bash $(PWD)/scripts/core_build.sh -u)
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release)
# Runs the tests.
unittest: test-cpp test-go


@ -1,4 +0,0 @@
SOURCE_REPO=milvusdb
TARGET_REPO=milvusdb
SOURCE_TAG=latest
TARGET_TAG=latest


@ -2,7 +2,6 @@ package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
@ -14,7 +13,8 @@ import (
func main() {
proxy.Init()
fmt.Println("ProxyID is", proxy.Params.ProxyID())
// Creates server.
ctx, cancel := context.WithCancel(context.Background())
svr, err := proxy.CreateProxy(ctx)
if err != nil {


@ -2,24 +2,18 @@ package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/querynode"
)
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
querynode.Init()
fmt.Println("QueryNodeID is", querynode.Params.QueryNodeID())
// Creates server.
ctx, cancel := context.WithCancel(context.Background())
svr := querynode.NewQueryNode(ctx, 0)
sc := make(chan os.Signal, 1)
signal.Notify(sc,
@ -34,14 +28,8 @@ func main() {
cancel()
}()
if err := svr.Start(); err != nil {
log.Fatal("run server failed", zap.Error(err))
}
querynode.StartQueryNode(ctx)
<-ctx.Done()
log.Print("Got signal to exit", zap.String("signal", sig.String()))
svr.Close()
switch sig {
case syscall.SIGTERM:
exit(0)


@ -32,7 +32,7 @@ msgChannel:
# default channel range [0, 1)
channelRange:
insert: [0, 2]
insert: [0, 1]
delete: [0, 1]
dataDefinition: [0,1]
k2s: [0, 1]

configs/config.yaml Normal file (118 lines added)

@ -0,0 +1,118 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
master:
address: localhost
port: 53100
pulsarmoniterinterval: 1
pulsartopic: "monitor-topic"
proxyidlist: [1, 2]
proxyTimeSyncChannels: ["proxy1", "proxy2"]
proxyTimeSyncSubName: "proxy-topic"
softTimeTickBarrierInterval: 500
writeidlist: [3, 4]
writeTimeSyncChannels: ["write3", "write4"]
writeTimeSyncSubName: "write-topic"
dmTimeSyncChannels: ["dm5", "dm6"]
k2sTimeSyncChannels: ["k2s7", "k2s8"]
defaultSizePerRecord: 1024
minimumAssignSize: 1048576
segmentThreshold: 536870912
segmentExpireDuration: 2000
segmentThresholdFactor: 0.75
querynodenum: 1
writenodenum: 1
statsChannels: "statistic"
etcd:
address: localhost
port: 2379
rootpath: by-dev
segthreshold: 10000
minio:
address: localhost
port: 9000
accessKeyID: minioadmin
secretAccessKey: minioadmin
useSSL: false
timesync:
interval: 400
storage:
driver: TIKV
address: localhost
port: 2379
accesskey:
secretkey:
pulsar:
authentication: false
user: user-default
token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY
address: localhost
port: 6650
topicnum: 128
reader:
clientid: 0
stopflag: -1
readerqueuesize: 10000
searchchansize: 10000
key2segchansize: 10000
topicstart: 0
topicend: 128
writer:
clientid: 0
stopflag: -2
readerqueuesize: 10000
searchbyidchansize: 10000
parallelism: 100
topicstart: 0
topicend: 128
bucket: "zilliz-hz"
proxy:
timezone: UTC+8
proxy_id: 1
numReaderNodes: 2
tsoSaveInterval: 200
timeTickInterval: 200
pulsarTopics:
readerTopicPrefix: "milvusReader"
numReaderTopics: 2
deleteTopic: "milvusDeleter"
queryTopic: "milvusQuery"
resultTopic: "milvusResult"
resultGroup: "milvusResultGroup"
timeTickTopic: "milvusTimeTick"
network:
address: 0.0.0.0
port: 19530
logs:
level: debug
trace.enable: true
path: /tmp/logs
max_log_file_size: 1024MB
log_rotate_num: 0
storage:
path: /var/lib/milvus
auto_flush_interval: 1


@ -12,7 +12,7 @@
nodeID: # will be deprecated after v0.2
proxyIDList: [0]
queryNodeIDList: [1, 2]
queryNodeIDList: [2]
writeNodeIDList: [3]
etcd:
@ -23,13 +23,6 @@ etcd:
kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath
segThreshold: 10000
minio:
address: localhost
port: 9000
accessKeyID: minioadmin
secretAccessKey: minioadmin
useSSL: false
pulsar:
address: localhost
port: 6650


@ -1,223 +0,0 @@
## Binlog
InsertBinlog, DeleteBinlog, DDLBinlog
Binlog is stored in a columnar format: every column in the schema is stored in an individual file. Timestamp, schema, row id, and the primary key allocated by the system are four special columns. The schema column records the DDL of the collection.
## Event format
A binlog file consists of a 4-byte magic number followed by a series of events. The first event must be a descriptor event.
### Event format
```
+=====================================+
| event | timestamp 0 : 8 | create timestamp
| header +----------------------------+
| | type_code 8 : 1 | event type code
| +----------------------------+
| | server_id 9 : 4 | write node id
| +----------------------------+
| | event_length 13 : 4 | length of event, including header and data
| +----------------------------+
| | next_position 17 : 4 | offset of next event from the start of file
| +----------------------------+
| | extra_headers 21 : x-21 | reserved part
+=====================================+
| event | fixed part x : y |
| data +----------------------------+
| | variable part |
+=====================================+
```
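For concreteness, the fixed header can be mapped onto a packed struct. This is a minimal sketch, assuming little-endian encoding (the byte order is not stated above); the `ForEachEvent` helper is illustrative, not part of the format:
```c++
#include <cstddef>
#include <cstdint>
#include <cstring>

#pragma pack(push, 1)
// Fixed part of the generic event header; widths and offsets follow the
// diagram above. Little-endian encoding is an assumption.
struct EventHeader {
    uint64_t timestamp;      // 0 : 8   create timestamp
    uint8_t  type_code;      // 8 : 1   event type code
    int32_t  server_id;      // 9 : 4   write node id
    int32_t  event_length;   // 13 : 4  length of event, header + data
    int32_t  next_position;  // 17 : 4  offset of next event from file start
};
#pragma pack(pop)
static_assert(sizeof(EventHeader) == 21, "fixed header is 21 bytes");

// Walk every event of an in-memory binlog file by chasing next_position.
// The 4-byte magic number is skipped; its value is not specified here.
template <typename Visitor>
void ForEachEvent(const uint8_t* file, size_t file_size, Visitor visit) {
    size_t pos = 4;  // skip the 4-byte magic number
    while (pos + sizeof(EventHeader) <= file_size) {
        EventHeader header;
        std::memcpy(&header, file + pos, sizeof(header));
        visit(header, file + pos + sizeof(header));
        if (header.next_position <= static_cast<int64_t>(pos))
            break;  // malformed file: no forward progress
        pos = static_cast<size_t>(header.next_position);
    }
}
```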
### Descriptor Event format
```
+=====================================+
| event | timestamp 0 : 8 | create timestamp
| header +----------------------------+
| | type_code 8 : 1 | event type code
| +----------------------------+
| | server_id 9 : 4 | write node id
| +----------------------------+
| | event_length 13 : 4 | length of event, including header and data
| +----------------------------+
| | next_position 17 : 4 | offset of next event from the start of file
+=====================================+
| event | binlog_version 21 : 2 | binlog version
| data +----------------------------+
| | server_version 23 : 8 | write node version
| +----------------------------+
| | commit_id 31 : 8 | commit id of the program in git
| +----------------------------+
| | header_length 39 : 1 | header length of other event
| +----------------------------+
| | collection_id 40 : 8 | collection id
| +----------------------------+
| | partition_id 48 : 8 | partition id (schema column does not need)
| +----------------------------+
| | segment_id 56 : 8 | segment id (schema column does not need)
| +----------------------------+
| | start_timestamp 64 : 1 | minimum timestamp allocated by master of all events in this file
| +----------------------------+
| | end_timestamp 65 : 1 | maximum timestamp allocated by master of all events in this file
| +----------------------------+
| | post-header 66 : n | array of n bytes, one byte per event type that the server knows about
| | lengths for all |
| | event types |
+=====================================+
```
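The descriptor event's fixed data part maps onto a struct the same way. Widths below are copied verbatim from the diagram (note it lists start_timestamp and end_timestamp as 1 byte each, unlike the 8-byte timestamps elsewhere):
```c++
#include <cstdint>

#pragma pack(push, 1)
// Fixed data part of the descriptor event, starting at file offset 21.
struct DescriptorEventData {
    uint16_t binlog_version;   // 21 : 2  binlog version
    uint64_t server_version;   // 23 : 8  write node version
    uint64_t commit_id;        // 31 : 8  commit id of the program in git
    uint8_t  header_length;    // 39 : 1  header length of other events
    int64_t  collection_id;    // 40 : 8  collection id
    int64_t  partition_id;     // 48 : 8  partition id
    int64_t  segment_id;       // 56 : 8  segment id
    uint8_t  start_timestamp;  // 64 : 1  min timestamp of events in this file
    uint8_t  end_timestamp;    // 65 : 1  max timestamp of events in this file
    // followed by the post-header length array: one byte per event type
};
#pragma pack(pop)
static_assert(sizeof(DescriptorEventData) == 66 - 21, "fixed part spans offsets 21..65");
```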
### Type code
```
DESCRIPTOR_EVENT
INSERT_EVENT
DELETE_EVENT
CREATE_COLLECTION_EVENT
DROP_COLLECTION_EVENT
CREATE_PARTITION_EVENT
DROP_PARTITION_EVENT
```
DESCRIPTOR_EVENT must appear in all column files and always be the first event.
INSERT_EVENT can appear in the binlogs of all columns except the DDL binlog files.
DELETE_EVENT can only appear in the primary key's binlog file (currently deletion is only supported by primary key).
CREATE_COLLECTION_EVENT, DROP_COLLECTION_EVENT, CREATE_PARTITION_EVENT, and DROP_PARTITION_EVENT appear only in the DDL binlog files.
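Since type_code is a single byte in the header, the codes fit a one-byte enum. A sketch; the numeric values are assumptions for illustration, as the list above does not fix them:
```c++
#include <cstdint>

// One value per event type listed above; the numeric assignments are
// illustrative assumptions, not taken from the spec.
enum class EventTypeCode : uint8_t {
    DESCRIPTOR_EVENT = 0,
    INSERT_EVENT = 1,
    DELETE_EVENT = 2,
    CREATE_COLLECTION_EVENT = 3,
    DROP_COLLECTION_EVENT = 4,
    CREATE_PARTITION_EVENT = 5,
    DROP_PARTITION_EVENT = 6,
};
```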
### Event data part
```
event data part
INSERT_EVENT:
+================================================+
| event | fixed | start_timestamp x : 8 | min timestamp in this event
| data | part +------------------------------+
| | | end_timestamp x+8 : 8 | max timestamp in this event
| | +------------------------------+
| | | reserved x+16 : y-x-16 | reserved part
| +--------+------------------------------+
| |variable| parquet payload   | payload in parquet format
| |part | |
+================================================+
other events are similar to INSERT_EVENT
```
### Example
Schema:
string | int | float(optional) | vector(512)
Requests:
InsertRequest with 10,000 rows
DeleteRequest with pk=1
DropPartition with partitionTag="abc"
insert binlogs:
rowid, pk, ts, string, int, float, vector (6 files)
all events are INSERT_EVENT
the float column file contains some NULL values
delete binlogs:
pk, ts (2 files)
pk's events are DELETE_EVENT, ts's events are INSERT_EVENT
DDL binlogs:
ddl, ts (2 files)
ddl's event is DROP_PARTITION_EVENT, ts's event is INSERT_EVENT
### C++ interface
```c++
// handle types for the payload writer/reader
typedef void* CPayloadWriter;
typedef void* CPayloadReader;
typedef struct CBuffer {
    char* data;
    int length;
} CBuffer;
typedef struct CStatus {
    int error_code;
    const char* error_msg;
} CStatus;
// C++ interface
// writer
CPayloadWriter NewPayloadWriter(int columnType);
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
CStatus AddInt64ToPayload(CPayloadWriter payloadWriter, int64_t *values, int length);
CStatus AddFloatToPayload(CPayloadWriter payloadWriter, float *values, int length);
CStatus AddDoubleToPayload(CPayloadWriter payloadWriter, double *values, int length);
CStatus AddOneStringToPayload(CPayloadWriter payloadWriter, char *cstr, int str_size);
CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_t *values, int dimension, int length);
CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *values, int dimension, int length);
CStatus FinishPayloadWriter(CPayloadWriter payloadWriter);
CBuffer GetPayloadBufferFromWriter(CPayloadWriter payloadWriter);
int GetPayloadLengthFromWriter(CPayloadWriter payloadWriter);
CStatus ReleasePayloadWriter(CPayloadWriter handler);
// reader
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
CStatus GetInt64FromPayload(CPayloadReader payloadReader, int64_t **values, int *length);
CStatus GetFloatFromPayload(CPayloadReader payloadReader, float **values, int *length);
CStatus GetDoubleFromPayload(CPayloadReader payloadReader, double **values, int *length);
CStatus GetOneStringFromPayload(CPayloadReader payloadReader, int idx, char **cstr, int *str_size);
CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, uint8_t **values, int *dimension, int *length);
CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, float **values, int *dimension, int *length);
int GetPayloadLengthFromReader(CPayloadReader payloadReader);
CStatus ReleasePayloadReader(CPayloadReader payloadReader);
```
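To show the intended call sequence, a minimal write/read round-trip against the interface above. The kFloatColumn constant is a placeholder (the real value comes from the storage layer's column-type enum), and the buffer-ownership comment is an assumption:
```c++
#include <cassert>
#include <cstdint>

// Write three floats into a payload, then read them back.
void PayloadRoundTrip() {
    const int kFloatColumn = 10;  // placeholder for the FLOAT column type
    float values[] = {1.0f, 2.0f, 3.0f};

    CPayloadWriter writer = NewPayloadWriter(kFloatColumn);
    CStatus st = AddFloatToPayload(writer, values, 3);
    assert(st.error_code == 0);
    st = FinishPayloadWriter(writer);
    assert(st.error_code == 0);

    CBuffer buf = GetPayloadBufferFromWriter(writer);
    CPayloadReader reader =
        NewPayloadReader(kFloatColumn, reinterpret_cast<uint8_t*>(buf.data), buf.length);

    float* out = nullptr;
    int length = 0;
    st = GetFloatFromPayload(reader, &out, &length);
    assert(st.error_code == 0 && length == 3);

    ReleasePayloadReader(reader);
    ReleasePayloadWriter(writer);  // assumed to also free the writer-owned buffer
}
```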

internal/conf/conf.go Normal file (150 lines added)

@ -0,0 +1,150 @@
package conf
import (
"io/ioutil"
"path"
"runtime"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
yaml "gopkg.in/yaml.v2"
)
type UniqueID = typeutil.UniqueID
// yaml.MapSlice
type MasterConfig struct {
Address string
Port int32
PulsarMonitorInterval int32
PulsarTopic string
SegmentThreshold float32
SegmentExpireDuration int64
ProxyIDList []UniqueID
QueryNodeNum int
WriteNodeNum int
}
type EtcdConfig struct {
Address string
Port int32
Rootpath string
Segthreshold int64
}
type TimeSyncConfig struct {
Interval int32
}
type StorageConfig struct {
Driver storagetype.DriverType
Address string
Port int32
Accesskey string
Secretkey string
}
type PulsarConfig struct {
Authentication bool
User string
Token string
Address string
Port int32
TopicNum int
}
type ProxyConfig struct {
Timezone string `yaml:"timezone"`
ProxyID int `yaml:"proxy_id"`
NumReaderNodes int `yaml:"numReaderNodes"`
TosSaveInterval int `yaml:"tsoSaveInterval"`
TimeTickInterval int `yaml:"timeTickInterval"`
PulsarTopics struct {
ReaderTopicPrefix string `yaml:"readerTopicPrefix"`
NumReaderTopics int `yaml:"numReaderTopics"`
DeleteTopic string `yaml:"deleteTopic"`
QueryTopic string `yaml:"queryTopic"`
ResultTopic string `yaml:"resultTopic"`
ResultGroup string `yaml:"resultGroup"`
TimeTickTopic string `yaml:"timeTickTopic"`
} `yaml:"pulsarTopics"`
Network struct {
Address string `yaml:"address"`
Port int `yaml:"port"`
} `yaml:"network"`
Logs struct {
Level string `yaml:"level"`
TraceEnable bool `yaml:"trace.enable"`
Path string `yaml:"path"`
MaxLogFileSize string `yaml:"max_log_file_size"`
LogRotateNum int `yaml:"log_rotate_num"`
} `yaml:"logs"`
Storage struct {
Path string `yaml:"path"`
AutoFlushInterval int `yaml:"auto_flush_interval"`
} `yaml:"storage"`
}
type Reader struct {
ClientID int
StopFlag int64
ReaderQueueSize int
SearchChanSize int
Key2SegChanSize int
TopicStart int
TopicEnd int
}
type Writer struct {
ClientID int
StopFlag int64
ReaderQueueSize int
SearchByIDChanSize int
Parallelism int
TopicStart int
TopicEnd int
Bucket string
}
type ServerConfig struct {
Master MasterConfig
Etcd EtcdConfig
Timesync TimeSyncConfig
Storage StorageConfig
Pulsar PulsarConfig
Writer Writer
Reader Reader
Proxy ProxyConfig
}
var Config ServerConfig
// func init() {
// load_config()
// }
func getConfigsDir() string {
_, fpath, _, _ := runtime.Caller(0)
configPath := path.Dir(fpath) + "/../../configs/"
configPath = path.Dir(configPath)
return configPath
}
func LoadConfigWithPath(yamlFilePath string) {
source, err := ioutil.ReadFile(yamlFilePath)
if err != nil {
panic(err)
}
err = yaml.Unmarshal(source, &Config)
if err != nil {
panic(err)
}
//fmt.Printf("Result: %v\n", Config)
}
func LoadConfig(yamlFile string) {
filePath := path.Join(getConfigsDir(), yamlFile)
LoadConfigWithPath(filePath)
}


@ -0,0 +1,10 @@
package conf
import (
"fmt"
"testing"
)
func TestMain(m *testing.M) {
fmt.Printf("Result: %v\n", Config)
}


@ -1,9 +1,8 @@
set(COMMON_SRC
Schema.cpp
Types.cpp
)
set(COMMON_SRC
Schema.cpp
)
add_library(milvus_common
${COMMON_SRC}
)
add_library(milvus_common
${COMMON_SRC}
)
target_link_libraries(milvus_common milvus_proto)


@ -10,13 +10,18 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "common/Types.h"
#include "utils/Types.h"
#include "utils/Status.h"
#include "utils/EasyAssert.h"
#include <string>
#include <stdexcept>
namespace milvus {
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
using engine::DataType;
using engine::FieldElementType;
inline int
field_sizeof(DataType data_type, int dim = 1) {
switch (data_type) {
@ -84,13 +89,7 @@ field_is_vector(DataType datatype) {
struct FieldMeta {
public:
FieldMeta(std::string_view name, DataType type) : name_(name), type_(type) {
Assert(!is_vector());
}
FieldMeta(std::string_view name, DataType type, int64_t dim, MetricType metric_type)
: name_(name), type_(type), vector_info_(VectorInfo{dim, metric_type}) {
Assert(is_vector());
FieldMeta(std::string_view name, DataType type, int dim = 1) : name_(name), type_(type), dim_(dim) {
}
bool
@ -99,11 +98,14 @@ struct FieldMeta {
return type_ == DataType::VECTOR_BINARY || type_ == DataType::VECTOR_FLOAT;
}
int64_t
void
set_dim(int dim) {
dim_ = dim;
}
int
get_dim() const {
Assert(is_vector());
Assert(vector_info_.has_value());
return vector_info_->dim_;
return dim_;
}
const std::string&
@ -118,20 +120,12 @@ struct FieldMeta {
int
get_sizeof() const {
if (is_vector()) {
return field_sizeof(type_, get_dim());
} else {
return field_sizeof(type_, 1);
}
return field_sizeof(type_, dim_);
}
private:
struct VectorInfo {
int64_t dim_;
MetricType metric_type_;
};
std::string name_;
DataType type_ = DataType::NONE;
std::optional<VectorInfo> vector_info_;
int dim_ = 1;
};
} // namespace milvus


@ -11,50 +11,35 @@
#include "common/Schema.h"
#include <google/protobuf/text_format.h>
#include <boost/lexical_cast.hpp>
namespace milvus {
using std::string;
static std::map<string, string>
RepeatedKeyValToMap(const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>& kvs) {
std::map<string, string> mapping;
for (auto& kv : kvs) {
AssertInfo(!mapping.count(kv.key()), "repeat key(" + kv.key() + ") in protobuf");
mapping.emplace(kv.key(), kv.value());
}
return mapping;
}
std::shared_ptr<Schema>
Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
auto schema = std::make_shared<Schema>();
schema->set_auto_id(schema_proto.autoid());
for (const milvus::proto::schema::FieldSchema& child : schema_proto.fields()) {
const auto& type_params = child.type_params();
int64_t dim = -1;
auto data_type = DataType(child.data_type());
for (const auto& type_param : type_params) {
if (type_param.key() == "dim") {
dim = strtoll(type_param.value().c_str(), nullptr, 10);
}
}
if (field_is_vector(data_type)) {
AssertInfo(dim != -1, "dim not found");
} else {
AssertInfo(dim == 1 || dim == -1, "Invalid dim field. Should be 1 or not exists");
dim = 1;
}
if (child.is_primary_key()) {
AssertInfo(!schema->primary_key_offset_opt_.has_value(), "repetitive primary key");
schema->primary_key_offset_opt_ = schema->size();
}
if (field_is_vector(data_type)) {
auto type_map = RepeatedKeyValToMap(child.type_params());
auto index_map = RepeatedKeyValToMap(child.index_params());
if (!index_map.count("metric_type")) {
auto default_metric_type =
data_type == DataType::VECTOR_FLOAT ? MetricType::METRIC_L2 : MetricType::METRIC_Jaccard;
index_map["metric_type"] = default_metric_type;
}
AssertInfo(type_map.count("dim"), "dim not found");
auto dim = boost::lexical_cast<int64_t>(type_map.at("dim"));
AssertInfo(index_map.count("metric_type"), "index not found");
auto metric_type = GetMetricType(index_map.at("metric_type"));
schema->AddField(child.name(), data_type, dim, metric_type);
} else {
schema->AddField(child.name(), data_type);
}
schema->AddField(child.name(), data_type, dim);
}
return schema;
}


@ -24,15 +24,19 @@ namespace milvus {
class Schema {
public:
void
AddField(std::string_view field_name, DataType data_type) {
auto field_meta = FieldMeta(field_name, data_type);
AddField(std::string_view field_name, DataType data_type, int dim = 1) {
auto field_meta = FieldMeta(field_name, data_type, dim);
this->AddField(std::move(field_meta));
}
void
AddField(std::string_view field_name, DataType data_type, int64_t dim, MetricType metric_type) {
auto field_meta = FieldMeta(field_name, data_type, dim, metric_type);
this->AddField(std::move(field_meta));
AddField(FieldMeta field_meta) {
auto offset = fields_.size();
fields_.emplace_back(field_meta);
offsets_.emplace(field_meta.get_name(), offset);
auto field_sizeof = field_meta.get_sizeof();
sizeof_infos_.push_back(field_sizeof);
total_sizeof_ += field_sizeof;
}
void
@ -40,6 +44,17 @@ class Schema {
is_auto_id_ = is_auto_id;
}
auto
begin() {
return fields_.begin();
}
auto
end() {
return fields_.end();
}
public:
bool
get_is_auto_id() const {
return is_auto_id_;
@ -108,20 +123,11 @@ class Schema {
static std::shared_ptr<Schema>
ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto);
void
AddField(FieldMeta&& field_meta) {
auto offset = fields_.size();
fields_.emplace_back(field_meta);
offsets_.emplace(field_meta.get_name(), offset);
auto field_sizeof = field_meta.get_sizeof();
sizeof_infos_.push_back(std::move(field_sizeof));
total_sizeof_ += field_sizeof;
}
private:
// this is where data holds
std::vector<FieldMeta> fields_;
private:
// a mapping for random access
std::unordered_map<std::string, int> offsets_;
std::vector<int> sizeof_infos_;


@ -1,45 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
//
// Created by mike on 12/3/20.
//
#include "common/Types.h"
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
#include "utils/EasyAssert.h"
#include <boost/bimap.hpp>
#include <boost/algorithm/string/case_conv.hpp>
namespace milvus {
using boost::algorithm::to_lower_copy;
namespace Metric = knowhere::Metric;
static auto map = [] {
boost::bimap<std::string, MetricType> mapping;
using pos = boost::bimap<std::string, MetricType>::value_type;
mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2));
mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT));
mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard));
mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto));
mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming));
mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure));
mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure));
return mapping;
}();
MetricType
GetMetricType(const std::string& type_name) {
auto real_name = to_lower_copy(type_name);
AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")");
return map.left.at(real_name);
}
} // namespace milvus


@ -1,40 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "utils/Types.h"
#include <faiss/MetricType.h>
#include <string>
#include <boost/align/aligned_allocator.hpp>
#include <vector>
namespace milvus {
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
using engine::DataType;
using engine::FieldElementType;
using engine::QueryResult;
using MetricType = faiss::MetricType;
faiss::MetricType
GetMetricType(const std::string& type);
// NOTE: dependent type
// used at meta-template programming
template <class...>
constexpr std::true_type always_true{};
template <class...>
constexpr std::false_type always_false{};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
} // namespace milvus


@ -11,66 +11,20 @@
#include "BruteForceSearch.h"
#include <vector>
#include <common/Types.h>
#include <boost/dynamic_bitset.hpp>
#include <queue>
namespace milvus::query {
void
BinarySearchBruteForceNaive(MetricType metric_type,
int64_t code_size,
const uint8_t* binary_chunk,
int64_t chunk_size,
int64_t topk,
int64_t num_queries,
const uint8_t* query_data,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
// THIS IS A NAIVE IMPLEMENTATION, ready for optimize
Assert(metric_type == faiss::METRIC_Jaccard);
Assert(code_size % 4 == 0);
using T = std::tuple<float, int>;
for (int64_t q = 0; q < num_queries; ++q) {
auto query_ptr = query_data + code_size * q;
auto query = boost::dynamic_bitset(query_ptr, query_ptr + code_size);
std::vector<T> max_heap(topk + 1, std::make_tuple(std::numeric_limits<float>::max(), -1));
for (int64_t i = 0; i < chunk_size; ++i) {
auto element_ptr = binary_chunk + code_size * i;
auto element = boost::dynamic_bitset(element_ptr, element_ptr + code_size);
auto the_and = (query & element).count();
auto the_or = (query | element).count();
auto distance = the_or ? (float)(the_or - the_and) / the_or : 0;
if (distance < std::get<0>(max_heap[0])) {
max_heap[topk] = std::make_tuple(distance, i);
std::push_heap(max_heap.begin(), max_heap.end());
std::pop_heap(max_heap.begin(), max_heap.end());
}
}
std::sort(max_heap.begin(), max_heap.end());
for (int k = 0; k < topk; ++k) {
auto info = max_heap[k];
result_distances[k + q * topk] = std::get<0>(info);
result_labels[k + q * topk] = std::get<1>(info);
}
}
}
void
BinarySearchBruteForceFast(MetricType metric_type,
int64_t code_size,
const uint8_t* binary_chunk,
int64_t chunk_size,
int64_t topk,
int64_t num_queries,
const uint8_t* query_data,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
BinarySearchBruteForce(faiss::MetricType metric_type,
int64_t code_size,
const uint8_t* binary_chunk,
int64_t chunk_size,
int64_t topk,
int64_t num_queries,
const uint8_t* query_data,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
const idx_t block_size = segcore::DefaultElementPerChunk;
bool use_heap = true;
@ -129,21 +83,6 @@ BinarySearchBruteForceFast(MetricType metric_type,
for (int i = 0; i < num_queries; ++i) {
result_distances[i] = static_cast<float>(int_distances[i]);
}
} else {
PanicInfo("Unsupported metric type");
}
}
void
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
const uint8_t* binary_chunk,
int64_t chunk_size,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset) {
// TODO: refactor the internal function
BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size,
query_dataset.topk, query_dataset.num_queries, query_dataset.query_data,
result_distances, result_labels, bitset);
}
} // namespace milvus::query


@ -15,25 +15,15 @@
#include "common/Schema.h"
namespace milvus::query {
using MetricType = faiss::MetricType;
namespace dataset {
struct BinaryQueryDataset {
MetricType metric_type;
int64_t num_queries;
int64_t topk;
int64_t code_size;
const uint8_t* query_data;
};
} // namespace dataset
void
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
BinarySearchBruteForce(faiss::MetricType metric_type,
int64_t code_size,
const uint8_t* binary_chunk,
int64_t chunk_size,
int64_t topk,
int64_t num_queries,
const uint8_t* query_data,
float* result_distances,
idx_t* result_labels,
faiss::ConcurrentBitsetPtr bitset = nullptr);
} // namespace milvus::query


@ -26,25 +26,15 @@ static std::unique_ptr<VectorPlanNode>
ParseVecNode(Plan* plan, const Json& out_body) {
Assert(out_body.is_object());
// TODO add binary info
auto vec_node = std::make_unique<FloatVectorANNS>();
Assert(out_body.size() == 1);
auto iter = out_body.begin();
std::string field_name = iter.key();
auto& vec_info = iter.value();
Assert(vec_info.is_object());
auto topK = vec_info["topk"];
AssertInfo(topK > 0, "topK must greater than 0");
AssertInfo(topK < 16384, "topK is too large");
auto field_meta = plan->schema_.operator[](field_name);
auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
auto data_type = field_meta.get_data_type();
if (data_type == DataType::VECTOR_FLOAT) {
return std::make_unique<FloatVectorANNS>();
} else {
return std::make_unique<BinaryVectorANNS>();
}
}();
vec_node->query_info_.topK_ = topK;
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
vec_node->query_info_.search_params_ = vec_info.at("params");
@ -70,6 +60,8 @@ to_lower(const std::string& raw) {
return data;
}
template <class...>
constexpr std::false_type always_false{};
template <typename T>
std::unique_ptr<Expr>
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
@ -83,62 +75,31 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
auto op = RangeExpr::mapping_.at(op_name);
if constexpr (std::is_same_v<T, bool>) {
Assert(item.value().is_boolean());
} else if constexpr (std::is_integral_v<T>) {
if constexpr (std::is_integral_v<T>) {
Assert(item.value().is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(item.value().is_number());
} else {
static_assert(always_false<T>, "unsupported type");
__builtin_unreachable();
}
T value = item.value();
expr->conditions_.emplace_back(op, value);
}
std::sort(expr->conditions_.begin(), expr->conditions_.end());
return expr;
}
template <typename T>
std::unique_ptr<Expr>
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
auto expr = std::make_unique<TermExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
Assert(body.is_array());
expr->field_id_ = field_name;
expr->data_type_ = data_type;
for (auto& value : body) {
if constexpr (std::is_same_v<T, bool>) {
Assert(value.is_boolean());
} else if constexpr (std::is_integral_v<T>) {
Assert(value.is_number_integer());
} else if constexpr (std::is_floating_point_v<T>) {
Assert(value.is_number());
} else {
static_assert(always_false<T>, "unsupported type");
__builtin_unreachable();
}
T real_value = value;
expr->terms_.push_back(real_value);
}
std::sort(expr->terms_.begin(), expr->terms_.end());
return expr;
}
std::unique_ptr<Expr>
ParseRangeNode(const Schema& schema, const Json& out_body) {
Assert(out_body.is_object());
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
return ParseRangeNodeImpl<bool>(schema, field_name, body);
PanicInfo("bool is not supported in Range node");
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
}
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
@ -157,42 +118,6 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
}
}
static std::unique_ptr<Expr>
ParseTermNode(const Schema& schema, const Json& out_body) {
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto field_name = out_iter.key();
auto body = out_iter.value();
auto data_type = schema[field_name].get_data_type();
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL: {
return ParseTermNodeImpl<bool>(schema, field_name, body);
}
case DataType::INT8: {
return ParseTermNodeImpl<int8_t>(schema, field_name, body);
}
case DataType::INT16: {
return ParseTermNodeImpl<int16_t>(schema, field_name, body);
}
case DataType::INT32: {
return ParseTermNodeImpl<int32_t>(schema, field_name, body);
}
case DataType::INT64: {
return ParseTermNodeImpl<int64_t>(schema, field_name, body);
}
case DataType::FLOAT: {
return ParseTermNodeImpl<float>(schema, field_name, body);
}
case DataType::DOUBLE: {
return ParseTermNodeImpl<double>(schema, field_name, body);
}
default: {
PanicInfo("unsupported data_type");
}
}
}
static std::unique_ptr<Plan>
CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
auto plan = std::make_unique<Plan>(schema);
@ -208,10 +133,6 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
if (pack.contains("vector")) {
auto& out_body = pack.at("vector");
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
} else if (pack.contains("term")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("term");
predicate = ParseTermNode(schema, out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("range");


@ -20,6 +20,7 @@
#include <map>
#include <string>
#include <vector>
#include <boost/align/aligned_allocator.hpp>
namespace milvus::query {
using Json = nlohmann::json;
@ -38,6 +39,9 @@ struct Plan {
// TODO: add move extra info
};
template <typename T>
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
struct Placeholder {
// milvus::proto::service::PlaceholderGroup group_;
std::string tag_;


@ -16,7 +16,6 @@
#include <faiss/utils/distances.h>
#include "utils/tools.h"
#include "query/BruteForceSearch.h"
namespace milvus::query {
using segcore::DefaultElementPerChunk;
@ -27,7 +26,7 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
return nullptr;
}
auto& bitmaps = *bitmaps_opt.value();
auto src_vec = ~bitmaps.at(chunk_id);
auto& src_vec = bitmaps.at(chunk_id);
auto dst = std::make_shared<faiss::ConcurrentBitset>(src_vec.size());
auto iter = reinterpret_cast<BitmapChunk::block_type*>(dst->mutable_data());
@ -42,7 +41,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
segcore::QueryResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();
@ -132,92 +131,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
}
results.result_ids_ = std::move(final_uids);
// TODO: deprecated code end
return Status::OK();
}
Status
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results) {
auto& schema = segment.get_schema();
auto& indexing_record = segment.get_indexing_record();
auto& record = segment.get_insert_record();
// step 1: binary search to find the barrier of the snapshot
auto ins_barrier = get_barrier(record, timestamp);
auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk);
auto metric_type = GetMetricType(info.metric_type_);
// auto del_barrier = get_barrier(deleted_record_, timestamp);
#if 0
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
Assert(bitmap_holder);
auto bitmap = bitmap_holder->bitmap_ptr;
#endif
// step 2.1: get meta
// step 2.2: get which vector field to search
auto vecfield_offset_opt = schema.get_offset(info.field_id_);
Assert(vecfield_offset_opt.has_value());
auto vecfield_offset = vecfield_offset_opt.value();
auto& field = schema[vecfield_offset];
Assert(field.get_data_type() == DataType::VECTOR_BINARY);
auto dim = field.get_dim();
auto code_size = dim / 8;
auto topK = info.topK_;
auto total_count = topK * num_queries;
// step 3: small indexing search
std::vector<int64_t> final_uids(total_count, -1);
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data};
using segcore::BinaryVector;
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);
auto max_indexed_id = 0;
// step 4: brute force search where small indexing is unavailable
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
std::vector<int64_t> buf_uids(total_count, -1);
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
auto& chunk = vec_ptr->get_chunk(chunk_id);
auto nsize =
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
BinarySearchBruteForce(query_dataset, chunk.data(), nsize, buf_dis.data(), buf_uids.data(), bitmap_view);
// convert chunk uid to segment uid
for (auto& x : buf_uids) {
if (x != -1) {
x += chunk_id * DefaultElementPerChunk;
}
}
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
}
results.result_distances_ = std::move(final_dis);
results.internal_seg_offsets_ = std::move(final_uids);
results.topK_ = topK;
results.num_queries_ = num_queries;
// TODO: deprecated code begin
final_uids = results.internal_seg_offsets_;
for (auto& id : final_uids) {
if (id == -1) {
continue;
}
id = record.uids_[id];
}
results.result_ids_ = std::move(final_uids);
// TODO: deprecated code end
return Status::OK();
}
} // namespace milvus::query


@ -27,14 +27,5 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmap_opt,
QueryResult& results);
Status
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
const query::QueryInfo& info,
const uint8_t* query_data,
int64_t num_queries,
Timestamp timestamp,
std::optional<const BitmapSimple*> bitmaps_opt,
QueryResult& results);
segcore::QueryResult& results);
} // namespace milvus::query


@ -18,7 +18,7 @@
#include <unordered_map>
#include <vector>
#include "common/Types.h"
#include "utils/Types.h"
#include "utils/Json.h"
namespace milvus {


@ -58,10 +58,6 @@ class ExecExprVisitor : ExprVisitor {
auto
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
template <typename T>
auto
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
private:
segcore::SegmentSmallIndex& segment_;
std::optional<RetType> ret_;


@ -28,7 +28,7 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
visit(BinaryVectorANNS& node) override;
public:
using RetType = QueryResult;
using RetType = segcore::QueryResult;
ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group)
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
}


@ -46,10 +46,6 @@ class ExecExprVisitor : ExprVisitor {
auto
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
template <typename T>
auto
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
private:
segcore::SegmentSmallIndex& segment_;
std::optional<RetType> ret_;
@ -67,6 +63,11 @@ ExecExprVisitor::visit(BoolBinaryExpr& expr) {
PanicInfo("unimplemented");
}
void
ExecExprVisitor::visit(TermExpr& expr) {
PanicInfo("unimplemented");
}
template <typename T, typename IndexFunc, typename ElementFunc>
auto
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_func, ElementFunc element_func)
@ -83,17 +84,17 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_fu
auto& indexing_record = segment_.get_indexing_record();
const segcore::ScalarIndexingEntry<T>& entry = indexing_record.get_scalar_entry<T>(field_offset);
RetType results(vec.num_chunk());
RetType results(vec.chunk_size());
auto indexing_barrier = indexing_record.get_finished_ack();
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
auto& result = results[chunk_id];
auto indexing = entry.get_indexing(chunk_id);
auto data = index_func(indexing);
result = std::move(*data);
result = ~std::move(*data);
Assert(result.size() == segcore::DefaultElementPerChunk);
}
for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) {
for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) {
auto& result = results[chunk_id];
result.resize(segcore::DefaultElementPerChunk);
auto chunk = vec.get_chunk(chunk_id);
@ -125,32 +126,32 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
switch (op) {
case OpType::Equal: {
auto index_func = [val](Index* index) { return index->In(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x == val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); });
}
case OpType::NotEqual: {
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x != val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
}
case OpType::GreaterEqual: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x >= val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); });
}
case OpType::GreaterThan: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x > val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); });
}
case OpType::LessEqual: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x <= val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); });
}
case OpType::LessThan: {
auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); };
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x < val); });
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); });
}
default: {
PanicInfo("unsupported range node");
@ -166,16 +167,16 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
if (false) {
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x < val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < val2); });
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x <= val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); });
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x < val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); });
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) {
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); };
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x <= val2); });
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); });
} else {
PanicInfo("unsupported range node");
}
@ -225,79 +226,4 @@ ExecExprVisitor::visit(RangeExpr& expr) {
ret_ = std::move(ret);
}
template <typename T>
auto
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
auto& records = segment_.get_insert_record();
auto data_type = expr.data_type_;
auto& schema = segment_.get_schema();
auto field_offset_opt = schema.get_offset(expr.field_id_);
Assert(field_offset_opt);
auto field_offset = field_offset_opt.value();
auto& field_meta = schema[field_offset];
auto vec_ptr = records.get_entity<T>(field_offset);
auto& vec = *vec_ptr;
auto num_chunk = vec.num_chunk();
RetType bitsets;
auto N = records.ack_responder_.GetAck();
// small batch
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
auto& chunk = vec.get_chunk(chunk_id);
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk
: segcore::DefaultElementPerChunk;
boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk);
for (int i = 0; i < size; ++i) {
auto value = chunk[i];
bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value);
bitset[i] = is_in;
}
bitsets.emplace_back(std::move(bitset));
}
return bitsets;
}
void
ExecExprVisitor::visit(TermExpr& expr) {
auto& field_meta = segment_.get_schema()[expr.field_id_];
Assert(expr.data_type_ == field_meta.get_data_type());
RetType ret;
switch (expr.data_type_) {
case DataType::BOOL: {
ret = ExecTermVisitorImpl<bool>(expr);
break;
}
case DataType::INT8: {
ret = ExecTermVisitorImpl<int8_t>(expr);
break;
}
case DataType::INT16: {
ret = ExecTermVisitorImpl<int16_t>(expr);
break;
}
case DataType::INT32: {
ret = ExecTermVisitorImpl<int32_t>(expr);
break;
}
case DataType::INT64: {
ret = ExecTermVisitorImpl<int64_t>(expr);
break;
}
case DataType::FLOAT: {
ret = ExecTermVisitorImpl<float>(expr);
break;
}
case DataType::DOUBLE: {
ret = ExecTermVisitorImpl<double>(expr);
break;
}
default:
PanicInfo("unsupported");
}
ret_ = std::move(ret);
}
} // namespace milvus::query


@ -26,7 +26,7 @@ namespace impl {
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
class ExecPlanNodeVisitor : PlanNodeVisitor {
public:
using RetType = QueryResult;
using RetType = segcore::QueryResult;
ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group)
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
}
@ -75,22 +75,7 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
void
ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) {
// TODO: optimize here, remove the dynamic cast
assert(!ret_.has_value());
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
AssertInfo(segment, "support SegmentSmallIndex Only");
RetType ret;
auto& ph = placeholder_group_.at(0);
auto src_data = ph.get_blob<uint8_t>();
auto num_queries = ph.num_of_queries_;
if (node.predicate_.has_value()) {
auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
auto ptr = &bitmap;
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret);
} else {
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret);
}
ret_ = ret;
// TODO
}
} // namespace milvus::query


@ -73,24 +73,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
void
ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) {
assert(!ret_);
auto& info = node.query_info_;
Json json_body{
{"node_type", "BinaryVectorANNS"}, //
{"metric_type", info.metric_type_}, //
{"field_id_", info.field_id_}, //
{"topK", info.topK_}, //
{"search_params", info.search_params_}, //
{"placeholder_tag", node.placeholder_tag_}, //
};
if (node.predicate_.has_value()) {
ShowExprVisitor expr_show;
Assert(node.predicate_.value());
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
} else {
json_body["predicate"] = "None";
}
ret_ = json_body;
// TODO
}
} // namespace milvus::query


@ -123,10 +123,9 @@ Collection::CreateIndex(std::string& index_config) {
void
Collection::parse() {
if (collection_proto_.empty()) {
// TODO: remove hard code use unittests are ready
std::cout << "WARN: Use default schema" << std::endl;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
schema_ = schema;
return;


@ -196,7 +196,7 @@ class ConcurrentVectorImpl : public VectorBase {
}
ssize_t
num_chunk() const {
chunk_size() const {
return chunks_.size();
}
@ -226,14 +226,8 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
using ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl;
};
class VectorTrait {};
class FloatVector : public VectorTrait {
using embedded_type = float;
};
class BinaryVector : public VectorTrait {
using embedded_type = uint8_t;
};
class FloatVector {};
class BinaryVector {};
template <>
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {


@ -24,7 +24,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
Assert(source);
auto chunk_size = source->num_chunk();
auto chunk_size = source->chunk_size();
assert(ack_end <= chunk_size);
auto conf = get_build_conf();
data_.grow_to_at_least(ack_end);
@ -85,9 +85,11 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record)
template <typename T>
void
ScalarIndexingEntry<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
auto dim = field_meta_.get_dim();
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
Assert(source);
auto chunk_size = source->num_chunk();
auto chunk_size = source->chunk_size();
assert(ack_end <= chunk_size);
data_.grow_to_at_least(ack_end);
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {


@ -24,7 +24,7 @@ namespace milvus {
namespace segcore {
// using engine::DataChunk;
// using engine::DataChunkPtr;
using QueryResult = milvus::QueryResult;
using engine::QueryResult;
struct RowBasedRawData {
void* raw_data; // schema
int sizeof_per_row; // alignment


@ -467,16 +467,16 @@ SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto dim = field.get_dim();
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
auto chunk_size = record_.uids_.num_chunk();
auto chunk_size = record_.uids_.chunk_size();
auto& uids = record_.uids_;
auto entities = record_.get_entity<FloatVector>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {


@ -241,10 +241,10 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
auto entities = record_.get_entity<FloatVector>(offset);
std::vector<knowhere::DatasetPtr> datasets;
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
auto entities_chunk = entities->get_chunk(chunk_id).data();
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
: DefaultElementPerChunk;
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
}
for (auto& ds : datasets) {


@ -42,7 +42,7 @@ DeleteSegment(CSegmentBase segment) {
void
DeleteQueryResult(CQueryResult query_result) {
auto res = (milvus::QueryResult*)query_result;
auto res = (milvus::segcore::QueryResult*)query_result;
delete res;
}
@ -134,7 +134,7 @@ Search(CSegmentBase c_segment,
placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
}
auto query_result = std::make_unique<milvus::QueryResult>();
auto query_result = std::make_unique<milvus::segcore::QueryResult>();
auto status = CStatus();
try {


@ -42,11 +42,8 @@ EasyAssertInfo(
[[noreturn]] void
ThrowWithTrace(const std::exception& exception) {
if (typeid(exception) == typeid(WrappedRuntimError)) {
throw exception;
}
auto err_msg = exception.what() + std::string("\n") + EasyStackTrace();
throw WrappedRuntimError(err_msg);
throw std::runtime_error(err_msg);
}
} // namespace milvus::impl


@ -11,7 +11,6 @@
#pragma once
#include <string_view>
#include <stdexcept>
#include <exception>
#include <stdio.h>
#include <stdlib.h>
@ -23,10 +22,6 @@ void
EasyAssertInfo(
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
class WrappedRuntimError : public std::runtime_error {
using std::runtime_error::runtime_error;
};
[[noreturn]] void
ThrowWithTrace(const std::exception& exception);


@ -26,5 +26,4 @@ target_link_libraries(all_tests
pthread
milvus_utils
)
install (TARGETS all_tests DESTINATION unittest)


@ -21,7 +21,7 @@ TEST(Binary, Insert) {
int64_t num_queries = 10;
int64_t topK = 5;
auto schema = std::make_shared<Schema>();
schema->AddField("vecbin", DataType::VECTOR_BINARY, 128, MetricType::METRIC_Jaccard);
schema->AddField("vecbin", DataType::VECTOR_BINARY, 128);
schema->AddField("age", DataType::INT64);
auto dataset = DataGen(schema, N, 10);
auto segment = CreateSegment(schema);


@ -1,12 +0,0 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>


@ -52,7 +52,7 @@ TEST(ConcurrentVector, TestSingle) {
c_vec.set_data(total_count, vec.data(), insert_size);
total_count += insert_size;
}
ASSERT_EQ(c_vec.num_chunk(), (total_count + 31) / 32);
ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32);
for (int i = 0; i < total_count; ++i) {
for (int d = 0; d < dim; ++d) {
auto std_data = d + i * dim;


@ -98,49 +98,7 @@ TEST(Expr, Range) {
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("age", DataType::INT32);
auto plan = CreatePlan(*schema, dsl_string);
ShowPlanNodeVisitor shower;
Assert(plan->tag2field_.at("$0") == "fakevec");
auto out = shower.call_child(*plan->plan_node_);
std::cout << out.dump(4);
}
TEST(Expr, RangeBinary) {
SUCCEED();
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
std::string dsl_string = R"(
{
"bool": {
"must": [
{
"range": {
"age": {
"GT": 1,
"LT": 100
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "Jaccard",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
auto plan = CreatePlan(*schema, dsl_string);
ShowPlanNodeVisitor shower;
@ -182,7 +140,7 @@ TEST(Expr, InvalidRange) {
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
@ -221,7 +179,7 @@ TEST(Expr, InvalidDSL) {
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
}
@ -231,7 +189,7 @@ TEST(Expr, ShowExecutor) {
using namespace milvus::segcore;
auto node = std::make_unique<FloatVectorANNS>();
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
int64_t num_queries = 100L;
auto raw_data = DataGen(schema, num_queries);
auto& info = node->query_info_;
@ -290,7 +248,7 @@ TEST(Expr, TestRange) {
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
auto seg = CreateSegment(schema);
@ -321,88 +279,7 @@ TEST(Expr, TestRange) {
auto ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
}
TEST(Expr, TestTerm) {
using namespace milvus::query;
using namespace milvus::segcore;
auto vec_2k_3k = [] {
std::string buf = "[";
for (int i = 2000; i < 3000 - 1; ++i) {
buf += std::to_string(i) + ", ";
}
buf += std::to_string(2999) + "]";
return buf;
}();
std::vector<std::tuple<std::string, std::function<bool(int)>>> testcases = {
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
{R"([2000])", [](int v) { return v == 2000; }},
{R"([3000])", [](int v) { return v == 3000; }},
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
};
std::string dsl_string_tmp = R"(
{
"bool": {
"must": [
{
"term": {
"age": @@@@
}
},
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
]
}
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("age", DataType::INT32);
auto seg = CreateSegment(schema);
int N = 10000;
std::vector<int> age_col;
int num_iters = 100;
for (int iter = 0; iter < num_iters; ++iter) {
auto raw_data = DataGen(schema, N, iter);
auto new_age_col = raw_data.get_col<int>(1);
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
seg->PreInsert(N);
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
ExecExprVisitor visitor(*seg_promote);
for (auto [clause, ref_func] : testcases) {
auto loc = dsl_string_tmp.find("@@@@");
auto dsl_string = dsl_string_tmp;
dsl_string.replace(loc, 4, clause);
auto plan = CreatePlan(*schema, dsl_string);
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
auto ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);
auto ref = !ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
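The assertions above map a flat row index to a (chunk, offset) pair via DefaultElementPerChunk, and size the chunked result with upper_div. A minimal sketch of that arithmetic — the constant's value here is an assumption; the real one is defined in segcore:

```go
package main

import "fmt"

// Assumed value; the real DefaultElementPerChunk constant lives in segcore.
const defaultElementPerChunk = 32 * 1024

// upperDiv mirrors the C++ upper_div used in the tests: ceiling division.
func upperDiv(a, b int) int {
	return (a + b - 1) / b
}

func main() {
	total := 10000 * 100 // N * num_iters, as in TestTerm
	fmt.Println("chunks:", upperDiv(total, defaultElementPerChunk))

	i := 123456
	vecID := i / defaultElementPerChunk // chunk holding row i ("vec_id" in the tests)
	offset := i % defaultElementPerChunk
	fmt.Println("row", i, "-> chunk", vecID, "offset", offset)
}
```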

View File

@ -235,14 +235,14 @@ TEST(Indexing, IVFFlatNM) {
}
}
TEST(Indexing, BinaryBruteForce) {
TEST(Indexing, DISABLED_BinaryBruteForce) {
int64_t N = 100000;
int64_t num_queries = 10;
int64_t topk = 5;
int64_t dim = 512;
int64_t dim = 64;
auto result_count = topk * num_queries;
auto schema = std::make_shared<Schema>();
schema->AddField("vecbin", DataType::VECTOR_BINARY, dim, MetricType::METRIC_Jaccard);
schema->AddField("vecbin", DataType::VECTOR_BINARY, dim);
schema->AddField("age", DataType::INT64);
auto dataset = DataGen(schema, N, 10);
vector<float> distances(result_count);
@ -250,16 +250,8 @@ TEST(Indexing, BinaryBruteForce) {
auto bin_vec = dataset.get_col<uint8_t>(0);
auto line_sizeof = schema->operator[](0).get_sizeof();
auto query_data = 1024 * line_sizeof + bin_vec.data();
query::dataset::BinaryQueryDataset query_dataset{
faiss::MetricType::METRIC_Jaccard, //
num_queries, //
topk, //
line_sizeof, //
query_data //
};
query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, distances.data(), ids.data());
query::BinarySearchBruteForce(faiss::MetricType::METRIC_Jaccard, line_sizeof, bin_vec.data(), N, topk, num_queries,
query_data, distances.data(), ids.data());
QueryResult qr;
qr.num_queries_ = num_queries;
qr.topK_ = topk;
@ -272,78 +264,76 @@ TEST(Indexing, BinaryBruteForce) {
[
[
"1024->0.000000",
"43190->0.578804",
"5255->0.586207",
"23247->0.586486",
"4936->0.588889"
"86966->0.395349",
"24843->0.404762",
"13806->0.416667",
"44313->0.421053"
],
[
"1025->0.000000",
"15147->0.562162",
"49910->0.564304",
"67435->0.567867",
"38292->0.569921"
"14226->0.348837",
"1488->0.365854",
"47337->0.377778",
"20913->0.377778"
],
[
"1026->0.000000",
"15332->0.569061",
"56391->0.572559",
"17187->0.572603",
"26988->0.573771"
"81882->0.386364",
"9215->0.409091",
"95024->0.409091",
"54987->0.414634"
],
[
"1027->0.000000",
"4502->0.559585",
"25879->0.566234",
"66937->0.566489",
"21228->0.566845"
"68981->0.394737",
"75528->0.404762",
"68794->0.405405",
"21975->0.425000"
],
[
"1028->0.000000",
"38490->0.578804",
"12946->0.581717",
"31677->0.582173",
"94474->0.583569"
"90290->0.375000",
"34309->0.394737",
"58559->0.400000",
"33865->0.400000"
],
[
"1029->0.000000",
"59011->0.551630",
"82575->0.555263",
"42914->0.561828",
"23705->0.564171"
"62722->0.388889",
"89070->0.394737",
"18528->0.414634",
"94971->0.421053"
],
[
"1030->0.000000",
"39782->0.579946",
"65553->0.589947",
"82154->0.590028",
"13374->0.590164"
"67402->0.333333",
"3988->0.347826",
"86376->0.354167",
"84381->0.361702"
],
[
"1031->0.000000",
"47826->0.582873",
"72669->0.587432",
"334->0.588076",
"80652->0.589333"
"81569->0.325581",
"12715->0.347826",
"40332->0.363636",
"21037->0.372093"
],
[
"1032->0.000000",
"31968->0.573034",
"63545->0.575758",
"76913->0.575916",
"6286->0.576000"
"60536->0.428571",
"93293->0.432432",
"70969->0.435897",
"64048->0.450000"
],
[
"1033->0.000000",
"95635->0.570248",
"93439->0.574866",
"6709->0.578534",
"6367->0.579634"
"99022->0.394737",
"11763->0.405405",
"50073->0.428571",
"97118->0.428571"
]
]
]
)");
auto json_str = json.dump(2);
auto ref_str = ref.dump(2);
ASSERT_EQ(json_str, ref_str);
ASSERT_EQ(json, ref);
}
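The expected JSON above lists results as "id->distance" pairs under METRIC_Jaccard; the first hit of each row is the query itself (e.g. "1024->0.000000"), since the queries are taken from offset 1024 of the dataset. Jaccard distance over packed bit vectors is 1 − |a∧b| / |a∨b|; a small self-contained sketch of that formula:

```go
package main

import (
	"fmt"
	"math/bits"
)

// jaccardDistance is 1 - |a AND b| / |a OR b| over packed bit vectors,
// i.e. what METRIC_Jaccard measures in the brute-force test above.
func jaccardDistance(a, b []byte) float64 {
	var inter, union int
	for i := range a {
		inter += bits.OnesCount8(a[i] & b[i])
		union += bits.OnesCount8(a[i] | b[i])
	}
	if union == 0 {
		return 0 // two all-zero vectors: define the distance as 0
	}
	return 1 - float64(inter)/float64(union)
}

func main() {
	a := []byte{0b10110010, 0b00001111} // 16-bit toy vectors; the test uses dim = 64
	b := []byte{0b10100010, 0b00001011}
	fmt.Printf("%.6f\n", jaccardDistance(a, a)) // 0.000000, like "1024->0.000000"
	fmt.Printf("%.6f\n", jaccardDistance(a, b))
}
```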

View File

@ -72,7 +72,7 @@ TEST(Query, ShowExecutor) {
using namespace milvus;
auto node = std::make_unique<FloatVectorANNS>();
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
int64_t num_queries = 100L;
auto raw_data = DataGen(schema, num_queries);
auto& info = node->query_info_;
@ -98,7 +98,7 @@ TEST(Query, DSL) {
"must": [
{
"vector": {
"fakevec": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
@ -113,7 +113,7 @@ TEST(Query, DSL) {
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
auto plan = CreatePlan(*schema, dsl_string);
auto res = shower.call_child(*plan->plan_node_);
@ -123,7 +123,7 @@ TEST(Query, DSL) {
{
"bool": {
"vector": {
"fakevec": {
"Vec": {
"metric_type": "L2",
"params": {
"nprobe": 10
@ -159,7 +159,7 @@ TEST(Query, ParsePlaceholderGroup) {
})";
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
auto plan = CreatePlan(*schema, dsl_string);
int64_t num_queries = 100000;
int dim = 16;
@ -172,7 +172,7 @@ TEST(Query, ExecWithPredicate) {
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::FLOAT);
std::string dsl = R"({
"bool": {
@ -217,8 +217,8 @@ TEST(Query, ExecWithPredicate) {
int topk = 5;
Json json = QueryResultToJson(qr);
auto ref = Json::parse(R"(
[
auto ref = Json::parse(R"([
[
[
"980486->3.149221",
@ -257,14 +257,15 @@ TEST(Query, ExecWithPredicate) {
]
]
])");
ASSERT_EQ(json.dump(2), ref.dump(2));
ASSERT_EQ(json, ref);
}
TEST(Query, ExecWithoutPredicate) {
TEST(Query, ExecWihtoutPredicate) {
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::FLOAT);
std::string dsl = R"({
"bool": {
@ -300,49 +301,18 @@ TEST(Query, ExecWithoutPredicate) {
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
std::vector<std::vector<std::string>> results;
int topk = 5;
auto json = QueryResultToJson(qr);
auto ref = Json::parse(R"(
[
[
[
"980486->3.149221",
"318367->3.661235",
"302798->4.553688",
"321424->4.757450",
"565529->5.083780"
],
[
"233390->7.931535",
"238958->8.109344",
"230645->8.439169",
"901939->8.658772",
"380328->8.731251"
],
[
"749862->3.398494",
"701321->3.632437",
"897246->3.749835",
"750683->3.897577",
"105995->4.073595"
],
[
"138274->3.454446",
"124548->3.783290",
"840855->4.782170",
"936719->5.026924",
"709627->5.063170"
],
[
"810401->3.926393",
"46575->4.054171",
"201740->4.274491",
"669040->4.399628",
"231500->4.831223"
]
]
]
)");
ASSERT_EQ(json.dump(2), ref.dump(2));
for (int q = 0; q < num_queries; ++q) {
std::vector<std::string> result;
for (int k = 0; k < topk; ++k) {
int index = q * topk + k;
result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" +
std::to_string(qr.result_distances_[index]));
}
results.emplace_back(std::move(result));
}
Json json{results};
std::cout << json.dump(2);
}
TEST(Query, FillSegment) {
@ -361,9 +331,6 @@ TEST(Query, FillSegment) {
auto param = field->add_type_params();
param->set_key("dim");
param->set_value("16");
auto iparam = field->add_index_params();
iparam->set_key("metric_type");
iparam->set_value("L2");
}
{
@ -425,57 +392,3 @@ TEST(Query, FillSegment) {
++std_index;
}
}
TEST(Query, ExecWithPredicateBinary) {
using namespace milvus::query;
using namespace milvus::segcore;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard);
schema->AddField("age", DataType::FLOAT);
std::string dsl = R"({
"bool": {
"must": [
{
"range": {
"age": {
"GE": -1,
"LT": 1
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "Jaccard",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
int64_t N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto segment = std::make_unique<SegmentSmallIndex>(schema);
segment->PreInsert(N);
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
auto vec_ptr = dataset.get_col<uint8_t>(0);
auto plan = CreatePlan(*schema, dsl);
auto num_queries = 5;
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
QueryResult qr;
Timestamp time = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
int topk = 5;
Json json = QueryResultToJson(qr);
std::cout << json.dump(2);
// ASSERT_EQ(json.dump(2), ref.dump(2));
}

View File

@ -63,7 +63,7 @@ TEST(SegmentCoreTest, NormalDistributionTest) {
using namespace milvus::segcore;
using namespace milvus::engine;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
int N = 1000 * 1000;
auto [raw_data, timestamps, uids] = generate_data(N);
@ -76,7 +76,7 @@ TEST(SegmentCoreTest, MockTest) {
using namespace milvus::segcore;
using namespace milvus::engine;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
std::vector<char> raw_data;
std::vector<Timestamp> timestamps;
@ -116,7 +116,7 @@ TEST(SegmentCoreTest, SmallIndex) {
using namespace milvus::segcore;
using namespace milvus::engine;
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
schema->AddField("age", DataType::INT32);
int N = 1024 * 1024;
auto data = DataGen(schema, N);

View File

@ -31,14 +31,6 @@ struct GeneratedData {
memcpy(ret.data(), target.data(), target.size());
return ret;
}
template <typename T>
auto
get_mutable_col(int index) {
auto& target = cols_.at(index);
assert(target.size() == row_ids_.size() * sizeof(T));
auto ptr = reinterpret_cast<T*>(target.data());
return ptr;
}
private:
GeneratedData() = default;
@ -66,9 +58,6 @@ GeneratedData::generate_rows(int N, SchemaPtr schema) {
}
}
rows_ = std::move(result);
raw_.raw_data = rows_.data();
raw_.sizeof_per_row = schema->get_total_sizeof();
raw_.count = N;
}
inline GeneratedData
@ -140,12 +129,14 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
}
GeneratedData res;
res.cols_ = std::move(cols);
res.generate_rows(N, schema);
for (int i = 0; i < N; ++i) {
res.row_ids_.push_back(i);
res.timestamps_.push_back(i);
}
res.generate_rows(N, schema);
res.raw_.raw_data = res.rows_.data();
res.raw_.sizeof_per_row = schema->get_total_sizeof();
res.raw_.count = N;
return std::move(res);
}
@ -176,7 +167,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_BINARY);
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
std::default_random_engine e(seed);
for (int i = 0; i < num_queries; ++i) {
std::vector<uint8_t> vec;
@ -184,27 +175,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42
vec.push_back(e());
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size());
}
return raw_group;
}
inline auto
CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries, int64_t dim, const uint8_t* ptr) {
assert(dim % 8 == 0);
namespace ser = milvus::proto::service;
ser::PlaceholderGroup raw_group;
auto value = raw_group.add_placeholders();
value->set_tag("$0");
value->set_type(ser::PlaceholderType::VECTOR_BINARY);
for (int i = 0; i < num_queries; ++i) {
std::vector<uint8_t> vec;
for (int d = 0; d < dim / 8; ++d) {
vec.push_back(*ptr);
++ptr;
}
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
value->add_values(vec.data(), vec.size());
value->add_values(vec.data(), vec.size() * sizeof(float));
}
return raw_group;
}

View File

@ -0,0 +1,14 @@
package mockkv
import (
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
)
// use MemoryKV to mock EtcdKV
func NewEtcdKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}
func NewMemoryKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}
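The new mockkv package simply re-exports an in-memory KV so unit tests can run without a live etcd. A usage sketch, assuming MemoryKV exposes the same Save/Load surface as EtcdKV (the method names and import path here are assumptions):

```go
package kvtest // hypothetical test package

import (
	"testing"

	// assumed import path for the new package shown above
	"github.com/zilliztech/milvus-distributed/internal/kv/mockkv"
)

func TestMetaWithMockKV(t *testing.T) {
	kv := mockkv.NewEtcdKV() // in-memory stand-in; no etcd process required
	// Save/Load are assumed to match the EtcdKV surface.
	if err := kv.Save("collection/0", "meta-blob"); err != nil {
		t.Fatal(err)
	}
	value, err := kv.Load("collection/0")
	if err != nil || value != "meta-blob" {
		t.Fatalf("got %q, err: %v", value, err)
	}
}
```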

View File

@ -19,7 +19,6 @@ import (
func TestMaster_CollectionTask(t *testing.T) {
Init()
refreshMasterAddress()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
@ -65,10 +64,10 @@ func TestMaster_CollectionTask(t *testing.T) {
svr, err := CreateServer(ctx)
assert.Nil(t, err)
err = svr.Run(int64(Params.Port))
err = svr.Run(10002)
assert.Nil(t, err)
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer conn.Close()

View File

@ -16,7 +16,6 @@ import (
func TestMaster_ConfigTask(t *testing.T) {
Init()
refreshMasterAddress()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
@ -60,11 +59,11 @@ func TestMaster_ConfigTask(t *testing.T) {
svr, err := CreateServer(ctx)
require.Nil(t, err)
err = svr.Run(int64(Params.Port))
err = svr.Run(10002)
defer svr.Close()
require.Nil(t, err)
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock())
require.Nil(t, err)
defer conn.Close()

View File

@ -1,6 +1,7 @@
package master
import (
"os"
"testing"
"time"
@ -11,6 +12,16 @@ import (
var gTestTsoAllocator Allocator
var gTestIDAllocator *GlobalIDAllocator
func TestMain(m *testing.M) {
Params.Init()
etcdAddr := Params.EtcdAddress
gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso"))
gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid"))
exitCode := m.Run()
os.Exit(exitCode)
}
func TestGlobalTSOAllocator_Initialize(t *testing.T) {
err := gTestTsoAllocator.Initialize()
assert.Nil(t, err)

View File

@ -17,7 +17,6 @@ import (
func TestMaster_CreateCollection(t *testing.T) {
Init()
refreshMasterAddress()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
@ -63,10 +62,10 @@ func TestMaster_CreateCollection(t *testing.T) {
svr, err := CreateServer(ctx)
assert.Nil(t, err)
err = svr.Run(int64(Params.Port))
err = svr.Run(10001)
assert.Nil(t, err)
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
conn, err := grpc.DialContext(ctx, "127.0.0.1:10001", grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer conn.Close()

View File

@ -4,12 +4,9 @@ import (
"context"
"log"
"math/rand"
"os"
"strconv"
"testing"
"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
@ -23,42 +20,6 @@ import (
"google.golang.org/grpc"
)
var testPORT = 53200
func genMasterTestPort() int64 {
testPORT++
return int64(testPORT)
}
func refreshMasterAddress() {
masterPort := genMasterTestPort()
Params.Port = int(masterPort)
masterAddr := makeMasterAddress(masterPort)
Params.Address = masterAddr
}
func makeMasterAddress(port int64) string {
masterAddr := "127.0.0.1:" + strconv.FormatInt(port, 10)
return masterAddr
}
func makeNewChannalNames(names []string, suffix string) []string {
var ret []string
for _, name := range names {
ret = append(ret, name+suffix)
}
return ret
}
func refreshChannelNames() {
suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10)
Params.DDChannelNames = makeNewChannalNames(Params.DDChannelNames, suffix)
Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(Params.WriteNodeTimeTickChannelNames, suffix)
Params.InsertChannelNames = makeNewChannalNames(Params.InsertChannelNames, suffix)
Params.K2SChannelNames = makeNewChannalNames(Params.K2SChannelNames, suffix)
Params.ProxyTimeTickChannelNames = makeNewChannalNames(Params.ProxyTimeTickChannelNames, suffix)
}
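The hunks above drop the per-test port counter (genMasterTestPort/refreshMasterAddress) in favor of fixed ports such as 10001 and 53100. Fixed ports are simpler but can collide when suites run concurrently; a common alternative — a sketch only, not what this commit does — is to let the kernel pick a free local port:

```go
package main

import (
	"fmt"
	"net"
)

// freePort asks the OS for an unused TCP port, which avoids collisions
// when several test servers start in parallel.
func freePort() (int, error) {
	l, err := net.Listen("tcp", "127.0.0.1:0") // port 0 = kernel chooses
	if err != nil {
		return 0, err
	}
	defer l.Close()
	return l.Addr().(*net.TCPAddr).Port, nil
}

func main() {
	p, err := freePort()
	if err != nil {
		panic(err)
	}
	fmt.Println("picked port", p)
}
```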
func receiveTimeTickMsg(stream *ms.MsgStream) bool {
for {
result := (*stream).Consume()
@ -76,25 +37,13 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
return &msgPack
}
func TestMain(m *testing.M) {
Init()
refreshMasterAddress()
refreshChannelNames()
etcdAddr := Params.EtcdAddress
gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso"))
gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid"))
exitCode := m.Run()
os.Exit(exitCode)
}
func TestMaster(t *testing.T) {
Init()
refreshMasterAddress()
pulsarAddr := Params.PulsarAddress
Params.ProxyIDList = []UniqueID{0}
//Param
// Creates server.
ctx, cancel := context.WithCancel(context.Background())
svr, err := CreateServer(ctx)
if err != nil {
log.Print("create server failed", zap.Error(err))
@ -149,7 +98,7 @@ func TestMaster(t *testing.T) {
var k2sMsgstream ms.MsgStream = k2sMs
assert.True(t, receiveTimeTickMsg(&k2sMsgstream))
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
conn, err := grpc.DialContext(ctx, "127.0.0.1:53100", grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer conn.Close()

View File

@ -55,8 +55,19 @@ var Params ParamTable
func (p *ParamTable) Init() {
// load yaml
p.BaseTable.Init()
err := p.LoadYaml("advanced/master.yaml")
err := p.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/master.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/common.yaml")
if err != nil {
panic(err)
}
@ -104,7 +115,15 @@ func (p *ParamTable) initAddress() {
}
func (p *ParamTable) initPort() {
p.Port = p.ParseInt("master.port")
masterPort, err := p.Load("master.port")
if err != nil {
panic(err)
}
port, err := strconv.Atoi(masterPort)
if err != nil {
panic(err)
}
p.Port = port
}
func (p *ParamTable) initEtcdAddress() {
@ -148,40 +167,117 @@ func (p *ParamTable) initKvRootPath() {
}
func (p *ParamTable) initTopicNum() {
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
insertChannelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
rangeSlice := paramtable.ConvertRangeToIntRange(iRangeStr, ",")
p.TopicNum = rangeSlice[1] - rangeSlice[0]
channelRange := strings.Split(insertChannelRange, ",")
if len(channelRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(channelRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(channelRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
p.TopicNum = channelEnd
}
func (p *ParamTable) initSegmentSize() {
p.SegmentSize = p.ParseFloat("master.segment.size")
threshold, err := p.Load("master.segment.size")
if err != nil {
panic(err)
}
segmentThreshold, err := strconv.ParseFloat(threshold, 64)
if err != nil {
panic(err)
}
p.SegmentSize = segmentThreshold
}
func (p *ParamTable) initSegmentSizeFactor() {
p.SegmentSizeFactor = p.ParseFloat("master.segment.sizeFactor")
segFactor, err := p.Load("master.segment.sizeFactor")
if err != nil {
panic(err)
}
factor, err := strconv.ParseFloat(segFactor, 64)
if err != nil {
panic(err)
}
p.SegmentSizeFactor = factor
}
func (p *ParamTable) initDefaultRecordSize() {
p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord")
size, err := p.Load("master.segment.defaultSizePerRecord")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.DefaultRecordSize = res
}
func (p *ParamTable) initMinSegIDAssignCnt() {
p.MinSegIDAssignCnt = p.ParseInt64("master.segment.minIDAssignCnt")
size, err := p.Load("master.segment.minIDAssignCnt")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.MinSegIDAssignCnt = res
}
func (p *ParamTable) initMaxSegIDAssignCnt() {
p.MaxSegIDAssignCnt = p.ParseInt64("master.segment.maxIDAssignCnt")
size, err := p.Load("master.segment.maxIDAssignCnt")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.MaxSegIDAssignCnt = res
}
func (p *ParamTable) initSegIDAssignExpiration() {
p.SegIDAssignExpiration = p.ParseInt64("master.segment.IDAssignExpiration")
duration, err := p.Load("master.segment.IDAssignExpiration")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(duration, 10, 64)
if err != nil {
panic(err)
}
p.SegIDAssignExpiration = res
}
func (p *ParamTable) initQueryNodeNum() {
p.QueryNodeNum = len(p.QueryNodeIDList())
id, err := p.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
}
ids := strings.Split(id, ",")
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
}
p.QueryNodeNum = len(ids)
}
func (p *ParamTable) initQueryNodeStatsChannelName() {
@ -193,7 +289,20 @@ func (p *ParamTable) initQueryNodeStatsChannelName() {
}
func (p *ParamTable) initProxyIDList() {
p.ProxyIDList = p.BaseTable.ProxyIDList()
id, err := p.Load("nodeID.proxyIDList")
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ids := strings.Split(id, ",")
idList := make([]typeutil.UniqueID, 0, len(ids))
for _, i := range ids {
v, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
idList = append(idList, typeutil.UniqueID(v))
}
p.ProxyIDList = idList
}
func (p *ParamTable) initProxyTimeTickChannelNames() {
@ -238,7 +347,20 @@ func (p *ParamTable) initSoftTimeTickBarrierInterval() {
}
func (p *ParamTable) initWriteNodeIDList() {
p.WriteNodeIDList = p.BaseTable.WriteNodeIDList()
id, err := p.Load("nodeID.writeNodeIDList")
if err != nil {
log.Panic(err)
}
ids := strings.Split(id, ",")
idlist := make([]typeutil.UniqueID, 0, len(ids))
for _, i := range ids {
v, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
idlist = append(idlist, typeutil.UniqueID(v))
}
p.WriteNodeIDList = idlist
}
func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
@ -263,57 +385,81 @@ func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
}
func (p *ParamTable) initDDChannelNames() {
prefix, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
ch, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
if err != nil {
panic(err)
log.Fatal(err)
}
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.dataDefinition")
id, err := p.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
log.Panicf("load query node id list error, %s", err.Error())
}
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
}
p.DDChannelNames = ret
p.DDChannelNames = channels
}
func (p *ParamTable) initInsertChannelNames() {
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
p.InsertChannelNames = ret
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
p.InsertChannelNames = channels
}
func (p *ParamTable) initK2SChannelNames() {
prefix, err := p.Load("msgChannel.chanNamePrefix.k2s")
ch, err := p.Load("msgChannel.chanNamePrefix.k2s")
if err != nil {
panic(err)
log.Fatal(err)
}
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.k2s")
id, err := p.Load("nodeID.writeNodeIDList")
if err != nil {
panic(err)
log.Panicf("load write node id list error, %s", err.Error())
}
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load write node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
}
p.K2SChannelNames = ret
p.K2SChannelNames = channels
}
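initTopicNum, initInsertChannelNames, and initK2SChannelNames above each repeat the same split/Atoi/validate dance on a "begin,end" range string. A sketch of a consolidated helper capturing that pattern (the helper name is hypothetical):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseChannelRange validates a "begin,end" spec exactly as the repeated
// blocks above do: two fields, both non-negative integers, begin <= end.
func parseChannelRange(spec string) (begin, end int, err error) {
	parts := strings.Split(spec, ",")
	if len(parts) != 2 {
		return 0, 0, fmt.Errorf("illegal channel range num: %q", spec)
	}
	if begin, err = strconv.Atoi(parts[0]); err != nil {
		return 0, 0, err
	}
	if end, err = strconv.Atoi(parts[1]); err != nil {
		return 0, 0, err
	}
	if begin < 0 || end < 0 || begin > end {
		return 0, 0, fmt.Errorf("illegal channel range value: %q", spec)
	}
	return begin, end, nil
}

func main() {
	b, e, err := parseChannelRange("0,1")
	if err != nil {
		panic(err)
	}
	for i := b; i < e; i++ {
		fmt.Println("insert-" + strconv.Itoa(i)) // insert-0
	}
}
```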
func (p *ParamTable) initMaxPartitionNum() {

View File

@ -1,7 +1,6 @@
package master
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@ -12,111 +11,133 @@ func TestParamTable_Init(t *testing.T) {
}
func TestParamTable_Address(t *testing.T) {
Params.Init()
address := Params.Address
assert.Equal(t, address, "localhost")
}
func TestParamTable_Port(t *testing.T) {
Params.Init()
port := Params.Port
assert.Equal(t, port, 53100)
}
func TestParamTable_MetaRootPath(t *testing.T) {
Params.Init()
path := Params.MetaRootPath
assert.Equal(t, path, "by-dev/meta")
}
func TestParamTable_KVRootPath(t *testing.T) {
Params.Init()
path := Params.KvRootPath
assert.Equal(t, path, "by-dev/kv")
}
func TestParamTable_TopicNum(t *testing.T) {
Params.Init()
num := Params.TopicNum
fmt.Println("TopicNum:", num)
assert.Equal(t, num, 1)
}
func TestParamTable_SegmentSize(t *testing.T) {
Params.Init()
size := Params.SegmentSize
assert.Equal(t, size, float64(512))
}
func TestParamTable_SegmentSizeFactor(t *testing.T) {
Params.Init()
factor := Params.SegmentSizeFactor
assert.Equal(t, factor, 0.75)
}
func TestParamTable_DefaultRecordSize(t *testing.T) {
Params.Init()
size := Params.DefaultRecordSize
assert.Equal(t, size, int64(1024))
}
func TestParamTable_MinSegIDAssignCnt(t *testing.T) {
Params.Init()
cnt := Params.MinSegIDAssignCnt
assert.Equal(t, cnt, int64(1024))
}
func TestParamTable_MaxSegIDAssignCnt(t *testing.T) {
Params.Init()
cnt := Params.MaxSegIDAssignCnt
assert.Equal(t, cnt, int64(16384))
}
func TestParamTable_SegIDAssignExpiration(t *testing.T) {
Params.Init()
expiration := Params.SegIDAssignExpiration
assert.Equal(t, expiration, int64(2000))
}
func TestParamTable_QueryNodeNum(t *testing.T) {
Params.Init()
num := Params.QueryNodeNum
fmt.Println("QueryNodeNum", num)
assert.Equal(t, num, 1)
}
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
Params.Init()
name := Params.QueryNodeStatsChannelName
assert.Equal(t, name, "query-node-stats")
}
func TestParamTable_ProxyIDList(t *testing.T) {
Params.Init()
ids := Params.ProxyIDList
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(0))
}
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.ProxyTimeTickChannelNames
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "proxyTimeTick-0")
}
func TestParamTable_MsgChannelSubName(t *testing.T) {
Params.Init()
name := Params.MsgChannelSubName
assert.Equal(t, name, "master")
}
func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) {
Params.Init()
interval := Params.SoftTimeTickBarrierInterval
assert.Equal(t, interval, Timestamp(0x7d00000))
}
func TestParamTable_WriteNodeIDList(t *testing.T) {
Params.Init()
ids := Params.WriteNodeIDList
assert.Equal(t, len(ids), 1)
assert.Equal(t, ids[0], int64(3))
}
func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
Params.Init()
names := Params.WriteNodeTimeTickChannelNames
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "writeNodeTimeTick-3")
}
func TestParamTable_InsertChannelNames(t *testing.T) {
Params.Init()
names := Params.InsertChannelNames
assert.Equal(t, Params.TopicNum, len(names))
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
}
func TestParamTable_K2SChannelNames(t *testing.T) {
Params.Init()
names := Params.K2SChannelNames
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "k2s-0")
assert.Equal(t, names[0], "k2s-3")
}

View File

@ -2,6 +2,8 @@ package master
import (
"context"
"math/rand"
"strconv"
"testing"
"github.com/golang/protobuf/proto"
@ -18,7 +20,6 @@ import (
func TestMaster_Partition(t *testing.T) {
Init()
refreshMasterAddress()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
@ -65,12 +66,13 @@ func TestMaster_Partition(t *testing.T) {
DefaultPartitionTag: "_default",
}
port := 10000 + rand.Intn(1000)
svr, err := CreateServer(ctx)
assert.Nil(t, err)
err = svr.Run(int64(Params.Port))
err = svr.Run(int64(port))
assert.Nil(t, err)
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
conn, err := grpc.DialContext(ctx, "127.0.0.1:"+strconv.Itoa(port), grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer conn.Close()

View File

@ -35,7 +35,7 @@ var master *Master
var masterCancelFunc context.CancelFunc
func setup() {
Init()
Params.Init()
etcdAddress := Params.EtcdAddress
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
@ -218,8 +218,7 @@ func TestSegmentManager_SegmentStats(t *testing.T) {
}
func startupMaster() {
Init()
refreshMasterAddress()
Params.Init()
etcdAddress := Params.EtcdAddress
rootPath := "/test/root"
ctx, cancel := context.WithCancel(context.TODO())
@ -232,6 +231,7 @@ func startupMaster() {
if err != nil {
panic(err)
}
Params = ParamTable{
Address: Params.Address,
Port: Params.Port,
@ -272,7 +272,7 @@ func startupMaster() {
if err != nil {
panic(err)
}
err = master.Run(int64(Params.Port))
err = master.Run(10013)
if err != nil {
panic(err)
@ -289,7 +289,7 @@ func TestSegmentManager_RPC(t *testing.T) {
defer shutdownMaster()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
dialContext, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
dialContext, err := grpc.DialContext(ctx, "127.0.0.1:10013", grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer dialContext.Close()
client := masterpb.NewMasterClient(dialContext)

View File

@ -19,8 +19,19 @@ var Params ParamTable
func (pt *ParamTable) Init() {
pt.BaseTable.Init()
err := pt.LoadYaml("advanced/proxy.yaml")
err := pt.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/proxy.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/common.yaml")
if err != nil {
panic(err)
}
@ -37,24 +48,15 @@ func (pt *ParamTable) Init() {
pt.Save("_proxyID", proxyIDStr)
}
func (pt *ParamTable) NetworkPort() int {
return pt.ParseInt("proxy.port")
}
func (pt *ParamTable) NetworkAddress() string {
addr, err := pt.Load("proxy.address")
func (pt *ParamTable) NetWorkAddress() string {
addr, err := pt.Load("proxy.network.address")
if err != nil {
panic(err)
}
hostName, _ := net.LookupHost(addr)
if len(hostName) <= 0 {
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip proxy.address")
}
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip proxy.network.address")
}
port, err := pt.Load("proxy.port")
port, err := pt.Load("proxy.network.port")
if err != nil {
panic(err)
}
@ -86,6 +88,23 @@ func (pt *ParamTable) ProxyNum() int {
return len(ret)
}
func (pt *ParamTable) ProxyIDList() []UniqueID {
proxyIDStr, err := pt.Load("nodeID.proxyIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (pt *ParamTable) queryNodeNum() int {
return len(pt.queryNodeIDList())
}
@ -131,6 +150,25 @@ func (pt *ParamTable) TimeTickInterval() time.Duration {
return time.Duration(interval) * time.Millisecond
}
func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
channelIDs := strings.Split(rangeStr, sep)
startStr := channelIDs[0]
endStr := channelIDs[1]
start, err := strconv.Atoi(startStr)
if err != nil {
panic(err)
}
end, err := strconv.Atoi(endStr)
if err != nil {
panic(err)
}
var ret []int
for i := start; i < end; i++ {
ret = append(ret, i)
}
return ret
}
func (pt *ParamTable) sliceIndex() int {
proxyID := pt.ProxyID()
proxyIDList := pt.ProxyIDList()
@ -152,7 +190,7 @@ func (pt *ParamTable) InsertChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
channelIDs := pt.convertRangeToSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
@ -178,12 +216,19 @@ func (pt *ParamTable) DeleteChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(dRangeStr, ",")
channelIDs := pt.convertRangeToSlice(dRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
}
func (pt *ParamTable) K2SChannelNames() []string {
@ -196,12 +241,19 @@ func (pt *ParamTable) K2SChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(k2sRangeStr, ",")
channelIDs := pt.convertRangeToSlice(k2sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
}
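DeleteChannelNames and K2SChannelNames above both carve the global channel list into per-proxy blocks: sep = len(channels)/proxyNum and start = sliceIndex()*sep. A toy reproduction of that arithmetic (function name hypothetical); note that when proxyNum does not divide the channel count evenly, the trailing channels are silently unassigned, matching the integer division in the original:

```go
package main

import "fmt"

// shard mirrors the slicing above: each proxy takes a contiguous block of
// len(channels)/proxyNum channels, starting at its slice index.
func shard(channels []string, proxyNum, index int) []string {
	sep := len(channels) / proxyNum
	start := index * sep
	return channels[start : start+sep]
}

func main() {
	channels := []string{"delete-0", "delete-1", "delete-2", "delete-3"}
	fmt.Println(shard(channels, 2, 0)) // [delete-0 delete-1]
	fmt.Println(shard(channels, 2, 1)) // [delete-2 delete-3]
	// With 3 proxies and 4 channels, sep = 1 and delete-3 is never assigned.
}
```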
func (pt *ParamTable) SearchChannelNames() []string {
@ -209,17 +261,8 @@ func (pt *ParamTable) SearchChannelNames() []string {
if err != nil {
panic(err)
}
prefix += "-"
sRangeStr, err := pt.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
prefix += "-0"
return []string{prefix}
}
func (pt *ParamTable) SearchResultChannelNames() []string {
@ -232,7 +275,7 @@ func (pt *ParamTable) SearchResultChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
channelIDs := pt.convertRangeToSlice(sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
@ -278,24 +321,144 @@ func (pt *ParamTable) DataDefinitionChannelNames() []string {
return []string{prefix}
}
func (pt *ParamTable) parseInt64(key string) int64 {
valueStr, err := pt.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int64(value)
}
func (pt *ParamTable) MsgStreamInsertBufSize() int64 {
return pt.ParseInt64("proxy.msgStream.insert.bufSize")
return pt.parseInt64("proxy.msgStream.insert.bufSize")
}
func (pt *ParamTable) MsgStreamSearchBufSize() int64 {
return pt.ParseInt64("proxy.msgStream.search.bufSize")
return pt.parseInt64("proxy.msgStream.search.bufSize")
}
func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 {
return pt.ParseInt64("proxy.msgStream.searchResult.recvBufSize")
return pt.parseInt64("proxy.msgStream.searchResult.recvBufSize")
}
func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
return pt.ParseInt64("proxy.msgStream.searchResult.pulsarBufSize")
return pt.parseInt64("proxy.msgStream.searchResult.pulsarBufSize")
}
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
return pt.ParseInt64("proxy.msgStream.timeTick.bufSize")
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
}
func (pt *ParamTable) insertChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchResultChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) MaxNameLength() int64 {

View File

@ -5,7 +5,6 @@ import (
"log"
"math/rand"
"net"
"strconv"
"sync"
"time"
@ -60,7 +59,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
p.queryMsgStream.SetPulsarClient(pulsarAddress)
p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())
masterAddr := Params.MasterAddress()
idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
}
@ -138,7 +137,7 @@ func (p *Proxy) AddCloseCallback(callbacks ...func()) {
func (p *Proxy) grpcLoop() {
defer p.proxyLoopWg.Done()
lis, err := net.Listen("tcp", ":"+strconv.Itoa(Params.NetworkPort()))
lis, err := net.Listen("tcp", Params.NetWorkAddress())
if err != nil {
log.Fatalf("Proxy grpc server fatal error=%v", err)
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log"
"math/rand"
"os"
"strconv"
"strings"
@ -35,26 +34,8 @@ var masterServer *master.Master
var testNum = 10
func makeNewChannalNames(names []string, suffix string) []string {
var ret []string
for _, name := range names {
ret = append(ret, name+suffix)
}
return ret
}
func refreshChannelNames() {
suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10)
master.Params.DDChannelNames = makeNewChannalNames(master.Params.DDChannelNames, suffix)
master.Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(master.Params.WriteNodeTimeTickChannelNames, suffix)
master.Params.InsertChannelNames = makeNewChannalNames(master.Params.InsertChannelNames, suffix)
master.Params.K2SChannelNames = makeNewChannalNames(master.Params.K2SChannelNames, suffix)
master.Params.ProxyTimeTickChannelNames = makeNewChannalNames(master.Params.ProxyTimeTickChannelNames, suffix)
}
func startMaster(ctx context.Context) {
master.Init()
refreshChannelNames()
etcdAddr := master.Params.EtcdAddress
metaRootPath := master.Params.MetaRootPath
@ -100,7 +81,7 @@ func setup() {
startMaster(ctx)
startProxy(ctx)
proxyAddr := Params.NetworkAddress()
proxyAddr := Params.NetWorkAddress()
addr := strings.Split(proxyAddr, ":")
if addr[0] == "0.0.0.0" {
proxyAddr = "127.0.0.1:" + addr[1]

View File

@ -364,7 +364,7 @@ func (sched *TaskScheduler) queryResultLoop() {
unmarshal := msgstream.NewUnmarshalDispatcher()
queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
Params.ProxySubName(),
unmarshal,
Params.MsgStreamSearchResultPulsarBufSize())

View File

@ -31,7 +31,7 @@ import (
* is up-to-date.
*/
type collectionReplica interface {
getTSafe() tSafe
getTSafe() *tSafe
// collection
getCollectionNum() int
@ -68,11 +68,11 @@ type collectionReplicaImpl struct {
collections []*Collection
segments map[UniqueID]*Segment
tSafe tSafe
tSafe *tSafe
}
//----------------------------------------------------------------------------------------------------- tSafe
func (colReplica *collectionReplicaImpl) getTSafe() tSafe {
func (colReplica *collectionReplicaImpl) getTSafe() *tSafe {
return colReplica.tSafe
}
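getTSafe now returns *tSafe, so the flow graph's serviceTimeNode and the search path mutate and read one shared value. A minimal sketch of such a type, assuming (as the serviceTimeNode diff below suggests) a set/get pair guarded by a lock — the real implementation lives in querynode:

```go
package main

import (
	"fmt"
	"sync"
)

type timestamp = uint64

// tSafe sketches the shared service-time value: the flow graph's
// serviceTimeNode calls set, search readers call get.
type tSafe struct {
	mu sync.RWMutex
	t  timestamp
}

func (ts *tSafe) set(t timestamp) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	ts.t = t
}

func (ts *tSafe) get() timestamp {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	return ts.t
}

func main() {
	shared := &tSafe{} // one instance handed to both the flow graph and search
	shared.set(42)
	fmt.Println(shared.get())
}
```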
@ -111,7 +111,6 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID)
if col.ID() == collectionID {
for _, p := range *col.Partitions() {
for _, s := range *p.Segments() {
deleteSegment(colReplica.segments[s.ID()])
delete(colReplica.segments, s.ID())
}
}
@ -203,7 +202,6 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID,
for _, p := range *collection.Partitions() {
if p.Tag() == partitionTag {
for _, s := range *p.Segments() {
deleteSegment(colReplica.segments[s.ID()])
delete(colReplica.segments, s.ID())
}
} else {

File diff suppressed because it is too large

View File

@ -1,48 +1,179 @@
package querynode
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestCollection_Partitions(t *testing.T) {
node := newQueryNode()
collectionName := "collection0"
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
ctx := context.Background()
node := NewQueryNode(ctx, 0)
collection, err := node.replica.getCollectionByName(collectionName)
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
for _, tag := range collectionMeta.PartitionTags {
err := (*node.replica).addPartition(collection.ID(), tag)
assert.NoError(t, err)
}
partitions := collection.Partitions()
assert.Equal(t, 1, len(*partitions))
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
}
func TestCollection_newCollection(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
}
func TestCollection_deleteCollection(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
deleteCollection(collection)
}

View File

@ -11,10 +11,10 @@ type dataSyncService struct {
ctx context.Context
fg *flowgraph.TimeTickedFlowGraph
replica collectionReplica
replica *collectionReplica
}
func newDataSyncService(ctx context.Context, replica collectionReplica) *dataSyncService {
func newDataSyncService(ctx context.Context, replica *collectionReplica) *dataSyncService {
return &dataSyncService{
ctx: ctx,

View File

@ -1,22 +1,101 @@
package querynode
import (
"context"
"encoding/binary"
"math"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
// NOTE: start pulsar before test
func TestDataSyncService_Start(t *testing.T) {
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
// test data generate
const msgLength = 10
const DIM = 16
@ -100,25 +179,25 @@ func TestDataSyncService_Start(t *testing.T) {
// pulsar produce
const receiveBufSize = 1024
producerChannels := Params.insertChannelNames()
pulsarURL, _ := Params.pulsarAddress()
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(producerChannels)
var insertMsgStream msgstream.MsgStream = insertStream
insertMsgStream.Start()
err := insertMsgStream.Produce(&msgPack)
err = insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
go node.dataSyncService.start()
node.Close()
<-ctx.Done()
}

View File

@ -10,7 +10,7 @@ import (
type insertNode struct {
BaseNode
replica collectionReplica
replica *collectionReplica
}
type InsertData struct {
@ -58,13 +58,13 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...)
// check if segment exists, if not, create this segment
if !iNode.replica.hasSegment(task.SegmentID) {
collection, err := iNode.replica.getCollectionByName(task.CollectionName)
if !(*iNode.replica).hasSegment(task.SegmentID) {
collection, err := (*iNode.replica).getCollectionByName(task.CollectionName)
if err != nil {
log.Println(err)
continue
}
err = iNode.replica.addSegment(task.SegmentID, task.PartitionTag, collection.ID())
err = (*iNode.replica).addSegment(task.SegmentID, task.PartitionTag, collection.ID())
if err != nil {
log.Println(err)
continue
@ -74,7 +74,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
// 2. do preInsert
for segmentID := range insertData.insertRecords {
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
if err != nil {
log.Println("preInsert failed")
// TODO: add error handling
@ -102,7 +102,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
}
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
if err != nil {
log.Println("cannot find segment:", segmentID)
// TODO: add error handling
@ -127,7 +127,7 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn
wg.Done()
}
func newInsertNode(replica collectionReplica) *insertNode {
func newInsertNode(replica *collectionReplica) *insertNode {
maxQueueLength := Params.flowGraphMaxQueueLength()
maxParallelism := Params.flowGraphMaxParallelism()

View File

@ -6,7 +6,7 @@ import (
type serviceTimeNode struct {
BaseNode
replica collectionReplica
replica *collectionReplica
}
func (stNode *serviceTimeNode) Name() string {
@ -28,12 +28,12 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
}
// update service time
stNode.replica.getTSafe().set(serviceTimeMsg.timeRange.timestampMax)
(*(*stNode.replica).getTSafe()).set(serviceTimeMsg.timeRange.timestampMax)
//fmt.Println("update tSafe to:", getPhysicalTime(serviceTimeMsg.timeRange.timestampMax))
return nil
}
func newServiceTimeNode(replica collectionReplica) *serviceTimeNode {
func newServiceTimeNode(replica *collectionReplica) *serviceTimeNode {
maxQueueLength := Params.flowGraphMaxQueueLength()
maxParallelism := Params.flowGraphMaxParallelism()
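The recurring edit in these hunks — `replica collectionReplica` becoming `replica *collectionReplica` — threads a pointer to the interface value through every node and service so they all share one replica variable, at the price of an explicit dereference at each call site. A small self-contained sketch of why the `(*x.replica).method()` form is required (all type names here are stand-ins):

```go
package main

import "fmt"

// replica stands in for the collectionReplica interface.
type replica interface {
	getSegmentNum() int
}

type replicaImpl struct{ segments []int64 }

func (r *replicaImpl) getSegmentNum() int { return len(r.segments) }

// node mirrors the refactored structs: it holds a pointer to the interface
// value so every component shares exactly one replica variable.
type node struct {
	replica *replica
}

func main() {
	var r replica = &replicaImpl{segments: []int64{0, 1, 2}}
	n := node{replica: &r}

	// Go does not promote methods through a pointer to an interface, so the
	// call must dereference first -- the same shape as the diff's rewrite of
	// iNode.replica.hasSegment(...) into (*iNode.replica).hasSegment(...).
	fmt.Println((*n.replica).getSegmentNum()) // prints 3
}
```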

View File

@ -26,10 +26,10 @@ const (
type metaService struct {
ctx context.Context
kvBase *etcdkv.EtcdKV
replica collectionReplica
replica *collectionReplica
}
func newMetaService(ctx context.Context, replica collectionReplica) *metaService {
func newMetaService(ctx context.Context, replica *collectionReplica) *metaService {
ETCDAddr := Params.etcdAddress()
MetaRootPath := Params.metaRootPath()
@ -149,12 +149,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
col := mService.collectionUnmarshal(value)
if col != nil {
err := mService.replica.addCollection(col, value)
err := (*mService.replica).addCollection(col, value)
if err != nil {
log.Println(err)
}
for _, partitionTag := range col.PartitionTags {
err = mService.replica.addPartition(col.ID, partitionTag)
err = (*mService.replica).addPartition(col.ID, partitionTag)
if err != nil {
log.Println(err)
}
@ -173,7 +173,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
// TODO: what if seg == nil? We need to notify the master and fail the RPC request
if seg != nil {
err := mService.replica.addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
err := (*mService.replica).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
if err != nil {
log.Println(err)
return
@ -202,7 +202,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
}
if seg != nil {
targetSegment, err := mService.replica.getSegmentByID(seg.SegmentID)
targetSegment, err := (*mService.replica).getSegmentByID(seg.SegmentID)
if err != nil {
log.Println(err)
return
@ -218,11 +218,11 @@ func (mService *metaService) processCollectionModify(id string, value string) {
col := mService.collectionUnmarshal(value)
if col != nil {
err := mService.replica.addPartitionsByCollectionMeta(col)
err := (*mService.replica).addPartitionsByCollectionMeta(col)
if err != nil {
log.Println(err)
}
err = mService.replica.removePartitionsByCollectionMeta(col)
err = (*mService.replica).removePartitionsByCollectionMeta(col)
if err != nil {
log.Println(err)
}
@ -249,7 +249,7 @@ func (mService *metaService) processSegmentDelete(id string) {
log.Println("Cannot parse segment id:" + id)
}
err = mService.replica.removeSegment(segmentID)
err = (*mService.replica).removeSegment(segmentID)
if err != nil {
log.Println(err)
return
@ -264,7 +264,7 @@ func (mService *metaService) processCollectionDelete(id string) {
log.Println("Cannot parse collection id:" + id)
}
err = mService.replica.removeCollection(collectionID)
err = (*mService.replica).removeCollection(collectionID)
if err != nil {
log.Println(err)
return
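metaService reacts to etcd watch events whose key path tells collections apart from segments; each handler unmarshals the text-format value and applies it to the shared replica. A reduced sketch of the key dispatch, with the prefixes inferred from the test keys used later (`by-dev/meta/collection/0`, `by-dev/meta/segment/0`) — the real service derives them from `metaRootPath()`:

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical prefixes; the real service builds them from metaRootPath().
const (
	collectionPrefix = "by-dev/meta/collection/"
	segmentPrefix    = "by-dev/meta/segment/"
)

// processCreate routes a watched etcd key to the matching handler, in the
// spirit of metaService.processCreate above.
func processCreate(key, value string) {
	switch {
	case strings.HasPrefix(key, collectionPrefix):
		id := strings.TrimPrefix(key, collectionPrefix)
		fmt.Println("create collection", id)
	case strings.HasPrefix(key, segmentPrefix):
		id := strings.TrimPrefix(key, segmentPrefix)
		fmt.Println("create segment", id)
	default:
		fmt.Println("ignore key", key)
	}
}

func main() {
	processCreate("by-dev/meta/collection/0", "schema: <...>")
	processCreate("by-dev/meta/segment/0", `partition_tag: "default"`)
}
```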

View File

@ -3,13 +3,23 @@ package querynode
import (
"context"
"math"
"os"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestMain(m *testing.M) {
Params.Init()
exitCode := m.Run()
os.Exit(exitCode)
}
func TestMetaService_start(t *testing.T) {
var ctx context.Context
@ -27,7 +37,6 @@ func TestMetaService_start(t *testing.T) {
node.metaService = newMetaService(ctx, node.replica)
(*node.metaService).start()
node.Close()
}
func TestMetaService_getCollectionObjId(t *testing.T) {
@ -110,9 +119,47 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T)
func TestMetaService_printCollectionStruct(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
printCollectionStruct(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
printCollectionStruct(&collectionMeta)
}
func TestMetaService_printSegmentStruct(t *testing.T) {
@ -131,8 +178,13 @@ func TestMetaService_printSegmentStruct(t *testing.T) {
}
func TestMetaService_processCollectionCreate(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
id := "0"
value := `schema: <
@ -144,10 +196,6 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -164,21 +212,71 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
node.metaService.processCollectionCreate(id, value)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentCreate(t *testing.T) {
node := newQueryNode()
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
collectionName := "collection0"
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
id := "0"
value := `partition_tag: "default"
@ -189,15 +287,19 @@ func TestMetaService_processSegmentCreate(t *testing.T) {
(*node.metaService).processSegmentCreate(id, value)
s, err := node.replica.getSegmentByID(UniqueID(0))
s, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processCreate(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -209,10 +311,6 @@ func TestMetaService_processCreate(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -228,10 +326,10 @@ func TestMetaService_processCreate(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
@ -243,19 +341,68 @@ func TestMetaService_processCreate(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
s, err := node.replica.getSegmentByID(UniqueID(0))
s, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentModify(t *testing.T) {
node := newQueryNode()
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
collectionName := "collection0"
collectionID := UniqueID(0)
segmentID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, segmentID)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
id := "0"
value := `partition_tag: "default"
@ -265,9 +412,9 @@ func TestMetaService_processSegmentModify(t *testing.T) {
`
(*node.metaService).processSegmentCreate(id, value)
s, err := node.replica.getSegmentByID(segmentID)
s, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, segmentID)
assert.Equal(t, s.segmentID, UniqueID(0))
newValue := `partition_tag: "default"
channel_start: 0
@ -277,15 +424,19 @@ func TestMetaService_processSegmentModify(t *testing.T) {
// TODO: modify segment for testing processCollectionModify
(*node.metaService).processSegmentModify(id, newValue)
seg, err := node.replica.getSegmentByID(segmentID)
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, segmentID)
node.Close()
assert.Equal(t, seg.segmentID, UniqueID(0))
}
func TestMetaService_processCollectionModify(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
id := "0"
value := `schema: <
@ -297,10 +448,6 @@ func TestMetaService_processCollectionModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -318,24 +465,24 @@ func TestMetaService_processCollectionModify(t *testing.T) {
`
(*node.metaService).processCollectionCreate(id, value)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, false)
newValue := `schema: <
@ -347,10 +494,6 @@ func TestMetaService_processCollectionModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -368,28 +511,32 @@ func TestMetaService_processCollectionModify(t *testing.T) {
`
(*node.metaService).processCollectionModify(id, newValue)
collection, err = node.replica.getCollectionByName("test")
collection, err = (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, false)
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, true)
node.Close()
}
func TestMetaService_processModify(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -401,10 +548,6 @@ func TestMetaService_processModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -422,24 +565,24 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, false)
key2 := "by-dev/meta/segment/0"
@ -450,7 +593,7 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
s, err := node.replica.getSegmentByID(UniqueID(0))
s, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
@ -465,10 +608,6 @@ func TestMetaService_processModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -486,21 +625,21 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processModify(key1, msg3)
collection, err = node.replica.getCollectionByName("test")
collection, err = (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, false)
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, true)
msg4 := `partition_tag: "p1"
@ -510,18 +649,68 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processModify(key2, msg4)
seg, err := node.replica.getSegmentByID(UniqueID(0))
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentDelete(t *testing.T) {
node := newQueryNode()
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
collectionName := "collection0"
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
id := "0"
value := `partition_tag: "default"
@ -531,19 +720,23 @@ func TestMetaService_processSegmentDelete(t *testing.T) {
`
(*node.metaService).processSegmentCreate(id, value)
seg, err := node.replica.getSegmentByID(UniqueID(0))
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
(*node.metaService).processSegmentDelete("0")
mapSize := node.replica.getSegmentNum()
mapSize := (*node.replica).getSegmentNum()
assert.Equal(t, mapSize, 0)
node.Close()
}
func TestMetaService_processCollectionDelete(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
id := "0"
value := `schema: <
@ -555,10 +748,6 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -574,22 +763,26 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
`
(*node.metaService).processCollectionCreate(id, value)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
(*node.metaService).processCollectionDelete(id)
collectionNum = node.replica.getCollectionNum()
collectionNum = (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 0)
node.Close()
}
func TestMetaService_processDelete(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -601,10 +794,6 @@ func TestMetaService_processDelete(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -620,10 +809,10 @@ func TestMetaService_processDelete(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := node.replica.getCollectionNum()
collectionNum := (*node.replica).getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := node.replica.getCollectionByName("test")
collection, err := (*node.replica).getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
@ -635,48 +824,77 @@ func TestMetaService_processDelete(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
seg, err := node.replica.getSegmentByID(UniqueID(0))
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
(*node.metaService).processDelete(key1)
collectionsSize := node.replica.getCollectionNum()
collectionsSize := (*node.replica).getCollectionNum()
assert.Equal(t, collectionsSize, 0)
mapSize := node.replica.getSegmentNum()
mapSize := (*node.replica).getSegmentNum()
assert.Equal(t, mapSize, 0)
node.Close()
}
func TestMetaService_processResp(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
metaChan := (*node.metaService).kvBase.WatchWithPrefix("")
select {
case <-node.queryNodeLoopCtx.Done():
case <-node.ctx.Done():
return
case resp := <-metaChan:
_ = (*node.metaService).processResp(resp)
}
node.Close()
}
func TestMetaService_loadCollections(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
err2 := (*node.metaService).loadCollections()
assert.Nil(t, err2)
node.Close()
}
func TestMetaService_loadSegments(t *testing.T) {
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
err2 := (*node.metaService).loadSegments()
assert.Nil(t, err2)
node.Close()
}
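These tests hand the service hand-written protobuf text-format blobs and assert on what the replica parsed. `proto.MarshalTextString` produces that form, and `proto.UnmarshalText` from the same `github.com/golang/protobuf/proto` package is the natural counterpart for reading it back; a minimal round-trip sketch, with the `etcdpb` import path taken from the test imports:

```go
package main

import (
	"fmt"

	"github.com/golang/protobuf/proto"
	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
)

func main() {
	meta := etcdpb.CollectionMeta{
		ID:            0,
		PartitionTags: []string{"default"},
	}
	blob := proto.MarshalTextString(&meta) // the text form written to etcd
	fmt.Println(blob)

	var parsed etcdpb.CollectionMeta
	if err := proto.UnmarshalText(blob, &parsed); err != nil {
		panic(err)
	}
	fmt.Println(parsed.PartitionTags) // [default]
}
```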

View File

@ -2,8 +2,8 @@ package querynode
import (
"log"
"os"
"strconv"
"strings"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
)
@ -21,16 +21,15 @@ func (p *ParamTable) Init() {
panic(err)
}
queryNodeIDStr := os.Getenv("QUERY_NODE_ID")
if queryNodeIDStr == "" {
queryNodeIDList := p.QueryNodeIDList()
if len(queryNodeIDList) <= 0 {
queryNodeIDStr = "0"
} else {
queryNodeIDStr = strconv.Itoa(int(queryNodeIDList[0]))
}
err = p.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
p.Save("_queryNodeID", queryNodeIDStr)
}
func (p *ParamTable) pulsarAddress() (string, error) {
@ -41,8 +40,8 @@ func (p *ParamTable) pulsarAddress() (string, error) {
return url, nil
}
func (p *ParamTable) QueryNodeID() UniqueID {
queryNodeID, err := p.Load("_queryNodeID")
func (p *ParamTable) queryNodeID() int {
queryNodeID, err := p.Load("reader.clientid")
if err != nil {
panic(err)
}
@ -50,7 +49,7 @@ func (p *ParamTable) QueryNodeID() UniqueID {
if err != nil {
panic(err)
}
return UniqueID(id)
return id
}
func (p *ParamTable) insertChannelRange() []int {
@ -58,47 +57,138 @@ func (p *ParamTable) insertChannelRange() []int {
if err != nil {
panic(err)
}
return paramtable.ConvertRangeToIntRange(insertChannelRange, ",")
channelRange := strings.Split(insertChannelRange, ",")
if len(channelRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(channelRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(channelRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
return []int{channelBegin, channelEnd}
}
// advanced params
// stats
func (p *ParamTable) statsPublishInterval() int {
return p.ParseInt("queryNode.stats.publishInterval")
timeInterval, err := p.Load("queryNode.stats.publishInterval")
if err != nil {
panic(err)
}
interval, err := strconv.Atoi(timeInterval)
if err != nil {
panic(err)
}
return interval
}
// dataSync:
func (p *ParamTable) flowGraphMaxQueueLength() int32 {
return p.ParseInt32("queryNode.dataSync.flowGraph.maxQueueLength")
queueLength, err := p.Load("queryNode.dataSync.flowGraph.maxQueueLength")
if err != nil {
panic(err)
}
length, err := strconv.Atoi(queueLength)
if err != nil {
panic(err)
}
return int32(length)
}
func (p *ParamTable) flowGraphMaxParallelism() int32 {
return p.ParseInt32("queryNode.dataSync.flowGraph.maxParallelism")
maxParallelism, err := p.Load("queryNode.dataSync.flowGraph.maxParallelism")
if err != nil {
panic(err)
}
maxPara, err := strconv.Atoi(maxParallelism)
if err != nil {
panic(err)
}
return int32(maxPara)
}
// msgStream
func (p *ParamTable) insertReceiveBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.insert.recvBufSize")
revBufSize, err := p.Load("queryNode.msgStream.insert.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) insertPulsarBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.insert.pulsarBufSize")
pulsarBufSize, err := p.Load("queryNode.msgStream.insert.pulsarBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(pulsarBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) searchReceiveBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.search.recvBufSize")
revBufSize, err := p.Load("queryNode.msgStream.search.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) searchPulsarBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.search.pulsarBufSize")
pulsarBufSize, err := p.Load("queryNode.msgStream.search.pulsarBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(pulsarBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) searchResultReceiveBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.searchResult.recvBufSize")
revBufSize, err := p.Load("queryNode.msgStream.searchResult.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) statsReceiveBufSize() int64 {
return p.ParseInt64("queryNode.msgStream.stats.recvBufSize")
revBufSize, err := p.Load("queryNode.msgStream.stats.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
}
func (p *ParamTable) etcdAddress() string {
@ -122,73 +212,123 @@ func (p *ParamTable) metaRootPath() string {
}
func (p *ParamTable) gracefulTime() int64 {
return p.ParseInt64("queryNode.gracefulTime")
gracefulTime, err := p.Load("queryNode.gracefulTime")
if err != nil {
panic(err)
}
time, err := strconv.Atoi(gracefulTime)
if err != nil {
panic(err)
}
return int64(time)
}
func (p *ParamTable) insertChannelNames() []string {
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
sep := len(channelIDs) / p.queryNodeNum()
index := p.sliceIndex()
if index == -1 {
panic("queryNodeID not Match with Config")
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
start := index * sep
return ret[start : start+sep]
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (p *ParamTable) searchChannelNames() []string {
prefix, err := p.Load("msgChannel.chanNamePrefix.search")
ch, err := p.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
return ret
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (p *ParamTable) searchResultChannelNames() []string {
prefix, err := p.Load("msgChannel.chanNamePrefix.searchResult")
ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
return ret
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (p *ParamTable) msgChannelSubName() string {
@ -197,11 +337,7 @@ func (p *ParamTable) msgChannelSubName() string {
if err != nil {
log.Panic(err)
}
queryNodeIDStr, err := p.Load("_QueryNodeID")
if err != nil {
panic(err)
}
return name + "-" + queryNodeIDStr
return name
}
func (p *ParamTable) statsChannelName() string {
@ -211,18 +347,3 @@ func (p *ParamTable) statsChannelName() string {
}
return channels
}
func (p *ParamTable) sliceIndex() int {
queryNodeID := p.QueryNodeID()
queryNodeIDList := p.QueryNodeIDList()
for i := 0; i < len(queryNodeIDList); i++ {
if queryNodeID == queryNodeIDList[i] {
return i
}
}
return -1
}
func (p *ParamTable) queryNodeNum() int {
return len(p.QueryNodeIDList())
}
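The rewritten getters drop the shared `ParseInt*`/`ConvertRange*` helpers in favor of inline `Load` plus `strconv` parsing, and all three channel-name builders expand a `begin,end` range into prefixed names. Note the half-open interpretation: `make([]string, channelEnd-channelBegin)` means the default range `0,1` yields exactly one channel, which is what the tests below assert. A standalone sketch of the expansion, returning an error instead of panicking:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// channelNames expands a "begin,end" range into prefixed channel names,
// mirroring the insertChannelNames logic above with error returns in place
// of the original panics.
func channelNames(prefix, channelRange string) ([]string, error) {
	parts := strings.Split(channelRange, ",")
	if len(parts) != 2 {
		return nil, fmt.Errorf("illegal channel range %q", channelRange)
	}
	begin, err := strconv.Atoi(parts[0])
	if err != nil {
		return nil, err
	}
	end, err := strconv.Atoi(parts[1])
	if err != nil {
		return nil, err
	}
	if begin < 0 || end < 0 || begin > end {
		return nil, fmt.Errorf("illegal channel range value %q", channelRange)
	}
	names := make([]string, end-begin)
	for i := range names {
		names[i] = prefix + "-" + strconv.Itoa(begin+i)
	}
	return names, nil
}

func main() {
	names, err := channelNames("insert", "0,1")
	if err != nil {
		panic(err)
	}
	fmt.Println(names) // [insert-0]
}
```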

View File

@ -1,109 +1,128 @@
package querynode
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParamTable_Init(t *testing.T) {
Params.Init()
}
func TestParamTable_PulsarAddress(t *testing.T) {
Params.Init()
address, err := Params.pulsarAddress()
assert.NoError(t, err)
split := strings.Split(address, ":")
assert.Equal(t, "pulsar", split[0])
assert.Equal(t, "6650", split[len(split)-1])
assert.Equal(t, split[0], "pulsar")
assert.Equal(t, split[len(split)-1], "6650")
}
func TestParamTable_QueryNodeID(t *testing.T) {
id := Params.QueryNodeID()
assert.Contains(t, Params.QueryNodeIDList(), id)
Params.Init()
id := Params.queryNodeID()
assert.Equal(t, id, 0)
}
func TestParamTable_insertChannelRange(t *testing.T) {
Params.Init()
channelRange := Params.insertChannelRange()
assert.Equal(t, 2, len(channelRange))
assert.Equal(t, len(channelRange), 2)
assert.Equal(t, channelRange[0], 0)
assert.Equal(t, channelRange[1], 1)
}
func TestParamTable_statsServiceTimeInterval(t *testing.T) {
Params.Init()
interval := Params.statsPublishInterval()
assert.Equal(t, 1000, interval)
assert.Equal(t, interval, 1000)
}
func TestParamTable_statsMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.statsReceiveBufSize()
assert.Equal(t, int64(64), bufSize)
assert.Equal(t, bufSize, int64(64))
}
func TestParamTable_insertMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.insertReceiveBufSize()
assert.Equal(t, int64(1024), bufSize)
assert.Equal(t, bufSize, int64(1024))
}
func TestParamTable_searchMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchReceiveBufSize()
assert.Equal(t, int64(512), bufSize)
assert.Equal(t, bufSize, int64(512))
}
func TestParamTable_searchResultMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchResultReceiveBufSize()
assert.Equal(t, int64(64), bufSize)
assert.Equal(t, bufSize, int64(64))
}
func TestParamTable_searchPulsarBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchPulsarBufSize()
assert.Equal(t, int64(512), bufSize)
assert.Equal(t, bufSize, int64(512))
}
func TestParamTable_insertPulsarBufSize(t *testing.T) {
Params.Init()
bufSize := Params.insertPulsarBufSize()
assert.Equal(t, int64(1024), bufSize)
assert.Equal(t, bufSize, int64(1024))
}
func TestParamTable_flowGraphMaxQueueLength(t *testing.T) {
Params.Init()
length := Params.flowGraphMaxQueueLength()
assert.Equal(t, int32(1024), length)
assert.Equal(t, length, int32(1024))
}
func TestParamTable_flowGraphMaxParallelism(t *testing.T) {
Params.Init()
maxParallelism := Params.flowGraphMaxParallelism()
assert.Equal(t, int32(1024), maxParallelism)
assert.Equal(t, maxParallelism, int32(1024))
}
func TestParamTable_insertChannelNames(t *testing.T) {
Params.Init()
names := Params.insertChannelNames()
channelRange := Params.insertChannelRange()
num := channelRange[1] - channelRange[0]
num = num / Params.queryNodeNum()
assert.Equal(t, num, len(names))
start := num * Params.sliceIndex()
assert.Equal(t, fmt.Sprintf("insert-%d", channelRange[start]), names[0])
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
}
func TestParamTable_searchChannelNames(t *testing.T) {
Params.Init()
names := Params.searchChannelNames()
assert.Equal(t, len(names), 1)
assert.Equal(t, "search-0", names[0])
assert.Equal(t, names[0], "search-0")
}
func TestParamTable_searchResultChannelNames(t *testing.T) {
Params.Init()
names := Params.searchResultChannelNames()
assert.NotNil(t, names)
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "searchResult-0")
}
func TestParamTable_msgChannelSubName(t *testing.T) {
Params.Init()
name := Params.msgChannelSubName()
expectName := fmt.Sprintf("queryNode-%d", Params.QueryNodeID())
assert.Equal(t, expectName, name)
assert.Equal(t, name, "queryNode")
}
func TestParamTable_statsChannelName(t *testing.T) {
Params.Init()
name := Params.statsChannelName()
assert.Equal(t, "query-node-stats", name)
assert.Equal(t, name, "query-node-stats")
}
func TestParamTable_metaRootPath(t *testing.T) {
Params.Init()
path := Params.metaRootPath()
assert.Equal(t, "by-dev/meta", path)
assert.Equal(t, path, "by-dev/meta")
}

View File

@ -1,20 +1,77 @@
package querynode
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestPartition_Segments(t *testing.T) {
node := newQueryNode()
collectionName := "collection0"
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
ctx := context.Background()
node := NewQueryNode(ctx, 0)
collection, err := node.replica.getCollectionByName(collectionName)
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collectionMeta := collection.meta
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
for _, tag := range collectionMeta.PartitionTags {
err := (*node.replica).addPartition(collection.ID(), tag)
assert.NoError(t, err)
}
partitions := collection.Partitions()
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
@ -23,12 +80,12 @@ func TestPartition_Segments(t *testing.T) {
const segmentNum = 3
for i := 0; i < segmentNum; i++ {
err := node.replica.addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
err := (*node.replica).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
assert.NoError(t, err)
}
segments := targetPartition.Segments()
assert.Equal(t, segmentNum+1, len(*segments))
assert.Equal(t, segmentNum, len(*segments))
}
func TestPartition_newPartition(t *testing.T) {

View File

@ -8,17 +8,59 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestPlan_Plan(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
@ -32,13 +74,52 @@ func TestPlan_Plan(t *testing.T) {
}
func TestPlan_PlaceholderGroup(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"

View File

@ -17,12 +17,11 @@ import (
)
type QueryNode struct {
queryNodeLoopCtx context.Context
queryNodeLoopCancel func()
ctx context.Context
QueryNodeID uint64
replica collectionReplica
replica *collectionReplica
dataSyncService *dataSyncService
metaService *metaService
@ -30,14 +29,7 @@ type QueryNode struct {
statsService *statsService
}
func Init() {
Params.Init()
}
func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
ctx1, cancel := context.WithCancel(ctx)
segmentsMap := make(map[int64]*Segment)
collections := make([]*Collection, 0)
@ -51,11 +43,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
}
return &QueryNode{
queryNodeLoopCtx: ctx1,
queryNodeLoopCancel: cancel,
QueryNodeID: queryNodeID,
ctx: ctx,
replica: replica,
QueryNodeID: queryNodeID,
replica: &replica,
dataSyncService: nil,
metaService: nil,
@ -64,34 +56,31 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
}
}
func (node *QueryNode) Start() error {
// todo add connectMaster logic
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
func (node *QueryNode) Start() {
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
node.searchService = newSearchService(node.ctx, node.replica)
node.metaService = newMetaService(node.ctx, node.replica)
node.statsService = newStatsService(node.ctx, node.replica)
go node.dataSyncService.start()
go node.searchService.start()
go node.metaService.start()
go node.statsService.start()
return nil
node.statsService.start()
}
func (node *QueryNode) Close() {
node.queryNodeLoopCancel()
<-node.ctx.Done()
// free collectionReplica
node.replica.freeAll()
(*node.replica).freeAll()
// close services
if node.dataSyncService != nil {
node.dataSyncService.close()
(*node.dataSyncService).close()
}
if node.searchService != nil {
node.searchService.close()
(*node.searchService).close()
}
if node.statsService != nil {
node.statsService.close()
(*node.statsService).close()
}
}
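With this refactor, QueryNode no longer owns a cancel function: `Close` blocks on `<-node.ctx.Done()` and only then frees the replica and services, so shutdown completes only after the caller cancels the context (or its deadline fires). A toy sketch of that contract, with a stand-in `node` type:

```go
package main

import (
	"context"
	"fmt"
)

// node stands in for QueryNode after the refactor: it no longer owns a
// cancel function, so shutdown is driven entirely by the caller's context.
type node struct{ ctx context.Context }

func (n *node) Close() {
	<-n.ctx.Done() // mirrors QueryNode.Close: wait for the caller's cancel
	fmt.Println("resources freed")
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	n := &node{ctx: ctx}

	// The caller must cancel before Close, or Close blocks forever.
	cancel()
	n.Close()
}
```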

View File

@ -2,93 +2,18 @@ package querynode
import (
"context"
"os"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
const ctxTimeInMillisecond = 200
const closeWithDeadline = true
func setup() {
// NOTE: start Pulsar and etcd before running this test
func TestQueryNode_start(t *testing.T) {
Params.Init()
}
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "L2",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: collectionID,
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
return &collectionMeta
}
func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) {
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = node.replica.addCollection(collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := node.replica.getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
assert.Equal(t, node.replica.getCollectionNum(), 1)
err = node.replica.addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
err = node.replica.addSegment(segmentID, collectionMeta.PartitionTags[0], collectionID)
assert.NoError(t, err)
}
func newQueryNode() *QueryNode {
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
@ -98,21 +23,7 @@ func newQueryNode() *QueryNode {
ctx = context.Background()
}
svr := NewQueryNode(ctx, 0)
return svr
}
func TestMain(m *testing.M) {
setup()
exitCode := m.Run()
os.Exit(exitCode)
}
// NOTE: start pulsar and etcd before test
func TestQueryNode_Start(t *testing.T) {
localNode := newQueryNode()
err := localNode.Start()
assert.Nil(t, err)
localNode.Close()
node := NewQueryNode(ctx, 0)
node.Start()
node.Close()
}

View File

@ -0,0 +1,15 @@
package querynode
import (
"context"
)
func Init() {
Params.Init()
}
func StartQueryNode(ctx context.Context) {
node := NewQueryNode(ctx, 0)
node.Start()
}
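This new file gives the package a process-level entry point. Because `node.Start()` now runs `statsService.start()` on the calling goroutine, `StartQueryNode` blocks until the context ends. A plausible caller, sketched under that assumption:

```go
package main

import (
	"context"
	"time"

	"github.com/zilliztech/milvus-distributed/internal/querynode"
)

func main() {
	querynode.Init() // loads the ParamTable once per process

	// Bound the run for illustration; a real binary would cancel on a signal.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	querynode.StartQueryNode(ctx) // blocks inside statsService.start()
}
```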

View File

@ -9,19 +9,63 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestReduce_AllFunc(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
segmentID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
assert.Equal(t, segmentID, segment.segmentID)
const DIM = 16
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}

View File

@ -4,12 +4,10 @@ import "C"
import (
"context"
"errors"
"fmt"
"github.com/golang/protobuf/proto"
"log"
"sync"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
@ -21,7 +19,7 @@ type searchService struct {
wait sync.WaitGroup
cancel context.CancelFunc
replica collectionReplica
replica *collectionReplica
tSafeWatcher *tSafeWatcher
serviceableTime Timestamp
@ -29,14 +27,13 @@ type searchService struct {
msgBuffer chan msgstream.TsMsg
unsolvedMsg []msgstream.TsMsg
searchMsgStream msgstream.MsgStream
searchResultMsgStream msgstream.MsgStream
queryNodeID UniqueID
searchMsgStream *msgstream.MsgStream
searchResultMsgStream *msgstream.MsgStream
}
type ResultEntityIds []UniqueID
func newSearchService(ctx context.Context, replica collectionReplica) *searchService {
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
receiveBufSize := Params.searchReceiveBufSize()
pulsarBufSize := Params.searchPulsarBufSize()
@ -72,15 +69,14 @@ func newSearchService(ctx context.Context, replica collectionReplica) *searchSer
replica: replica,
tSafeWatcher: newTSafeWatcher(),
searchMsgStream: inputStream,
searchResultMsgStream: outputStream,
queryNodeID: Params.QueryNodeID(),
searchMsgStream: &inputStream,
searchResultMsgStream: &outputStream,
}
}
func (ss *searchService) start() {
ss.searchMsgStream.Start()
ss.searchResultMsgStream.Start()
(*ss.searchMsgStream).Start()
(*ss.searchResultMsgStream).Start()
ss.register()
ss.wait.Add(2)
go ss.receiveSearchMsg()
@ -89,24 +85,20 @@ func (ss *searchService) start() {
}
func (ss *searchService) close() {
if ss.searchMsgStream != nil {
ss.searchMsgStream.Close()
}
if ss.searchResultMsgStream != nil {
ss.searchResultMsgStream.Close()
}
(*ss.searchMsgStream).Close()
(*ss.searchResultMsgStream).Close()
ss.cancel()
}
func (ss *searchService) register() {
tSafe := ss.replica.getTSafe()
tSafe.registerTSafeWatcher(ss.tSafeWatcher)
tSafe := (*(ss.replica)).getTSafe()
(*tSafe).registerTSafeWatcher(ss.tSafeWatcher)
}
func (ss *searchService) waitNewTSafe() Timestamp {
// block until dataSyncService updates tSafe
ss.tSafeWatcher.hasUpdate()
timestamp := ss.replica.getTSafe().get()
timestamp := (*(*ss.replica).getTSafe()).get()
return timestamp
}
@ -130,7 +122,7 @@ func (ss *searchService) receiveSearchMsg() {
case <-ss.ctx.Done():
return
default:
msgPack := ss.searchMsgStream.Consume()
msgPack := (*ss.searchMsgStream).Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
@ -227,7 +219,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
collectionName := query.CollectionName
partitionTags := query.PartitionTags
collection, err := ss.replica.getCollectionByName(collectionName)
collection, err := (*ss.replica).getCollectionByName(collectionName)
if err != nil {
return err
}
@ -249,14 +241,14 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
matchedSegments := make([]*Segment, 0)
for _, partitionTag := range partitionTags {
hasPartition := ss.replica.hasPartition(collectionID, partitionTag)
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
if !hasPartition {
return errors.New("search Failed, invalid partitionTag")
}
}
for _, partitionTag := range partitionTags {
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
partition, _ := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
for _, segment := range partition.segments {
//fmt.Println("dsl = ", dsl)
@ -276,13 +268,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
ReqID: searchMsg.ReqID,
ProxyID: searchMsg.ProxyID,
QueryNodeID: ss.queryNodeID,
QueryNodeID: searchMsg.ProxyID,
Timestamp: searchTimestamp,
ResultChannelID: searchMsg.ResultChannelID,
Hits: nil,
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
SearchResult: results,
}
err = ss.publishSearchResult(searchResultMsg)
@ -341,7 +333,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
Hits: hits,
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
SearchResult: results,
}
err = ss.publishSearchResult(searchResultMsg)
@ -358,10 +350,9 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
fmt.Println("Public SearchResult", msg.HashKeys())
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, msg)
err := ss.searchResultMsgStream.Produce(&msgPack)
err := (*ss.searchResultMsgStream).Produce(&msgPack)
if err != nil {
return err
}
@ -386,11 +377,11 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s
}
tsMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
SearchResult: results,
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
err := ss.searchResultMsgStream.Produce(&msgPack)
err := (*ss.searchResultMsgStream).Produce(&msgPack)
if err != nil {
return err
}
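The change running through this whole file is mechanical: replica and both message streams move from interface values (collectionReplica, msgstream.MsgStream) to pointers to interfaces, so every call site now dereferences explicitly, up to the double dereference in waitNewTSafe. A minimal, self-contained sketch of the pattern; MsgStream, fakeStream, and service are illustrative stand-ins, not names from this codebase:

package main

import "fmt"

// MsgStream stands in for msgstream.MsgStream.
type MsgStream interface {
	Start()
	Produce(msg string) error
}

type fakeStream struct{}

func (fakeStream) Start()                   { fmt.Println("stream started") }
func (fakeStream) Produce(msg string) error { fmt.Println("produced:", msg); return nil }

// service mirrors searchService after this diff: the field is a
// pointer to an interface, so each call dereferences first.
type service struct {
	stream *MsgStream
}

func (s *service) start() {
	(*s.stream).Start() // same shape as (*ss.searchMsgStream).Start()
}

func (s *service) publish() error {
	return (*s.stream).Produce("result") // cf. (*ss.searchResultMsgStream).Produce(&msgPack)
}

func main() {
	var ms MsgStream = fakeStream{}
	s := service{stream: &ms} // store the address of the interface value
	s.start()
	if err := s.publish(); err != nil {
		fmt.Println(err)
	}
}

A pointer to an interface is rarely needed in Go, since an interface value is already a small header that is cheap to copy; the extra indirection only pays off if several owners must observe reassignment of the same underlying stream.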


@ -13,15 +13,80 @@ import (
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestSearch_Search(t *testing.T) {
node := NewQueryNode(context.Background(), 0)
initTestMeta(t, node, "collection0", 0, 0)
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 10
@ -93,14 +158,14 @@ func TestSearch_Search(t *testing.T) {
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
node.searchService = newSearchService(node.ctx, node.replica)
go node.searchService.start()
// start insert
@ -170,7 +235,7 @@ func TestSearch_Search(t *testing.T) {
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
// pulsar produce
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
insertStream.Start()
@ -180,19 +245,83 @@ func TestSearch_Search(t *testing.T) {
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
go node.dataSyncService.start()
time.Sleep(1 * time.Second)
cancel()
node.Close()
}
func TestSearch_SearchMultiSegments(t *testing.T) {
node := NewQueryNode(context.Background(), 0)
initTestMeta(t, node, "collection0", 0, 0)
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 1024
@ -264,14 +393,14 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
node.searchService = newSearchService(node.ctx, node.replica)
go node.searchService.start()
// start insert
@ -345,7 +474,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
// pulsar produce
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
insertStream.Start()
@ -355,10 +484,11 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
go node.dataSyncService.start()
time.Sleep(1 * time.Second)
cancel()
node.Close()
}


@ -1,6 +1,7 @@
package querynode
import (
"context"
"encoding/binary"
"log"
"math"
@ -9,21 +10,61 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
//-------------------------------------------------------------------------------------- constructor and destructor
func TestSegment_newSegment(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -33,15 +74,52 @@ func TestSegment_newSegment(t *testing.T) {
}
func TestSegment_deleteSegment(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -53,15 +131,52 @@ func TestSegment_deleteSegment(t *testing.T) {
//-------------------------------------------------------------------------------------- stats functions
func TestSegment_getRowCount(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -104,15 +219,52 @@ func TestSegment_getRowCount(t *testing.T) {
}
func TestSegment_getDeletedCount(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -161,15 +313,52 @@ func TestSegment_getDeletedCount(t *testing.T) {
}
func TestSegment_getMemSize(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -213,15 +402,53 @@ func TestSegment_getMemSize(t *testing.T) {
//-------------------------------------------------------------------------------------- dm & search functions
func TestSegment_segmentInsert(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
assert.Equal(t, segmentID, segment.segmentID)
@ -259,15 +486,52 @@ func TestSegment_segmentInsert(t *testing.T) {
}
func TestSegment_segmentDelete(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -312,15 +576,55 @@ func TestSegment_segmentDelete(t *testing.T) {
}
func TestSegment_segmentSearch(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -357,6 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) {
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
pulsarURL, _ := Params.pulsarAddress()
const receiveBufSize = 1024
searchProducerChannels := Params.searchChannelNames()
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
var searchRawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
@ -397,15 +708,52 @@ func TestSegment_segmentSearch(t *testing.T) {
//-------------------------------------------------------------------------------------- preDm functions
func TestSegment_segmentPreInsert(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -439,15 +787,52 @@ func TestSegment_segmentPreInsert(t *testing.T) {
}
func TestSegment_segmentPreDelete(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)


@ -14,11 +14,11 @@ import (
type statsService struct {
ctx context.Context
statsStream msgstream.MsgStream
replica collectionReplica
statsStream *msgstream.MsgStream
replica *collectionReplica
}
func newStatsService(ctx context.Context, replica collectionReplica) *statsService {
func newStatsService(ctx context.Context, replica *collectionReplica) *statsService {
return &statsService{
ctx: ctx,
@ -44,8 +44,8 @@ func (sService *statsService) start() {
var statsMsgStream msgstream.MsgStream = statsStream
sService.statsStream = statsMsgStream
sService.statsStream.Start()
sService.statsStream = &statsMsgStream
(*sService.statsStream).Start()
// start service
fmt.Println("do segments statistic in ", strconv.Itoa(sleepTimeInterval), "ms")
@ -60,13 +60,11 @@ func (sService *statsService) start() {
}
func (sService *statsService) close() {
if sService.statsStream != nil {
sService.statsStream.Close()
}
(*sService.statsStream).Close()
}
func (sService *statsService) sendSegmentStatistic() {
statisticData := sService.replica.getSegmentStatistics()
statisticData := (*sService.replica).getSegmentStatistics()
// fmt.Println("Publish segment statistic")
// fmt.Println(statisticData)
@ -84,7 +82,7 @@ func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSeg
var msgPack = msgstream.MsgPack{
Msgs: []msgstream.TsMsg{msg},
}
err := sService.statsStream.Produce(&msgPack)
err := (*sService.statsStream).Produce(&msgPack)
if err != nil {
log.Println(err)
}


@ -1,42 +1,193 @@
package querynode
import (
"context"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
// NOTE: start pulsar before test
func TestStatsService_start(t *testing.T) {
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// start stats service
node.statsService = newStatsService(node.ctx, node.replica)
node.statsService.start()
node.Close()
}
//NOTE: start pulsar before test
// NOTE: start pulsar before test
func TestSegmentManagement_SegmentStatisticService(t *testing.T) {
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
const receiveBufSize = 1024
// start pulsar
producerChannels := []string{Params.statsChannelName()}
pulsarURL, _ := Params.pulsarAddress()
statsStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
statsStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
statsStream.SetPulsarClient(pulsarURL)
statsStream.CreatePulsarProducers(producerChannels)
var statsMsgStream msgstream.MsgStream = statsStream
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
node.statsService.statsStream = statsMsgStream
node.statsService.statsStream.Start()
node.statsService = newStatsService(node.ctx, node.replica)
node.statsService.statsStream = &statsMsgStream
(*node.statsService.statsStream).Start()
// send stats
node.statsService.sendSegmentStatistic()
node.Close()
}


@ -36,11 +36,11 @@ type tSafeImpl struct {
watcherList []*tSafeWatcher
}
func newTSafe() tSafe {
func newTSafe() *tSafe {
var t tSafe = &tSafeImpl{
watcherList: make([]*tSafeWatcher, 0),
}
return t
return &t
}
func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) {


@ -9,13 +9,13 @@ import (
func TestTSafe_GetAndSet(t *testing.T) {
tSafe := newTSafe()
watcher := newTSafeWatcher()
tSafe.registerTSafeWatcher(watcher)
(*tSafe).registerTSafeWatcher(watcher)
go func() {
watcher.hasUpdate()
timestamp := tSafe.get()
timestamp := (*tSafe).get()
assert.Equal(t, timestamp, Timestamp(1000))
}()
tSafe.set(Timestamp(1000))
(*tSafe).set(Timestamp(1000))
}


@ -1,4 +1,3 @@
output
cmake-build-debug
.idea
cmake_build


@ -2,17 +2,6 @@ cmake_minimum_required(VERSION 3.14...3.17 FATAL_ERROR)
project(wrapper)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if (NOT GIT_ARROW_REPO)
set(GIT_ARROW_REPO "https://github.com/apache/arrow.git")
endif ()
message(STATUS "Arrow Repo:" ${GIT_ARROW_REPO})
if (NOT GIT_ARROW_TAG)
set(GIT_ARROW_TAG "apache-arrow-2.0.0")
endif ()
message(STATUS "Arrow Tag:" ${GIT_ARROW_TAG})
###################################################################################################
# - cmake modules ---------------------------------------------------------------------------------
@ -25,39 +14,29 @@ set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/" ${CMAKE_MODUL
message(STATUS "BUILDING ARROW")
include(ConfigureArrow)
if (ARROW_FOUND)
if(ARROW_FOUND)
message(STATUS "Apache Arrow found in ${ARROW_INCLUDE_DIR}")
else ()
else()
message(FATAL_ERROR "Apache Arrow not found, please check your settings.")
endif (ARROW_FOUND)
endif(ARROW_FOUND)
add_library(arrow STATIC IMPORTED ${ARROW_LIB})
add_library(parquet STATIC IMPORTED ${PARQUET_LIB})
add_library(thrift STATIC IMPORTED ${THRIFT_LIB})
add_library(utf8proc STATIC IMPORTED ${UTF8PROC_LIB})
if (ARROW_FOUND)
if(ARROW_FOUND)
set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${ARROW_LIB})
set_target_properties(parquet PROPERTIES IMPORTED_LOCATION ${PARQUET_LIB})
set_target_properties(thrift PROPERTIES IMPORTED_LOCATION ${THRIFT_LIB})
set_target_properties(utf8proc PROPERTIES IMPORTED_LOCATION ${UTF8PROC_LIB})
endif (ARROW_FOUND)
endif(ARROW_FOUND)
###################################################################################################
include_directories(${ARROW_INCLUDE_DIR})
include_directories(${PROJECT_SOURCE_DIR})
add_library(wrapper STATIC)
target_sources(wrapper PUBLIC ParquetWrapper.cpp
PayloadStream.cpp)
target_link_libraries(wrapper PUBLIC parquet arrow thrift utf8proc pthread)
if(NOT CMAKE_INSTALL_PREFIX)
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR})
endif()
install(TARGETS wrapper DESTINATION ${CMAKE_INSTALL_PREFIX})
install(FILES ${ARROW_LIB} ${PARQUET_LIB} ${THRIFT_LIB} ${UTF8PROC_LIB} DESTINATION ${CMAKE_INSTALL_PREFIX})
add_library(wrapper ParquetWrapper.cpp ParquetWrapper.h ColumnType.h PayloadStream.h PayloadStream.cpp)
add_subdirectory(test)


@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(values, length);
auto ast = builder->AppendValues(values, (dimension / 8) * length);
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() * 8;
*length = array->length();
*length = array->length() / array->byte_width();
*values = (uint8_t *) array->raw_values();
return st;
}
@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() / sizeof(float);
*length = array->length();
*length = array->length() / array->byte_width();
*values = (float *) array->raw_values();
return st;
}
@ -478,7 +478,12 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
if (p->array == nullptr) return 0;
return p->array->length();
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
if (ba == nullptr) {
return p->array->length();
} else {
return ba->length() / ba->byte_width();
}
}
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {


@ -5,7 +5,6 @@ extern "C" {
#endif
#include <stdint.h>
#include <stdbool.h>
typedef void *CPayloadWriter;
@ -56,4 +55,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
#ifdef __cplusplus
}
#endif
#endif


@ -1,58 +0,0 @@
#!/bin/bash
SOURCE=${BASH_SOURCE[0]}
while [ -h $SOURCE ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$( cd -P $( dirname $SOURCE ) && pwd )
SOURCE=$(readlink $SOURCE)
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$( cd -P $( dirname $SOURCE ) && pwd )
# echo $DIR
CMAKE_BUILD=${DIR}/cmake_build
OUTPUT_LIB=${DIR}/output
if [ ! -d ${CMAKE_BUILD} ];then
mkdir ${CMAKE_BUILD}
fi
if [ -d ${OUTPUT_LIB} ];then
rm -rf ${OUTPUT_LIB}
fi
mkdir ${OUTPUT_LIB}
BUILD_TYPE="Debug"
GIT_ARROW_REPO="https://github.com/apache/arrow.git"
GIT_ARROW_TAG="apache-arrow-2.0.0"
while getopts "a:b:t:h" arg; do
case $arg in
t)
BUILD_TYPE=$OPTARG # BUILD_TYPE
;;
a)
GIT_ARROW_REPO=$OPTARG
;;
b)
GIT_ARROW_TAG=$OPTARG
;;
h) # help
echo "-t: build type(default: Debug)
-a: arrow repo(default: https://github.com/apache/arrow.git)
-b: arrow tag(default: apache-arrow-2.0.0)
-h: help
"
exit 0
;;
?)
echo "ERROR! unknown argument"
exit 1
;;
esac
done
echo "BUILD_TYPE: " $BUILD_TYPE
echo "GIT_ARROW_REPO: " $GIT_ARROW_REPO
echo "GIT_ARROW_TAG: " $GIT_ARROW_TAG
pushd ${CMAKE_BUILD}
cmake -DCMAKE_INSTALL_PREFIX=${OUTPUT_LIB} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DGIT_ARROW_REPO=${GIT_ARROW_REPO} -DGIT_ARROW_TAG=${GIT_ARROW_TAG} .. && make && make install


@ -35,14 +35,11 @@ if(ARROW_CONFIG)
message(FATAL_ERROR "Configuring Arrow failed: " ${ARROW_CONFIG})
endif(ARROW_CONFIG)
#set(PARALLEL_BUILD -j)
#if($ENV{PARALLEL_LEVEL})
# set(NUM_JOBS $ENV{PARALLEL_LEVEL})
# set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}")
#endif($ENV{PARALLEL_LEVEL})
set(NUM_JOBS 4)
set(PARALLEL_BUILD "-j${NUM_JOBS}")
set(PARALLEL_BUILD -j)
if($ENV{PARALLEL_LEVEL})
set(NUM_JOBS $ENV{PARALLEL_LEVEL})
set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}")
endif($ENV{PARALLEL_LEVEL})
if(${NUM_JOBS})
if(${NUM_JOBS} EQUAL 1)
@ -91,8 +88,8 @@ if(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB)
set(ARROW_FOUND TRUE)
endif(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB)
# message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT})
# set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include")
# set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib")
message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT})
set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include")
set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib")
add_definitions(-DARROW_METADATA_V4)


@ -20,8 +20,8 @@ project(wrapper-Arrow)
include(ExternalProject)
ExternalProject_Add(Arrow
GIT_REPOSITORY ${GIT_ARROW_REPO}
GIT_TAG ${GIT_ARROW_TAG}
GIT_REPOSITORY https://github.com/apache/arrow.git
GIT_TAG apache-arrow-2.0.0
GIT_SHALLOW true
SOURCE_DIR "${ARROW_ROOT}/arrow"
SOURCE_SUBDIR "cpp"


@ -14,8 +14,6 @@ target_link_libraries(wrapper_test
parquet arrow thrift utf8proc pthread
)
install(TARGETS wrapper_test DESTINATION ${CMAKE_INSTALL_PREFIX})
# Defines `gtest_discover_tests()`.
#include(GoogleTest)
#gtest_discover_tests(milvusd_test)


@ -71,36 +71,36 @@ TEST(wrapper, inoutstream) {
}
TEST(wrapper, boolean) {
auto payload = NewPayloadWriter(ColumnType::BOOL);
bool data[] = {true, false, true, false};
auto st = AddBooleanToPayload(payload, data, 4);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = FinishPayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
auto cb = GetPayloadBufferFromWriter(payload);
ASSERT_GT(cb.length, 0);
ASSERT_NE(cb.data, nullptr);
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
bool *values;
int length;
st = GetBoolFromPayload(reader, &values, &length);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
ASSERT_NE(values, nullptr);
ASSERT_EQ(length, 4);
length = GetPayloadLengthFromReader(reader);
ASSERT_EQ(length, 4);
for (int i = 0; i < length; i++) {
ASSERT_EQ(data[i], values[i]);
}
st = ReleasePayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
}
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \


@ -1,626 +0,0 @@
package storage
/*
#cgo CFLAGS: -I${SRCDIR}/cwrapper
#cgo LDFLAGS: -L${SRCDIR}/cwrapper/output -l:libwrapper.a -l:libparquet.a -l:libarrow.a -l:libthrift.a -l:libutf8proc.a -lstdc++ -lm
#include <stdlib.h>
#include "ParquetWrapper.h"
*/
import "C"
import (
"unsafe"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
type (
PayloadWriter struct {
payloadWriterPtr C.CPayloadWriter
colType schemapb.DataType
}
PayloadReader struct {
payloadReaderPtr C.CPayloadReader
colType schemapb.DataType
}
)
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
w := C.NewPayloadWriter(C.int(colType))
if w == nil {
return nil, errors.New("create Payload writer failed")
}
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
}
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
switch len(dim) {
case 0:
switch w.colType {
case schemapb.DataType_BOOL:
val, ok := msgs.([]bool)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBoolToPayload(val)
case schemapb.DataType_INT8:
val, ok := msgs.([]int8)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt8ToPayload(val)
case schemapb.DataType_INT16:
val, ok := msgs.([]int16)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt16ToPayload(val)
case schemapb.DataType_INT32:
val, ok := msgs.([]int32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt32ToPayload(val)
case schemapb.DataType_INT64:
val, ok := msgs.([]int64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt64ToPayload(val)
case schemapb.DataType_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatToPayload(val)
case schemapb.DataType_DOUBLE:
val, ok := msgs.([]float64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddDoubleToPayload(val)
case schemapb.DataType_STRING:
val, ok := msgs.(string)
if !ok {
return errors.New("incorrect data type")
}
return w.AddOneStringToPayload(val)
}
case 1:
switch w.colType {
case schemapb.DataType_VECTOR_BINARY:
val, ok := msgs.([]byte)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBinaryVectorToPayload(val, dim[0])
case schemapb.DataType_VECTOR_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatVectorToPayload(val, dim[0])
}
default:
return errors.New("incorrect input numbers")
}
return nil
}
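A hedged usage sketch for the dispatcher above, assuming it sits in this same storage package: scalar columns take no dim argument and vector columns take exactly one. Note that a colType/dim combination not covered by either switch falls through to the final return nil, so a mismatch is silently accepted.

func exampleAddData() error {
	w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
	if err != nil {
		return err
	}
	defer w.ReleasePayloadWriter()
	// Two 4-dim float vectors; the single trailing argument is dim.
	if err := w.AddDataToPayload([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 4); err != nil {
		return err
	}
	return w.FinishPayloadWriter()
}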
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
length := len(msg)
if length == 0 {
return errors.New("can't add empty string into payload")
}
cmsg := C.CString(msg)
clength := C.int(length)
defer C.free(unsafe.Pointer(cmsg))
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dimension > 0 && (%8 == 0)
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
length := len(binVec)
if length <= 0 {
return errors.New("can't add empty binVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.uint8_t)(&binVec[0])
cDim := C.int(dim)
cLength := C.int(length / (dim / 8))
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dimension > 0 && (%8 == 0)
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
length := len(floatVec)
if length <= 0 {
return errors.New("can't add empty floatVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.float)(&floatVec[0])
cDim := C.int(dim)
cLength := C.int(length / dim)
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) FinishPayloadWriter() error {
st := C.FinishPayloadWriter(w.payloadWriterPtr)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
pointer := unsafe.Pointer(cb.data)
length := int(cb.length)
if length <= 0 {
return nil, errors.New("empty buffer")
}
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
slice := (*[1 << 28]byte)(pointer)[:length:length]
return slice, nil
}
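GetPayloadBufferFromWriter and the readers below share one cgo idiom: cast the C pointer to a pointer to a very large Go array type, then slice it with an explicit length and capacity so the result can never grow past the C-owned memory. A self-contained illustration on ordinary Go memory; in the real code the pointer comes from C:

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	backing := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
	p := unsafe.Pointer(&backing[0]) // stands in for cb.data
	length := 4
	// Same shape as (*[1 << 28]byte)(pointer)[:length:length] above;
	// the three-index slice pins cap(s) == len(s) == length.
	s := (*[1 << 28]byte)(p)[:length:length]
	fmt.Println(s, len(s), cap(s)) // [1 2 3 4] 4 4
}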
func (w *PayloadWriter) GetPayloadLengthFromWriter() (int, error) {
length := C.GetPayloadLengthFromWriter(w.payloadWriterPtr)
return int(length), nil
}
func (w *PayloadWriter) ReleasePayloadWriter() error {
st := C.ReleasePayloadWriter(w.payloadWriterPtr)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) Close() error {
return w.ReleasePayloadWriter()
}
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
if len(buf) == 0 {
return nil, errors.New("create Payload reader failed, buffer is empty")
}
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
}
// Params:
//  `idx`: string index, used only for STRING columns.
// Return:
//  `interface{}`: the column data; the concrete type depends on colType.
//  `int`: dimension, only meaningful for FLOAT/BINARY vector types.
//  `error`: error.
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
switch len(idx) {
case 1:
switch r.colType {
case schemapb.DataType_STRING:
val, err := r.GetOneStringFromPayload(idx[0])
return val, 0, err
}
case 0:
switch r.colType {
case schemapb.DataType_BOOL:
val, err := r.GetBoolFromPayload()
return val, 0, err
case schemapb.DataType_INT8:
val, err := r.GetInt8FromPayload()
return val, 0, err
case schemapb.DataType_INT16:
val, err := r.GetInt16FromPayload()
return val, 0, err
case schemapb.DataType_INT32:
val, err := r.GetInt32FromPayload()
return val, 0, err
case schemapb.DataType_INT64:
val, err := r.GetInt64FromPayload()
return val, 0, err
case schemapb.DataType_FLOAT:
val, err := r.GetFloatFromPayload()
return val, 0, err
case schemapb.DataType_DOUBLE:
val, err := r.GetDoubleFromPayload()
return val, 0, err
case schemapb.DataType_VECTOR_BINARY:
return r.GetBinaryVectorFromPayload()
case schemapb.DataType_VECTOR_FLOAT:
return r.GetFloatVectorFromPayload()
default:
return nil, 0, errors.New("Unknown type")
}
default:
return nil, 0, errors.New("incorrect number of index")
}
return nil, 0, errors.New("unknown error")
}
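And the matching read path, again as a sketch under the same same-package assumption: non-string columns pass no index, and the int returned alongside vector data is the dimension.

func exampleReadVector(buf []byte) ([]float32, int, error) {
	r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buf)
	if err != nil {
		return nil, 0, err
	}
	defer r.ReleasePayloadReader()
	data, dim, err := r.GetDataFromPayload() // no idx for vector columns
	if err != nil {
		return nil, 0, err
	}
	return data.([]float32), dim, nil
}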
func (r *PayloadReader) ReleasePayloadReader() error {
st := C.ReleasePayloadReader(r.payloadReaderPtr)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
var cMsg *C.bool
var cSize C.int
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
var cMsg *C.int8_t
var cSize C.int
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
var cMsg *C.int16_t
var cSize C.int
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
var cMsg *C.int32_t
var cSize C.int
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
var cMsg *C.int64_t
var cSize C.int
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
var cMsg *C.float
var cSize C.int
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
var cMsg *C.double
var cSize C.int
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
var cStr *C.char
var cSize C.int
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return "", errors.New(msg)
}
return C.GoStringN(cStr, cSize), nil
}
// GetBinaryVectorFromPayload returns the packed vector bytes, the vector
// dimension (in bits), and an error.
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
var cMsg *C.uint8_t
var cDim C.int
var cLen C.int
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := (cDim / 8) * cLen
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
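// Worked example of the byte-length arithmetic above (hypothetical figures):
// cLen = 10 binary vectors of cDim = 128 bits each occupy 128/8 * 10 = 160
// bytes, since a binary vector packs 8 dimensions per byte.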
// GetFloatVectorFromPayload returns the flattened vector data, the vector
// dimension (in elements), and an error.
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
var cMsg *C.float
var cDim C.int
var cLen C.int
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := cDim * cLen
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
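// GetPayloadLengthFromReader returns the number of entries (rows) stored in
// the payload, as reported by the underlying C reader.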
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
length := C.GetPayloadLengthFromReader(r.payloadReaderPtr)
return int(length), nil
}
func (r *PayloadReader) Close() error {
return r.ReleasePayloadReader()
}


@ -1,426 +0,0 @@
package storage
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestPayload_ReaderandWriter(t *testing.T) {
t.Run("TestBool", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBoolToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.AddDataToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 8, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 8)
bools, err := r.GetBoolFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
ibools, _, err := r.GetDataFromPayload()
bools = ibools.([]bool)
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
defer r.ReleasePayloadReader()
})
t.Run("TestInt8", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT8)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt8ToPayload([]int8{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int8{4, 5, 6})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int8s, err := r.GetInt8FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
iint8s, _, err := r.GetDataFromPayload()
int8s = iint8s.([]int8)
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt16", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT16)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt16ToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int16s, err := r.GetInt16FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
iint16s, _, err := r.GetDataFromPayload()
int16s = iint16s.([]int16)
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT32)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt32ToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int32s, err := r.GetInt32FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
iint32s, _, err := r.GetDataFromPayload()
int32s = iint32s.([]int32)
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt64", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT64)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt64ToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int64s, err := r.GetInt64FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
iint64s, _, err := r.GetDataFromPayload()
int64s = iint64s.([]int64)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
defer r.ReleasePayloadReader()
})
t.Run("TestFloat32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float32s, err := r.GetFloatFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
ifloat32s, _, err := r.GetDataFromPayload()
float32s = ifloat32s.([]float32)
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
defer r.ReleasePayloadReader()
})
t.Run("TestDouble", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float64s, err := r.GetDoubleFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
ifloat64s, _, err := r.GetDataFromPayload()
float64s = ifloat64s.([]float64)
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
defer r.ReleasePayloadReader()
})
t.Run("TestAddOneString", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddOneStringToPayload("hello0")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello1")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello2")
assert.Nil(t, err)
err = w.AddDataToPayload("hello3")
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, length, 4)
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
assert.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
str0, err := r.GetOneStringFromPayload(0)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
str1, err := r.GetOneStringFromPayload(1)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
str2, err := r.GetOneStringFromPayload(2)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
str3, err := r.GetOneStringFromPayload(3)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
istr0, _, err := r.GetDataFromPayload(0)
str0 = istr0.(string)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
istr1, _, err := r.GetDataFromPayload(1)
str1 = istr1.(string)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
istr2, _, err := r.GetDataFromPayload(2)
str2 = istr2.(string)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
istr3, _, err := r.GetDataFromPayload(3)
str3 = istr3.(string)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
err = r.ReleasePayloadReader()
assert.Nil(t, err)
err = w.ReleasePayloadWriter()
assert.Nil(t, err)
})
t.Run("TestBinaryVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
require.Nil(t, err)
require.NotNil(t, w)
in := make([]byte, 16)
for i := 0; i < 16; i++ {
in[i] = 1
}
in2 := make([]byte, 8)
for i := 0; i < 8; i++ {
in2[i] = 1
}
err = w.AddBinaryVectorToPayload(in, 8)
assert.Nil(t, err)
err = w.AddDataToPayload(in2, 8)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 24, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 24)
binVecs, dim, err := r.GetBinaryVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
fmt.Println(binVecs)
ibinVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
binVecs = ibinVecs.([]byte)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
defer r.ReleasePayloadReader()
})
t.Run("TestFloatVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 4, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
floatVecs, dim, err := r.GetFloatVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
ifloatVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
floatVecs = ifloatVecs.([]float32)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
defer r.ReleasePayloadReader()
})
}


@ -16,17 +16,13 @@ import (
"os"
"path"
"runtime"
"strconv"
"strings"
"github.com/spf13/cast"
"github.com/spf13/viper"
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)
type UniqueID = typeutil.UniqueID
type Base interface {
Load(key string) (string, error)
LoadRange(key, endKey string, limit int) ([]string, []string, error)
@ -42,18 +38,7 @@ type BaseTable struct {
func (gp *BaseTable) Init() {
gp.params = memkv.NewMemoryKV()
err := gp.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = gp.LoadYaml("advanced/common.yaml")
if err != nil {
panic(err)
}
err = gp.LoadYaml("advanced/channel.yaml")
err := gp.LoadYaml("config.yaml")
if err != nil {
panic(err)
}
@ -161,140 +146,3 @@ func (gp *BaseTable) Remove(key string) error {
func (gp *BaseTable) Save(key, value string) error {
return gp.params.Save(strings.ToLower(key), value)
}
func (gp *BaseTable) ParseFloat(key string) float64 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.ParseFloat(valueStr, 64)
if err != nil {
panic(err)
}
return value
}
func (gp *BaseTable) ParseInt64(key string) int64 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int64(value)
}
func (gp *BaseTable) ParseInt32(key string) int32 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int32(value)
}
func (gp *BaseTable) ParseInt(key string) int {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return value
}
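// Hypothetical usage of the panic-on-error accessors above, assuming an
// initialized BaseTable (keys are illustrative, not real config entries):
//
//	gp := &BaseTable{}
//	gp.Init()
//	_ = gp.Save("demo.ratio", "0.75")
//	_ = gp.Save("demo.count", "42")
//	fmt.Println(gp.ParseFloat("demo.ratio")) // 0.75
//	fmt.Println(gp.ParseInt64("demo.count")) // 42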
func (gp *BaseTable) WriteNodeIDList() []UniqueID {
proxyIDStr, err := gp.Load("nodeID.writeNodeIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load write node id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (gp *BaseTable) ProxyIDList() []UniqueID {
proxyIDStr, err := gp.Load("nodeID.proxyIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (gp *BaseTable) QueryNodeIDList() []UniqueID {
queryNodeIDStr, err := gp.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
queryNodeIDs := strings.Split(queryNodeIDStr, ",")
for _, i := range queryNodeIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
// package methods
func ConvertRangeToIntRange(rangeStr, sep string) []int {
items := strings.Split(rangeStr, sep)
if len(items) != 2 {
panic("Illegal range ")
}
startStr := items[0]
endStr := items[1]
start, err := strconv.Atoi(startStr)
if err != nil {
panic(err)
}
end, err := strconv.Atoi(endStr)
if err != nil {
panic(err)
}
if start < 0 || end < 0 {
panic("Illegal range value")
}
if start > end {
panic("Illegal range value, start > end")
}
return []int{start, end}
}
func ConvertRangeToIntSlice(rangeStr, sep string) []int {
rangeSlice := ConvertRangeToIntRange(rangeStr, sep)
start, end := rangeSlice[0], rangeSlice[1]
var ret []int
for i := start; i < end; i++ {
ret = append(ret, i)
}
return ret
}
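// Usage sketch for the range helpers above (hypothetical channel-range
// string; sep matches the comma-separated convention used in the config):
//
//	ConvertRangeToIntRange("0,8", ",") // [0 8]
//	ConvertRangeToIntSlice("0,8", ",") // [0 1 2 3 4 5 6 7] (half-open)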


@ -12,7 +12,6 @@
package paramtable
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
@ -22,8 +21,6 @@ var Params = BaseTable{}
func TestMain(m *testing.M) {
Params.Init()
code := m.Run()
os.Exit(code)
}
//func TestMain
@ -58,13 +55,13 @@ func TestGlobalParamsTable_SaveAndLoad(t *testing.T) {
}
func TestGlobalParamsTable_LoadRange(t *testing.T) {
_ = Params.Save("xxxaab", "10")
_ = Params.Save("xxxfghz", "20")
_ = Params.Save("xxxbcde", "1.1")
_ = Params.Save("xxxabcd", "testSaveAndLoad")
_ = Params.Save("xxxzhi", "12")
_ = Params.Save("abc", "10")
_ = Params.Save("fghz", "20")
_ = Params.Save("bcde", "1.1")
_ = Params.Save("abcd", "testSaveAndLoad")
_ = Params.Save("zhi", "12")
keys, values, err := Params.LoadRange("xxxa", "xxxg", 10)
keys, values, err := Params.LoadRange("a", "g", 10)
assert.Nil(t, err)
assert.Equal(t, 4, len(keys))
assert.Equal(t, "10", values[0])
@ -100,17 +97,24 @@ func TestGlobalParamsTable_Remove(t *testing.T) {
}
func TestGlobalParamsTable_LoadYaml(t *testing.T) {
err := Params.LoadYaml("milvus.yaml")
err := Params.LoadYaml("config.yaml")
assert.Nil(t, err)
err = Params.LoadYaml("advanced/channel.yaml")
assert.Nil(t, err)
value1, err1 := Params.Load("etcd.address")
value2, err2 := Params.Load("pulsar.port")
value3, err3 := Params.Load("reader.topicend")
value4, err4 := Params.Load("proxy.pulsarTopics.readerTopicPrefix")
value5, err5 := Params.Load("proxy.network.address")
_, err = Params.Load("etcd.address")
assert.Nil(t, err)
_, err = Params.Load("pulsar.port")
assert.Nil(t, err)
_, err = Params.Load("msgChannel.channelRange.insert")
assert.Nil(t, err)
assert.Equal(t, value1, "localhost")
assert.Equal(t, value2, "6650")
assert.Equal(t, value3, "128")
assert.Equal(t, value4, "milvusReader")
assert.Equal(t, value5, "0.0.0.0")
assert.Nil(t, err1)
assert.Nil(t, err2)
assert.Nil(t, err3)
assert.Nil(t, err4)
assert.Nil(t, err5)
}
