mirror of https://github.com/milvus-io/milvus.git
parent
63c8f60c6e
commit
70710dee47
|
@ -55,7 +55,3 @@ cmake_build/
|
|||
|
||||
.DS_Store
|
||||
*.swp
|
||||
cwrapper_build
|
||||
**/.clangd/*
|
||||
**/compile_commands.json
|
||||
**/.lint
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
timeout(time: 10, unit: 'MINUTES') {
|
||||
timeout(time: 5, unit: 'MINUTES') {
|
||||
dir ("scripts") {
|
||||
sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"'
|
||||
}
|
||||
|
||||
sh '. ./scripts/before-install.sh && make install'
|
||||
sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install'
|
||||
|
||||
dir ("scripts") {
|
||||
withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
|
||||
|
|
|
@ -4,10 +4,7 @@ try {
|
|||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
|
||||
dir ('build/docker/deploy') {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d master'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d proxy'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=1 -d querynode'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=2 -d querynode'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
|
||||
}
|
||||
|
||||
dir ('build/docker/test') {
|
||||
|
|
8
Makefile
8
Makefile
|
@ -41,9 +41,9 @@ fmt:
|
|||
lint:
|
||||
@echo "Running $@ check"
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/go/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
|
||||
|
||||
ruleguard:
|
||||
@echo "Running $@ check"
|
||||
|
@ -65,11 +65,9 @@ build-go:
|
|||
|
||||
build-cpp:
|
||||
@(env bash $(PWD)/scripts/core_build.sh)
|
||||
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release)
|
||||
|
||||
build-cpp-with-unittest:
|
||||
@(env bash $(PWD)/scripts/core_build.sh -u)
|
||||
@(env bash $(PWD)/scripts/cwrapper_build.sh -t Release)
|
||||
|
||||
# Runs the tests.
|
||||
unittest: test-cpp test-go
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
SOURCE_REPO=milvusdb
|
||||
TARGET_REPO=milvusdb
|
||||
SOURCE_TAG=latest
|
||||
TARGET_TAG=latest
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -14,7 +13,8 @@ import (
|
|||
|
||||
func main() {
|
||||
proxy.Init()
|
||||
fmt.Println("ProxyID is", proxy.Params.ProxyID())
|
||||
|
||||
// Creates server.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
svr, err := proxy.CreateProxy(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,24 +2,18 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/querynode"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
querynode.Init()
|
||||
fmt.Println("QueryNodeID is", querynode.Params.QueryNodeID())
|
||||
// Creates server.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
svr := querynode.NewQueryNode(ctx, 0)
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc,
|
||||
|
@ -34,14 +28,8 @@ func main() {
|
|||
cancel()
|
||||
}()
|
||||
|
||||
if err := svr.Start(); err != nil {
|
||||
log.Fatal("run server failed", zap.Error(err))
|
||||
}
|
||||
querynode.StartQueryNode(ctx)
|
||||
|
||||
<-ctx.Done()
|
||||
log.Print("Got signal to exit", zap.String("signal", sig.String()))
|
||||
|
||||
svr.Close()
|
||||
switch sig {
|
||||
case syscall.SIGTERM:
|
||||
exit(0)
|
||||
|
|
|
@ -32,7 +32,7 @@ msgChannel:
|
|||
|
||||
# default channel range [0, 1)
|
||||
channelRange:
|
||||
insert: [0, 2]
|
||||
insert: [0, 1]
|
||||
delete: [0, 1]
|
||||
dataDefinition: [0,1]
|
||||
k2s: [0, 1]
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
master:
|
||||
address: localhost
|
||||
port: 53100
|
||||
pulsarmoniterinterval: 1
|
||||
pulsartopic: "monitor-topic"
|
||||
|
||||
proxyidlist: [1, 2]
|
||||
proxyTimeSyncChannels: ["proxy1", "proxy2"]
|
||||
proxyTimeSyncSubName: "proxy-topic"
|
||||
softTimeTickBarrierInterval: 500
|
||||
|
||||
writeidlist: [3, 4]
|
||||
writeTimeSyncChannels: ["write3", "write4"]
|
||||
writeTimeSyncSubName: "write-topic"
|
||||
|
||||
dmTimeSyncChannels: ["dm5", "dm6"]
|
||||
k2sTimeSyncChannels: ["k2s7", "k2s8"]
|
||||
|
||||
defaultSizePerRecord: 1024
|
||||
minimumAssignSize: 1048576
|
||||
segmentThreshold: 536870912
|
||||
segmentExpireDuration: 2000
|
||||
segmentThresholdFactor: 0.75
|
||||
querynodenum: 1
|
||||
writenodenum: 1
|
||||
statsChannels: "statistic"
|
||||
|
||||
etcd:
|
||||
address: localhost
|
||||
port: 2379
|
||||
rootpath: by-dev
|
||||
segthreshold: 10000
|
||||
|
||||
minio:
|
||||
address: localhost
|
||||
port: 9000
|
||||
accessKeyID: minioadmin
|
||||
secretAccessKey: minioadmin
|
||||
useSSL: false
|
||||
|
||||
timesync:
|
||||
interval: 400
|
||||
|
||||
storage:
|
||||
driver: TIKV
|
||||
address: localhost
|
||||
port: 2379
|
||||
accesskey:
|
||||
secretkey:
|
||||
|
||||
pulsar:
|
||||
authentication: false
|
||||
user: user-default
|
||||
token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY
|
||||
address: localhost
|
||||
port: 6650
|
||||
topicnum: 128
|
||||
|
||||
reader:
|
||||
clientid: 0
|
||||
stopflag: -1
|
||||
readerqueuesize: 10000
|
||||
searchchansize: 10000
|
||||
key2segchansize: 10000
|
||||
topicstart: 0
|
||||
topicend: 128
|
||||
|
||||
writer:
|
||||
clientid: 0
|
||||
stopflag: -2
|
||||
readerqueuesize: 10000
|
||||
searchbyidchansize: 10000
|
||||
parallelism: 100
|
||||
topicstart: 0
|
||||
topicend: 128
|
||||
bucket: "zilliz-hz"
|
||||
|
||||
proxy:
|
||||
timezone: UTC+8
|
||||
proxy_id: 1
|
||||
numReaderNodes: 2
|
||||
tsoSaveInterval: 200
|
||||
timeTickInterval: 200
|
||||
|
||||
pulsarTopics:
|
||||
readerTopicPrefix: "milvusReader"
|
||||
numReaderTopics: 2
|
||||
deleteTopic: "milvusDeleter"
|
||||
queryTopic: "milvusQuery"
|
||||
resultTopic: "milvusResult"
|
||||
resultGroup: "milvusResultGroup"
|
||||
timeTickTopic: "milvusTimeTick"
|
||||
|
||||
network:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
|
||||
logs:
|
||||
level: debug
|
||||
trace.enable: true
|
||||
path: /tmp/logs
|
||||
max_log_file_size: 1024MB
|
||||
log_rotate_num: 0
|
||||
|
||||
storage:
|
||||
path: /var/lib/milvus
|
||||
auto_flush_interval: 1
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
nodeID: # will be deprecated after v0.2
|
||||
proxyIDList: [0]
|
||||
queryNodeIDList: [1, 2]
|
||||
queryNodeIDList: [2]
|
||||
writeNodeIDList: [3]
|
||||
|
||||
etcd:
|
||||
|
@ -23,13 +23,6 @@ etcd:
|
|||
kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath
|
||||
segThreshold: 10000
|
||||
|
||||
minio:
|
||||
address: localhost
|
||||
port: 9000
|
||||
accessKeyID: minioadmin
|
||||
secretAccessKey: minioadmin
|
||||
useSSL: false
|
||||
|
||||
pulsar:
|
||||
address: localhost
|
||||
port: 6650
|
||||
|
|
|
@ -1,223 +0,0 @@
|
|||
## Binlog
|
||||
|
||||
InsertBinlog、DeleteBinlog、DDLBinlog
|
||||
|
||||
Binlog is stored in a columnar storage format, every column in schema should be stored in a individual file. Timestamp, schema, row id and primary key allocated by system are four special columns. Schema column records the DDL of the collection.
|
||||
|
||||
|
||||
|
||||
## Event format
|
||||
|
||||
Binlog file consists of 4 bytes magic number and a series of events. The first event must be descriptor event.
|
||||
|
||||
### Event format
|
||||
|
||||
```
|
||||
+=====================================+
|
||||
| event | timestamp 0 : 8 | create timestamp
|
||||
| header +----------------------------+
|
||||
| | type_code 8 : 1 | event type code
|
||||
| +----------------------------+
|
||||
| | server_id 9 : 4 | write node id
|
||||
| +----------------------------+
|
||||
| | event_length 13 : 4 | length of event, including header and data
|
||||
| +----------------------------+
|
||||
| | next_position 17 : 4 | offset of next event from the start of file
|
||||
| +----------------------------+
|
||||
| | extra_headers 21 : x-21 | reserved part
|
||||
+=====================================+
|
||||
| event | fixed part x : y |
|
||||
| data +----------------------------+
|
||||
| | variable part |
|
||||
+=====================================+
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Descriptor Event format
|
||||
|
||||
```
|
||||
+=====================================+
|
||||
| event | timestamp 0 : 8 | create timestamp
|
||||
| header +----------------------------+
|
||||
| | type_code 8 : 1 | event type code
|
||||
| +----------------------------+
|
||||
| | server_id 9 : 4 | write node id
|
||||
| +----------------------------+
|
||||
| | event_length 13 : 4 | length of event, including header and data
|
||||
| +----------------------------+
|
||||
| | next_position 17 : 4 | offset of next event from the start of file
|
||||
+=====================================+
|
||||
| event | binlog_version 21 : 2 | binlog version
|
||||
| data +----------------------------+
|
||||
| | server_version 23 : 8 | write node version
|
||||
| +----------------------------+
|
||||
| | commit_id 31 : 8 | commit id of the programe in git
|
||||
| +----------------------------+
|
||||
| | header_length 39 : 1 | header length of other event
|
||||
| +----------------------------+
|
||||
| | collection_id 40 : 8 | collection id
|
||||
| +----------------------------+
|
||||
| | partition_id 48 : 8 | partition id (schema column does not need)
|
||||
| +----------------------------+
|
||||
| | segment_id 56 : 8 | segment id (schema column does not need)
|
||||
| +----------------------------+
|
||||
| | start_timestamp 64 : 1 | minimum timestamp allocated by master of all events in this file
|
||||
| +----------------------------+
|
||||
| | end_timestamp 65 : 1 | maximum timestamp allocated by master of all events in this file
|
||||
| +----------------------------+
|
||||
| | post-header 66 : n | array of n bytes, one byte per event type that the server knows about
|
||||
| | lengths for all |
|
||||
| | event types |
|
||||
+=====================================+
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Type code
|
||||
|
||||
```
|
||||
DESCRIPTOR_EVENT
|
||||
INSERT_EVENT
|
||||
DELETE_EVENT
|
||||
CREATE_COLLECTION_EVENT
|
||||
DROP_COLLECTION_EVENT
|
||||
CREATE_PARTITION_EVENT
|
||||
DROP_PARTITION_EVENT
|
||||
```
|
||||
|
||||
DESCRIPTOR_EVENT must appear in all column files and always be the first event.
|
||||
|
||||
INSERT_EVENT 可以出现在除DDL binlog文件外的其他列的binlog
|
||||
|
||||
DELETE_EVENT 只能用于primary key 的binlog文件(目前只有按照primary key删除)
|
||||
|
||||
CREATE_COLLECTION_EVENT、DROP_COLLECTION_EVENT、CREATE_PARTITION_EVENT、DROP_PARTITION_EVENT 只出现在DDL binlog文件
|
||||
|
||||
|
||||
|
||||
### Event data part
|
||||
|
||||
```
|
||||
event data part
|
||||
|
||||
INSERT_EVENT:
|
||||
+================================================+
|
||||
| event | fixed | start_timestamp x : 8 | min timestamp in this event
|
||||
| data | part +------------------------------+
|
||||
| | | end_timestamp x+8 : 8 | max timestamp in this event
|
||||
| | +------------------------------+
|
||||
| | | reserved x+16 : y-x-16 | reserved part
|
||||
| +--------+------------------------------+
|
||||
| |variable| parquet payloI ad | payload in parquet format
|
||||
| |part | |
|
||||
+================================================+
|
||||
|
||||
other events is similar with INSERT_EVENT
|
||||
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### Example
|
||||
|
||||
Schema
|
||||
|
||||
string | int | float(optional) | vector(512)
|
||||
|
||||
|
||||
|
||||
Request:
|
||||
|
||||
InsertRequest rows(1W)
|
||||
|
||||
DeleteRequest pk=1
|
||||
|
||||
DropPartition partitionTag="abc"
|
||||
|
||||
|
||||
|
||||
insert binlogs:
|
||||
|
||||
rowid, pk, ts, string, int, float, vector 6 files
|
||||
|
||||
all events are INSERT_EVENT
|
||||
float column file contains some NULL value
|
||||
|
||||
delete binlogs:
|
||||
|
||||
pk, ts 2 files
|
||||
|
||||
pk's events are DELETE_EVENT, ts's events are INSERT_EVENT
|
||||
|
||||
DDL binlogs:
|
||||
|
||||
ddl, ts
|
||||
|
||||
ddl's event is DROP_PARTITION_EVENT, ts's event is INSERT_EVENT
|
||||
|
||||
|
||||
|
||||
C++ interface
|
||||
|
||||
```c++
|
||||
typedef void* CPayloadWriter
|
||||
typedef struct CBuffer {
|
||||
char* data;
|
||||
int length;
|
||||
} CBuffer
|
||||
|
||||
typedef struct CStatus {
|
||||
int error_code;
|
||||
const char* error_msg;
|
||||
} CStatus
|
||||
|
||||
|
||||
// C++ interface
|
||||
// writer
|
||||
CPayloadWriter NewPayloadWriter(int columnType);
|
||||
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
|
||||
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
|
||||
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
|
||||
CStatus AddInt64ToPayload(CPayloadWriter payloadWriter, int64_t *values, int length);
|
||||
CStatus AddFloatToPayload(CPayloadWriter payloadWriter, float *values, int length);
|
||||
CStatus AddDoubleToPayload(CPayloadWriter payloadWriter, double *values, int length);
|
||||
CStatus AddOneStringToPayload(CPayloadWriter payloadWriter, char *cstr, int str_size);
|
||||
CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_t *values, int dimension, int length);
|
||||
CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *values, int dimension, int length);
|
||||
|
||||
CStatus FinishPayloadWriter(CPayloadWriter payloadWriter);
|
||||
CBuffer GetPayloadBufferFromWriter(CPayloadWriter payloadWriter);
|
||||
int GetPayloadLengthFromWriter(CPayloadWriter payloadWriter);
|
||||
CStatus ReleasePayloadWriter(CPayloadWriter handler);
|
||||
|
||||
// reader
|
||||
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
|
||||
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
|
||||
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
|
||||
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
|
||||
CStatus GetInt64FromPayload(CPayloadReader payloadReader, int64_t **values, int *length);
|
||||
CStatus GetFloatFromPayload(CPayloadReader payloadReader, float **values, int *length);
|
||||
CStatus GetDoubleFromPayload(CPayloadReader payloadReader, double **values, int *length);
|
||||
CStatus GetOneStringFromPayload(CPayloadReader payloadReader, int idx, char **cstr, int *str_size);
|
||||
CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader, uint8_t **values, int *dimension, int *length);
|
||||
CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader, float **values, int *dimension, int *length);
|
||||
|
||||
int GetPayloadLengthFromReader(CPayloadReader payloadReader);
|
||||
CStatus ReleasePayloadReader(CPayloadReader payloadReader);
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"runtime"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
|
||||
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type UniqueID = typeutil.UniqueID
|
||||
|
||||
// yaml.MapSlice
|
||||
|
||||
type MasterConfig struct {
|
||||
Address string
|
||||
Port int32
|
||||
PulsarMonitorInterval int32
|
||||
PulsarTopic string
|
||||
SegmentThreshold float32
|
||||
SegmentExpireDuration int64
|
||||
ProxyIDList []UniqueID
|
||||
QueryNodeNum int
|
||||
WriteNodeNum int
|
||||
}
|
||||
|
||||
type EtcdConfig struct {
|
||||
Address string
|
||||
Port int32
|
||||
Rootpath string
|
||||
Segthreshold int64
|
||||
}
|
||||
|
||||
type TimeSyncConfig struct {
|
||||
Interval int32
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Driver storagetype.DriverType
|
||||
Address string
|
||||
Port int32
|
||||
Accesskey string
|
||||
Secretkey string
|
||||
}
|
||||
|
||||
type PulsarConfig struct {
|
||||
Authentication bool
|
||||
User string
|
||||
Token string
|
||||
Address string
|
||||
Port int32
|
||||
TopicNum int
|
||||
}
|
||||
|
||||
type ProxyConfig struct {
|
||||
Timezone string `yaml:"timezone"`
|
||||
ProxyID int `yaml:"proxy_id"`
|
||||
NumReaderNodes int `yaml:"numReaderNodes"`
|
||||
TosSaveInterval int `yaml:"tsoSaveInterval"`
|
||||
TimeTickInterval int `yaml:"timeTickInterval"`
|
||||
PulsarTopics struct {
|
||||
ReaderTopicPrefix string `yaml:"readerTopicPrefix"`
|
||||
NumReaderTopics int `yaml:"numReaderTopics"`
|
||||
DeleteTopic string `yaml:"deleteTopic"`
|
||||
QueryTopic string `yaml:"queryTopic"`
|
||||
ResultTopic string `yaml:"resultTopic"`
|
||||
ResultGroup string `yaml:"resultGroup"`
|
||||
TimeTickTopic string `yaml:"timeTickTopic"`
|
||||
} `yaml:"pulsarTopics"`
|
||||
Network struct {
|
||||
Address string `yaml:"address"`
|
||||
Port int `yaml:"port"`
|
||||
} `yaml:"network"`
|
||||
Logs struct {
|
||||
Level string `yaml:"level"`
|
||||
TraceEnable bool `yaml:"trace.enable"`
|
||||
Path string `yaml:"path"`
|
||||
MaxLogFileSize string `yaml:"max_log_file_size"`
|
||||
LogRotateNum int `yaml:"log_rotate_num"`
|
||||
} `yaml:"logs"`
|
||||
Storage struct {
|
||||
Path string `yaml:"path"`
|
||||
AutoFlushInterval int `yaml:"auto_flush_interval"`
|
||||
} `yaml:"storage"`
|
||||
}
|
||||
|
||||
type Reader struct {
|
||||
ClientID int
|
||||
StopFlag int64
|
||||
ReaderQueueSize int
|
||||
SearchChanSize int
|
||||
Key2SegChanSize int
|
||||
TopicStart int
|
||||
TopicEnd int
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
ClientID int
|
||||
StopFlag int64
|
||||
ReaderQueueSize int
|
||||
SearchByIDChanSize int
|
||||
Parallelism int
|
||||
TopicStart int
|
||||
TopicEnd int
|
||||
Bucket string
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Master MasterConfig
|
||||
Etcd EtcdConfig
|
||||
Timesync TimeSyncConfig
|
||||
Storage StorageConfig
|
||||
Pulsar PulsarConfig
|
||||
Writer Writer
|
||||
Reader Reader
|
||||
Proxy ProxyConfig
|
||||
}
|
||||
|
||||
var Config ServerConfig
|
||||
|
||||
// func init() {
|
||||
// load_config()
|
||||
// }
|
||||
|
||||
func getConfigsDir() string {
|
||||
_, fpath, _, _ := runtime.Caller(0)
|
||||
configPath := path.Dir(fpath) + "/../../configs/"
|
||||
configPath = path.Dir(configPath)
|
||||
return configPath
|
||||
}
|
||||
|
||||
func LoadConfigWithPath(yamlFilePath string) {
|
||||
source, err := ioutil.ReadFile(yamlFilePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = yaml.Unmarshal(source, &Config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//fmt.Printf("Result: %v\n", Config)
|
||||
}
|
||||
|
||||
func LoadConfig(yamlFile string) {
|
||||
filePath := path.Join(getConfigsDir(), yamlFile)
|
||||
LoadConfigWithPath(filePath)
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
fmt.Printf("Result: %v\n", Config)
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
set(COMMON_SRC
|
||||
Schema.cpp
|
||||
Types.cpp
|
||||
)
|
||||
set(COMMON_SRC
|
||||
Schema.cpp
|
||||
)
|
||||
|
||||
add_library(milvus_common
|
||||
${COMMON_SRC}
|
||||
)
|
||||
add_library(milvus_common
|
||||
${COMMON_SRC}
|
||||
)
|
||||
target_link_libraries(milvus_common milvus_proto)
|
||||
|
|
|
@ -10,13 +10,18 @@
|
|||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include "common/Types.h"
|
||||
#include "utils/Types.h"
|
||||
#include "utils/Status.h"
|
||||
#include "utils/EasyAssert.h"
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
|
||||
using engine::DataType;
|
||||
using engine::FieldElementType;
|
||||
|
||||
inline int
|
||||
field_sizeof(DataType data_type, int dim = 1) {
|
||||
switch (data_type) {
|
||||
|
@ -84,13 +89,7 @@ field_is_vector(DataType datatype) {
|
|||
|
||||
struct FieldMeta {
|
||||
public:
|
||||
FieldMeta(std::string_view name, DataType type) : name_(name), type_(type) {
|
||||
Assert(!is_vector());
|
||||
}
|
||||
|
||||
FieldMeta(std::string_view name, DataType type, int64_t dim, MetricType metric_type)
|
||||
: name_(name), type_(type), vector_info_(VectorInfo{dim, metric_type}) {
|
||||
Assert(is_vector());
|
||||
FieldMeta(std::string_view name, DataType type, int dim = 1) : name_(name), type_(type), dim_(dim) {
|
||||
}
|
||||
|
||||
bool
|
||||
|
@ -99,11 +98,14 @@ struct FieldMeta {
|
|||
return type_ == DataType::VECTOR_BINARY || type_ == DataType::VECTOR_FLOAT;
|
||||
}
|
||||
|
||||
int64_t
|
||||
void
|
||||
set_dim(int dim) {
|
||||
dim_ = dim;
|
||||
}
|
||||
|
||||
int
|
||||
get_dim() const {
|
||||
Assert(is_vector());
|
||||
Assert(vector_info_.has_value());
|
||||
return vector_info_->dim_;
|
||||
return dim_;
|
||||
}
|
||||
|
||||
const std::string&
|
||||
|
@ -118,20 +120,12 @@ struct FieldMeta {
|
|||
|
||||
int
|
||||
get_sizeof() const {
|
||||
if (is_vector()) {
|
||||
return field_sizeof(type_, get_dim());
|
||||
} else {
|
||||
return field_sizeof(type_, 1);
|
||||
}
|
||||
return field_sizeof(type_, dim_);
|
||||
}
|
||||
|
||||
private:
|
||||
struct VectorInfo {
|
||||
int64_t dim_;
|
||||
MetricType metric_type_;
|
||||
};
|
||||
std::string name_;
|
||||
DataType type_ = DataType::NONE;
|
||||
std::optional<VectorInfo> vector_info_;
|
||||
int dim_ = 1;
|
||||
};
|
||||
} // namespace milvus
|
||||
|
|
|
@ -11,50 +11,35 @@
|
|||
|
||||
#include "common/Schema.h"
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <boost/lexical_cast.hpp>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
using std::string;
|
||||
static std::map<string, string>
|
||||
RepeatedKeyValToMap(const google::protobuf::RepeatedPtrField<proto::common::KeyValuePair>& kvs) {
|
||||
std::map<string, string> mapping;
|
||||
for (auto& kv : kvs) {
|
||||
AssertInfo(!mapping.count(kv.key()), "repeat key(" + kv.key() + ") in protobuf");
|
||||
mapping.emplace(kv.key(), kv.value());
|
||||
}
|
||||
return mapping;
|
||||
}
|
||||
|
||||
std::shared_ptr<Schema>
|
||||
Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) {
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->set_auto_id(schema_proto.autoid());
|
||||
for (const milvus::proto::schema::FieldSchema& child : schema_proto.fields()) {
|
||||
const auto& type_params = child.type_params();
|
||||
int64_t dim = -1;
|
||||
auto data_type = DataType(child.data_type());
|
||||
for (const auto& type_param : type_params) {
|
||||
if (type_param.key() == "dim") {
|
||||
dim = strtoll(type_param.value().c_str(), nullptr, 10);
|
||||
}
|
||||
}
|
||||
|
||||
if (field_is_vector(data_type)) {
|
||||
AssertInfo(dim != -1, "dim not found");
|
||||
} else {
|
||||
AssertInfo(dim == 1 || dim == -1, "Invalid dim field. Should be 1 or not exists");
|
||||
dim = 1;
|
||||
}
|
||||
|
||||
if (child.is_primary_key()) {
|
||||
AssertInfo(!schema->primary_key_offset_opt_.has_value(), "repetitive primary key");
|
||||
schema->primary_key_offset_opt_ = schema->size();
|
||||
}
|
||||
|
||||
if (field_is_vector(data_type)) {
|
||||
auto type_map = RepeatedKeyValToMap(child.type_params());
|
||||
auto index_map = RepeatedKeyValToMap(child.index_params());
|
||||
if (!index_map.count("metric_type")) {
|
||||
auto default_metric_type =
|
||||
data_type == DataType::VECTOR_FLOAT ? MetricType::METRIC_L2 : MetricType::METRIC_Jaccard;
|
||||
index_map["metric_type"] = default_metric_type;
|
||||
}
|
||||
|
||||
AssertInfo(type_map.count("dim"), "dim not found");
|
||||
auto dim = boost::lexical_cast<int64_t>(type_map.at("dim"));
|
||||
AssertInfo(index_map.count("metric_type"), "index not found");
|
||||
auto metric_type = GetMetricType(index_map.at("metric_type"));
|
||||
schema->AddField(child.name(), data_type, dim, metric_type);
|
||||
} else {
|
||||
schema->AddField(child.name(), data_type);
|
||||
}
|
||||
schema->AddField(child.name(), data_type, dim);
|
||||
}
|
||||
return schema;
|
||||
}
|
||||
|
|
|
@ -24,15 +24,19 @@ namespace milvus {
|
|||
class Schema {
|
||||
public:
|
||||
void
|
||||
AddField(std::string_view field_name, DataType data_type) {
|
||||
auto field_meta = FieldMeta(field_name, data_type);
|
||||
AddField(std::string_view field_name, DataType data_type, int dim = 1) {
|
||||
auto field_meta = FieldMeta(field_name, data_type, dim);
|
||||
this->AddField(std::move(field_meta));
|
||||
}
|
||||
|
||||
void
|
||||
AddField(std::string_view field_name, DataType data_type, int64_t dim, MetricType metric_type) {
|
||||
auto field_meta = FieldMeta(field_name, data_type, dim, metric_type);
|
||||
this->AddField(std::move(field_meta));
|
||||
AddField(FieldMeta field_meta) {
|
||||
auto offset = fields_.size();
|
||||
fields_.emplace_back(field_meta);
|
||||
offsets_.emplace(field_meta.get_name(), offset);
|
||||
auto field_sizeof = field_meta.get_sizeof();
|
||||
sizeof_infos_.push_back(field_sizeof);
|
||||
total_sizeof_ += field_sizeof;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -40,6 +44,17 @@ class Schema {
|
|||
is_auto_id_ = is_auto_id;
|
||||
}
|
||||
|
||||
auto
|
||||
begin() {
|
||||
return fields_.begin();
|
||||
}
|
||||
|
||||
auto
|
||||
end() {
|
||||
return fields_.end();
|
||||
}
|
||||
|
||||
public:
|
||||
bool
|
||||
get_is_auto_id() const {
|
||||
return is_auto_id_;
|
||||
|
@ -108,20 +123,11 @@ class Schema {
|
|||
static std::shared_ptr<Schema>
|
||||
ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto);
|
||||
|
||||
void
|
||||
AddField(FieldMeta&& field_meta) {
|
||||
auto offset = fields_.size();
|
||||
fields_.emplace_back(field_meta);
|
||||
offsets_.emplace(field_meta.get_name(), offset);
|
||||
auto field_sizeof = field_meta.get_sizeof();
|
||||
sizeof_infos_.push_back(std::move(field_sizeof));
|
||||
total_sizeof_ += field_sizeof;
|
||||
}
|
||||
|
||||
private:
|
||||
// this is where data holds
|
||||
std::vector<FieldMeta> fields_;
|
||||
|
||||
private:
|
||||
// a mapping for random access
|
||||
std::unordered_map<std::string, int> offsets_;
|
||||
std::vector<int> sizeof_infos_;
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
//
|
||||
// Created by mike on 12/3/20.
|
||||
//
|
||||
#include "common/Types.h"
|
||||
#include <knowhere/index/vector_index/helpers/IndexParameter.h>
|
||||
#include "utils/EasyAssert.h"
|
||||
#include <boost/bimap.hpp>
|
||||
#include <boost/algorithm/string/case_conv.hpp>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
using boost::algorithm::to_lower_copy;
|
||||
namespace Metric = knowhere::Metric;
|
||||
static auto map = [] {
|
||||
boost::bimap<std::string, MetricType> mapping;
|
||||
using pos = boost::bimap<std::string, MetricType>::value_type;
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure));
|
||||
mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure));
|
||||
return mapping;
|
||||
}();
|
||||
|
||||
MetricType
|
||||
GetMetricType(const std::string& type_name) {
|
||||
auto real_name = to_lower_copy(type_name);
|
||||
AssertInfo(map.left.count(real_name), "metric type not found: (" + type_name + ")");
|
||||
return map.left.at(real_name);
|
||||
}
|
||||
|
||||
} // namespace milvus
|
|
@ -1,40 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#pragma once
|
||||
#include "utils/Types.h"
|
||||
#include <faiss/MetricType.h>
|
||||
#include <string>
|
||||
#include <boost/align/aligned_allocator.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
|
||||
using engine::DataType;
|
||||
using engine::FieldElementType;
|
||||
using engine::QueryResult;
|
||||
using MetricType = faiss::MetricType;
|
||||
|
||||
faiss::MetricType
|
||||
GetMetricType(const std::string& type);
|
||||
|
||||
// NOTE: dependent type
|
||||
// used at meta-template programming
|
||||
template <class...>
|
||||
constexpr std::true_type always_true{};
|
||||
|
||||
template <class...>
|
||||
constexpr std::false_type always_false{};
|
||||
|
||||
template <typename T>
|
||||
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
|
||||
|
||||
} // namespace milvus
|
|
@ -11,66 +11,20 @@
|
|||
|
||||
#include "BruteForceSearch.h"
|
||||
#include <vector>
|
||||
#include <common/Types.h>
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
#include <queue>
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
void
|
||||
BinarySearchBruteForceNaive(MetricType metric_type,
|
||||
int64_t code_size,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
int64_t topk,
|
||||
int64_t num_queries,
|
||||
const uint8_t* query_data,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
// THIS IS A NAIVE IMPLEMENTATION, ready for optimize
|
||||
Assert(metric_type == faiss::METRIC_Jaccard);
|
||||
Assert(code_size % 4 == 0);
|
||||
|
||||
using T = std::tuple<float, int>;
|
||||
|
||||
for (int64_t q = 0; q < num_queries; ++q) {
|
||||
auto query_ptr = query_data + code_size * q;
|
||||
auto query = boost::dynamic_bitset(query_ptr, query_ptr + code_size);
|
||||
std::vector<T> max_heap(topk + 1, std::make_tuple(std::numeric_limits<float>::max(), -1));
|
||||
|
||||
for (int64_t i = 0; i < chunk_size; ++i) {
|
||||
auto element_ptr = binary_chunk + code_size * i;
|
||||
auto element = boost::dynamic_bitset(element_ptr, element_ptr + code_size);
|
||||
auto the_and = (query & element).count();
|
||||
auto the_or = (query | element).count();
|
||||
auto distance = the_or ? (float)(the_or - the_and) / the_or : 0;
|
||||
if (distance < std::get<0>(max_heap[0])) {
|
||||
max_heap[topk] = std::make_tuple(distance, i);
|
||||
std::push_heap(max_heap.begin(), max_heap.end());
|
||||
std::pop_heap(max_heap.begin(), max_heap.end());
|
||||
}
|
||||
}
|
||||
std::sort(max_heap.begin(), max_heap.end());
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
auto info = max_heap[k];
|
||||
result_distances[k + q * topk] = std::get<0>(info);
|
||||
result_labels[k + q * topk] = std::get<1>(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinarySearchBruteForceFast(MetricType metric_type,
|
||||
int64_t code_size,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
int64_t topk,
|
||||
int64_t num_queries,
|
||||
const uint8_t* query_data,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
BinarySearchBruteForce(faiss::MetricType metric_type,
|
||||
int64_t code_size,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
int64_t topk,
|
||||
int64_t num_queries,
|
||||
const uint8_t* query_data,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
const idx_t block_size = segcore::DefaultElementPerChunk;
|
||||
bool use_heap = true;
|
||||
|
||||
|
@ -129,21 +83,6 @@ BinarySearchBruteForceFast(MetricType metric_type,
|
|||
for (int i = 0; i < num_queries; ++i) {
|
||||
result_distances[i] = static_cast<float>(int_distances[i]);
|
||||
}
|
||||
} else {
|
||||
PanicInfo("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
faiss::ConcurrentBitsetPtr bitset) {
|
||||
// TODO: refactor the internal function
|
||||
BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size,
|
||||
query_dataset.topk, query_dataset.num_queries, query_dataset.query_data,
|
||||
result_distances, result_labels, bitset);
|
||||
}
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -15,25 +15,15 @@
|
|||
#include "common/Schema.h"
|
||||
|
||||
namespace milvus::query {
|
||||
using MetricType = faiss::MetricType;
|
||||
|
||||
namespace dataset {
|
||||
struct BinaryQueryDataset {
|
||||
MetricType metric_type;
|
||||
int64_t num_queries;
|
||||
int64_t topk;
|
||||
int64_t code_size;
|
||||
const uint8_t* query_data;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
||||
void
|
||||
BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset,
|
||||
BinarySearchBruteForce(faiss::MetricType metric_type,
|
||||
int64_t code_size,
|
||||
const uint8_t* binary_chunk,
|
||||
int64_t chunk_size,
|
||||
int64_t topk,
|
||||
int64_t num_queries,
|
||||
const uint8_t* query_data,
|
||||
float* result_distances,
|
||||
idx_t* result_labels,
|
||||
faiss::ConcurrentBitsetPtr bitset = nullptr);
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -26,25 +26,15 @@ static std::unique_ptr<VectorPlanNode>
|
|||
ParseVecNode(Plan* plan, const Json& out_body) {
|
||||
Assert(out_body.is_object());
|
||||
// TODO add binary info
|
||||
auto vec_node = std::make_unique<FloatVectorANNS>();
|
||||
Assert(out_body.size() == 1);
|
||||
auto iter = out_body.begin();
|
||||
std::string field_name = iter.key();
|
||||
|
||||
auto& vec_info = iter.value();
|
||||
Assert(vec_info.is_object());
|
||||
auto topK = vec_info["topk"];
|
||||
AssertInfo(topK > 0, "topK must greater than 0");
|
||||
AssertInfo(topK < 16384, "topK is too large");
|
||||
auto field_meta = plan->schema_.operator[](field_name);
|
||||
|
||||
auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
if (data_type == DataType::VECTOR_FLOAT) {
|
||||
return std::make_unique<FloatVectorANNS>();
|
||||
} else {
|
||||
return std::make_unique<BinaryVectorANNS>();
|
||||
}
|
||||
}();
|
||||
vec_node->query_info_.topK_ = topK;
|
||||
vec_node->query_info_.metric_type_ = vec_info.at("metric_type");
|
||||
vec_node->query_info_.search_params_ = vec_info.at("params");
|
||||
|
@ -70,6 +60,8 @@ to_lower(const std::string& raw) {
|
|||
return data;
|
||||
}
|
||||
|
||||
template <class...>
|
||||
constexpr std::false_type always_false{};
|
||||
template <typename T>
|
||||
std::unique_ptr<Expr>
|
||||
ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
|
@ -83,62 +75,31 @@ ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Js
|
|||
|
||||
AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found");
|
||||
auto op = RangeExpr::mapping_.at(op_name);
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(item.value().is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
Assert(item.value().is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(item.value().is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
__builtin_unreachable();
|
||||
}
|
||||
T value = item.value();
|
||||
expr->conditions_.emplace_back(op, value);
|
||||
}
|
||||
std::sort(expr->conditions_.begin(), expr->conditions_.end());
|
||||
return expr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<Expr>
|
||||
ParseTermNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) {
|
||||
auto expr = std::make_unique<TermExprImpl<T>>();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(body.is_array());
|
||||
expr->field_id_ = field_name;
|
||||
expr->data_type_ = data_type;
|
||||
for (auto& value : body) {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
Assert(value.is_boolean());
|
||||
} else if constexpr (std::is_integral_v<T>) {
|
||||
Assert(value.is_number_integer());
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
Assert(value.is_number());
|
||||
} else {
|
||||
static_assert(always_false<T>, "unsupported type");
|
||||
__builtin_unreachable();
|
||||
}
|
||||
T real_value = value;
|
||||
expr->terms_.push_back(real_value);
|
||||
}
|
||||
std::sort(expr->terms_.begin(), expr->terms_.end());
|
||||
return expr;
|
||||
}
|
||||
|
||||
std::unique_ptr<Expr>
|
||||
ParseRangeNode(const Schema& schema, const Json& out_body) {
|
||||
Assert(out_body.is_object());
|
||||
Assert(out_body.size() == 1);
|
||||
auto out_iter = out_body.begin();
|
||||
auto field_name = out_iter.key();
|
||||
auto body = out_iter.value();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(!field_is_vector(data_type));
|
||||
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
PanicInfo("bool is not supported in Range node");
|
||||
// return ParseRangeNodeImpl<bool>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT8:
|
||||
return ParseRangeNodeImpl<int8_t>(schema, field_name, body);
|
||||
|
@ -157,42 +118,6 @@ ParseRangeNode(const Schema& schema, const Json& out_body) {
|
|||
}
|
||||
}
|
||||
|
||||
static std::unique_ptr<Expr>
|
||||
ParseTermNode(const Schema& schema, const Json& out_body) {
|
||||
Assert(out_body.size() == 1);
|
||||
auto out_iter = out_body.begin();
|
||||
auto field_name = out_iter.key();
|
||||
auto body = out_iter.value();
|
||||
auto data_type = schema[field_name].get_data_type();
|
||||
Assert(!field_is_vector(data_type));
|
||||
switch (data_type) {
|
||||
case DataType::BOOL: {
|
||||
return ParseTermNodeImpl<bool>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT8: {
|
||||
return ParseTermNodeImpl<int8_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT16: {
|
||||
return ParseTermNodeImpl<int16_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT32: {
|
||||
return ParseTermNodeImpl<int32_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::INT64: {
|
||||
return ParseTermNodeImpl<int64_t>(schema, field_name, body);
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
return ParseTermNodeImpl<float>(schema, field_name, body);
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
return ParseTermNodeImpl<double>(schema, field_name, body);
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported data_type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::unique_ptr<Plan>
|
||||
CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
||||
auto plan = std::make_unique<Plan>(schema);
|
||||
|
@ -208,10 +133,6 @@ CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) {
|
|||
if (pack.contains("vector")) {
|
||||
auto& out_body = pack.at("vector");
|
||||
plan->plan_node_ = ParseVecNode(plan.get(), out_body);
|
||||
} else if (pack.contains("term")) {
|
||||
AssertInfo(!predicate, "unsupported complex DSL");
|
||||
auto& out_body = pack.at("term");
|
||||
predicate = ParseTermNode(schema, out_body);
|
||||
} else if (pack.contains("range")) {
|
||||
AssertInfo(!predicate, "unsupported complex DSL");
|
||||
auto& out_body = pack.at("range");
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/align/aligned_allocator.hpp>
|
||||
|
||||
namespace milvus::query {
|
||||
using Json = nlohmann::json;
|
||||
|
@ -38,6 +39,9 @@ struct Plan {
|
|||
// TODO: add move extra info
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using aligned_vector = std::vector<T, boost::alignment::aligned_allocator<T, 512>>;
|
||||
|
||||
struct Placeholder {
|
||||
// milvus::proto::service::PlaceholderGroup group_;
|
||||
std::string tag_;
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include <faiss/utils/distances.h>
|
||||
#include "utils/tools.h"
|
||||
#include "query/BruteForceSearch.h"
|
||||
|
||||
namespace milvus::query {
|
||||
using segcore::DefaultElementPerChunk;
|
||||
|
@ -27,7 +26,7 @@ create_bitmap_view(std::optional<const BitmapSimple*> bitmaps_opt, int64_t chunk
|
|||
return nullptr;
|
||||
}
|
||||
auto& bitmaps = *bitmaps_opt.value();
|
||||
auto src_vec = ~bitmaps.at(chunk_id);
|
||||
auto& src_vec = bitmaps.at(chunk_id);
|
||||
auto dst = std::make_shared<faiss::ConcurrentBitset>(src_vec.size());
|
||||
auto iter = reinterpret_cast<BitmapChunk::block_type*>(dst->mutable_data());
|
||||
|
||||
|
@ -42,7 +41,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
|
|||
int64_t num_queries,
|
||||
Timestamp timestamp,
|
||||
std::optional<const BitmapSimple*> bitmaps_opt,
|
||||
QueryResult& results) {
|
||||
segcore::QueryResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& indexing_record = segment.get_indexing_record();
|
||||
auto& record = segment.get_insert_record();
|
||||
|
@ -132,92 +131,7 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
|
|||
}
|
||||
results.result_ids_ = std::move(final_uids);
|
||||
// TODO: deprecated code end
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
|
||||
const query::QueryInfo& info,
|
||||
const uint8_t* query_data,
|
||||
int64_t num_queries,
|
||||
Timestamp timestamp,
|
||||
std::optional<const BitmapSimple*> bitmaps_opt,
|
||||
QueryResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& indexing_record = segment.get_indexing_record();
|
||||
auto& record = segment.get_insert_record();
|
||||
// step 1: binary search to find the barrier of the snapshot
|
||||
auto ins_barrier = get_barrier(record, timestamp);
|
||||
auto max_chunk = upper_div(ins_barrier, DefaultElementPerChunk);
|
||||
auto metric_type = GetMetricType(info.metric_type_);
|
||||
// auto del_barrier = get_barrier(deleted_record_, timestamp);
|
||||
|
||||
#if 0
|
||||
auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
|
||||
Assert(bitmap_holder);
|
||||
auto bitmap = bitmap_holder->bitmap_ptr;
|
||||
#endif
|
||||
|
||||
// step 2.1: get meta
|
||||
// step 2.2: get which vector field to search
|
||||
auto vecfield_offset_opt = schema.get_offset(info.field_id_);
|
||||
Assert(vecfield_offset_opt.has_value());
|
||||
auto vecfield_offset = vecfield_offset_opt.value();
|
||||
auto& field = schema[vecfield_offset];
|
||||
|
||||
Assert(field.get_data_type() == DataType::VECTOR_BINARY);
|
||||
auto dim = field.get_dim();
|
||||
auto code_size = dim / 8;
|
||||
auto topK = info.topK_;
|
||||
auto total_count = topK * num_queries;
|
||||
|
||||
// step 3: small indexing search
|
||||
std::vector<int64_t> final_uids(total_count, -1);
|
||||
std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
|
||||
query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data};
|
||||
|
||||
using segcore::BinaryVector;
|
||||
auto vec_ptr = record.get_entity<BinaryVector>(vecfield_offset);
|
||||
|
||||
auto max_indexed_id = 0;
|
||||
// step 4: brute force search where small indexing is unavailable
|
||||
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
|
||||
std::vector<int64_t> buf_uids(total_count, -1);
|
||||
std::vector<float> buf_dis(total_count, std::numeric_limits<float>::max());
|
||||
|
||||
auto& chunk = vec_ptr->get_chunk(chunk_id);
|
||||
auto nsize =
|
||||
chunk_id != max_chunk - 1 ? DefaultElementPerChunk : ins_barrier - chunk_id * DefaultElementPerChunk;
|
||||
|
||||
auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id);
|
||||
BinarySearchBruteForce(query_dataset, chunk.data(), nsize, buf_dis.data(), buf_uids.data(), bitmap_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : buf_uids) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * DefaultElementPerChunk;
|
||||
}
|
||||
}
|
||||
|
||||
segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data());
|
||||
}
|
||||
|
||||
results.result_distances_ = std::move(final_dis);
|
||||
results.internal_seg_offsets_ = std::move(final_uids);
|
||||
results.topK_ = topK;
|
||||
results.num_queries_ = num_queries;
|
||||
|
||||
// TODO: deprecated code begin
|
||||
final_uids = results.internal_seg_offsets_;
|
||||
for (auto& id : final_uids) {
|
||||
if (id == -1) {
|
||||
continue;
|
||||
}
|
||||
id = record.uids_[id];
|
||||
}
|
||||
results.result_ids_ = std::move(final_uids);
|
||||
// TODO: deprecated code end
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -27,14 +27,5 @@ QueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
|
|||
int64_t num_queries,
|
||||
Timestamp timestamp,
|
||||
std::optional<const BitmapSimple*> bitmap_opt,
|
||||
QueryResult& results);
|
||||
|
||||
Status
|
||||
BinaryQueryBruteForceImpl(const segcore::SegmentSmallIndex& segment,
|
||||
const query::QueryInfo& info,
|
||||
const uint8_t* query_data,
|
||||
int64_t num_queries,
|
||||
Timestamp timestamp,
|
||||
std::optional<const BitmapSimple*> bitmaps_opt,
|
||||
QueryResult& results);
|
||||
segcore::QueryResult& results);
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "utils/Types.h"
|
||||
#include "utils/Json.h"
|
||||
|
||||
namespace milvus {
|
||||
|
|
|
@ -58,10 +58,6 @@ class ExecExprVisitor : ExprVisitor {
|
|||
auto
|
||||
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
|
||||
|
||||
private:
|
||||
segcore::SegmentSmallIndex& segment_;
|
||||
std::optional<RetType> ret_;
|
||||
|
|
|
@ -28,7 +28,7 @@ class ExecPlanNodeVisitor : PlanNodeVisitor {
|
|||
visit(BinaryVectorANNS& node) override;
|
||||
|
||||
public:
|
||||
using RetType = QueryResult;
|
||||
using RetType = segcore::QueryResult;
|
||||
ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group)
|
||||
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
|
||||
}
|
||||
|
|
|
@ -46,10 +46,6 @@ class ExecExprVisitor : ExprVisitor {
|
|||
auto
|
||||
ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType;
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType;
|
||||
|
||||
private:
|
||||
segcore::SegmentSmallIndex& segment_;
|
||||
std::optional<RetType> ret_;
|
||||
|
@ -67,6 +63,11 @@ ExecExprVisitor::visit(BoolBinaryExpr& expr) {
|
|||
PanicInfo("unimplemented");
|
||||
}
|
||||
|
||||
void
|
||||
ExecExprVisitor::visit(TermExpr& expr) {
|
||||
PanicInfo("unimplemented");
|
||||
}
|
||||
|
||||
template <typename T, typename IndexFunc, typename ElementFunc>
|
||||
auto
|
||||
ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_func, ElementFunc element_func)
|
||||
|
@ -83,17 +84,17 @@ ExecExprVisitor::ExecRangeVisitorImpl(RangeExprImpl<T>& expr, IndexFunc index_fu
|
|||
auto& indexing_record = segment_.get_indexing_record();
|
||||
const segcore::ScalarIndexingEntry<T>& entry = indexing_record.get_scalar_entry<T>(field_offset);
|
||||
|
||||
RetType results(vec.num_chunk());
|
||||
RetType results(vec.chunk_size());
|
||||
auto indexing_barrier = indexing_record.get_finished_ack();
|
||||
for (auto chunk_id = 0; chunk_id < indexing_barrier; ++chunk_id) {
|
||||
auto& result = results[chunk_id];
|
||||
auto indexing = entry.get_indexing(chunk_id);
|
||||
auto data = index_func(indexing);
|
||||
result = std::move(*data);
|
||||
result = ~std::move(*data);
|
||||
Assert(result.size() == segcore::DefaultElementPerChunk);
|
||||
}
|
||||
|
||||
for (auto chunk_id = indexing_barrier; chunk_id < vec.num_chunk(); ++chunk_id) {
|
||||
for (auto chunk_id = indexing_barrier; chunk_id < vec.chunk_size(); ++chunk_id) {
|
||||
auto& result = results[chunk_id];
|
||||
result.resize(segcore::DefaultElementPerChunk);
|
||||
auto chunk = vec.get_chunk(chunk_id);
|
||||
|
@ -125,32 +126,32 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
|
|||
switch (op) {
|
||||
case OpType::Equal: {
|
||||
auto index_func = [val](Index* index) { return index->In(1, &val); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x == val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x == val); });
|
||||
}
|
||||
|
||||
case OpType::NotEqual: {
|
||||
auto index_func = [val](Index* index) { return index->NotIn(1, &val); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x != val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x != val); });
|
||||
}
|
||||
|
||||
case OpType::GreaterEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::GE); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x >= val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x >= val); });
|
||||
}
|
||||
|
||||
case OpType::GreaterThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::GT); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x > val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x > val); });
|
||||
}
|
||||
|
||||
case OpType::LessEqual: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::LE); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x <= val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x <= val); });
|
||||
}
|
||||
|
||||
case OpType::LessThan: {
|
||||
auto index_func = [val](Index* index) { return index->Range(val, Operator::LT); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return (x < val); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val](T x) { return !(x < val); });
|
||||
}
|
||||
default: {
|
||||
PanicInfo("unsupported range node");
|
||||
|
@ -166,16 +167,16 @@ ExecExprVisitor::ExecRangeVisitorDispatcher(RangeExpr& expr_raw) -> RetType {
|
|||
if (false) {
|
||||
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessThan)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, false); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x < val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x < val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterThan, OpType::LessEqual)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, false, val2, true); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 < x && x <= val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 < x && x <= val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessThan)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, false); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x < val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x < val2); });
|
||||
} else if (ops == std::make_tuple(OpType::GreaterEqual, OpType::LessEqual)) {
|
||||
auto index_func = [val1, val2](Index* index) { return index->Range(val1, true, val2, true); };
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return (val1 <= x && x <= val2); });
|
||||
return ExecRangeVisitorImpl(expr, index_func, [val1, val2](T x) { return !(val1 <= x && x <= val2); });
|
||||
} else {
|
||||
PanicInfo("unsupported range node");
|
||||
}
|
||||
|
@ -225,79 +226,4 @@ ExecExprVisitor::visit(RangeExpr& expr) {
|
|||
ret_ = std::move(ret);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto
|
||||
ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> RetType {
|
||||
auto& expr = static_cast<TermExprImpl<T>&>(expr_raw);
|
||||
auto& records = segment_.get_insert_record();
|
||||
auto data_type = expr.data_type_;
|
||||
auto& schema = segment_.get_schema();
|
||||
auto field_offset_opt = schema.get_offset(expr.field_id_);
|
||||
Assert(field_offset_opt);
|
||||
auto field_offset = field_offset_opt.value();
|
||||
auto& field_meta = schema[field_offset];
|
||||
auto vec_ptr = records.get_entity<T>(field_offset);
|
||||
auto& vec = *vec_ptr;
|
||||
auto num_chunk = vec.num_chunk();
|
||||
RetType bitsets;
|
||||
|
||||
auto N = records.ack_responder_.GetAck();
|
||||
|
||||
// small batch
|
||||
for (int64_t chunk_id = 0; chunk_id < num_chunk; ++chunk_id) {
|
||||
auto& chunk = vec.get_chunk(chunk_id);
|
||||
|
||||
auto size = chunk_id == num_chunk - 1 ? N - chunk_id * segcore::DefaultElementPerChunk
|
||||
: segcore::DefaultElementPerChunk;
|
||||
|
||||
boost::dynamic_bitset<> bitset(segcore::DefaultElementPerChunk);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
auto value = chunk[i];
|
||||
bool is_in = std::binary_search(expr.terms_.begin(), expr.terms_.end(), value);
|
||||
bitset[i] = is_in;
|
||||
}
|
||||
bitsets.emplace_back(std::move(bitset));
|
||||
}
|
||||
return bitsets;
|
||||
}
|
||||
|
||||
void
|
||||
ExecExprVisitor::visit(TermExpr& expr) {
|
||||
auto& field_meta = segment_.get_schema()[expr.field_id_];
|
||||
Assert(expr.data_type_ == field_meta.get_data_type());
|
||||
RetType ret;
|
||||
switch (expr.data_type_) {
|
||||
case DataType::BOOL: {
|
||||
ret = ExecTermVisitorImpl<bool>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT8: {
|
||||
ret = ExecTermVisitorImpl<int8_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT16: {
|
||||
ret = ExecTermVisitorImpl<int16_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT32: {
|
||||
ret = ExecTermVisitorImpl<int32_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::INT64: {
|
||||
ret = ExecTermVisitorImpl<int64_t>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::FLOAT: {
|
||||
ret = ExecTermVisitorImpl<float>(expr);
|
||||
break;
|
||||
}
|
||||
case DataType::DOUBLE: {
|
||||
ret = ExecTermVisitorImpl<double>(expr);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PanicInfo("unsupported");
|
||||
}
|
||||
ret_ = std::move(ret);
|
||||
}
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace impl {
|
|||
// WILL BE USED BY GENERATOR UNDER suvlim/core_gen/
|
||||
class ExecPlanNodeVisitor : PlanNodeVisitor {
|
||||
public:
|
||||
using RetType = QueryResult;
|
||||
using RetType = segcore::QueryResult;
|
||||
ExecPlanNodeVisitor(segcore::SegmentBase& segment, Timestamp timestamp, const PlaceholderGroup& placeholder_group)
|
||||
: segment_(segment), timestamp_(timestamp), placeholder_group_(placeholder_group) {
|
||||
}
|
||||
|
@ -75,22 +75,7 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||
|
||||
void
|
||||
ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) {
|
||||
// TODO: optimize here, remove the dynamic cast
|
||||
assert(!ret_.has_value());
|
||||
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
|
||||
AssertInfo(segment, "support SegmentSmallIndex Only");
|
||||
RetType ret;
|
||||
auto& ph = placeholder_group_.at(0);
|
||||
auto src_data = ph.get_blob<uint8_t>();
|
||||
auto num_queries = ph.num_of_queries_;
|
||||
if (node.predicate_.has_value()) {
|
||||
auto bitmap = ExecExprVisitor(*segment).call_child(*node.predicate_.value());
|
||||
auto ptr = &bitmap;
|
||||
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, ptr, ret);
|
||||
} else {
|
||||
BinaryQueryBruteForceImpl(*segment, node.query_info_, src_data, num_queries, timestamp_, std::nullopt, ret);
|
||||
}
|
||||
ret_ = ret;
|
||||
// TODO
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -73,24 +73,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||
|
||||
void
|
||||
ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) {
|
||||
assert(!ret_);
|
||||
auto& info = node.query_info_;
|
||||
Json json_body{
|
||||
{"node_type", "BinaryVectorANNS"}, //
|
||||
{"metric_type", info.metric_type_}, //
|
||||
{"field_id_", info.field_id_}, //
|
||||
{"topK", info.topK_}, //
|
||||
{"search_params", info.search_params_}, //
|
||||
{"placeholder_tag", node.placeholder_tag_}, //
|
||||
};
|
||||
if (node.predicate_.has_value()) {
|
||||
ShowExprVisitor expr_show;
|
||||
Assert(node.predicate_.value());
|
||||
json_body["predicate"] = expr_show.call_child(node.predicate_->operator*());
|
||||
} else {
|
||||
json_body["predicate"] = "None";
|
||||
}
|
||||
ret_ = json_body;
|
||||
// TODO
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
|
|
@ -123,10 +123,9 @@ Collection::CreateIndex(std::string& index_config) {
|
|||
void
|
||||
Collection::parse() {
|
||||
if (collection_proto_.empty()) {
|
||||
// TODO: remove hard code use unittests are ready
|
||||
std::cout << "WARN: Use default schema" << std::endl;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
schema_ = schema;
|
||||
return;
|
||||
|
|
|
@ -196,7 +196,7 @@ class ConcurrentVectorImpl : public VectorBase {
|
|||
}
|
||||
|
||||
ssize_t
|
||||
num_chunk() const {
|
||||
chunk_size() const {
|
||||
return chunks_.size();
|
||||
}
|
||||
|
||||
|
@ -226,14 +226,8 @@ class ConcurrentVector : public ConcurrentVectorImpl<Type, true> {
|
|||
using ConcurrentVectorImpl<Type, true>::ConcurrentVectorImpl;
|
||||
};
|
||||
|
||||
class VectorTrait {};
|
||||
|
||||
class FloatVector : public VectorTrait {
|
||||
using embedded_type = float;
|
||||
};
|
||||
class BinaryVector : public VectorTrait {
|
||||
using embedded_type = uint8_t;
|
||||
};
|
||||
class FloatVector {};
|
||||
class BinaryVector {};
|
||||
|
||||
template <>
|
||||
class ConcurrentVector<FloatVector> : public ConcurrentVectorImpl<float, false> {
|
||||
|
|
|
@ -24,7 +24,7 @@ VecIndexingEntry::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const Vector
|
|||
|
||||
auto source = dynamic_cast<const ConcurrentVector<FloatVector>*>(vec_base);
|
||||
Assert(source);
|
||||
auto chunk_size = source->num_chunk();
|
||||
auto chunk_size = source->chunk_size();
|
||||
assert(ack_end <= chunk_size);
|
||||
auto conf = get_build_conf();
|
||||
data_.grow_to_at_least(ack_end);
|
||||
|
@ -85,9 +85,11 @@ IndexingRecord::UpdateResourceAck(int64_t chunk_ack, const InsertRecord& record)
|
|||
template <typename T>
|
||||
void
|
||||
ScalarIndexingEntry<T>::BuildIndexRange(int64_t ack_beg, int64_t ack_end, const VectorBase* vec_base) {
|
||||
auto dim = field_meta_.get_dim();
|
||||
|
||||
auto source = dynamic_cast<const ConcurrentVector<T>*>(vec_base);
|
||||
Assert(source);
|
||||
auto chunk_size = source->num_chunk();
|
||||
auto chunk_size = source->chunk_size();
|
||||
assert(ack_end <= chunk_size);
|
||||
data_.grow_to_at_least(ack_end);
|
||||
for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) {
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace milvus {
|
|||
namespace segcore {
|
||||
// using engine::DataChunk;
|
||||
// using engine::DataChunkPtr;
|
||||
using QueryResult = milvus::QueryResult;
|
||||
using engine::QueryResult;
|
||||
struct RowBasedRawData {
|
||||
void* raw_data; // schema
|
||||
int sizeof_per_row; // alignment
|
||||
|
|
|
@ -467,16 +467,16 @@ SegmentNaive::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
|
|||
auto dim = field.get_dim();
|
||||
|
||||
auto indexing = knowhere::VecIndexFactory::GetInstance().CreateVecIndex(entry.type, entry.mode);
|
||||
auto chunk_size = record_.uids_.num_chunk();
|
||||
auto chunk_size = record_.uids_.chunk_size();
|
||||
|
||||
auto& uids = record_.uids_;
|
||||
auto entities = record_.get_entity<FloatVector>(offset);
|
||||
|
||||
std::vector<knowhere::DatasetPtr> datasets;
|
||||
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
|
||||
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
|
||||
auto entities_chunk = entities->get_chunk(chunk_id).data();
|
||||
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
|
||||
}
|
||||
for (auto& ds : datasets) {
|
||||
|
|
|
@ -241,10 +241,10 @@ SegmentSmallIndex::BuildVecIndexImpl(const IndexMeta::Entry& entry) {
|
|||
auto entities = record_.get_entity<FloatVector>(offset);
|
||||
|
||||
std::vector<knowhere::DatasetPtr> datasets;
|
||||
for (int chunk_id = 0; chunk_id < uids.num_chunk(); ++chunk_id) {
|
||||
for (int chunk_id = 0; chunk_id < uids.chunk_size(); ++chunk_id) {
|
||||
auto entities_chunk = entities->get_chunk(chunk_id).data();
|
||||
int64_t count = chunk_id == uids.num_chunk() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
int64_t count = chunk_id == uids.chunk_size() - 1 ? record_.reserved - chunk_id * DefaultElementPerChunk
|
||||
: DefaultElementPerChunk;
|
||||
datasets.push_back(knowhere::GenDataset(count, dim, entities_chunk));
|
||||
}
|
||||
for (auto& ds : datasets) {
|
||||
|
|
|
@ -42,7 +42,7 @@ DeleteSegment(CSegmentBase segment) {
|
|||
|
||||
void
|
||||
DeleteQueryResult(CQueryResult query_result) {
|
||||
auto res = (milvus::QueryResult*)query_result;
|
||||
auto res = (milvus::segcore::QueryResult*)query_result;
|
||||
delete res;
|
||||
}
|
||||
|
||||
|
@ -134,7 +134,7 @@ Search(CSegmentBase c_segment,
|
|||
placeholder_groups.push_back((const milvus::query::PlaceholderGroup*)c_placeholder_groups[i]);
|
||||
}
|
||||
|
||||
auto query_result = std::make_unique<milvus::QueryResult>();
|
||||
auto query_result = std::make_unique<milvus::segcore::QueryResult>();
|
||||
|
||||
auto status = CStatus();
|
||||
try {
|
||||
|
|
|
@ -42,11 +42,8 @@ EasyAssertInfo(
|
|||
|
||||
[[noreturn]] void
|
||||
ThrowWithTrace(const std::exception& exception) {
|
||||
if (typeid(exception) == typeid(WrappedRuntimError)) {
|
||||
throw exception;
|
||||
}
|
||||
auto err_msg = exception.what() + std::string("\n") + EasyStackTrace();
|
||||
throw WrappedRuntimError(err_msg);
|
||||
throw std::runtime_error(err_msg);
|
||||
}
|
||||
|
||||
} // namespace milvus::impl
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
|
||||
#pragma once
|
||||
#include <string_view>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
@ -23,10 +22,6 @@ void
|
|||
EasyAssertInfo(
|
||||
bool value, std::string_view expr_str, std::string_view filename, int lineno, std::string_view extra_info);
|
||||
|
||||
class WrappedRuntimError : public std::runtime_error {
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
[[noreturn]] void
|
||||
ThrowWithTrace(const std::exception& exception);
|
||||
|
||||
|
|
|
@ -26,5 +26,4 @@ target_link_libraries(all_tests
|
|||
pthread
|
||||
milvus_utils
|
||||
)
|
||||
|
||||
install (TARGETS all_tests DESTINATION unittest)
|
||||
|
|
|
@ -21,7 +21,7 @@ TEST(Binary, Insert) {
|
|||
int64_t num_queries = 10;
|
||||
int64_t topK = 5;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("vecbin", DataType::VECTOR_BINARY, 128, MetricType::METRIC_Jaccard);
|
||||
schema->AddField("vecbin", DataType::VECTOR_BINARY, 128);
|
||||
schema->AddField("age", DataType::INT64);
|
||||
auto dataset = DataGen(schema, N, 10);
|
||||
auto segment = CreateSegment(schema);
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
|
@ -52,7 +52,7 @@ TEST(ConcurrentVector, TestSingle) {
|
|||
c_vec.set_data(total_count, vec.data(), insert_size);
|
||||
total_count += insert_size;
|
||||
}
|
||||
ASSERT_EQ(c_vec.num_chunk(), (total_count + 31) / 32);
|
||||
ASSERT_EQ(c_vec.chunk_size(), (total_count + 31) / 32);
|
||||
for (int i = 0; i < total_count; ++i) {
|
||||
for (int d = 0; d < dim; ++d) {
|
||||
auto std_data = d + i * dim;
|
||||
|
|
|
@ -98,49 +98,7 @@ TEST(Expr, Range) {
|
|||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
ShowPlanNodeVisitor shower;
|
||||
Assert(plan->tag2field_.at("$0") == "fakevec");
|
||||
auto out = shower.call_child(*plan->plan_node_);
|
||||
std::cout << out.dump(4);
|
||||
}
|
||||
|
||||
TEST(Expr, RangeBinary) {
|
||||
SUCCEED();
|
||||
using namespace milvus;
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
std::string dsl_string = R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GT": 1,
|
||||
"LT": 100
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "Jaccard",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
ShowPlanNodeVisitor shower;
|
||||
|
@ -182,7 +140,7 @@ TEST(Expr, InvalidRange) {
|
|||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
|
||||
}
|
||||
|
@ -221,7 +179,7 @@ TEST(Expr, InvalidDSL) {
|
|||
})";
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
ASSERT_ANY_THROW(CreatePlan(*schema, dsl_string));
|
||||
}
|
||||
|
@ -231,7 +189,7 @@ TEST(Expr, ShowExecutor) {
|
|||
using namespace milvus::segcore;
|
||||
auto node = std::make_unique<FloatVectorANNS>();
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
int64_t num_queries = 100L;
|
||||
auto raw_data = DataGen(schema, num_queries);
|
||||
auto& info = node->query_info_;
|
||||
|
@ -290,7 +248,7 @@ TEST(Expr, TestRange) {
|
|||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
|
||||
auto seg = CreateSegment(schema);
|
||||
|
@ -321,88 +279,7 @@ TEST(Expr, TestRange) {
|
|||
auto ans = final[vec_id][offset];
|
||||
|
||||
auto val = age_col[i];
|
||||
auto ref = ref_func(val);
|
||||
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Expr, TestTerm) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto vec_2k_3k = [] {
|
||||
std::string buf = "[";
|
||||
for (int i = 2000; i < 3000 - 1; ++i) {
|
||||
buf += std::to_string(i) + ", ";
|
||||
}
|
||||
buf += std::to_string(2999) + "]";
|
||||
return buf;
|
||||
}();
|
||||
|
||||
std::vector<std::tuple<std::string, std::function<bool(int)>>> testcases = {
|
||||
{R"([2000, 3000])", [](int v) { return v == 2000 || v == 3000; }},
|
||||
{R"([2000])", [](int v) { return v == 2000; }},
|
||||
{R"([3000])", [](int v) { return v == 3000; }},
|
||||
{vec_2k_3k, [](int v) { return 2000 <= v && v < 3000; }},
|
||||
};
|
||||
|
||||
std::string dsl_string_tmp = R"(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"term": {
|
||||
"age": @@@@
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
|
||||
auto seg = CreateSegment(schema);
|
||||
int N = 10000;
|
||||
std::vector<int> age_col;
|
||||
int num_iters = 100;
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
auto raw_data = DataGen(schema, N, iter);
|
||||
auto new_age_col = raw_data.get_col<int>(1);
|
||||
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
|
||||
seg->PreInsert(N);
|
||||
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
|
||||
}
|
||||
|
||||
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
|
||||
ExecExprVisitor visitor(*seg_promote);
|
||||
for (auto [clause, ref_func] : testcases) {
|
||||
auto loc = dsl_string_tmp.find("@@@@");
|
||||
auto dsl_string = dsl_string_tmp;
|
||||
dsl_string.replace(loc, 4, clause);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
|
||||
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
|
||||
|
||||
for (int i = 0; i < N * num_iters; ++i) {
|
||||
auto vec_id = i / DefaultElementPerChunk;
|
||||
auto offset = i % DefaultElementPerChunk;
|
||||
auto ans = final[vec_id][offset];
|
||||
|
||||
auto val = age_col[i];
|
||||
auto ref = ref_func(val);
|
||||
auto ref = !ref_func(val);
|
||||
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -235,14 +235,14 @@ TEST(Indexing, IVFFlatNM) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(Indexing, BinaryBruteForce) {
|
||||
TEST(Indexing, DISABLED_BinaryBruteForce) {
|
||||
int64_t N = 100000;
|
||||
int64_t num_queries = 10;
|
||||
int64_t topk = 5;
|
||||
int64_t dim = 512;
|
||||
int64_t dim = 64;
|
||||
auto result_count = topk * num_queries;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("vecbin", DataType::VECTOR_BINARY, dim, MetricType::METRIC_Jaccard);
|
||||
schema->AddField("vecbin", DataType::VECTOR_BINARY, dim);
|
||||
schema->AddField("age", DataType::INT64);
|
||||
auto dataset = DataGen(schema, N, 10);
|
||||
vector<float> distances(result_count);
|
||||
|
@ -250,16 +250,8 @@ TEST(Indexing, BinaryBruteForce) {
|
|||
auto bin_vec = dataset.get_col<uint8_t>(0);
|
||||
auto line_sizeof = schema->operator[](0).get_sizeof();
|
||||
auto query_data = 1024 * line_sizeof + bin_vec.data();
|
||||
query::dataset::BinaryQueryDataset query_dataset{
|
||||
faiss::MetricType::METRIC_Jaccard, //
|
||||
num_queries, //
|
||||
topk, //
|
||||
line_sizeof, //
|
||||
query_data //
|
||||
};
|
||||
|
||||
query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, distances.data(), ids.data());
|
||||
|
||||
query::BinarySearchBruteForce(faiss::MetricType::METRIC_Jaccard, line_sizeof, bin_vec.data(), N, topk, num_queries,
|
||||
query_data, distances.data(), ids.data());
|
||||
QueryResult qr;
|
||||
qr.num_queries_ = num_queries;
|
||||
qr.topK_ = topk;
|
||||
|
@ -272,78 +264,76 @@ TEST(Indexing, BinaryBruteForce) {
|
|||
[
|
||||
[
|
||||
"1024->0.000000",
|
||||
"43190->0.578804",
|
||||
"5255->0.586207",
|
||||
"23247->0.586486",
|
||||
"4936->0.588889"
|
||||
"86966->0.395349",
|
||||
"24843->0.404762",
|
||||
"13806->0.416667",
|
||||
"44313->0.421053"
|
||||
],
|
||||
[
|
||||
"1025->0.000000",
|
||||
"15147->0.562162",
|
||||
"49910->0.564304",
|
||||
"67435->0.567867",
|
||||
"38292->0.569921"
|
||||
"14226->0.348837",
|
||||
"1488->0.365854",
|
||||
"47337->0.377778",
|
||||
"20913->0.377778"
|
||||
],
|
||||
[
|
||||
"1026->0.000000",
|
||||
"15332->0.569061",
|
||||
"56391->0.572559",
|
||||
"17187->0.572603",
|
||||
"26988->0.573771"
|
||||
"81882->0.386364",
|
||||
"9215->0.409091",
|
||||
"95024->0.409091",
|
||||
"54987->0.414634"
|
||||
],
|
||||
[
|
||||
"1027->0.000000",
|
||||
"4502->0.559585",
|
||||
"25879->0.566234",
|
||||
"66937->0.566489",
|
||||
"21228->0.566845"
|
||||
"68981->0.394737",
|
||||
"75528->0.404762",
|
||||
"68794->0.405405",
|
||||
"21975->0.425000"
|
||||
],
|
||||
[
|
||||
"1028->0.000000",
|
||||
"38490->0.578804",
|
||||
"12946->0.581717",
|
||||
"31677->0.582173",
|
||||
"94474->0.583569"
|
||||
"90290->0.375000",
|
||||
"34309->0.394737",
|
||||
"58559->0.400000",
|
||||
"33865->0.400000"
|
||||
],
|
||||
[
|
||||
"1029->0.000000",
|
||||
"59011->0.551630",
|
||||
"82575->0.555263",
|
||||
"42914->0.561828",
|
||||
"23705->0.564171"
|
||||
"62722->0.388889",
|
||||
"89070->0.394737",
|
||||
"18528->0.414634",
|
||||
"94971->0.421053"
|
||||
],
|
||||
[
|
||||
"1030->0.000000",
|
||||
"39782->0.579946",
|
||||
"65553->0.589947",
|
||||
"82154->0.590028",
|
||||
"13374->0.590164"
|
||||
"67402->0.333333",
|
||||
"3988->0.347826",
|
||||
"86376->0.354167",
|
||||
"84381->0.361702"
|
||||
],
|
||||
[
|
||||
"1031->0.000000",
|
||||
"47826->0.582873",
|
||||
"72669->0.587432",
|
||||
"334->0.588076",
|
||||
"80652->0.589333"
|
||||
"81569->0.325581",
|
||||
"12715->0.347826",
|
||||
"40332->0.363636",
|
||||
"21037->0.372093"
|
||||
],
|
||||
[
|
||||
"1032->0.000000",
|
||||
"31968->0.573034",
|
||||
"63545->0.575758",
|
||||
"76913->0.575916",
|
||||
"6286->0.576000"
|
||||
"60536->0.428571",
|
||||
"93293->0.432432",
|
||||
"70969->0.435897",
|
||||
"64048->0.450000"
|
||||
],
|
||||
[
|
||||
"1033->0.000000",
|
||||
"95635->0.570248",
|
||||
"93439->0.574866",
|
||||
"6709->0.578534",
|
||||
"6367->0.579634"
|
||||
"99022->0.394737",
|
||||
"11763->0.405405",
|
||||
"50073->0.428571",
|
||||
"97118->0.428571"
|
||||
]
|
||||
]
|
||||
]
|
||||
)");
|
||||
auto json_str = json.dump(2);
|
||||
auto ref_str = ref.dump(2);
|
||||
ASSERT_EQ(json_str, ref_str);
|
||||
ASSERT_EQ(json, ref);
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ TEST(Query, ShowExecutor) {
|
|||
using namespace milvus;
|
||||
auto node = std::make_unique<FloatVectorANNS>();
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
int64_t num_queries = 100L;
|
||||
auto raw_data = DataGen(schema, num_queries);
|
||||
auto& info = node->query_info_;
|
||||
|
@ -98,7 +98,7 @@ TEST(Query, DSL) {
|
|||
"must": [
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"Vec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
|
@ -113,7 +113,7 @@ TEST(Query, DSL) {
|
|||
})";
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
auto res = shower.call_child(*plan->plan_node_);
|
||||
|
@ -123,7 +123,7 @@ TEST(Query, DSL) {
|
|||
{
|
||||
"bool": {
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"Vec": {
|
||||
"metric_type": "L2",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
|
@ -159,7 +159,7 @@ TEST(Query, ParsePlaceholderGroup) {
|
|||
})";
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
auto plan = CreatePlan(*schema, dsl_string);
|
||||
int64_t num_queries = 100000;
|
||||
int dim = 16;
|
||||
|
@ -172,7 +172,7 @@ TEST(Query, ExecWithPredicate) {
|
|||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::FLOAT);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
|
@ -217,8 +217,8 @@ TEST(Query, ExecWithPredicate) {
|
|||
int topk = 5;
|
||||
|
||||
Json json = QueryResultToJson(qr);
|
||||
auto ref = Json::parse(R"(
|
||||
[
|
||||
|
||||
auto ref = Json::parse(R"([
|
||||
[
|
||||
[
|
||||
"980486->3.149221",
|
||||
|
@ -257,14 +257,15 @@ TEST(Query, ExecWithPredicate) {
|
|||
]
|
||||
]
|
||||
])");
|
||||
ASSERT_EQ(json.dump(2), ref.dump(2));
|
||||
|
||||
ASSERT_EQ(json, ref);
|
||||
}
|
||||
|
||||
TEST(Query, ExecWithoutPredicate) {
|
||||
TEST(Query, ExecWihtoutPredicate) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::FLOAT);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
|
@ -300,49 +301,18 @@ TEST(Query, ExecWithoutPredicate) {
|
|||
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
|
||||
std::vector<std::vector<std::string>> results;
|
||||
int topk = 5;
|
||||
auto json = QueryResultToJson(qr);
|
||||
auto ref = Json::parse(R"(
|
||||
[
|
||||
[
|
||||
[
|
||||
"980486->3.149221",
|
||||
"318367->3.661235",
|
||||
"302798->4.553688",
|
||||
"321424->4.757450",
|
||||
"565529->5.083780"
|
||||
],
|
||||
[
|
||||
"233390->7.931535",
|
||||
"238958->8.109344",
|
||||
"230645->8.439169",
|
||||
"901939->8.658772",
|
||||
"380328->8.731251"
|
||||
],
|
||||
[
|
||||
"749862->3.398494",
|
||||
"701321->3.632437",
|
||||
"897246->3.749835",
|
||||
"750683->3.897577",
|
||||
"105995->4.073595"
|
||||
],
|
||||
[
|
||||
"138274->3.454446",
|
||||
"124548->3.783290",
|
||||
"840855->4.782170",
|
||||
"936719->5.026924",
|
||||
"709627->5.063170"
|
||||
],
|
||||
[
|
||||
"810401->3.926393",
|
||||
"46575->4.054171",
|
||||
"201740->4.274491",
|
||||
"669040->4.399628",
|
||||
"231500->4.831223"
|
||||
]
|
||||
]
|
||||
]
|
||||
)");
|
||||
ASSERT_EQ(json.dump(2), ref.dump(2));
|
||||
for (int q = 0; q < num_queries; ++q) {
|
||||
std::vector<std::string> result;
|
||||
for (int k = 0; k < topk; ++k) {
|
||||
int index = q * topk + k;
|
||||
result.emplace_back(std::to_string(qr.result_ids_[index]) + "->" +
|
||||
std::to_string(qr.result_distances_[index]));
|
||||
}
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
|
||||
Json json{results};
|
||||
std::cout << json.dump(2);
|
||||
}
|
||||
|
||||
TEST(Query, FillSegment) {
|
||||
|
@ -361,9 +331,6 @@ TEST(Query, FillSegment) {
|
|||
auto param = field->add_type_params();
|
||||
param->set_key("dim");
|
||||
param->set_value("16");
|
||||
auto iparam = field->add_index_params();
|
||||
iparam->set_key("metric_type");
|
||||
iparam->set_value("L2");
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -425,57 +392,3 @@ TEST(Query, FillSegment) {
|
|||
++std_index;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Query, ExecWithPredicateBinary) {
|
||||
using namespace milvus::query;
|
||||
using namespace milvus::segcore;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_BINARY, 512, MetricType::METRIC_Jaccard);
|
||||
schema->AddField("age", DataType::FLOAT);
|
||||
std::string dsl = R"({
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"range": {
|
||||
"age": {
|
||||
"GE": -1,
|
||||
"LT": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"vector": {
|
||||
"fakevec": {
|
||||
"metric_type": "Jaccard",
|
||||
"params": {
|
||||
"nprobe": 10
|
||||
},
|
||||
"query": "$0",
|
||||
"topk": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})";
|
||||
int64_t N = 1000 * 1000;
|
||||
auto dataset = DataGen(schema, N);
|
||||
auto segment = std::make_unique<SegmentSmallIndex>(schema);
|
||||
segment->PreInsert(N);
|
||||
segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_);
|
||||
auto vec_ptr = dataset.get_col<uint8_t>(0);
|
||||
|
||||
auto plan = CreatePlan(*schema, dsl);
|
||||
auto num_queries = 5;
|
||||
auto ph_group_raw = CreateBinaryPlaceholderGroupFromBlob(num_queries, 512, vec_ptr.data() + 1024 * 512 / 8);
|
||||
auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
QueryResult qr;
|
||||
Timestamp time = 1000000;
|
||||
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
|
||||
segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr);
|
||||
int topk = 5;
|
||||
|
||||
Json json = QueryResultToJson(qr);
|
||||
std::cout << json.dump(2);
|
||||
// ASSERT_EQ(json.dump(2), ref.dump(2));
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ TEST(SegmentCoreTest, NormalDistributionTest) {
|
|||
using namespace milvus::segcore;
|
||||
using namespace milvus::engine;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
int N = 1000 * 1000;
|
||||
auto [raw_data, timestamps, uids] = generate_data(N);
|
||||
|
@ -76,7 +76,7 @@ TEST(SegmentCoreTest, MockTest) {
|
|||
using namespace milvus::segcore;
|
||||
using namespace milvus::engine;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
std::vector<char> raw_data;
|
||||
std::vector<Timestamp> timestamps;
|
||||
|
@ -116,7 +116,7 @@ TEST(SegmentCoreTest, SmallIndex) {
|
|||
using namespace milvus::segcore;
|
||||
using namespace milvus::engine;
|
||||
auto schema = std::make_shared<Schema>();
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
|
||||
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||
schema->AddField("age", DataType::INT32);
|
||||
int N = 1024 * 1024;
|
||||
auto data = DataGen(schema, N);
|
||||
|
|
|
@ -31,14 +31,6 @@ struct GeneratedData {
|
|||
memcpy(ret.data(), target.data(), target.size());
|
||||
return ret;
|
||||
}
|
||||
template <typename T>
|
||||
auto
|
||||
get_mutable_col(int index) {
|
||||
auto& target = cols_.at(index);
|
||||
assert(target.size() == row_ids_.size() * sizeof(T));
|
||||
auto ptr = reinterpret_cast<T*>(target.data());
|
||||
return ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
GeneratedData() = default;
|
||||
|
@ -66,9 +58,6 @@ GeneratedData::generate_rows(int N, SchemaPtr schema) {
|
|||
}
|
||||
}
|
||||
rows_ = std::move(result);
|
||||
raw_.raw_data = rows_.data();
|
||||
raw_.sizeof_per_row = schema->get_total_sizeof();
|
||||
raw_.count = N;
|
||||
}
|
||||
|
||||
inline GeneratedData
|
||||
|
@ -140,12 +129,14 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) {
|
|||
}
|
||||
GeneratedData res;
|
||||
res.cols_ = std::move(cols);
|
||||
res.generate_rows(N, schema);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
res.row_ids_.push_back(i);
|
||||
res.timestamps_.push_back(i);
|
||||
}
|
||||
|
||||
res.generate_rows(N, schema);
|
||||
res.raw_.raw_data = res.rows_.data();
|
||||
res.raw_.sizeof_per_row = schema->get_total_sizeof();
|
||||
res.raw_.count = N;
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
|
@ -176,7 +167,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42
|
|||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::VECTOR_BINARY);
|
||||
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
|
||||
std::default_random_engine e(seed);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<uint8_t> vec;
|
||||
|
@ -184,27 +175,7 @@ CreateBinaryPlaceholderGroup(int64_t num_queries, int64_t dim, int64_t seed = 42
|
|||
vec.push_back(e());
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size());
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
||||
inline auto
|
||||
CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries, int64_t dim, const uint8_t* ptr) {
|
||||
assert(dim % 8 == 0);
|
||||
namespace ser = milvus::proto::service;
|
||||
ser::PlaceholderGroup raw_group;
|
||||
auto value = raw_group.add_placeholders();
|
||||
value->set_tag("$0");
|
||||
value->set_type(ser::PlaceholderType::VECTOR_BINARY);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::vector<uint8_t> vec;
|
||||
for (int d = 0; d < dim / 8; ++d) {
|
||||
vec.push_back(*ptr);
|
||||
++ptr;
|
||||
}
|
||||
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||
value->add_values(vec.data(), vec.size());
|
||||
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||
}
|
||||
return raw_group;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
package mockkv
|
||||
|
||||
import (
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
)
|
||||
|
||||
// use MemoryKV to mock EtcdKV
|
||||
func NewEtcdKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
||||
|
||||
func NewMemoryKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
|
@ -19,7 +19,6 @@ import (
|
|||
|
||||
func TestMaster_CollectionTask(t *testing.T) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
|
@ -65,10 +64,10 @@ func TestMaster_CollectionTask(t *testing.T) {
|
|||
|
||||
svr, err := CreateServer(ctx)
|
||||
assert.Nil(t, err)
|
||||
err = svr.Run(int64(Params.Port))
|
||||
err = svr.Run(10002)
|
||||
assert.Nil(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock())
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
|
||||
func TestMaster_ConfigTask(t *testing.T) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
|
@ -60,11 +59,11 @@ func TestMaster_ConfigTask(t *testing.T) {
|
|||
|
||||
svr, err := CreateServer(ctx)
|
||||
require.Nil(t, err)
|
||||
err = svr.Run(int64(Params.Port))
|
||||
err = svr.Run(10002)
|
||||
defer svr.Close()
|
||||
require.Nil(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err := grpc.DialContext(ctx, "127.0.0.1:10002", grpc.WithInsecure(), grpc.WithBlock())
|
||||
require.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package master
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -11,6 +12,16 @@ import (
|
|||
var gTestTsoAllocator Allocator
|
||||
var gTestIDAllocator *GlobalIDAllocator
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
Params.Init()
|
||||
|
||||
etcdAddr := Params.EtcdAddress
|
||||
gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso"))
|
||||
gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid"))
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestGlobalTSOAllocator_Initialize(t *testing.T) {
|
||||
err := gTestTsoAllocator.Initialize()
|
||||
assert.Nil(t, err)
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
|
||||
func TestMaster_CreateCollection(t *testing.T) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
|
@ -63,10 +62,10 @@ func TestMaster_CreateCollection(t *testing.T) {
|
|||
|
||||
svr, err := CreateServer(ctx)
|
||||
assert.Nil(t, err)
|
||||
err = svr.Run(int64(Params.Port))
|
||||
err = svr.Run(10001)
|
||||
assert.Nil(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err := grpc.DialContext(ctx, "127.0.0.1:10001", grpc.WithInsecure(), grpc.WithBlock())
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -4,12 +4,9 @@ import (
|
|||
"context"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
|
@ -23,42 +20,6 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var testPORT = 53200
|
||||
|
||||
func genMasterTestPort() int64 {
|
||||
testPORT++
|
||||
return int64(testPORT)
|
||||
}
|
||||
|
||||
func refreshMasterAddress() {
|
||||
masterPort := genMasterTestPort()
|
||||
Params.Port = int(masterPort)
|
||||
masterAddr := makeMasterAddress(masterPort)
|
||||
Params.Address = masterAddr
|
||||
}
|
||||
|
||||
func makeMasterAddress(port int64) string {
|
||||
masterAddr := "127.0.0.1:" + strconv.FormatInt(port, 10)
|
||||
return masterAddr
|
||||
}
|
||||
|
||||
func makeNewChannalNames(names []string, suffix string) []string {
|
||||
var ret []string
|
||||
for _, name := range names {
|
||||
ret = append(ret, name+suffix)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func refreshChannelNames() {
|
||||
suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10)
|
||||
Params.DDChannelNames = makeNewChannalNames(Params.DDChannelNames, suffix)
|
||||
Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(Params.WriteNodeTimeTickChannelNames, suffix)
|
||||
Params.InsertChannelNames = makeNewChannalNames(Params.InsertChannelNames, suffix)
|
||||
Params.K2SChannelNames = makeNewChannalNames(Params.K2SChannelNames, suffix)
|
||||
Params.ProxyTimeTickChannelNames = makeNewChannalNames(Params.ProxyTimeTickChannelNames, suffix)
|
||||
}
|
||||
|
||||
func receiveTimeTickMsg(stream *ms.MsgStream) bool {
|
||||
for {
|
||||
result := (*stream).Consume()
|
||||
|
@ -76,25 +37,13 @@ func getTimeTickMsgPack(ttmsgs [][2]uint64) *ms.MsgPack {
|
|||
return &msgPack
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
refreshChannelNames()
|
||||
etcdAddr := Params.EtcdAddress
|
||||
gTestTsoAllocator = NewGlobalTSOAllocator("timestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "tso"))
|
||||
gTestIDAllocator = NewGlobalIDAllocator("idTimestamp", tsoutil.NewTSOKVBase([]string{etcdAddr}, "/test/root/kv", "gid"))
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestMaster(t *testing.T) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
pulsarAddr := Params.PulsarAddress
|
||||
Params.ProxyIDList = []UniqueID{0}
|
||||
//Param
|
||||
|
||||
// Creates server.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
svr, err := CreateServer(ctx)
|
||||
if err != nil {
|
||||
log.Print("create server failed", zap.Error(err))
|
||||
|
@ -149,7 +98,7 @@ func TestMaster(t *testing.T) {
|
|||
var k2sMsgstream ms.MsgStream = k2sMs
|
||||
assert.True(t, receiveTimeTickMsg(&k2sMsgstream))
|
||||
|
||||
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err := grpc.DialContext(ctx, "127.0.0.1:53100", grpc.WithInsecure(), grpc.WithBlock())
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -55,8 +55,19 @@ var Params ParamTable
|
|||
func (p *ParamTable) Init() {
|
||||
// load yaml
|
||||
p.BaseTable.Init()
|
||||
|
||||
err := p.LoadYaml("advanced/master.yaml")
|
||||
err := p.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/master.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/common.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -104,7 +115,15 @@ func (p *ParamTable) initAddress() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initPort() {
|
||||
p.Port = p.ParseInt("master.port")
|
||||
masterPort, err := p.Load("master.port")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port, err := strconv.Atoi(masterPort)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Port = port
|
||||
}
|
||||
|
||||
func (p *ParamTable) initEtcdAddress() {
|
||||
|
@ -148,40 +167,117 @@ func (p *ParamTable) initKvRootPath() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initTopicNum() {
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
|
||||
insertChannelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rangeSlice := paramtable.ConvertRangeToIntRange(iRangeStr, ",")
|
||||
p.TopicNum = rangeSlice[1] - rangeSlice[0]
|
||||
|
||||
channelRange := strings.Split(insertChannelRange, ",")
|
||||
if len(channelRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(channelRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(channelRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
p.TopicNum = channelEnd
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegmentSize() {
|
||||
p.SegmentSize = p.ParseFloat("master.segment.size")
|
||||
threshold, err := p.Load("master.segment.size")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
segmentThreshold, err := strconv.ParseFloat(threshold, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegmentSize = segmentThreshold
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegmentSizeFactor() {
|
||||
p.SegmentSizeFactor = p.ParseFloat("master.segment.sizeFactor")
|
||||
segFactor, err := p.Load("master.segment.sizeFactor")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
factor, err := strconv.ParseFloat(segFactor, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegmentSizeFactor = factor
|
||||
}
|
||||
|
||||
func (p *ParamTable) initDefaultRecordSize() {
|
||||
p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord")
|
||||
size, err := p.Load("master.segment.defaultSizePerRecord")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.DefaultRecordSize = res
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMinSegIDAssignCnt() {
|
||||
p.MinSegIDAssignCnt = p.ParseInt64("master.segment.minIDAssignCnt")
|
||||
size, err := p.Load("master.segment.minIDAssignCnt")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.MinSegIDAssignCnt = res
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMaxSegIDAssignCnt() {
|
||||
p.MaxSegIDAssignCnt = p.ParseInt64("master.segment.maxIDAssignCnt")
|
||||
size, err := p.Load("master.segment.maxIDAssignCnt")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.MaxSegIDAssignCnt = res
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegIDAssignExpiration() {
|
||||
p.SegIDAssignExpiration = p.ParseInt64("master.segment.IDAssignExpiration")
|
||||
duration, err := p.Load("master.segment.IDAssignExpiration")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(duration, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegIDAssignExpiration = res
|
||||
}
|
||||
|
||||
func (p *ParamTable) initQueryNodeNum() {
|
||||
p.QueryNodeNum = len(p.QueryNodeIDList())
|
||||
id, err := p.Load("nodeID.queryNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
}
|
||||
p.QueryNodeNum = len(ids)
|
||||
}
|
||||
|
||||
func (p *ParamTable) initQueryNodeStatsChannelName() {
|
||||
|
@ -193,7 +289,20 @@ func (p *ParamTable) initQueryNodeStatsChannelName() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initProxyIDList() {
|
||||
p.ProxyIDList = p.BaseTable.ProxyIDList()
|
||||
id, err := p.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
idList := make([]typeutil.UniqueID, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
v, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
idList = append(idList, typeutil.UniqueID(v))
|
||||
}
|
||||
p.ProxyIDList = idList
|
||||
}
|
||||
|
||||
func (p *ParamTable) initProxyTimeTickChannelNames() {
|
||||
|
@ -238,7 +347,20 @@ func (p *ParamTable) initSoftTimeTickBarrierInterval() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initWriteNodeIDList() {
|
||||
p.WriteNodeIDList = p.BaseTable.WriteNodeIDList()
|
||||
id, err := p.Load("nodeID.writeNodeIDList")
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
idlist := make([]typeutil.UniqueID, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
v, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
idlist = append(idlist, typeutil.UniqueID(v))
|
||||
}
|
||||
p.WriteNodeIDList = idlist
|
||||
}
|
||||
|
||||
func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
|
||||
|
@ -263,57 +385,81 @@ func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initDDChannelNames() {
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.dataDefinition")
|
||||
id, err := p.Load("nodeID.queryNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
ids := strings.Split(id, ",")
|
||||
channels := make([]string, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
}
|
||||
channels = append(channels, ch+"-"+i)
|
||||
}
|
||||
p.DDChannelNames = ret
|
||||
p.DDChannelNames = channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) initInsertChannelNames() {
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.InsertChannelNames = ret
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
p.InsertChannelNames = channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) initK2SChannelNames() {
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.k2s")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.k2s")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.k2s")
|
||||
id, err := p.Load("nodeID.writeNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
ids := strings.Split(id, ",")
|
||||
channels := make([]string, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
}
|
||||
channels = append(channels, ch+"-"+i)
|
||||
}
|
||||
p.K2SChannelNames = ret
|
||||
p.K2SChannelNames = channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMaxPartitionNum() {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package master
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -12,111 +11,133 @@ func TestParamTable_Init(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParamTable_Address(t *testing.T) {
|
||||
Params.Init()
|
||||
address := Params.Address
|
||||
assert.Equal(t, address, "localhost")
|
||||
}
|
||||
|
||||
func TestParamTable_Port(t *testing.T) {
|
||||
Params.Init()
|
||||
port := Params.Port
|
||||
assert.Equal(t, port, 53100)
|
||||
}
|
||||
|
||||
func TestParamTable_MetaRootPath(t *testing.T) {
|
||||
Params.Init()
|
||||
path := Params.MetaRootPath
|
||||
assert.Equal(t, path, "by-dev/meta")
|
||||
}
|
||||
|
||||
func TestParamTable_KVRootPath(t *testing.T) {
|
||||
Params.Init()
|
||||
path := Params.KvRootPath
|
||||
assert.Equal(t, path, "by-dev/kv")
|
||||
}
|
||||
|
||||
func TestParamTable_TopicNum(t *testing.T) {
|
||||
Params.Init()
|
||||
num := Params.TopicNum
|
||||
fmt.Println("TopicNum:", num)
|
||||
assert.Equal(t, num, 1)
|
||||
}
|
||||
|
||||
func TestParamTable_SegmentSize(t *testing.T) {
|
||||
Params.Init()
|
||||
size := Params.SegmentSize
|
||||
assert.Equal(t, size, float64(512))
|
||||
}
|
||||
|
||||
func TestParamTable_SegmentSizeFactor(t *testing.T) {
|
||||
Params.Init()
|
||||
factor := Params.SegmentSizeFactor
|
||||
assert.Equal(t, factor, 0.75)
|
||||
}
|
||||
|
||||
func TestParamTable_DefaultRecordSize(t *testing.T) {
|
||||
Params.Init()
|
||||
size := Params.DefaultRecordSize
|
||||
assert.Equal(t, size, int64(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_MinSegIDAssignCnt(t *testing.T) {
|
||||
Params.Init()
|
||||
cnt := Params.MinSegIDAssignCnt
|
||||
assert.Equal(t, cnt, int64(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_MaxSegIDAssignCnt(t *testing.T) {
|
||||
Params.Init()
|
||||
cnt := Params.MaxSegIDAssignCnt
|
||||
assert.Equal(t, cnt, int64(16384))
|
||||
}
|
||||
|
||||
func TestParamTable_SegIDAssignExpiration(t *testing.T) {
|
||||
Params.Init()
|
||||
expiration := Params.SegIDAssignExpiration
|
||||
assert.Equal(t, expiration, int64(2000))
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeNum(t *testing.T) {
|
||||
Params.Init()
|
||||
num := Params.QueryNodeNum
|
||||
fmt.Println("QueryNodeNum", num)
|
||||
assert.Equal(t, num, 1)
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.QueryNodeStatsChannelName
|
||||
assert.Equal(t, name, "query-node-stats")
|
||||
}
|
||||
|
||||
func TestParamTable_ProxyIDList(t *testing.T) {
|
||||
Params.Init()
|
||||
ids := Params.ProxyIDList
|
||||
assert.Equal(t, len(ids), 1)
|
||||
assert.Equal(t, ids[0], int64(0))
|
||||
}
|
||||
|
||||
func TestParamTable_ProxyTimeTickChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.ProxyTimeTickChannelNames
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "proxyTimeTick-0")
|
||||
}
|
||||
|
||||
func TestParamTable_MsgChannelSubName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.MsgChannelSubName
|
||||
assert.Equal(t, name, "master")
|
||||
}
|
||||
|
||||
func TestParamTable_SoftTimeTickBarrierInterval(t *testing.T) {
|
||||
Params.Init()
|
||||
interval := Params.SoftTimeTickBarrierInterval
|
||||
assert.Equal(t, interval, Timestamp(0x7d00000))
|
||||
}
|
||||
|
||||
func TestParamTable_WriteNodeIDList(t *testing.T) {
|
||||
Params.Init()
|
||||
ids := Params.WriteNodeIDList
|
||||
assert.Equal(t, len(ids), 1)
|
||||
assert.Equal(t, ids[0], int64(3))
|
||||
}
|
||||
|
||||
func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.WriteNodeTimeTickChannelNames
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "writeNodeTimeTick-3")
|
||||
}
|
||||
|
||||
func TestParamTable_InsertChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.InsertChannelNames
|
||||
assert.Equal(t, Params.TopicNum, len(names))
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "insert-0")
|
||||
}
|
||||
|
||||
func TestParamTable_K2SChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.K2SChannelNames
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "k2s-0")
|
||||
assert.Equal(t, names[0], "k2s-3")
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package master
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
@ -18,7 +20,6 @@ import (
|
|||
|
||||
func TestMaster_Partition(t *testing.T) {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
|
@ -65,12 +66,13 @@ func TestMaster_Partition(t *testing.T) {
|
|||
DefaultPartitionTag: "_default",
|
||||
}
|
||||
|
||||
port := 10000 + rand.Intn(1000)
|
||||
svr, err := CreateServer(ctx)
|
||||
assert.Nil(t, err)
|
||||
err = svr.Run(int64(Params.Port))
|
||||
err = svr.Run(int64(port))
|
||||
assert.Nil(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err := grpc.DialContext(ctx, "127.0.0.1:"+strconv.Itoa(port), grpc.WithInsecure(), grpc.WithBlock())
|
||||
assert.Nil(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ var master *Master
|
|||
var masterCancelFunc context.CancelFunc
|
||||
|
||||
func setup() {
|
||||
Init()
|
||||
Params.Init()
|
||||
etcdAddress := Params.EtcdAddress
|
||||
|
||||
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
|
||||
|
@ -218,8 +218,7 @@ func TestSegmentManager_SegmentStats(t *testing.T) {
|
|||
}
|
||||
|
||||
func startupMaster() {
|
||||
Init()
|
||||
refreshMasterAddress()
|
||||
Params.Init()
|
||||
etcdAddress := Params.EtcdAddress
|
||||
rootPath := "/test/root"
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
@ -232,6 +231,7 @@ func startupMaster() {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
Params = ParamTable{
|
||||
Address: Params.Address,
|
||||
Port: Params.Port,
|
||||
|
@ -272,7 +272,7 @@ func startupMaster() {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = master.Run(int64(Params.Port))
|
||||
err = master.Run(10013)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -289,7 +289,7 @@ func TestSegmentManager_RPC(t *testing.T) {
|
|||
defer shutdownMaster()
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
defer cancel()
|
||||
dialContext, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
|
||||
dialContext, err := grpc.DialContext(ctx, "127.0.0.1:10013", grpc.WithInsecure(), grpc.WithBlock())
|
||||
assert.Nil(t, err)
|
||||
defer dialContext.Close()
|
||||
client := masterpb.NewMasterClient(dialContext)
|
||||
|
|
|
@ -19,8 +19,19 @@ var Params ParamTable
|
|||
|
||||
func (pt *ParamTable) Init() {
|
||||
pt.BaseTable.Init()
|
||||
|
||||
err := pt.LoadYaml("advanced/proxy.yaml")
|
||||
err := pt.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/proxy.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/common.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -37,24 +48,15 @@ func (pt *ParamTable) Init() {
|
|||
pt.Save("_proxyID", proxyIDStr)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) NetworkPort() int {
|
||||
return pt.ParseInt("proxy.port")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) NetworkAddress() string {
|
||||
addr, err := pt.Load("proxy.address")
|
||||
func (pt *ParamTable) NetWorkAddress() string {
|
||||
addr, err := pt.Load("proxy.network.address")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
hostName, _ := net.LookupHost(addr)
|
||||
if len(hostName) <= 0 {
|
||||
if ip := net.ParseIP(addr); ip == nil {
|
||||
panic("invalid ip proxy.address")
|
||||
}
|
||||
if ip := net.ParseIP(addr); ip == nil {
|
||||
panic("invalid ip proxy.network.address")
|
||||
}
|
||||
|
||||
port, err := pt.Load("proxy.port")
|
||||
port, err := pt.Load("proxy.network.port")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -86,6 +88,23 @@ func (pt *ParamTable) ProxyNum() int {
|
|||
return len(ret)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) ProxyIDList() []UniqueID {
|
||||
proxyIDStr, err := pt.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) queryNodeNum() int {
|
||||
return len(pt.queryNodeIDList())
|
||||
}
|
||||
|
@ -131,6 +150,25 @@ func (pt *ParamTable) TimeTickInterval() time.Duration {
|
|||
return time.Duration(interval) * time.Millisecond
|
||||
}
|
||||
|
||||
func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
|
||||
channelIDs := strings.Split(rangeStr, sep)
|
||||
startStr := channelIDs[0]
|
||||
endStr := channelIDs[1]
|
||||
start, err := strconv.Atoi(startStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
end, err := strconv.Atoi(endStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []int
|
||||
for i := start; i < end; i++ {
|
||||
ret = append(ret, i)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) sliceIndex() int {
|
||||
proxyID := pt.ProxyID()
|
||||
proxyIDList := pt.ProxyIDList()
|
||||
|
@ -152,7 +190,7 @@ func (pt *ParamTable) InsertChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
channelIDs := pt.convertRangeToSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
|
@ -178,12 +216,19 @@ func (pt *ParamTable) DeleteChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(dRangeStr, ",")
|
||||
channelIDs := pt.convertRangeToSlice(dRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return ret
|
||||
proxyNum := pt.ProxyNum()
|
||||
sep := len(channelIDs) / proxyNum
|
||||
index := pt.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("ProxyID not Match with Config")
|
||||
}
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
}
|
||||
|
||||
func (pt *ParamTable) K2SChannelNames() []string {
|
||||
|
@ -196,12 +241,19 @@ func (pt *ParamTable) K2SChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(k2sRangeStr, ",")
|
||||
channelIDs := pt.convertRangeToSlice(k2sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return ret
|
||||
proxyNum := pt.ProxyNum()
|
||||
sep := len(channelIDs) / proxyNum
|
||||
index := pt.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("ProxyID not Match with Config")
|
||||
}
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
}
|
||||
|
||||
func (pt *ParamTable) SearchChannelNames() []string {
|
||||
|
@ -209,17 +261,8 @@ func (pt *ParamTable) SearchChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
prefix += "-"
|
||||
sRangeStr, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return ret
|
||||
prefix += "-0"
|
||||
return []string{prefix}
|
||||
}
|
||||
|
||||
func (pt *ParamTable) SearchResultChannelNames() []string {
|
||||
|
@ -232,7 +275,7 @@ func (pt *ParamTable) SearchResultChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
|
||||
channelIDs := pt.convertRangeToSlice(sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
|
@ -278,24 +321,144 @@ func (pt *ParamTable) DataDefinitionChannelNames() []string {
|
|||
return []string{prefix}
|
||||
}
|
||||
|
||||
func (pt *ParamTable) parseInt64(key string) int64 {
|
||||
valueStr, err := pt.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(value)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamInsertBufSize() int64 {
|
||||
return pt.ParseInt64("proxy.msgStream.insert.bufSize")
|
||||
return pt.parseInt64("proxy.msgStream.insert.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchBufSize() int64 {
|
||||
return pt.ParseInt64("proxy.msgStream.search.bufSize")
|
||||
return pt.parseInt64("proxy.msgStream.search.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 {
|
||||
return pt.ParseInt64("proxy.msgStream.searchResult.recvBufSize")
|
||||
return pt.parseInt64("proxy.msgStream.searchResult.recvBufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
|
||||
return pt.ParseInt64("proxy.msgStream.searchResult.pulsarBufSize")
|
||||
return pt.parseInt64("proxy.msgStream.searchResult.pulsarBufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
|
||||
return pt.ParseInt64("proxy.msgStream.timeTick.bufSize")
|
||||
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) insertChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchResultChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MaxNameLength() int64 {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -60,7 +59,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
|
|||
|
||||
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
|
||||
p.queryMsgStream.SetPulsarClient(pulsarAddress)
|
||||
p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
|
||||
p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())
|
||||
|
||||
masterAddr := Params.MasterAddress()
|
||||
idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
|
||||
|
@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
|
|||
|
||||
p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
|
||||
p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
|
||||
p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
|
||||
p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
|
||||
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
|
||||
}
|
||||
|
@ -138,7 +137,7 @@ func (p *Proxy) AddCloseCallback(callbacks ...func()) {
|
|||
func (p *Proxy) grpcLoop() {
|
||||
defer p.proxyLoopWg.Done()
|
||||
|
||||
lis, err := net.Listen("tcp", ":"+strconv.Itoa(Params.NetworkPort()))
|
||||
lis, err := net.Listen("tcp", Params.NetWorkAddress())
|
||||
if err != nil {
|
||||
log.Fatalf("Proxy grpc server fatal error=%v", err)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -35,26 +34,8 @@ var masterServer *master.Master
|
|||
|
||||
var testNum = 10
|
||||
|
||||
func makeNewChannalNames(names []string, suffix string) []string {
|
||||
var ret []string
|
||||
for _, name := range names {
|
||||
ret = append(ret, name+suffix)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func refreshChannelNames() {
|
||||
suffix := "_test" + strconv.FormatInt(rand.Int63n(100), 10)
|
||||
master.Params.DDChannelNames = makeNewChannalNames(master.Params.DDChannelNames, suffix)
|
||||
master.Params.WriteNodeTimeTickChannelNames = makeNewChannalNames(master.Params.WriteNodeTimeTickChannelNames, suffix)
|
||||
master.Params.InsertChannelNames = makeNewChannalNames(master.Params.InsertChannelNames, suffix)
|
||||
master.Params.K2SChannelNames = makeNewChannalNames(master.Params.K2SChannelNames, suffix)
|
||||
master.Params.ProxyTimeTickChannelNames = makeNewChannalNames(master.Params.ProxyTimeTickChannelNames, suffix)
|
||||
}
|
||||
|
||||
func startMaster(ctx context.Context) {
|
||||
master.Init()
|
||||
refreshChannelNames()
|
||||
etcdAddr := master.Params.EtcdAddress
|
||||
metaRootPath := master.Params.MetaRootPath
|
||||
|
||||
|
@ -100,7 +81,7 @@ func setup() {
|
|||
|
||||
startMaster(ctx)
|
||||
startProxy(ctx)
|
||||
proxyAddr := Params.NetworkAddress()
|
||||
proxyAddr := Params.NetWorkAddress()
|
||||
addr := strings.Split(proxyAddr, ":")
|
||||
if addr[0] == "0.0.0.0" {
|
||||
proxyAddr = "127.0.0.1:" + addr[1]
|
||||
|
|
|
@ -364,7 +364,7 @@ func (sched *TaskScheduler) queryResultLoop() {
|
|||
unmarshal := msgstream.NewUnmarshalDispatcher()
|
||||
queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
|
||||
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
|
||||
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
|
||||
queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
|
||||
Params.ProxySubName(),
|
||||
unmarshal,
|
||||
Params.MsgStreamSearchResultPulsarBufSize())
|
||||
|
|
|
@ -31,7 +31,7 @@ import (
|
|||
* is up-to-date.
|
||||
*/
|
||||
type collectionReplica interface {
|
||||
getTSafe() tSafe
|
||||
getTSafe() *tSafe
|
||||
|
||||
// collection
|
||||
getCollectionNum() int
|
||||
|
@ -68,11 +68,11 @@ type collectionReplicaImpl struct {
|
|||
collections []*Collection
|
||||
segments map[UniqueID]*Segment
|
||||
|
||||
tSafe tSafe
|
||||
tSafe *tSafe
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------------- tSafe
|
||||
func (colReplica *collectionReplicaImpl) getTSafe() tSafe {
|
||||
func (colReplica *collectionReplicaImpl) getTSafe() *tSafe {
|
||||
return colReplica.tSafe
|
||||
}
|
||||
|
||||
|
@ -111,7 +111,6 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID)
|
|||
if col.ID() == collectionID {
|
||||
for _, p := range *col.Partitions() {
|
||||
for _, s := range *p.Segments() {
|
||||
deleteSegment(colReplica.segments[s.ID()])
|
||||
delete(colReplica.segments, s.ID())
|
||||
}
|
||||
}
|
||||
|
@ -203,7 +202,6 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID,
|
|||
for _, p := range *collection.Partitions() {
|
||||
if p.Tag() == partitionTag {
|
||||
for _, s := range *p.Segments() {
|
||||
deleteSegment(colReplica.segments[s.ID()])
|
||||
delete(colReplica.segments, s.ID())
|
||||
}
|
||||
} else {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,48 +1,179 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestCollection_Partitions(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
ctx := context.Background()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
for _, tag := range collectionMeta.PartitionTags {
|
||||
err := (*node.replica).addPartition(collection.ID(), tag)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
partitions := collection.Partitions()
|
||||
assert.Equal(t, 1, len(*partitions))
|
||||
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
|
||||
}
|
||||
|
||||
func TestCollection_newCollection(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
}
|
||||
|
||||
func TestCollection_deleteCollection(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
deleteCollection(collection)
|
||||
}
|
||||
|
|
|
@ -11,10 +11,10 @@ type dataSyncService struct {
|
|||
ctx context.Context
|
||||
fg *flowgraph.TimeTickedFlowGraph
|
||||
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
func newDataSyncService(ctx context.Context, replica collectionReplica) *dataSyncService {
|
||||
func newDataSyncService(ctx context.Context, replica *collectionReplica) *dataSyncService {
|
||||
|
||||
return &dataSyncService{
|
||||
ctx: ctx,
|
||||
|
|
|
@ -1,22 +1,101 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
// NOTE: start pulsar before test
|
||||
func TestDataSyncService_Start(t *testing.T) {
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
// test data generate
|
||||
const msgLength = 10
|
||||
const DIM = 16
|
||||
|
@ -100,25 +179,25 @@ func TestDataSyncService_Start(t *testing.T) {
|
|||
// pulsar produce
|
||||
const receiveBufSize = 1024
|
||||
producerChannels := Params.insertChannelNames()
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(producerChannels)
|
||||
|
||||
var insertMsgStream msgstream.MsgStream = insertStream
|
||||
insertMsgStream.Start()
|
||||
|
||||
err := insertMsgStream.Produce(&msgPack)
|
||||
err = insertMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = insertMsgStream.Broadcast(&timeTickMsgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
node.Close()
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
type insertNode struct {
|
||||
BaseNode
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
type InsertData struct {
|
||||
|
@ -58,13 +58,13 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...)
|
||||
|
||||
// check if segment exists, if not, create this segment
|
||||
if !iNode.replica.hasSegment(task.SegmentID) {
|
||||
collection, err := iNode.replica.getCollectionByName(task.CollectionName)
|
||||
if !(*iNode.replica).hasSegment(task.SegmentID) {
|
||||
collection, err := (*iNode.replica).getCollectionByName(task.CollectionName)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
err = iNode.replica.addSegment(task.SegmentID, task.PartitionTag, collection.ID())
|
||||
err = (*iNode.replica).addSegment(task.SegmentID, task.PartitionTag, collection.ID())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
|
@ -74,7 +74,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
|
||||
// 2. do preInsert
|
||||
for segmentID := range insertData.insertRecords {
|
||||
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
|
||||
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
|
||||
if err != nil {
|
||||
log.Println("preInsert failed")
|
||||
// TODO: add error handling
|
||||
|
@ -102,7 +102,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
}
|
||||
|
||||
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
|
||||
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
|
||||
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
|
||||
if err != nil {
|
||||
log.Println("cannot find segment:", segmentID)
|
||||
// TODO: add error handling
|
||||
|
@ -127,7 +127,7 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn
|
|||
wg.Done()
|
||||
}
|
||||
|
||||
func newInsertNode(replica collectionReplica) *insertNode {
|
||||
func newInsertNode(replica *collectionReplica) *insertNode {
|
||||
maxQueueLength := Params.flowGraphMaxQueueLength()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type serviceTimeNode struct {
|
||||
BaseNode
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
func (stNode *serviceTimeNode) Name() string {
|
||||
|
@ -28,12 +28,12 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
|
|||
}
|
||||
|
||||
// update service time
|
||||
stNode.replica.getTSafe().set(serviceTimeMsg.timeRange.timestampMax)
|
||||
(*(*stNode.replica).getTSafe()).set(serviceTimeMsg.timeRange.timestampMax)
|
||||
//fmt.Println("update tSafe to:", getPhysicalTime(serviceTimeMsg.timeRange.timestampMax))
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServiceTimeNode(replica collectionReplica) *serviceTimeNode {
|
||||
func newServiceTimeNode(replica *collectionReplica) *serviceTimeNode {
|
||||
maxQueueLength := Params.flowGraphMaxQueueLength()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
|
||||
|
|
|
@ -26,10 +26,10 @@ const (
|
|||
type metaService struct {
|
||||
ctx context.Context
|
||||
kvBase *etcdkv.EtcdKV
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
func newMetaService(ctx context.Context, replica collectionReplica) *metaService {
|
||||
func newMetaService(ctx context.Context, replica *collectionReplica) *metaService {
|
||||
ETCDAddr := Params.etcdAddress()
|
||||
MetaRootPath := Params.metaRootPath()
|
||||
|
||||
|
@ -149,12 +149,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
|
|||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
err := mService.replica.addCollection(col, value)
|
||||
err := (*mService.replica).addCollection(col, value)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
for _, partitionTag := range col.PartitionTags {
|
||||
err = mService.replica.addPartition(col.ID, partitionTag)
|
||||
err = (*mService.replica).addPartition(col.ID, partitionTag)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
|
|||
|
||||
// TODO: what if seg == nil? We need to notify master and return rpc request failed
|
||||
if seg != nil {
|
||||
err := mService.replica.addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
|
||||
err := (*mService.replica).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -202,7 +202,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
|
|||
}
|
||||
|
||||
if seg != nil {
|
||||
targetSegment, err := mService.replica.getSegmentByID(seg.SegmentID)
|
||||
targetSegment, err := (*mService.replica).getSegmentByID(seg.SegmentID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -218,11 +218,11 @@ func (mService *metaService) processCollectionModify(id string, value string) {
|
|||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
err := mService.replica.addPartitionsByCollectionMeta(col)
|
||||
err := (*mService.replica).addPartitionsByCollectionMeta(col)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
err = mService.replica.removePartitionsByCollectionMeta(col)
|
||||
err = (*mService.replica).removePartitionsByCollectionMeta(col)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ func (mService *metaService) processSegmentDelete(id string) {
|
|||
log.Println("Cannot parse segment id:" + id)
|
||||
}
|
||||
|
||||
err = mService.replica.removeSegment(segmentID)
|
||||
err = (*mService.replica).removeSegment(segmentID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -264,7 +264,7 @@ func (mService *metaService) processCollectionDelete(id string) {
|
|||
log.Println("Cannot parse collection id:" + id)
|
||||
}
|
||||
|
||||
err = mService.replica.removeCollection(collectionID)
|
||||
err = (*mService.replica).removeCollection(collectionID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
|
|
@ -3,13 +3,23 @@ package querynode
|
|||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
Params.Init()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestMetaService_start(t *testing.T) {
|
||||
var ctx context.Context
|
||||
|
||||
|
@ -27,7 +37,6 @@ func TestMetaService_start(t *testing.T) {
|
|||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
(*node.metaService).start()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_getCollectionObjId(t *testing.T) {
|
||||
|
@ -110,9 +119,47 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T)
|
|||
|
||||
func TestMetaService_printCollectionStruct(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
printCollectionStruct(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
printCollectionStruct(&collectionMeta)
|
||||
}
|
||||
|
||||
func TestMetaService_printSegmentStruct(t *testing.T) {
|
||||
|
@ -131,8 +178,13 @@ func TestMetaService_printSegmentStruct(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMetaService_processCollectionCreate(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -144,10 +196,6 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -164,21 +212,71 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
|
|||
|
||||
node.metaService.processCollectionCreate(id, value)
|
||||
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentCreate(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -189,15 +287,19 @@ func TestMetaService_processSegmentCreate(t *testing.T) {
|
|||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processCreate(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -209,10 +311,6 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -228,10 +326,10 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
|
@ -243,19 +341,68 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentModify(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
segmentID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, segmentID)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -265,9 +412,9 @@ func TestMetaService_processSegmentModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
s, err := node.replica.getSegmentByID(segmentID)
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, segmentID)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
|
||||
newValue := `partition_tag: "default"
|
||||
channel_start: 0
|
||||
|
@ -277,15 +424,19 @@ func TestMetaService_processSegmentModify(t *testing.T) {
|
|||
|
||||
// TODO: modify segment for testing processCollectionModify
|
||||
(*node.metaService).processSegmentModify(id, newValue)
|
||||
seg, err := node.replica.getSegmentByID(segmentID)
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, segmentID)
|
||||
node.Close()
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
}
|
||||
|
||||
func TestMetaService_processCollectionModify(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -297,10 +448,6 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -318,24 +465,24 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionCreate(id, value)
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
|
||||
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
|
||||
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
|
||||
newValue := `schema: <
|
||||
|
@ -347,10 +494,6 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -368,28 +511,32 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionModify(id, newValue)
|
||||
collection, err = node.replica.getCollectionByName("test")
|
||||
collection, err = (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
|
||||
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processModify(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -401,10 +548,6 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -422,24 +565,24 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
|
||||
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
|
||||
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
|
||||
key2 := "by-dev/meta/segment/0"
|
||||
|
@ -450,7 +593,7 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
|
||||
|
@ -465,10 +608,6 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -486,21 +625,21 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processModify(key1, msg3)
|
||||
collection, err = node.replica.getCollectionByName("test")
|
||||
collection, err = (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
|
||||
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
|
||||
msg4 := `partition_tag: "p1"
|
||||
|
@ -510,18 +649,68 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processModify(key2, msg4)
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentDelete(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -531,19 +720,23 @@ func TestMetaService_processSegmentDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
|
||||
(*node.metaService).processSegmentDelete("0")
|
||||
mapSize := node.replica.getSegmentNum()
|
||||
mapSize := (*node.replica).getSegmentNum()
|
||||
assert.Equal(t, mapSize, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processCollectionDelete(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -555,10 +748,6 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -574,22 +763,26 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionCreate(id, value)
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
(*node.metaService).processCollectionDelete(id)
|
||||
collectionNum = node.replica.getCollectionNum()
|
||||
collectionNum = (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processDelete(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -601,10 +794,6 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -620,10 +809,10 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
|
@ -635,48 +824,77 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
|
||||
(*node.metaService).processDelete(key1)
|
||||
collectionsSize := node.replica.getCollectionNum()
|
||||
collectionsSize := (*node.replica).getCollectionNum()
|
||||
assert.Equal(t, collectionsSize, 0)
|
||||
|
||||
mapSize := node.replica.getSegmentNum()
|
||||
mapSize := (*node.replica).getSegmentNum()
|
||||
assert.Equal(t, mapSize, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processResp(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
metaChan := (*node.metaService).kvBase.WatchWithPrefix("")
|
||||
|
||||
select {
|
||||
case <-node.queryNodeLoopCtx.Done():
|
||||
case <-node.ctx.Done():
|
||||
return
|
||||
case resp := <-metaChan:
|
||||
_ = (*node.metaService).processResp(resp)
|
||||
}
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_loadCollections(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
err2 := (*node.metaService).loadCollections()
|
||||
assert.Nil(t, err2)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_loadSegments(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
err2 := (*node.metaService).loadSegments()
|
||||
assert.Nil(t, err2)
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package querynode
|
|||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
|
||||
)
|
||||
|
@ -21,16 +21,15 @@ func (p *ParamTable) Init() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
queryNodeIDStr := os.Getenv("QUERY_NODE_ID")
|
||||
if queryNodeIDStr == "" {
|
||||
queryNodeIDList := p.QueryNodeIDList()
|
||||
if len(queryNodeIDList) <= 0 {
|
||||
queryNodeIDStr = "0"
|
||||
} else {
|
||||
queryNodeIDStr = strconv.Itoa(int(queryNodeIDList[0]))
|
||||
}
|
||||
err = p.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = p.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Save("_queryNodeID", queryNodeIDStr)
|
||||
}
|
||||
|
||||
func (p *ParamTable) pulsarAddress() (string, error) {
|
||||
|
@ -41,8 +40,8 @@ func (p *ParamTable) pulsarAddress() (string, error) {
|
|||
return url, nil
|
||||
}
|
||||
|
||||
func (p *ParamTable) QueryNodeID() UniqueID {
|
||||
queryNodeID, err := p.Load("_queryNodeID")
|
||||
func (p *ParamTable) queryNodeID() int {
|
||||
queryNodeID, err := p.Load("reader.clientid")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -50,7 +49,7 @@ func (p *ParamTable) QueryNodeID() UniqueID {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return UniqueID(id)
|
||||
return id
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertChannelRange() []int {
|
||||
|
@ -58,47 +57,138 @@ func (p *ParamTable) insertChannelRange() []int {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return paramtable.ConvertRangeToIntRange(insertChannelRange, ",")
|
||||
|
||||
channelRange := strings.Split(insertChannelRange, ",")
|
||||
if len(channelRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(channelRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(channelRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
return []int{channelBegin, channelEnd}
|
||||
}
|
||||
|
||||
// advanced params
|
||||
// stats
|
||||
func (p *ParamTable) statsPublishInterval() int {
|
||||
return p.ParseInt("queryNode.stats.publishInterval")
|
||||
timeInterval, err := p.Load("queryNode.stats.publishInterval")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
interval, err := strconv.Atoi(timeInterval)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return interval
|
||||
}
|
||||
|
||||
// dataSync:
|
||||
func (p *ParamTable) flowGraphMaxQueueLength() int32 {
|
||||
return p.ParseInt32("queryNode.dataSync.flowGraph.maxQueueLength")
|
||||
queueLength, err := p.Load("queryNode.dataSync.flowGraph.maxQueueLength")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
length, err := strconv.Atoi(queueLength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(length)
|
||||
}
|
||||
|
||||
func (p *ParamTable) flowGraphMaxParallelism() int32 {
|
||||
return p.ParseInt32("queryNode.dataSync.flowGraph.maxParallelism")
|
||||
maxParallelism, err := p.Load("queryNode.dataSync.flowGraph.maxParallelism")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
maxPara, err := strconv.Atoi(maxParallelism)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(maxPara)
|
||||
}
|
||||
|
||||
// msgStream
|
||||
func (p *ParamTable) insertReceiveBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.insert.recvBufSize")
|
||||
revBufSize, err := p.Load("queryNode.msgStream.insert.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertPulsarBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.insert.pulsarBufSize")
|
||||
pulsarBufSize, err := p.Load("queryNode.msgStream.insert.pulsarBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(pulsarBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchReceiveBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.search.recvBufSize")
|
||||
revBufSize, err := p.Load("queryNode.msgStream.search.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchPulsarBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.search.pulsarBufSize")
|
||||
pulsarBufSize, err := p.Load("queryNode.msgStream.search.pulsarBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(pulsarBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchResultReceiveBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.searchResult.recvBufSize")
|
||||
revBufSize, err := p.Load("queryNode.msgStream.searchResult.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) statsReceiveBufSize() int64 {
|
||||
return p.ParseInt64("queryNode.msgStream.stats.recvBufSize")
|
||||
revBufSize, err := p.Load("queryNode.msgStream.stats.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
}
|
||||
|
||||
func (p *ParamTable) etcdAddress() string {
|
||||
|
@ -122,73 +212,123 @@ func (p *ParamTable) metaRootPath() string {
|
|||
}
|
||||
|
||||
func (p *ParamTable) gracefulTime() int64 {
|
||||
return p.ParseInt64("queryNode.gracefulTime")
|
||||
gracefulTime, err := p.Load("queryNode.gracefulTime")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
time, err := strconv.Atoi(gracefulTime)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(time)
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertChannelNames() []string {
|
||||
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
sep := len(channelIDs) / p.queryNodeNum()
|
||||
index := p.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("queryNodeID not Match with Config")
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchChannelNames() []string {
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
return ret
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchResultChannelNames() []string {
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.searchResult")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
return ret
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) msgChannelSubName() string {
|
||||
|
@ -197,11 +337,7 @@ func (p *ParamTable) msgChannelSubName() string {
|
|||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
queryNodeIDStr, err := p.Load("_QueryNodeID")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return name + "-" + queryNodeIDStr
|
||||
return name
|
||||
}
|
||||
|
||||
func (p *ParamTable) statsChannelName() string {
|
||||
|
@ -211,18 +347,3 @@ func (p *ParamTable) statsChannelName() string {
|
|||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) sliceIndex() int {
|
||||
queryNodeID := p.QueryNodeID()
|
||||
queryNodeIDList := p.QueryNodeIDList()
|
||||
for i := 0; i < len(queryNodeIDList); i++ {
|
||||
if queryNodeID == queryNodeIDList[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *ParamTable) queryNodeNum() int {
|
||||
return len(p.QueryNodeIDList())
|
||||
}
|
||||
|
|
|
@ -1,109 +1,128 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParamTable_Init(t *testing.T) {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func TestParamTable_PulsarAddress(t *testing.T) {
|
||||
Params.Init()
|
||||
address, err := Params.pulsarAddress()
|
||||
assert.NoError(t, err)
|
||||
split := strings.Split(address, ":")
|
||||
assert.Equal(t, "pulsar", split[0])
|
||||
assert.Equal(t, "6650", split[len(split)-1])
|
||||
assert.Equal(t, split[0], "pulsar")
|
||||
assert.Equal(t, split[len(split)-1], "6650")
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeID(t *testing.T) {
|
||||
id := Params.QueryNodeID()
|
||||
assert.Contains(t, Params.QueryNodeIDList(), id)
|
||||
Params.Init()
|
||||
id := Params.queryNodeID()
|
||||
assert.Equal(t, id, 0)
|
||||
}
|
||||
|
||||
func TestParamTable_insertChannelRange(t *testing.T) {
|
||||
Params.Init()
|
||||
channelRange := Params.insertChannelRange()
|
||||
assert.Equal(t, 2, len(channelRange))
|
||||
assert.Equal(t, len(channelRange), 2)
|
||||
assert.Equal(t, channelRange[0], 0)
|
||||
assert.Equal(t, channelRange[1], 1)
|
||||
}
|
||||
|
||||
func TestParamTable_statsServiceTimeInterval(t *testing.T) {
|
||||
Params.Init()
|
||||
interval := Params.statsPublishInterval()
|
||||
assert.Equal(t, 1000, interval)
|
||||
assert.Equal(t, interval, 1000)
|
||||
}
|
||||
|
||||
func TestParamTable_statsMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.statsReceiveBufSize()
|
||||
assert.Equal(t, int64(64), bufSize)
|
||||
assert.Equal(t, bufSize, int64(64))
|
||||
}
|
||||
|
||||
func TestParamTable_insertMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.insertReceiveBufSize()
|
||||
assert.Equal(t, int64(1024), bufSize)
|
||||
assert.Equal(t, bufSize, int64(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_searchMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchReceiveBufSize()
|
||||
assert.Equal(t, int64(512), bufSize)
|
||||
assert.Equal(t, bufSize, int64(512))
|
||||
}
|
||||
|
||||
func TestParamTable_searchResultMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchResultReceiveBufSize()
|
||||
assert.Equal(t, int64(64), bufSize)
|
||||
assert.Equal(t, bufSize, int64(64))
|
||||
}
|
||||
|
||||
func TestParamTable_searchPulsarBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchPulsarBufSize()
|
||||
assert.Equal(t, int64(512), bufSize)
|
||||
assert.Equal(t, bufSize, int64(512))
|
||||
}
|
||||
|
||||
func TestParamTable_insertPulsarBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.insertPulsarBufSize()
|
||||
assert.Equal(t, int64(1024), bufSize)
|
||||
assert.Equal(t, bufSize, int64(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_flowGraphMaxQueueLength(t *testing.T) {
|
||||
Params.Init()
|
||||
length := Params.flowGraphMaxQueueLength()
|
||||
assert.Equal(t, int32(1024), length)
|
||||
assert.Equal(t, length, int32(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_flowGraphMaxParallelism(t *testing.T) {
|
||||
Params.Init()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
assert.Equal(t, int32(1024), maxParallelism)
|
||||
assert.Equal(t, maxParallelism, int32(1024))
|
||||
}
|
||||
|
||||
func TestParamTable_insertChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.insertChannelNames()
|
||||
channelRange := Params.insertChannelRange()
|
||||
num := channelRange[1] - channelRange[0]
|
||||
num = num / Params.queryNodeNum()
|
||||
assert.Equal(t, num, len(names))
|
||||
start := num * Params.sliceIndex()
|
||||
assert.Equal(t, fmt.Sprintf("insert-%d", channelRange[start]), names[0])
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "insert-0")
|
||||
}
|
||||
|
||||
func TestParamTable_searchChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.searchChannelNames()
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, "search-0", names[0])
|
||||
assert.Equal(t, names[0], "search-0")
|
||||
}
|
||||
|
||||
func TestParamTable_searchResultChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.searchResultChannelNames()
|
||||
assert.NotNil(t, names)
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "searchResult-0")
|
||||
}
|
||||
|
||||
func TestParamTable_msgChannelSubName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.msgChannelSubName()
|
||||
expectName := fmt.Sprintf("queryNode-%d", Params.QueryNodeID())
|
||||
assert.Equal(t, expectName, name)
|
||||
assert.Equal(t, name, "queryNode")
|
||||
}
|
||||
|
||||
func TestParamTable_statsChannelName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.statsChannelName()
|
||||
assert.Equal(t, "query-node-stats", name)
|
||||
assert.Equal(t, name, "query-node-stats")
|
||||
}
|
||||
|
||||
func TestParamTable_metaRootPath(t *testing.T) {
|
||||
Params.Init()
|
||||
path := Params.metaRootPath()
|
||||
assert.Equal(t, "by-dev/meta", path)
|
||||
assert.Equal(t, path, "by-dev/meta")
|
||||
}
|
||||
|
|
|
@ -1,20 +1,77 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestPartition_Segments(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
ctx := context.Background()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
collectionMeta := collection.meta
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
for _, tag := range collectionMeta.PartitionTags {
|
||||
err := (*node.replica).addPartition(collection.ID(), tag)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
partitions := collection.Partitions()
|
||||
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
|
||||
|
@ -23,12 +80,12 @@ func TestPartition_Segments(t *testing.T) {
|
|||
|
||||
const segmentNum = 3
|
||||
for i := 0; i < segmentNum; i++ {
|
||||
err := node.replica.addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
|
||||
err := (*node.replica).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
segments := targetPartition.Segments()
|
||||
assert.Equal(t, segmentNum+1, len(*segments))
|
||||
assert.Equal(t, segmentNum, len(*segments))
|
||||
}
|
||||
|
||||
func TestPartition_newPartition(t *testing.T) {
|
||||
|
|
|
@ -8,17 +8,59 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestPlan_Plan(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
|
@ -32,13 +74,52 @@ func TestPlan_Plan(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPlan_PlaceholderGroup(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
|
|
|
@ -17,12 +17,11 @@ import (
|
|||
)
|
||||
|
||||
type QueryNode struct {
|
||||
queryNodeLoopCtx context.Context
|
||||
queryNodeLoopCancel func()
|
||||
ctx context.Context
|
||||
|
||||
QueryNodeID uint64
|
||||
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
|
||||
dataSyncService *dataSyncService
|
||||
metaService *metaService
|
||||
|
@ -30,14 +29,7 @@ type QueryNode struct {
|
|||
statsService *statsService
|
||||
}
|
||||
|
||||
func Init() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
||||
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
|
||||
segmentsMap := make(map[int64]*Segment)
|
||||
collections := make([]*Collection, 0)
|
||||
|
||||
|
@ -51,11 +43,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
|||
}
|
||||
|
||||
return &QueryNode{
|
||||
queryNodeLoopCtx: ctx1,
|
||||
queryNodeLoopCancel: cancel,
|
||||
QueryNodeID: queryNodeID,
|
||||
ctx: ctx,
|
||||
|
||||
replica: replica,
|
||||
QueryNodeID: queryNodeID,
|
||||
|
||||
replica: &replica,
|
||||
|
||||
dataSyncService: nil,
|
||||
metaService: nil,
|
||||
|
@ -64,34 +56,31 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
|||
}
|
||||
}
|
||||
|
||||
func (node *QueryNode) Start() error {
|
||||
// todo add connectMaster logic
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
func (node *QueryNode) Start() {
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
node.metaService = newMetaService(node.ctx, node.replica)
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
|
||||
go node.dataSyncService.start()
|
||||
go node.searchService.start()
|
||||
go node.metaService.start()
|
||||
go node.statsService.start()
|
||||
return nil
|
||||
node.statsService.start()
|
||||
}
|
||||
|
||||
func (node *QueryNode) Close() {
|
||||
node.queryNodeLoopCancel()
|
||||
|
||||
<-node.ctx.Done()
|
||||
// free collectionReplica
|
||||
node.replica.freeAll()
|
||||
(*node.replica).freeAll()
|
||||
|
||||
// close services
|
||||
if node.dataSyncService != nil {
|
||||
node.dataSyncService.close()
|
||||
(*node.dataSyncService).close()
|
||||
}
|
||||
if node.searchService != nil {
|
||||
node.searchService.close()
|
||||
(*node.searchService).close()
|
||||
}
|
||||
if node.statsService != nil {
|
||||
node.statsService.close()
|
||||
(*node.statsService).close()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,93 +2,18 @@ package querynode
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
const ctxTimeInMillisecond = 200
|
||||
const closeWithDeadline = true
|
||||
|
||||
func setup() {
|
||||
// NOTE: start pulsar and etcd before test
|
||||
func TestQueryNode_start(t *testing.T) {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "metric_type",
|
||||
Value: "L2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: collectionID,
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
return &collectionMeta
|
||||
}
|
||||
|
||||
func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) {
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = node.replica.addCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
assert.Equal(t, node.replica.getCollectionNum(), 1)
|
||||
|
||||
err = node.replica.addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.replica.addSegment(segmentID, collectionMeta.PartitionTags[0], collectionID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func newQueryNode() *QueryNode {
|
||||
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
|
@ -98,21 +23,7 @@ func newQueryNode() *QueryNode {
|
|||
ctx = context.Background()
|
||||
}
|
||||
|
||||
svr := NewQueryNode(ctx, 0)
|
||||
return svr
|
||||
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setup()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
// NOTE: start pulsar and etcd before test
|
||||
func TestQueryNode_Start(t *testing.T) {
|
||||
localNode := newQueryNode()
|
||||
err := localNode.Start()
|
||||
assert.Nil(t, err)
|
||||
localNode.Close()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.Start()
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func StartQueryNode(ctx context.Context) {
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
node.Start()
|
||||
}
|
|
@ -9,19 +9,63 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestReduce_AllFunc(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
segmentID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
assert.Equal(t, segmentID, segment.segmentID)
|
||||
|
||||
const DIM = 16
|
||||
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
|
|
|
@ -4,12 +4,10 @@ import "C"
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
|
@ -21,7 +19,7 @@ type searchService struct {
|
|||
wait sync.WaitGroup
|
||||
cancel context.CancelFunc
|
||||
|
||||
replica collectionReplica
|
||||
replica *collectionReplica
|
||||
tSafeWatcher *tSafeWatcher
|
||||
|
||||
serviceableTime Timestamp
|
||||
|
@ -29,14 +27,13 @@ type searchService struct {
|
|||
|
||||
msgBuffer chan msgstream.TsMsg
|
||||
unsolvedMsg []msgstream.TsMsg
|
||||
searchMsgStream msgstream.MsgStream
|
||||
searchResultMsgStream msgstream.MsgStream
|
||||
queryNodeID UniqueID
|
||||
searchMsgStream *msgstream.MsgStream
|
||||
searchResultMsgStream *msgstream.MsgStream
|
||||
}
|
||||
|
||||
type ResultEntityIds []UniqueID
|
||||
|
||||
func newSearchService(ctx context.Context, replica collectionReplica) *searchService {
|
||||
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
|
||||
receiveBufSize := Params.searchReceiveBufSize()
|
||||
pulsarBufSize := Params.searchPulsarBufSize()
|
||||
|
||||
|
@ -72,15 +69,14 @@ func newSearchService(ctx context.Context, replica collectionReplica) *searchSer
|
|||
replica: replica,
|
||||
tSafeWatcher: newTSafeWatcher(),
|
||||
|
||||
searchMsgStream: inputStream,
|
||||
searchResultMsgStream: outputStream,
|
||||
queryNodeID: Params.QueryNodeID(),
|
||||
searchMsgStream: &inputStream,
|
||||
searchResultMsgStream: &outputStream,
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *searchService) start() {
|
||||
ss.searchMsgStream.Start()
|
||||
ss.searchResultMsgStream.Start()
|
||||
(*ss.searchMsgStream).Start()
|
||||
(*ss.searchResultMsgStream).Start()
|
||||
ss.register()
|
||||
ss.wait.Add(2)
|
||||
go ss.receiveSearchMsg()
|
||||
|
@ -89,24 +85,20 @@ func (ss *searchService) start() {
|
|||
}
|
||||
|
||||
func (ss *searchService) close() {
|
||||
if ss.searchMsgStream != nil {
|
||||
ss.searchMsgStream.Close()
|
||||
}
|
||||
if ss.searchResultMsgStream != nil {
|
||||
ss.searchResultMsgStream.Close()
|
||||
}
|
||||
(*ss.searchMsgStream).Close()
|
||||
(*ss.searchResultMsgStream).Close()
|
||||
ss.cancel()
|
||||
}
|
||||
|
||||
func (ss *searchService) register() {
|
||||
tSafe := ss.replica.getTSafe()
|
||||
tSafe.registerTSafeWatcher(ss.tSafeWatcher)
|
||||
tSafe := (*(ss.replica)).getTSafe()
|
||||
(*tSafe).registerTSafeWatcher(ss.tSafeWatcher)
|
||||
}
|
||||
|
||||
func (ss *searchService) waitNewTSafe() Timestamp {
|
||||
// block until dataSyncService updating tSafe
|
||||
ss.tSafeWatcher.hasUpdate()
|
||||
timestamp := ss.replica.getTSafe().get()
|
||||
timestamp := (*(*ss.replica).getTSafe()).get()
|
||||
return timestamp
|
||||
}
|
||||
|
||||
|
@ -130,7 +122,7 @@ func (ss *searchService) receiveSearchMsg() {
|
|||
case <-ss.ctx.Done():
|
||||
return
|
||||
default:
|
||||
msgPack := ss.searchMsgStream.Consume()
|
||||
msgPack := (*ss.searchMsgStream).Consume()
|
||||
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
@ -227,7 +219,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
}
|
||||
collectionName := query.CollectionName
|
||||
partitionTags := query.PartitionTags
|
||||
collection, err := ss.replica.getCollectionByName(collectionName)
|
||||
collection, err := (*ss.replica).getCollectionByName(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -249,14 +241,14 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
matchedSegments := make([]*Segment, 0)
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
hasPartition := ss.replica.hasPartition(collectionID, partitionTag)
|
||||
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
|
||||
if !hasPartition {
|
||||
return errors.New("search Failed, invalid partitionTag")
|
||||
}
|
||||
}
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
|
||||
partition, _ := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
|
||||
for _, segment := range partition.segments {
|
||||
//fmt.Println("dsl = ", dsl)
|
||||
|
||||
|
@ -276,13 +268,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
|
||||
ReqID: searchMsg.ReqID,
|
||||
ProxyID: searchMsg.ProxyID,
|
||||
QueryNodeID: ss.queryNodeID,
|
||||
QueryNodeID: searchMsg.ProxyID,
|
||||
Timestamp: searchTimestamp,
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
Hits: nil,
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
|
@ -341,7 +333,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
Hits: hits,
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
|
@ -358,10 +350,9 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
}
|
||||
|
||||
func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
|
||||
fmt.Println("Public SearchResult", msg.HashKeys())
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err := ss.searchResultMsgStream.Produce(&msgPack)
|
||||
err := (*ss.searchResultMsgStream).Produce(&msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -386,11 +377,11 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s
|
|||
}
|
||||
|
||||
tsMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
SearchResult: results,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
|
||||
err := ss.searchResultMsgStream.Produce(&msgPack)
|
||||
err := (*ss.searchResultMsgStream).Produce(&msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -13,15 +13,80 @@ import (
|
|||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestSearch_Search(t *testing.T) {
|
||||
node := NewQueryNode(context.Background(), 0)
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
Params.Init()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 10
|
||||
|
@ -93,14 +158,14 @@ func TestSearch_Search(t *testing.T) {
|
|||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
|
||||
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
searchStream.Start()
|
||||
err = searchStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
go node.searchService.start()
|
||||
|
||||
// start insert
|
||||
|
@ -170,7 +235,7 @@ func TestSearch_Search(t *testing.T) {
|
|||
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
|
||||
|
||||
// pulsar produce
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
insertStream.Start()
|
||||
|
@ -180,19 +245,83 @@ func TestSearch_Search(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestSearch_SearchMultiSegments(t *testing.T) {
|
||||
node := NewQueryNode(context.Background(), 0)
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
Params.Init()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 1024
|
||||
|
@ -264,14 +393,14 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
|
||||
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
searchStream.Start()
|
||||
err = searchStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
go node.searchService.start()
|
||||
|
||||
// start insert
|
||||
|
@ -345,7 +474,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
|
||||
|
||||
// pulsar produce
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
insertStream.Start()
|
||||
|
@ -355,10 +484,11 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"math"
|
||||
|
@ -9,21 +10,61 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
//-------------------------------------------------------------------------------------- constructor and destructor
|
||||
func TestSegment_newSegment(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -33,15 +74,52 @@ func TestSegment_newSegment(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_deleteSegment(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -53,15 +131,52 @@ func TestSegment_deleteSegment(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- stats functions
|
||||
func TestSegment_getRowCount(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -104,15 +219,52 @@ func TestSegment_getRowCount(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_getDeletedCount(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -161,15 +313,52 @@ func TestSegment_getDeletedCount(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_getMemSize(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -213,15 +402,53 @@ func TestSegment_getMemSize(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- dm & search functions
|
||||
func TestSegment_segmentInsert(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
assert.Equal(t, segmentID, segment.segmentID)
|
||||
|
@ -259,15 +486,52 @@ func TestSegment_segmentInsert(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentDelete(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -312,15 +576,55 @@ func TestSegment_segmentDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentSearch(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -357,6 +661,13 @@ func TestSegment_segmentSearch(t *testing.T) {
|
|||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
const receiveBufSize = 1024
|
||||
searchProducerChannels := Params.searchChannelNames()
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
|
||||
var searchRawData []byte
|
||||
for _, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
|
@ -397,15 +708,52 @@ func TestSegment_segmentSearch(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- preDm functions
|
||||
func TestSegment_segmentPreInsert(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -439,15 +787,52 @@ func TestSegment_segmentPreInsert(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentPreDelete(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
|
|
@ -14,11 +14,11 @@ import (
|
|||
|
||||
type statsService struct {
|
||||
ctx context.Context
|
||||
statsStream msgstream.MsgStream
|
||||
replica collectionReplica
|
||||
statsStream *msgstream.MsgStream
|
||||
replica *collectionReplica
|
||||
}
|
||||
|
||||
func newStatsService(ctx context.Context, replica collectionReplica) *statsService {
|
||||
func newStatsService(ctx context.Context, replica *collectionReplica) *statsService {
|
||||
|
||||
return &statsService{
|
||||
ctx: ctx,
|
||||
|
@ -44,8 +44,8 @@ func (sService *statsService) start() {
|
|||
|
||||
var statsMsgStream msgstream.MsgStream = statsStream
|
||||
|
||||
sService.statsStream = statsMsgStream
|
||||
sService.statsStream.Start()
|
||||
sService.statsStream = &statsMsgStream
|
||||
(*sService.statsStream).Start()
|
||||
|
||||
// start service
|
||||
fmt.Println("do segments statistic in ", strconv.Itoa(sleepTimeInterval), "ms")
|
||||
|
@ -60,13 +60,11 @@ func (sService *statsService) start() {
|
|||
}
|
||||
|
||||
func (sService *statsService) close() {
|
||||
if sService.statsStream != nil {
|
||||
sService.statsStream.Close()
|
||||
}
|
||||
(*sService.statsStream).Close()
|
||||
}
|
||||
|
||||
func (sService *statsService) sendSegmentStatistic() {
|
||||
statisticData := sService.replica.getSegmentStatistics()
|
||||
statisticData := (*sService.replica).getSegmentStatistics()
|
||||
|
||||
// fmt.Println("Publish segment statistic")
|
||||
// fmt.Println(statisticData)
|
||||
|
@ -84,7 +82,7 @@ func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSeg
|
|||
var msgPack = msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{msg},
|
||||
}
|
||||
err := sService.statsStream.Produce(&msgPack)
|
||||
err := (*sService.statsStream).Produce(&msgPack)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
|
|
@ -1,42 +1,193 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
// NOTE: start pulsar before test
|
||||
func TestStatsService_start(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// start stats service
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
node.statsService.start()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
//NOTE: start pulsar before test
|
||||
// NOTE: start pulsar before test
|
||||
func TestSegmentManagement_SegmentStatisticService(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
const receiveBufSize = 1024
|
||||
// start pulsar
|
||||
producerChannels := []string{Params.statsChannelName()}
|
||||
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
|
||||
statsStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
statsStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
statsStream.SetPulsarClient(pulsarURL)
|
||||
statsStream.CreatePulsarProducers(producerChannels)
|
||||
|
||||
var statsMsgStream msgstream.MsgStream = statsStream
|
||||
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
node.statsService.statsStream = statsMsgStream
|
||||
node.statsService.statsStream.Start()
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
node.statsService.statsStream = &statsMsgStream
|
||||
(*node.statsService.statsStream).Start()
|
||||
|
||||
// send stats
|
||||
node.statsService.sendSegmentStatistic()
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -36,11 +36,11 @@ type tSafeImpl struct {
|
|||
watcherList []*tSafeWatcher
|
||||
}
|
||||
|
||||
func newTSafe() tSafe {
|
||||
func newTSafe() *tSafe {
|
||||
var t tSafe = &tSafeImpl{
|
||||
watcherList: make([]*tSafeWatcher, 0),
|
||||
}
|
||||
return t
|
||||
return &t
|
||||
}
|
||||
|
||||
func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) {
|
||||
|
|
|
@ -9,13 +9,13 @@ import (
|
|||
func TestTSafe_GetAndSet(t *testing.T) {
|
||||
tSafe := newTSafe()
|
||||
watcher := newTSafeWatcher()
|
||||
tSafe.registerTSafeWatcher(watcher)
|
||||
(*tSafe).registerTSafeWatcher(watcher)
|
||||
|
||||
go func() {
|
||||
watcher.hasUpdate()
|
||||
timestamp := tSafe.get()
|
||||
timestamp := (*tSafe).get()
|
||||
assert.Equal(t, timestamp, Timestamp(1000))
|
||||
}()
|
||||
|
||||
tSafe.set(Timestamp(1000))
|
||||
(*tSafe).set(Timestamp(1000))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
output
|
||||
cmake-build-debug
|
||||
.idea
|
||||
cmake_build
|
||||
|
|
|
@ -2,17 +2,6 @@ cmake_minimum_required(VERSION 3.14...3.17 FATAL_ERROR)
|
|||
project(wrapper)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
if (NOT GIT_ARROW_REPO)
|
||||
set(GIT_ARROW_REPO "https://github.com/apache/arrow.git")
|
||||
endif ()
|
||||
message(STATUS "Arrow Repo:" ${GIT_ARROW_REPO})
|
||||
|
||||
if (NOT GIT_ARROW_TAG)
|
||||
set(GIT_ARROW_TAG "apache-arrow-2.0.0")
|
||||
endif ()
|
||||
message(STATUS "Arrow Tag:" ${GIT_ARROW_TAG})
|
||||
|
||||
###################################################################################################
|
||||
# - cmake modules ---------------------------------------------------------------------------------
|
||||
|
@ -25,39 +14,29 @@ set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/" ${CMAKE_MODUL
|
|||
message(STATUS "BUILDING ARROW")
|
||||
include(ConfigureArrow)
|
||||
|
||||
if (ARROW_FOUND)
|
||||
if(ARROW_FOUND)
|
||||
message(STATUS "Apache Arrow found in ${ARROW_INCLUDE_DIR}")
|
||||
else ()
|
||||
else()
|
||||
message(FATAL_ERROR "Apache Arrow not found, please check your settings.")
|
||||
endif (ARROW_FOUND)
|
||||
endif(ARROW_FOUND)
|
||||
|
||||
add_library(arrow STATIC IMPORTED ${ARROW_LIB})
|
||||
add_library(parquet STATIC IMPORTED ${PARQUET_LIB})
|
||||
add_library(thrift STATIC IMPORTED ${THRIFT_LIB})
|
||||
add_library(utf8proc STATIC IMPORTED ${UTF8PROC_LIB})
|
||||
|
||||
if (ARROW_FOUND)
|
||||
if(ARROW_FOUND)
|
||||
set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${ARROW_LIB})
|
||||
set_target_properties(parquet PROPERTIES IMPORTED_LOCATION ${PARQUET_LIB})
|
||||
set_target_properties(thrift PROPERTIES IMPORTED_LOCATION ${THRIFT_LIB})
|
||||
set_target_properties(utf8proc PROPERTIES IMPORTED_LOCATION ${UTF8PROC_LIB})
|
||||
endif (ARROW_FOUND)
|
||||
endif(ARROW_FOUND)
|
||||
|
||||
###################################################################################################
|
||||
|
||||
include_directories(${ARROW_INCLUDE_DIR})
|
||||
include_directories(${PROJECT_SOURCE_DIR})
|
||||
|
||||
add_library(wrapper STATIC)
|
||||
target_sources(wrapper PUBLIC ParquetWrapper.cpp
|
||||
PayloadStream.cpp)
|
||||
|
||||
target_link_libraries(wrapper PUBLIC parquet arrow thrift utf8proc pthread)
|
||||
|
||||
if(NOT CMAKE_INSTALL_PREFIX)
|
||||
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR})
|
||||
endif()
|
||||
install(TARGETS wrapper DESTINATION ${CMAKE_INSTALL_PREFIX})
|
||||
install(FILES ${ARROW_LIB} ${PARQUET_LIB} ${THRIFT_LIB} ${UTF8PROC_LIB} DESTINATION ${CMAKE_INSTALL_PREFIX})
|
||||
add_library(wrapper ParquetWrapper.cpp ParquetWrapper.h ColumnType.h PayloadStream.h PayloadStream.cpp)
|
||||
|
||||
add_subdirectory(test)
|
|
@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(values, length);
|
||||
auto ast = builder->AppendValues(values, (dimension / 8) * length);
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() * 8;
|
||||
*length = array->length();
|
||||
*length = array->length() / array->byte_width();
|
||||
*values = (uint8_t *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() / sizeof(float);
|
||||
*length = array->length();
|
||||
*length = array->length() / array->byte_width();
|
||||
*values = (float *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -478,7 +478,12 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
|
||||
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
|
||||
if (p->array == nullptr) return 0;
|
||||
return p->array->length();
|
||||
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
|
||||
if (ba == nullptr) {
|
||||
return p->array->length();
|
||||
} else {
|
||||
return ba->length() / ba->byte_width();
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {
|
||||
|
|
|
@ -5,7 +5,6 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
typedef void *CPayloadWriter;
|
||||
|
||||
|
@ -56,4 +55,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
|
|||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -1,58 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
SOURCE=${BASH_SOURCE[0]}
|
||||
while [ -h $SOURCE ]; do # resolve $SOURCE until the file is no longer a symlink
|
||||
DIR=$( cd -P $( dirname $SOURCE ) && pwd )
|
||||
SOURCE=$(readlink $SOURCE)
|
||||
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
|
||||
done
|
||||
DIR=$( cd -P $( dirname $SOURCE ) && pwd )
|
||||
# echo $DIR
|
||||
|
||||
CMAKE_BUILD=${DIR}/cmake_build
|
||||
OUTPUT_LIB=${DIR}/output
|
||||
|
||||
if [ ! -d ${CMAKE_BUILD} ];then
|
||||
mkdir ${CMAKE_BUILD}
|
||||
fi
|
||||
|
||||
if [ -d ${OUTPUT_LIB} ];then
|
||||
rm -rf ${OUTPUT_LIB}
|
||||
fi
|
||||
mkdir ${OUTPUT_LIB}
|
||||
|
||||
BUILD_TYPE="Debug"
|
||||
GIT_ARROW_REPO="https://github.com/apache/arrow.git"
|
||||
GIT_ARROW_TAG="apache-arrow-2.0.0"
|
||||
|
||||
while getopts "a:b:t:h" arg; do
|
||||
case $arg in
|
||||
t)
|
||||
BUILD_TYPE=$OPTARG # BUILD_TYPE
|
||||
;;
|
||||
a)
|
||||
GIT_ARROW_REPO=$OPTARG
|
||||
;;
|
||||
b)
|
||||
GIT_ARROW_TAG=$OPTARG
|
||||
;;
|
||||
h) # help
|
||||
echo "-t: build type(default: Debug)
|
||||
-a: arrow repo(default: https://github.com/apache/arrow.git)
|
||||
-b: arrow tag(default: apache-arrow-2.0.0)
|
||||
-h: help
|
||||
"
|
||||
exit 0
|
||||
;;
|
||||
?)
|
||||
echo "ERROR! unknown argument"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
echo "BUILD_TYPE: " $BUILD_TYPE
|
||||
echo "GIT_ARROW_REPO: " $GIT_ARROW_REPO
|
||||
echo "GIT_ARROW_TAG: " $GIT_ARROW_TAG
|
||||
|
||||
pushd ${CMAKE_BUILD}
|
||||
cmake -DCMAKE_INSTALL_PREFIX=${OUTPUT_LIB} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DGIT_ARROW_REPO=${GIT_ARROW_REPO} -DGIT_ARROW_TAG=${GIT_ARROW_TAG} .. && make && make install
|
|
@ -35,14 +35,11 @@ if(ARROW_CONFIG)
|
|||
message(FATAL_ERROR "Configuring Arrow failed: " ${ARROW_CONFIG})
|
||||
endif(ARROW_CONFIG)
|
||||
|
||||
#set(PARALLEL_BUILD -j)
|
||||
#if($ENV{PARALLEL_LEVEL})
|
||||
# set(NUM_JOBS $ENV{PARALLEL_LEVEL})
|
||||
# set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}")
|
||||
#endif($ENV{PARALLEL_LEVEL})
|
||||
set(NUM_JOBS 4)
|
||||
set(PARALLEL_BUILD "-j${NUM_JOBS}")
|
||||
|
||||
set(PARALLEL_BUILD -j)
|
||||
if($ENV{PARALLEL_LEVEL})
|
||||
set(NUM_JOBS $ENV{PARALLEL_LEVEL})
|
||||
set(PARALLEL_BUILD "${PARALLEL_BUILD}${NUM_JOBS}")
|
||||
endif($ENV{PARALLEL_LEVEL})
|
||||
|
||||
if(${NUM_JOBS})
|
||||
if(${NUM_JOBS} EQUAL 1)
|
||||
|
@ -91,8 +88,8 @@ if(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB)
|
|||
set(ARROW_FOUND TRUE)
|
||||
endif(ARROW_LIB AND PARQUET_LIB AND THRIFT_LIB AND UTF8PROC_LIB)
|
||||
|
||||
# message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT})
|
||||
# set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include")
|
||||
# set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib")
|
||||
message(STATUS "FlatBuffers installed here: " ${FLATBUFFERS_ROOT})
|
||||
set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_ROOT}/include")
|
||||
set(FLATBUFFERS_LIBRARY_DIR "${FLATBUFFERS_ROOT}/lib")
|
||||
|
||||
add_definitions(-DARROW_METADATA_V4)
|
||||
|
|
|
@ -20,8 +20,8 @@ project(wrapper-Arrow)
|
|||
include(ExternalProject)
|
||||
|
||||
ExternalProject_Add(Arrow
|
||||
GIT_REPOSITORY ${GIT_ARROW_REPO}
|
||||
GIT_TAG ${GIT_ARROW_TAG}
|
||||
GIT_REPOSITORY https://github.com/apache/arrow.git
|
||||
GIT_TAG apache-arrow-2.0.0
|
||||
GIT_SHALLOW true
|
||||
SOURCE_DIR "${ARROW_ROOT}/arrow"
|
||||
SOURCE_SUBDIR "cpp"
|
||||
|
|
|
@ -14,8 +14,6 @@ target_link_libraries(wrapper_test
|
|||
parquet arrow thrift utf8proc pthread
|
||||
)
|
||||
|
||||
install(TARGETS wrapper_test DESTINATION ${CMAKE_INSTALL_PREFIX})
|
||||
|
||||
# Defines `gtest_discover_tests()`.
|
||||
#include(GoogleTest)
|
||||
#gtest_discover_tests(milvusd_test)
|
|
@ -71,36 +71,36 @@ TEST(wrapper, inoutstream) {
|
|||
}
|
||||
|
||||
TEST(wrapper, boolean) {
|
||||
auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
bool data[] = {true, false, true, false};
|
||||
auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
bool data[] = {true, false, true, false};
|
||||
|
||||
auto st = AddBooleanToPayload(payload, data, 4);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = FinishPayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
auto cb = GetPayloadBufferFromWriter(payload);
|
||||
ASSERT_GT(cb.length, 0);
|
||||
ASSERT_NE(cb.data, nullptr);
|
||||
auto nums = GetPayloadLengthFromWriter(payload);
|
||||
ASSERT_EQ(nums, 4);
|
||||
auto st = AddBooleanToPayload(payload, data, 4);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = FinishPayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
auto cb = GetPayloadBufferFromWriter(payload);
|
||||
ASSERT_GT(cb.length, 0);
|
||||
ASSERT_NE(cb.data, nullptr);
|
||||
auto nums = GetPayloadLengthFromWriter(payload);
|
||||
ASSERT_EQ(nums, 4);
|
||||
|
||||
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
bool *values;
|
||||
int length;
|
||||
st = GetBoolFromPayload(reader, &values, &length);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
ASSERT_NE(values, nullptr);
|
||||
ASSERT_EQ(length, 4);
|
||||
length = GetPayloadLengthFromReader(reader);
|
||||
ASSERT_EQ(length, 4);
|
||||
for (int i = 0; i < length; i++) {
|
||||
ASSERT_EQ(data[i], values[i]);
|
||||
}
|
||||
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
bool *values;
|
||||
int length;
|
||||
st = GetBoolFromPayload(reader, &values, &length);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
ASSERT_NE(values, nullptr);
|
||||
ASSERT_EQ(length, 4);
|
||||
length = GetPayloadLengthFromReader(reader);
|
||||
ASSERT_EQ(length, 4);
|
||||
for (int i = 0; i < length; i++) {
|
||||
ASSERT_EQ(data[i], values[i]);
|
||||
}
|
||||
|
||||
st = ReleasePayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadReader(reader);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadReader(reader);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
}
|
||||
|
||||
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
|
||||
|
|
|
@ -1,626 +0,0 @@
|
|||
package storage
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/cwrapper
|
||||
|
||||
#cgo LDFLAGS: -L${SRCDIR}/cwrapper/output -l:libwrapper.a -l:libparquet.a -l:libarrow.a -l:libthrift.a -l:libutf8proc.a -lstdc++ -lm
|
||||
#include <stdlib.h>
|
||||
#include "ParquetWrapper.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/errors"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
type (
|
||||
PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
colType schemapb.DataType
|
||||
}
|
||||
|
||||
PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
colType schemapb.DataType
|
||||
}
|
||||
)
|
||||
|
||||
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
|
||||
w := C.NewPayloadWriter(C.int(colType))
|
||||
if w == nil {
|
||||
return nil, errors.New("create Payload writer failed")
|
||||
}
|
||||
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
|
||||
switch len(dim) {
|
||||
case 0:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, ok := msgs.([]bool)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBoolToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, ok := msgs.([]int8)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt8ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, ok := msgs.([]int16)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt16ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, ok := msgs.([]int32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt32ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, ok := msgs.([]int64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt64ToPayload(val)
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatToPayload(val)
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, ok := msgs.([]float64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddDoubleToPayload(val)
|
||||
|
||||
case schemapb.DataType_STRING:
|
||||
val, ok := msgs.(string)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddOneStringToPayload(val)
|
||||
}
|
||||
case 1:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
val, ok := msgs.([]byte)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBinaryVectorToPayload(val, dim[0])
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatVectorToPayload(val, dim[0])
|
||||
}
|
||||
|
||||
default:
|
||||
return errors.New("incorrect input numbers")
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
|
||||
length := len(msg)
|
||||
if length == 0 {
|
||||
return errors.New("can't add empty string into payload")
|
||||
}
|
||||
|
||||
cmsg := C.CString(msg)
|
||||
clength := C.int(length)
|
||||
defer C.free(unsafe.Pointer(cmsg))
|
||||
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
|
||||
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
|
||||
length := len(binVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty binVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.uint8_t)(&binVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / (dim / 8))
|
||||
|
||||
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
|
||||
length := len(floatVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty floatVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.float)(&floatVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / dim)
|
||||
|
||||
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) FinishPayloadWriter() error {
|
||||
st := C.FinishPayloadWriter(w.payloadWriterPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
|
||||
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
|
||||
pointer := unsafe.Pointer(cb.data)
|
||||
length := int(cb.length)
|
||||
if length <= 0 {
|
||||
return nil, errors.New("empty buffer")
|
||||
}
|
||||
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
slice := (*[1 << 28]byte)(pointer)[:length:length]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadLengthFromWriter() (int, error) {
|
||||
length := C.GetPayloadLengthFromWriter(w.payloadWriterPtr)
|
||||
return int(length), nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) ReleasePayloadWriter() error {
|
||||
st := C.ReleasePayloadWriter(w.payloadWriterPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) Close() error {
|
||||
return w.ReleasePayloadWriter()
|
||||
}
|
||||
|
||||
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New("create Payload reader failed, buffer is empty")
|
||||
}
|
||||
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
|
||||
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
|
||||
}
|
||||
|
||||
// Params:
|
||||
// `idx`: String index
|
||||
// Return:
|
||||
// `interface{}`: all types.
|
||||
// `int`: length, only meaningful to FLOAT/BINARY VECTOR type.
|
||||
// `error`: error.
|
||||
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
|
||||
switch len(idx) {
|
||||
case 1:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_STRING:
|
||||
val, err := r.GetOneStringFromPayload(idx[0])
|
||||
return val, 0, err
|
||||
}
|
||||
case 0:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, err := r.GetBoolFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, err := r.GetInt8FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, err := r.GetInt16FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, err := r.GetInt32FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, err := r.GetInt64FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, err := r.GetFloatFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, err := r.GetDoubleFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
return r.GetBinaryVectorFromPayload()
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
return r.GetFloatVectorFromPayload()
|
||||
default:
|
||||
return nil, 0, errors.New("Unknown type")
|
||||
}
|
||||
default:
|
||||
return nil, 0, errors.New("incorrect number of index")
|
||||
}
|
||||
|
||||
return nil, 0, errors.New("unknown error")
|
||||
}
|
||||
|
||||
func (r *PayloadReader) ReleasePayloadReader() error {
|
||||
st := C.ReleasePayloadReader(r.payloadReaderPtr)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
|
||||
var cMsg *C.bool
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
|
||||
var cMsg *C.int8_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
|
||||
var cMsg *C.int16_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
|
||||
var cMsg *C.int32_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
|
||||
var cMsg *C.int64_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
|
||||
var cMsg *C.float
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
|
||||
var cMsg *C.double
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
|
||||
var cStr *C.char
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
|
||||
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return C.GoStringN(cStr, cSize), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
|
||||
var cMsg *C.uint8_t
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := (cDim / 8) * cLen
|
||||
|
||||
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
||||
var cMsg *C.float
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := cDim * cLen
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
|
||||
length := C.GetPayloadLengthFromReader(r.payloadReaderPtr)
|
||||
return int(length), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) Close() error {
|
||||
return r.ReleasePayloadReader()
|
||||
}
|
|
@ -1,426 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestPayload_ReaderandWriter(t *testing.T) {
|
||||
|
||||
t.Run("TestBool", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBoolToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 8)
|
||||
bools, err := r.GetBoolFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
ibools, _, err := r.GetDataFromPayload()
|
||||
bools = ibools.([]bool)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
defer r.ReleasePayloadReader()
|
||||
|
||||
})
|
||||
|
||||
t.Run("TestInt8", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT8)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt8ToPayload([]int8{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int8{4, 5, 6})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int8s, err := r.GetInt8FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
|
||||
iint8s, _, err := r.GetDataFromPayload()
|
||||
int8s = iint8s.([]int8)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt16", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT16)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt16ToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
int16s, err := r.GetInt16FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
|
||||
iint16s, _, err := r.GetDataFromPayload()
|
||||
int16s = iint16s.([]int16)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT32)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt32ToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int32s, err := r.GetInt32FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
|
||||
iint32s, _, err := r.GetDataFromPayload()
|
||||
int32s = iint32s.([]int32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt64", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT64)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt64ToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int64s, err := r.GetInt64FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
|
||||
iint64s, _, err := r.GetDataFromPayload()
|
||||
int64s = iint64s.([]int64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloat32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float32s, err := r.GetFloatFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
|
||||
ifloat32s, _, err := r.GetDataFromPayload()
|
||||
float32s = ifloat32s.([]float32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestDouble", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float64s, err := r.GetDoubleFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
|
||||
ifloat64s, _, err := r.GetDataFromPayload()
|
||||
float64s = ifloat64s.([]float64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestAddOneString", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload("hello3")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
str3, err := r.GetOneStringFromPayload(3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
istr0, _, err := r.GetDataFromPayload(0)
|
||||
str0 = istr0.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
|
||||
istr1, _, err := r.GetDataFromPayload(1)
|
||||
str1 = istr1.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
|
||||
istr2, _, err := r.GetDataFromPayload(2)
|
||||
str2 = istr2.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
istr3, _, err := r.GetDataFromPayload(3)
|
||||
str3 = istr3.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("TestBinaryVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
in := make([]byte, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
in[i] = 1
|
||||
}
|
||||
in2 := make([]byte, 8)
|
||||
for i := 0; i < 8; i++ {
|
||||
in2[i] = 1
|
||||
}
|
||||
|
||||
err = w.AddBinaryVectorToPayload(in, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload(in2, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 24, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 24)
|
||||
|
||||
binVecs, dim, err := r.GetBinaryVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
fmt.Println(binVecs)
|
||||
|
||||
ibinVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
binVecs = ibinVecs.([]byte)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloatVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
|
||||
floatVecs, dim, err := r.GetFloatVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
|
||||
ifloatVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
floatVecs = ifloatVecs.([]float32)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
}
|
|
@ -16,17 +16,13 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/viper"
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type UniqueID = typeutil.UniqueID
|
||||
|
||||
type Base interface {
|
||||
Load(key string) (string, error)
|
||||
LoadRange(key, endKey string, limit int) ([]string, []string, error)
|
||||
|
@ -42,18 +38,7 @@ type BaseTable struct {
|
|||
|
||||
func (gp *BaseTable) Init() {
|
||||
gp.params = memkv.NewMemoryKV()
|
||||
|
||||
err := gp.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = gp.LoadYaml("advanced/common.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = gp.LoadYaml("advanced/channel.yaml")
|
||||
err := gp.LoadYaml("config.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -161,140 +146,3 @@ func (gp *BaseTable) Remove(key string) error {
|
|||
func (gp *BaseTable) Save(key, value string) error {
|
||||
return gp.params.Save(strings.ToLower(key), value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseFloat(key string) float64 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.ParseFloat(valueStr, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt64(key string) int64 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt32(key string) int32 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt(key string) int {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (gp *BaseTable) WriteNodeIDList() []UniqueID {
|
||||
proxyIDStr, err := gp.Load("nodeID.writeNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ProxyIDList() []UniqueID {
|
||||
proxyIDStr, err := gp.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (gp *BaseTable) QueryNodeIDList() []UniqueID {
|
||||
queryNodeIDStr, err := gp.Load("nodeID.queryNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
queryNodeIDs := strings.Split(queryNodeIDStr, ",")
|
||||
for _, i := range queryNodeIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// package methods
|
||||
|
||||
func ConvertRangeToIntRange(rangeStr, sep string) []int {
|
||||
items := strings.Split(rangeStr, sep)
|
||||
if len(items) != 2 {
|
||||
panic("Illegal range ")
|
||||
}
|
||||
|
||||
startStr := items[0]
|
||||
endStr := items[1]
|
||||
start, err := strconv.Atoi(startStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
end, err := strconv.Atoi(endStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if start < 0 || end < 0 {
|
||||
panic("Illegal range value")
|
||||
}
|
||||
if start > end {
|
||||
panic("Illegal range value, start > end")
|
||||
}
|
||||
return []int{start, end}
|
||||
}
|
||||
|
||||
func ConvertRangeToIntSlice(rangeStr, sep string) []int {
|
||||
rangeSlice := ConvertRangeToIntRange(rangeStr, sep)
|
||||
start, end := rangeSlice[0], rangeSlice[1]
|
||||
var ret []int
|
||||
for i := start; i < end; i++ {
|
||||
ret = append(ret, i)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
package paramtable
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -22,8 +21,6 @@ var Params = BaseTable{}
|
|||
|
||||
func TestMain(m *testing.M) {
|
||||
Params.Init()
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
//func TestMain
|
||||
|
@ -58,13 +55,13 @@ func TestGlobalParamsTable_SaveAndLoad(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGlobalParamsTable_LoadRange(t *testing.T) {
|
||||
_ = Params.Save("xxxaab", "10")
|
||||
_ = Params.Save("xxxfghz", "20")
|
||||
_ = Params.Save("xxxbcde", "1.1")
|
||||
_ = Params.Save("xxxabcd", "testSaveAndLoad")
|
||||
_ = Params.Save("xxxzhi", "12")
|
||||
_ = Params.Save("abc", "10")
|
||||
_ = Params.Save("fghz", "20")
|
||||
_ = Params.Save("bcde", "1.1")
|
||||
_ = Params.Save("abcd", "testSaveAndLoad")
|
||||
_ = Params.Save("zhi", "12")
|
||||
|
||||
keys, values, err := Params.LoadRange("xxxa", "xxxg", 10)
|
||||
keys, values, err := Params.LoadRange("a", "g", 10)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, len(keys))
|
||||
assert.Equal(t, "10", values[0])
|
||||
|
@ -100,17 +97,24 @@ func TestGlobalParamsTable_Remove(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGlobalParamsTable_LoadYaml(t *testing.T) {
|
||||
err := Params.LoadYaml("milvus.yaml")
|
||||
err := Params.LoadYaml("config.yaml")
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = Params.LoadYaml("advanced/channel.yaml")
|
||||
assert.Nil(t, err)
|
||||
value1, err1 := Params.Load("etcd.address")
|
||||
value2, err2 := Params.Load("pulsar.port")
|
||||
value3, err3 := Params.Load("reader.topicend")
|
||||
value4, err4 := Params.Load("proxy.pulsarTopics.readerTopicPrefix")
|
||||
value5, err5 := Params.Load("proxy.network.address")
|
||||
|
||||
_, err = Params.Load("etcd.address")
|
||||
assert.Nil(t, err)
|
||||
_, err = Params.Load("pulsar.port")
|
||||
assert.Nil(t, err)
|
||||
_, err = Params.Load("msgChannel.channelRange.insert")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, value1, "localhost")
|
||||
assert.Equal(t, value2, "6650")
|
||||
assert.Equal(t, value3, "128")
|
||||
assert.Equal(t, value4, "milvusReader")
|
||||
assert.Equal(t, value5, "0.0.0.0")
|
||||
|
||||
assert.Nil(t, err1)
|
||||
assert.Nil(t, err2)
|
||||
assert.Nil(t, err3)
|
||||
assert.Nil(t, err4)
|
||||
assert.Nil(t, err5)
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue