mirror of https://github.com/milvus-io/milvus.git
Add impl cgo of parquet
Signed-off-by: XuanYang-cn <xuan.yang@zilliz.com>pull/4973/head^2
parent
65ce1f97e7
commit
32977e270c
|
@ -56,3 +56,6 @@ cmake_build/
|
|||
.DS_Store
|
||||
*.swp
|
||||
cwrapper_build
|
||||
**/.clangd/*
|
||||
**/compile_commands.json
|
||||
**/.lint
|
||||
|
|
|
@ -3,7 +3,7 @@ timeout(time: 10, unit: 'MINUTES') {
|
|||
sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"'
|
||||
}
|
||||
|
||||
sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install'
|
||||
sh '. ./scripts/before-install.sh && make install'
|
||||
|
||||
dir ("scripts") {
|
||||
withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
|
||||
|
|
|
@ -4,7 +4,10 @@ try {
|
|||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
|
||||
dir ('build/docker/deploy') {
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d master'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d proxy'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=1 -d querynode'
|
||||
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=2 -d querynode'
|
||||
}
|
||||
|
||||
dir ('build/docker/test') {
|
||||
|
|
6
Makefile
6
Makefile
|
@ -41,9 +41,9 @@ fmt:
|
|||
lint:
|
||||
@echo "Running $@ check"
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/...
|
||||
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/go/...
|
||||
|
||||
ruleguard:
|
||||
@echo "Running $@ check"
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
SOURCE_REPO=milvusdb
|
||||
TARGET_REPO=milvusdb
|
||||
SOURCE_TAG=latest
|
||||
TARGET_TAG=latest
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
@ -13,8 +14,7 @@ import (
|
|||
|
||||
func main() {
|
||||
proxy.Init()
|
||||
|
||||
// Creates server.
|
||||
fmt.Println("ProxyID is", proxy.Params.ProxyID())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
svr, err := proxy.CreateProxy(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,18 +2,24 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/querynode"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
querynode.Init()
|
||||
fmt.Println("QueryNodeID is", querynode.Params.QueryNodeID())
|
||||
// Creates server.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
svr := querynode.NewQueryNode(ctx, 0)
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc,
|
||||
|
@ -28,8 +34,14 @@ func main() {
|
|||
cancel()
|
||||
}()
|
||||
|
||||
querynode.StartQueryNode(ctx)
|
||||
if err := svr.Start(); err != nil {
|
||||
log.Fatal("run server failed", zap.Error(err))
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
log.Print("Got signal to exit", zap.String("signal", sig.String()))
|
||||
|
||||
svr.Close()
|
||||
switch sig {
|
||||
case syscall.SIGTERM:
|
||||
exit(0)
|
||||
|
|
|
@ -32,7 +32,7 @@ msgChannel:
|
|||
|
||||
# default channel range [0, 1)
|
||||
channelRange:
|
||||
insert: [0, 1]
|
||||
insert: [0, 2]
|
||||
delete: [0, 1]
|
||||
dataDefinition: [0,1]
|
||||
k2s: [0, 1]
|
||||
|
|
|
@ -1,118 +0,0 @@
|
|||
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
master:
|
||||
address: localhost
|
||||
port: 53100
|
||||
pulsarmoniterinterval: 1
|
||||
pulsartopic: "monitor-topic"
|
||||
|
||||
proxyidlist: [1, 2]
|
||||
proxyTimeSyncChannels: ["proxy1", "proxy2"]
|
||||
proxyTimeSyncSubName: "proxy-topic"
|
||||
softTimeTickBarrierInterval: 500
|
||||
|
||||
writeidlist: [3, 4]
|
||||
writeTimeSyncChannels: ["write3", "write4"]
|
||||
writeTimeSyncSubName: "write-topic"
|
||||
|
||||
dmTimeSyncChannels: ["dm5", "dm6"]
|
||||
k2sTimeSyncChannels: ["k2s7", "k2s8"]
|
||||
|
||||
defaultSizePerRecord: 1024
|
||||
minimumAssignSize: 1048576
|
||||
segmentThreshold: 536870912
|
||||
segmentExpireDuration: 2000
|
||||
segmentThresholdFactor: 0.75
|
||||
querynodenum: 1
|
||||
writenodenum: 1
|
||||
statsChannels: "statistic"
|
||||
|
||||
etcd:
|
||||
address: localhost
|
||||
port: 2379
|
||||
rootpath: by-dev
|
||||
segthreshold: 10000
|
||||
|
||||
minio:
|
||||
address: localhost
|
||||
port: 9000
|
||||
accessKeyID: minioadmin
|
||||
secretAccessKey: minioadmin
|
||||
useSSL: false
|
||||
|
||||
timesync:
|
||||
interval: 400
|
||||
|
||||
storage:
|
||||
driver: TIKV
|
||||
address: localhost
|
||||
port: 2379
|
||||
accesskey:
|
||||
secretkey:
|
||||
|
||||
pulsar:
|
||||
authentication: false
|
||||
user: user-default
|
||||
token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY
|
||||
address: localhost
|
||||
port: 6650
|
||||
topicnum: 128
|
||||
|
||||
reader:
|
||||
clientid: 0
|
||||
stopflag: -1
|
||||
readerqueuesize: 10000
|
||||
searchchansize: 10000
|
||||
key2segchansize: 10000
|
||||
topicstart: 0
|
||||
topicend: 128
|
||||
|
||||
writer:
|
||||
clientid: 0
|
||||
stopflag: -2
|
||||
readerqueuesize: 10000
|
||||
searchbyidchansize: 10000
|
||||
parallelism: 100
|
||||
topicstart: 0
|
||||
topicend: 128
|
||||
bucket: "zilliz-hz"
|
||||
|
||||
proxy:
|
||||
timezone: UTC+8
|
||||
proxy_id: 1
|
||||
numReaderNodes: 2
|
||||
tsoSaveInterval: 200
|
||||
timeTickInterval: 200
|
||||
|
||||
pulsarTopics:
|
||||
readerTopicPrefix: "milvusReader"
|
||||
numReaderTopics: 2
|
||||
deleteTopic: "milvusDeleter"
|
||||
queryTopic: "milvusQuery"
|
||||
resultTopic: "milvusResult"
|
||||
resultGroup: "milvusResultGroup"
|
||||
timeTickTopic: "milvusTimeTick"
|
||||
|
||||
network:
|
||||
address: 0.0.0.0
|
||||
port: 19530
|
||||
|
||||
logs:
|
||||
level: debug
|
||||
trace.enable: true
|
||||
path: /tmp/logs
|
||||
max_log_file_size: 1024MB
|
||||
log_rotate_num: 0
|
||||
|
||||
storage:
|
||||
path: /var/lib/milvus
|
||||
auto_flush_interval: 1
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
nodeID: # will be deprecated after v0.2
|
||||
proxyIDList: [0]
|
||||
queryNodeIDList: [2]
|
||||
queryNodeIDList: [1, 2]
|
||||
writeNodeIDList: [3]
|
||||
|
||||
etcd:
|
||||
|
@ -23,6 +23,13 @@ etcd:
|
|||
kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath
|
||||
segThreshold: 10000
|
||||
|
||||
minio:
|
||||
address: localhost
|
||||
port: 9000
|
||||
accessKeyID: minioadmin
|
||||
secretAccessKey: minioadmin
|
||||
useSSL: false
|
||||
|
||||
pulsar:
|
||||
address: localhost
|
||||
port: 6650
|
||||
|
|
|
@ -1,150 +0,0 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"runtime"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
|
||||
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type UniqueID = typeutil.UniqueID
|
||||
|
||||
// yaml.MapSlice
|
||||
|
||||
type MasterConfig struct {
|
||||
Address string
|
||||
Port int32
|
||||
PulsarMonitorInterval int32
|
||||
PulsarTopic string
|
||||
SegmentThreshold float32
|
||||
SegmentExpireDuration int64
|
||||
ProxyIDList []UniqueID
|
||||
QueryNodeNum int
|
||||
WriteNodeNum int
|
||||
}
|
||||
|
||||
type EtcdConfig struct {
|
||||
Address string
|
||||
Port int32
|
||||
Rootpath string
|
||||
Segthreshold int64
|
||||
}
|
||||
|
||||
type TimeSyncConfig struct {
|
||||
Interval int32
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Driver storagetype.DriverType
|
||||
Address string
|
||||
Port int32
|
||||
Accesskey string
|
||||
Secretkey string
|
||||
}
|
||||
|
||||
type PulsarConfig struct {
|
||||
Authentication bool
|
||||
User string
|
||||
Token string
|
||||
Address string
|
||||
Port int32
|
||||
TopicNum int
|
||||
}
|
||||
|
||||
type ProxyConfig struct {
|
||||
Timezone string `yaml:"timezone"`
|
||||
ProxyID int `yaml:"proxy_id"`
|
||||
NumReaderNodes int `yaml:"numReaderNodes"`
|
||||
TosSaveInterval int `yaml:"tsoSaveInterval"`
|
||||
TimeTickInterval int `yaml:"timeTickInterval"`
|
||||
PulsarTopics struct {
|
||||
ReaderTopicPrefix string `yaml:"readerTopicPrefix"`
|
||||
NumReaderTopics int `yaml:"numReaderTopics"`
|
||||
DeleteTopic string `yaml:"deleteTopic"`
|
||||
QueryTopic string `yaml:"queryTopic"`
|
||||
ResultTopic string `yaml:"resultTopic"`
|
||||
ResultGroup string `yaml:"resultGroup"`
|
||||
TimeTickTopic string `yaml:"timeTickTopic"`
|
||||
} `yaml:"pulsarTopics"`
|
||||
Network struct {
|
||||
Address string `yaml:"address"`
|
||||
Port int `yaml:"port"`
|
||||
} `yaml:"network"`
|
||||
Logs struct {
|
||||
Level string `yaml:"level"`
|
||||
TraceEnable bool `yaml:"trace.enable"`
|
||||
Path string `yaml:"path"`
|
||||
MaxLogFileSize string `yaml:"max_log_file_size"`
|
||||
LogRotateNum int `yaml:"log_rotate_num"`
|
||||
} `yaml:"logs"`
|
||||
Storage struct {
|
||||
Path string `yaml:"path"`
|
||||
AutoFlushInterval int `yaml:"auto_flush_interval"`
|
||||
} `yaml:"storage"`
|
||||
}
|
||||
|
||||
type Reader struct {
|
||||
ClientID int
|
||||
StopFlag int64
|
||||
ReaderQueueSize int
|
||||
SearchChanSize int
|
||||
Key2SegChanSize int
|
||||
TopicStart int
|
||||
TopicEnd int
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
ClientID int
|
||||
StopFlag int64
|
||||
ReaderQueueSize int
|
||||
SearchByIDChanSize int
|
||||
Parallelism int
|
||||
TopicStart int
|
||||
TopicEnd int
|
||||
Bucket string
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Master MasterConfig
|
||||
Etcd EtcdConfig
|
||||
Timesync TimeSyncConfig
|
||||
Storage StorageConfig
|
||||
Pulsar PulsarConfig
|
||||
Writer Writer
|
||||
Reader Reader
|
||||
Proxy ProxyConfig
|
||||
}
|
||||
|
||||
var Config ServerConfig
|
||||
|
||||
// func init() {
|
||||
// load_config()
|
||||
// }
|
||||
|
||||
func getConfigsDir() string {
|
||||
_, fpath, _, _ := runtime.Caller(0)
|
||||
configPath := path.Dir(fpath) + "/../../configs/"
|
||||
configPath = path.Dir(configPath)
|
||||
return configPath
|
||||
}
|
||||
|
||||
func LoadConfigWithPath(yamlFilePath string) {
|
||||
source, err := ioutil.ReadFile(yamlFilePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = yaml.Unmarshal(source, &Config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//fmt.Printf("Result: %v\n", Config)
|
||||
}
|
||||
|
||||
func LoadConfig(yamlFile string) {
|
||||
filePath := path.Join(getConfigsDir(), yamlFile)
|
||||
LoadConfigWithPath(filePath)
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package conf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
fmt.Printf("Result: %v\n", Config)
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
package mockkv
|
||||
|
||||
import (
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
)
|
||||
|
||||
// use MemoryKV to mock EtcdKV
|
||||
func NewEtcdKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
||||
|
||||
func NewMemoryKV() *memkv.MemoryKV {
|
||||
return memkv.NewMemoryKV()
|
||||
}
|
|
@ -55,19 +55,8 @@ var Params ParamTable
|
|||
func (p *ParamTable) Init() {
|
||||
// load yaml
|
||||
p.BaseTable.Init()
|
||||
err := p.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/master.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = p.LoadYaml("advanced/common.yaml")
|
||||
|
||||
err := p.LoadYaml("advanced/master.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -115,15 +104,7 @@ func (p *ParamTable) initAddress() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initPort() {
|
||||
masterPort, err := p.Load("master.port")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port, err := strconv.Atoi(masterPort)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Port = port
|
||||
p.Port = p.ParseInt("master.port")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initEtcdAddress() {
|
||||
|
@ -167,117 +148,40 @@ func (p *ParamTable) initKvRootPath() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initTopicNum() {
|
||||
insertChannelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
channelRange := strings.Split(insertChannelRange, ",")
|
||||
if len(channelRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(channelRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(channelRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
p.TopicNum = channelEnd
|
||||
rangeSlice := paramtable.ConvertRangeToIntRange(iRangeStr, ",")
|
||||
p.TopicNum = rangeSlice[1] - rangeSlice[0]
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegmentSize() {
|
||||
threshold, err := p.Load("master.segment.size")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
segmentThreshold, err := strconv.ParseFloat(threshold, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegmentSize = segmentThreshold
|
||||
p.SegmentSize = p.ParseFloat("master.segment.size")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegmentSizeFactor() {
|
||||
segFactor, err := p.Load("master.segment.sizeFactor")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
factor, err := strconv.ParseFloat(segFactor, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegmentSizeFactor = factor
|
||||
p.SegmentSizeFactor = p.ParseFloat("master.segment.sizeFactor")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initDefaultRecordSize() {
|
||||
size, err := p.Load("master.segment.defaultSizePerRecord")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.DefaultRecordSize = res
|
||||
p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMinSegIDAssignCnt() {
|
||||
size, err := p.Load("master.segment.minIDAssignCnt")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.MinSegIDAssignCnt = res
|
||||
p.MinSegIDAssignCnt = p.ParseInt64("master.segment.minIDAssignCnt")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMaxSegIDAssignCnt() {
|
||||
size, err := p.Load("master.segment.maxIDAssignCnt")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(size, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.MaxSegIDAssignCnt = res
|
||||
p.MaxSegIDAssignCnt = p.ParseInt64("master.segment.maxIDAssignCnt")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initSegIDAssignExpiration() {
|
||||
duration, err := p.Load("master.segment.IDAssignExpiration")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res, err := strconv.ParseInt(duration, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.SegIDAssignExpiration = res
|
||||
p.SegIDAssignExpiration = p.ParseInt64("master.segment.IDAssignExpiration")
|
||||
}
|
||||
|
||||
func (p *ParamTable) initQueryNodeNum() {
|
||||
id, err := p.Load("nodeID.queryNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
}
|
||||
p.QueryNodeNum = len(ids)
|
||||
p.QueryNodeNum = len(p.QueryNodeIDList())
|
||||
}
|
||||
|
||||
func (p *ParamTable) initQueryNodeStatsChannelName() {
|
||||
|
@ -289,20 +193,7 @@ func (p *ParamTable) initQueryNodeStatsChannelName() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initProxyIDList() {
|
||||
id, err := p.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
idList := make([]typeutil.UniqueID, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
v, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
idList = append(idList, typeutil.UniqueID(v))
|
||||
}
|
||||
p.ProxyIDList = idList
|
||||
p.ProxyIDList = p.BaseTable.ProxyIDList()
|
||||
}
|
||||
|
||||
func (p *ParamTable) initProxyTimeTickChannelNames() {
|
||||
|
@ -347,20 +238,7 @@ func (p *ParamTable) initSoftTimeTickBarrierInterval() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initWriteNodeIDList() {
|
||||
id, err := p.Load("nodeID.writeNodeIDList")
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
idlist := make([]typeutil.UniqueID, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
v, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
idlist = append(idlist, typeutil.UniqueID(v))
|
||||
}
|
||||
p.WriteNodeIDList = idlist
|
||||
p.WriteNodeIDList = p.BaseTable.WriteNodeIDList()
|
||||
}
|
||||
|
||||
func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
|
||||
|
@ -385,81 +263,57 @@ func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
|
|||
}
|
||||
|
||||
func (p *ParamTable) initDDChannelNames() {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
id, err := p.Load("nodeID.queryNodeIDList")
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.dataDefinition")
|
||||
if err != nil {
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
channels := make([]string, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load query node id list error, %s", err.Error())
|
||||
}
|
||||
channels = append(channels, ch+"-"+i)
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
p.DDChannelNames = channels
|
||||
p.DDChannelNames = ret
|
||||
}
|
||||
|
||||
func (p *ParamTable) initInsertChannelNames() {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
p.InsertChannelNames = channels
|
||||
p.InsertChannelNames = ret
|
||||
}
|
||||
|
||||
func (p *ParamTable) initK2SChannelNames() {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.k2s")
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.k2s")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
id, err := p.Load("nodeID.writeNodeIDList")
|
||||
prefix += "-"
|
||||
iRangeStr, err := p.Load("msgChannel.channelRange.k2s")
|
||||
if err != nil {
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
panic(err)
|
||||
}
|
||||
ids := strings.Split(id, ",")
|
||||
channels := make([]string, 0, len(ids))
|
||||
for _, i := range ids {
|
||||
_, err := strconv.ParseInt(i, 10, 64)
|
||||
if err != nil {
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
}
|
||||
channels = append(channels, ch+"-"+i)
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
p.K2SChannelNames = channels
|
||||
p.K2SChannelNames = ret
|
||||
}
|
||||
|
||||
func (p *ParamTable) initMaxPartitionNum() {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package master
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -32,7 +33,7 @@ func TestParamTable_KVRootPath(t *testing.T) {
|
|||
|
||||
func TestParamTable_TopicNum(t *testing.T) {
|
||||
num := Params.TopicNum
|
||||
assert.Equal(t, num, 1)
|
||||
fmt.Println("TopicNum:", num)
|
||||
}
|
||||
|
||||
func TestParamTable_SegmentSize(t *testing.T) {
|
||||
|
@ -67,7 +68,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) {
|
|||
|
||||
func TestParamTable_QueryNodeNum(t *testing.T) {
|
||||
num := Params.QueryNodeNum
|
||||
assert.Equal(t, num, 1)
|
||||
fmt.Println("QueryNodeNum", num)
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
|
||||
|
@ -111,12 +112,11 @@ func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
|
|||
|
||||
func TestParamTable_InsertChannelNames(t *testing.T) {
|
||||
names := Params.InsertChannelNames
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "insert-0")
|
||||
assert.Equal(t, Params.TopicNum, len(names))
|
||||
}
|
||||
|
||||
func TestParamTable_K2SChannelNames(t *testing.T) {
|
||||
names := Params.K2SChannelNames
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "k2s-3")
|
||||
assert.Equal(t, names[0], "k2s-0")
|
||||
}
|
||||
|
|
|
@ -19,19 +19,8 @@ var Params ParamTable
|
|||
|
||||
func (pt *ParamTable) Init() {
|
||||
pt.BaseTable.Init()
|
||||
err := pt.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/proxy.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = pt.LoadYaml("advanced/common.yaml")
|
||||
|
||||
err := pt.LoadYaml("advanced/proxy.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -48,15 +37,24 @@ func (pt *ParamTable) Init() {
|
|||
pt.Save("_proxyID", proxyIDStr)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) NetWorkAddress() string {
|
||||
addr, err := pt.Load("proxy.network.address")
|
||||
func (pt *ParamTable) NetworkPort() int {
|
||||
return pt.ParseInt("proxy.port")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) NetworkAddress() string {
|
||||
addr, err := pt.Load("proxy.address")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if ip := net.ParseIP(addr); ip == nil {
|
||||
panic("invalid ip proxy.network.address")
|
||||
|
||||
hostName, _ := net.LookupHost(addr)
|
||||
if len(hostName) <= 0 {
|
||||
if ip := net.ParseIP(addr); ip == nil {
|
||||
panic("invalid ip proxy.address")
|
||||
}
|
||||
}
|
||||
port, err := pt.Load("proxy.network.port")
|
||||
|
||||
port, err := pt.Load("proxy.port")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -88,23 +86,6 @@ func (pt *ParamTable) ProxyNum() int {
|
|||
return len(ret)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) ProxyIDList() []UniqueID {
|
||||
proxyIDStr, err := pt.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) queryNodeNum() int {
|
||||
return len(pt.queryNodeIDList())
|
||||
}
|
||||
|
@ -150,25 +131,6 @@ func (pt *ParamTable) TimeTickInterval() time.Duration {
|
|||
return time.Duration(interval) * time.Millisecond
|
||||
}
|
||||
|
||||
func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
|
||||
channelIDs := strings.Split(rangeStr, sep)
|
||||
startStr := channelIDs[0]
|
||||
endStr := channelIDs[1]
|
||||
start, err := strconv.Atoi(startStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
end, err := strconv.Atoi(endStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []int
|
||||
for i := start; i < end; i++ {
|
||||
ret = append(ret, i)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) sliceIndex() int {
|
||||
proxyID := pt.ProxyID()
|
||||
proxyIDList := pt.ProxyIDList()
|
||||
|
@ -190,7 +152,7 @@ func (pt *ParamTable) InsertChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := pt.convertRangeToSlice(iRangeStr, ",")
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
|
@ -216,19 +178,12 @@ func (pt *ParamTable) DeleteChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := pt.convertRangeToSlice(dRangeStr, ",")
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(dRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
proxyNum := pt.ProxyNum()
|
||||
sep := len(channelIDs) / proxyNum
|
||||
index := pt.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("ProxyID not Match with Config")
|
||||
}
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) K2SChannelNames() []string {
|
||||
|
@ -241,19 +196,12 @@ func (pt *ParamTable) K2SChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := pt.convertRangeToSlice(k2sRangeStr, ",")
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(k2sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
proxyNum := pt.ProxyNum()
|
||||
sep := len(channelIDs) / proxyNum
|
||||
index := pt.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("ProxyID not Match with Config")
|
||||
}
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) SearchChannelNames() []string {
|
||||
|
@ -261,8 +209,17 @@ func (pt *ParamTable) SearchChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
prefix += "-0"
|
||||
return []string{prefix}
|
||||
prefix += "-"
|
||||
sRangeStr, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pt *ParamTable) SearchResultChannelNames() []string {
|
||||
|
@ -275,7 +232,7 @@ func (pt *ParamTable) SearchResultChannelNames() []string {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := pt.convertRangeToSlice(sRangeStr, ",")
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
|
@ -321,144 +278,24 @@ func (pt *ParamTable) DataDefinitionChannelNames() []string {
|
|||
return []string{prefix}
|
||||
}
|
||||
|
||||
func (pt *ParamTable) parseInt64(key string) int64 {
|
||||
valueStr, err := pt.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(value)
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamInsertBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.insert.bufSize")
|
||||
return pt.ParseInt64("proxy.msgStream.insert.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.search.bufSize")
|
||||
return pt.ParseInt64("proxy.msgStream.search.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.searchResult.recvBufSize")
|
||||
return pt.ParseInt64("proxy.msgStream.searchResult.recvBufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.searchResult.pulsarBufSize")
|
||||
return pt.ParseInt64("proxy.msgStream.searchResult.pulsarBufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
|
||||
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) insertChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (pt *ParamTable) searchResultChannelNames() []string {
|
||||
ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
return pt.ParseInt64("proxy.msgStream.timeTick.bufSize")
|
||||
}
|
||||
|
||||
func (pt *ParamTable) MaxNameLength() int64 {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -59,7 +60,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
|
|||
|
||||
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
|
||||
p.queryMsgStream.SetPulsarClient(pulsarAddress)
|
||||
p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())
|
||||
p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
|
||||
|
||||
masterAddr := Params.MasterAddress()
|
||||
idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
|
||||
|
@ -83,7 +84,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
|
|||
|
||||
p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
|
||||
p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
|
||||
p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
|
||||
p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
|
||||
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
|
||||
}
|
||||
|
@ -137,7 +138,7 @@ func (p *Proxy) AddCloseCallback(callbacks ...func()) {
|
|||
func (p *Proxy) grpcLoop() {
|
||||
defer p.proxyLoopWg.Done()
|
||||
|
||||
lis, err := net.Listen("tcp", Params.NetWorkAddress())
|
||||
lis, err := net.Listen("tcp", ":"+strconv.Itoa(Params.NetworkPort()))
|
||||
if err != nil {
|
||||
log.Fatalf("Proxy grpc server fatal error=%v", err)
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func setup() {
|
|||
|
||||
startMaster(ctx)
|
||||
startProxy(ctx)
|
||||
proxyAddr := Params.NetWorkAddress()
|
||||
proxyAddr := Params.NetworkAddress()
|
||||
addr := strings.Split(proxyAddr, ":")
|
||||
if addr[0] == "0.0.0.0" {
|
||||
proxyAddr = "127.0.0.1:" + addr[1]
|
||||
|
|
|
@ -364,7 +364,7 @@ func (sched *TaskScheduler) queryResultLoop() {
|
|||
unmarshal := msgstream.NewUnmarshalDispatcher()
|
||||
queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
|
||||
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
|
||||
queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
|
||||
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
|
||||
Params.ProxySubName(),
|
||||
unmarshal,
|
||||
Params.MsgStreamSearchResultPulsarBufSize())
|
||||
|
|
|
@ -31,7 +31,7 @@ import (
|
|||
* is up-to-date.
|
||||
*/
|
||||
type collectionReplica interface {
|
||||
getTSafe() *tSafe
|
||||
getTSafe() tSafe
|
||||
|
||||
// collection
|
||||
getCollectionNum() int
|
||||
|
@ -68,11 +68,11 @@ type collectionReplicaImpl struct {
|
|||
collections []*Collection
|
||||
segments map[UniqueID]*Segment
|
||||
|
||||
tSafe *tSafe
|
||||
tSafe tSafe
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------------- tSafe
|
||||
func (colReplica *collectionReplicaImpl) getTSafe() *tSafe {
|
||||
func (colReplica *collectionReplicaImpl) getTSafe() tSafe {
|
||||
return colReplica.tSafe
|
||||
}
|
||||
|
||||
|
@ -111,6 +111,7 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID)
|
|||
if col.ID() == collectionID {
|
||||
for _, p := range *col.Partitions() {
|
||||
for _, s := range *p.Segments() {
|
||||
deleteSegment(colReplica.segments[s.ID()])
|
||||
delete(colReplica.segments, s.ID())
|
||||
}
|
||||
}
|
||||
|
@ -202,6 +203,7 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID,
|
|||
for _, p := range *collection.Partitions() {
|
||||
if p.Tag() == partitionTag {
|
||||
for _, s := range *p.Segments() {
|
||||
deleteSegment(colReplica.segments[s.ID()])
|
||||
delete(colReplica.segments, s.ID())
|
||||
}
|
||||
} else {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,179 +1,48 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestCollection_Partitions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
for _, tag := range collectionMeta.PartitionTags {
|
||||
err := (*node.replica).addPartition(collection.ID(), tag)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
partitions := collection.Partitions()
|
||||
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
|
||||
assert.Equal(t, 1, len(*partitions))
|
||||
}
|
||||
|
||||
func TestCollection_newCollection(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
}
|
||||
|
||||
func TestCollection_deleteCollection(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
deleteCollection(collection)
|
||||
}
|
||||
|
|
|
@ -11,10 +11,10 @@ type dataSyncService struct {
|
|||
ctx context.Context
|
||||
fg *flowgraph.TimeTickedFlowGraph
|
||||
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
}
|
||||
|
||||
func newDataSyncService(ctx context.Context, replica *collectionReplica) *dataSyncService {
|
||||
func newDataSyncService(ctx context.Context, replica collectionReplica) *dataSyncService {
|
||||
|
||||
return &dataSyncService{
|
||||
ctx: ctx,
|
||||
|
|
|
@ -1,101 +1,22 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
// NOTE: start pulsar before test
|
||||
func TestDataSyncService_Start(t *testing.T) {
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
// test data generate
|
||||
const msgLength = 10
|
||||
const DIM = 16
|
||||
|
@ -179,25 +100,25 @@ func TestDataSyncService_Start(t *testing.T) {
|
|||
// pulsar produce
|
||||
const receiveBufSize = 1024
|
||||
producerChannels := Params.insertChannelNames()
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(producerChannels)
|
||||
|
||||
var insertMsgStream msgstream.MsgStream = insertStream
|
||||
insertMsgStream.Start()
|
||||
|
||||
err = insertMsgStream.Produce(&msgPack)
|
||||
err := insertMsgStream.Produce(&msgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = insertMsgStream.Broadcast(&timeTickMsgPack)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
node.Close()
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
type insertNode struct {
|
||||
BaseNode
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
}
|
||||
|
||||
type InsertData struct {
|
||||
|
@ -58,13 +58,13 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...)
|
||||
|
||||
// check if segment exists, if not, create this segment
|
||||
if !(*iNode.replica).hasSegment(task.SegmentID) {
|
||||
collection, err := (*iNode.replica).getCollectionByName(task.CollectionName)
|
||||
if !iNode.replica.hasSegment(task.SegmentID) {
|
||||
collection, err := iNode.replica.getCollectionByName(task.CollectionName)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
err = (*iNode.replica).addSegment(task.SegmentID, task.PartitionTag, collection.ID())
|
||||
err = iNode.replica.addSegment(task.SegmentID, task.PartitionTag, collection.ID())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
|
@ -74,7 +74,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
|
||||
// 2. do preInsert
|
||||
for segmentID := range insertData.insertRecords {
|
||||
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
|
||||
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
|
||||
if err != nil {
|
||||
log.Println("preInsert failed")
|
||||
// TODO: add error handling
|
||||
|
@ -102,7 +102,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
|
|||
}
|
||||
|
||||
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
|
||||
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
|
||||
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
|
||||
if err != nil {
|
||||
log.Println("cannot find segment:", segmentID)
|
||||
// TODO: add error handling
|
||||
|
@ -127,7 +127,7 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn
|
|||
wg.Done()
|
||||
}
|
||||
|
||||
func newInsertNode(replica *collectionReplica) *insertNode {
|
||||
func newInsertNode(replica collectionReplica) *insertNode {
|
||||
maxQueueLength := Params.flowGraphMaxQueueLength()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type serviceTimeNode struct {
|
||||
BaseNode
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
}
|
||||
|
||||
func (stNode *serviceTimeNode) Name() string {
|
||||
|
@ -28,12 +28,12 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
|
|||
}
|
||||
|
||||
// update service time
|
||||
(*(*stNode.replica).getTSafe()).set(serviceTimeMsg.timeRange.timestampMax)
|
||||
stNode.replica.getTSafe().set(serviceTimeMsg.timeRange.timestampMax)
|
||||
//fmt.Println("update tSafe to:", getPhysicalTime(serviceTimeMsg.timeRange.timestampMax))
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServiceTimeNode(replica *collectionReplica) *serviceTimeNode {
|
||||
func newServiceTimeNode(replica collectionReplica) *serviceTimeNode {
|
||||
maxQueueLength := Params.flowGraphMaxQueueLength()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
|
||||
|
|
|
@ -26,10 +26,10 @@ const (
|
|||
type metaService struct {
|
||||
ctx context.Context
|
||||
kvBase *etcdkv.EtcdKV
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
}
|
||||
|
||||
func newMetaService(ctx context.Context, replica *collectionReplica) *metaService {
|
||||
func newMetaService(ctx context.Context, replica collectionReplica) *metaService {
|
||||
ETCDAddr := Params.etcdAddress()
|
||||
MetaRootPath := Params.metaRootPath()
|
||||
|
||||
|
@ -149,12 +149,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
|
|||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
err := (*mService.replica).addCollection(col, value)
|
||||
err := mService.replica.addCollection(col, value)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
for _, partitionTag := range col.PartitionTags {
|
||||
err = (*mService.replica).addPartition(col.ID, partitionTag)
|
||||
err = mService.replica.addPartition(col.ID, partitionTag)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
|
|||
|
||||
// TODO: what if seg == nil? We need to notify master and return rpc request failed
|
||||
if seg != nil {
|
||||
err := (*mService.replica).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
|
||||
err := mService.replica.addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -202,7 +202,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
|
|||
}
|
||||
|
||||
if seg != nil {
|
||||
targetSegment, err := (*mService.replica).getSegmentByID(seg.SegmentID)
|
||||
targetSegment, err := mService.replica.getSegmentByID(seg.SegmentID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -218,11 +218,11 @@ func (mService *metaService) processCollectionModify(id string, value string) {
|
|||
|
||||
col := mService.collectionUnmarshal(value)
|
||||
if col != nil {
|
||||
err := (*mService.replica).addPartitionsByCollectionMeta(col)
|
||||
err := mService.replica.addPartitionsByCollectionMeta(col)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
err = (*mService.replica).removePartitionsByCollectionMeta(col)
|
||||
err = mService.replica.removePartitionsByCollectionMeta(col)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ func (mService *metaService) processSegmentDelete(id string) {
|
|||
log.Println("Cannot parse segment id:" + id)
|
||||
}
|
||||
|
||||
err = (*mService.replica).removeSegment(segmentID)
|
||||
err = mService.replica.removeSegment(segmentID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
@ -264,7 +264,7 @@ func (mService *metaService) processCollectionDelete(id string) {
|
|||
log.Println("Cannot parse collection id:" + id)
|
||||
}
|
||||
|
||||
err = (*mService.replica).removeCollection(collectionID)
|
||||
err = mService.replica.removeCollection(collectionID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
|
|
|
@ -3,23 +3,13 @@ package querynode
|
|||
import (
|
||||
"context"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
Params.Init()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestMetaService_start(t *testing.T) {
|
||||
var ctx context.Context
|
||||
|
||||
|
@ -37,6 +27,7 @@ func TestMetaService_start(t *testing.T) {
|
|||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
(*node.metaService).start()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_getCollectionObjId(t *testing.T) {
|
||||
|
@ -119,47 +110,9 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T)
|
|||
|
||||
func TestMetaService_printCollectionStruct(t *testing.T) {
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
printCollectionStruct(&collectionMeta)
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
printCollectionStruct(collectionMeta)
|
||||
}
|
||||
|
||||
func TestMetaService_printSegmentStruct(t *testing.T) {
|
||||
|
@ -178,13 +131,8 @@ func TestMetaService_printSegmentStruct(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMetaService_processCollectionCreate(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -196,6 +144,10 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -212,71 +164,21 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
|
|||
|
||||
node.metaService.processCollectionCreate(id, value)
|
||||
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentCreate(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -287,19 +189,15 @@ func TestMetaService_processSegmentCreate(t *testing.T) {
|
|||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processCreate(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -311,6 +209,10 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -326,10 +228,10 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
|
@ -341,68 +243,19 @@ func TestMetaService_processCreate(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentModify(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
collectionID := UniqueID(0)
|
||||
segmentID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, segmentID)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -412,9 +265,9 @@ func TestMetaService_processSegmentModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
s, err := node.replica.getSegmentByID(segmentID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
assert.Equal(t, s.segmentID, segmentID)
|
||||
|
||||
newValue := `partition_tag: "default"
|
||||
channel_start: 0
|
||||
|
@ -424,19 +277,15 @@ func TestMetaService_processSegmentModify(t *testing.T) {
|
|||
|
||||
// TODO: modify segment for testing processCollectionModify
|
||||
(*node.metaService).processSegmentModify(id, newValue)
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
seg, err := node.replica.getSegmentByID(segmentID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
assert.Equal(t, seg.segmentID, segmentID)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processCollectionModify(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -448,6 +297,10 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -465,24 +318,24 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionCreate(id, value)
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
|
||||
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
|
||||
newValue := `schema: <
|
||||
|
@ -494,6 +347,10 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -511,32 +368,28 @@ func TestMetaService_processCollectionModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionModify(id, newValue)
|
||||
collection, err = (*node.replica).getCollectionByName("test")
|
||||
collection, err = node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
|
||||
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processModify(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -548,6 +401,10 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -565,24 +422,24 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
|
||||
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
|
||||
key2 := "by-dev/meta/segment/0"
|
||||
|
@ -593,7 +450,7 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
s, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
s, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, s.segmentID, UniqueID(0))
|
||||
|
||||
|
@ -608,6 +465,10 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -625,21 +486,21 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processModify(key1, msg3)
|
||||
collection, err = (*node.replica).getCollectionByName("test")
|
||||
collection, err = node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
|
||||
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, partitionNum, 3)
|
||||
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
|
||||
assert.Equal(t, hasPartition, false)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
|
||||
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
|
||||
assert.Equal(t, hasPartition, true)
|
||||
|
||||
msg4 := `partition_tag: "p1"
|
||||
|
@ -649,68 +510,18 @@ func TestMetaService_processModify(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processModify(key2, msg4)
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processSegmentDelete(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
colMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
|
||||
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = (*node.replica).addPartition(UniqueID(0), "default")
|
||||
assert.NoError(t, err)
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `partition_tag: "default"
|
||||
|
@ -720,23 +531,19 @@ func TestMetaService_processSegmentDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processSegmentCreate(id, value)
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
|
||||
(*node.metaService).processSegmentDelete("0")
|
||||
mapSize := (*node.replica).getSegmentNum()
|
||||
mapSize := node.replica.getSegmentNum()
|
||||
assert.Equal(t, mapSize, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processCollectionDelete(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
id := "0"
|
||||
value := `schema: <
|
||||
|
@ -748,6 +555,10 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -763,26 +574,22 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCollectionCreate(id, value)
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
(*node.metaService).processCollectionDelete(id)
|
||||
collectionNum = (*node.replica).getCollectionNum()
|
||||
collectionNum = node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processDelete(t *testing.T) {
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
key1 := "by-dev/meta/collection/0"
|
||||
msg1 := `schema: <
|
||||
|
@ -794,6 +601,10 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
key: "dim"
|
||||
value: "16"
|
||||
>
|
||||
index_params: <
|
||||
key: "metric_type"
|
||||
value: "L2"
|
||||
>
|
||||
>
|
||||
fields: <
|
||||
name: "age"
|
||||
|
@ -809,10 +620,10 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key1, msg1)
|
||||
collectionNum := (*node.replica).getCollectionNum()
|
||||
collectionNum := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionNum, 1)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName("test")
|
||||
collection, err := node.replica.getCollectionByName("test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.ID(), UniqueID(0))
|
||||
|
||||
|
@ -824,77 +635,48 @@ func TestMetaService_processDelete(t *testing.T) {
|
|||
`
|
||||
|
||||
(*node.metaService).processCreate(key2, msg2)
|
||||
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
|
||||
seg, err := node.replica.getSegmentByID(UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, seg.segmentID, UniqueID(0))
|
||||
|
||||
(*node.metaService).processDelete(key1)
|
||||
collectionsSize := (*node.replica).getCollectionNum()
|
||||
collectionsSize := node.replica.getCollectionNum()
|
||||
assert.Equal(t, collectionsSize, 0)
|
||||
|
||||
mapSize := (*node.replica).getSegmentNum()
|
||||
mapSize := node.replica.getSegmentNum()
|
||||
assert.Equal(t, mapSize, 0)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_processResp(t *testing.T) {
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
metaChan := (*node.metaService).kvBase.WatchWithPrefix("")
|
||||
|
||||
select {
|
||||
case <-node.ctx.Done():
|
||||
case <-node.queryNodeLoopCtx.Done():
|
||||
return
|
||||
case resp := <-metaChan:
|
||||
_ = (*node.metaService).processResp(resp)
|
||||
}
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_loadCollections(t *testing.T) {
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
err2 := (*node.metaService).loadCollections()
|
||||
assert.Nil(t, err2)
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestMetaService_loadSegments(t *testing.T) {
|
||||
var ctx context.Context
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init metaService
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.metaService = newMetaService(ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
err2 := (*node.metaService).loadSegments()
|
||||
assert.Nil(t, err2)
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@ package querynode
|
|||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
|
||||
)
|
||||
|
@ -21,15 +21,16 @@ func (p *ParamTable) Init() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
err = p.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = p.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
queryNodeIDStr := os.Getenv("QUERY_NODE_ID")
|
||||
if queryNodeIDStr == "" {
|
||||
queryNodeIDList := p.QueryNodeIDList()
|
||||
if len(queryNodeIDList) <= 0 {
|
||||
queryNodeIDStr = "0"
|
||||
} else {
|
||||
queryNodeIDStr = strconv.Itoa(int(queryNodeIDList[0]))
|
||||
}
|
||||
}
|
||||
p.Save("_queryNodeID", queryNodeIDStr)
|
||||
}
|
||||
|
||||
func (p *ParamTable) pulsarAddress() (string, error) {
|
||||
|
@ -40,8 +41,8 @@ func (p *ParamTable) pulsarAddress() (string, error) {
|
|||
return url, nil
|
||||
}
|
||||
|
||||
func (p *ParamTable) queryNodeID() int {
|
||||
queryNodeID, err := p.Load("reader.clientid")
|
||||
func (p *ParamTable) QueryNodeID() UniqueID {
|
||||
queryNodeID, err := p.Load("_queryNodeID")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -49,7 +50,7 @@ func (p *ParamTable) queryNodeID() int {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
return UniqueID(id)
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertChannelRange() []int {
|
||||
|
@ -57,138 +58,47 @@ func (p *ParamTable) insertChannelRange() []int {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
channelRange := strings.Split(insertChannelRange, ",")
|
||||
if len(channelRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(channelRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(channelRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
return []int{channelBegin, channelEnd}
|
||||
return paramtable.ConvertRangeToIntRange(insertChannelRange, ",")
|
||||
}
|
||||
|
||||
// advanced params
|
||||
// stats
|
||||
func (p *ParamTable) statsPublishInterval() int {
|
||||
timeInterval, err := p.Load("queryNode.stats.publishInterval")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
interval, err := strconv.Atoi(timeInterval)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return interval
|
||||
return p.ParseInt("queryNode.stats.publishInterval")
|
||||
}
|
||||
|
||||
// dataSync:
|
||||
func (p *ParamTable) flowGraphMaxQueueLength() int32 {
|
||||
queueLength, err := p.Load("queryNode.dataSync.flowGraph.maxQueueLength")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
length, err := strconv.Atoi(queueLength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(length)
|
||||
return p.ParseInt32("queryNode.dataSync.flowGraph.maxQueueLength")
|
||||
}
|
||||
|
||||
func (p *ParamTable) flowGraphMaxParallelism() int32 {
|
||||
maxParallelism, err := p.Load("queryNode.dataSync.flowGraph.maxParallelism")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
maxPara, err := strconv.Atoi(maxParallelism)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(maxPara)
|
||||
return p.ParseInt32("queryNode.dataSync.flowGraph.maxParallelism")
|
||||
}
|
||||
|
||||
// msgStream
|
||||
func (p *ParamTable) insertReceiveBufSize() int64 {
|
||||
revBufSize, err := p.Load("queryNode.msgStream.insert.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.insert.recvBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertPulsarBufSize() int64 {
|
||||
pulsarBufSize, err := p.Load("queryNode.msgStream.insert.pulsarBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(pulsarBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.insert.pulsarBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchReceiveBufSize() int64 {
|
||||
revBufSize, err := p.Load("queryNode.msgStream.search.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.search.recvBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchPulsarBufSize() int64 {
|
||||
pulsarBufSize, err := p.Load("queryNode.msgStream.search.pulsarBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(pulsarBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.search.pulsarBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchResultReceiveBufSize() int64 {
|
||||
revBufSize, err := p.Load("queryNode.msgStream.searchResult.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.searchResult.recvBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) statsReceiveBufSize() int64 {
|
||||
revBufSize, err := p.Load("queryNode.msgStream.stats.recvBufSize")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bufSize, err := strconv.Atoi(revBufSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(bufSize)
|
||||
return p.ParseInt64("queryNode.msgStream.stats.recvBufSize")
|
||||
}
|
||||
|
||||
func (p *ParamTable) etcdAddress() string {
|
||||
|
@ -212,123 +122,73 @@ func (p *ParamTable) metaRootPath() string {
|
|||
}
|
||||
|
||||
func (p *ParamTable) gracefulTime() int64 {
|
||||
gracefulTime, err := p.Load("queryNode.gracefulTime")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
time, err := strconv.Atoi(gracefulTime)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(time)
|
||||
return p.ParseInt64("queryNode.gracefulTime")
|
||||
}
|
||||
|
||||
func (p *ParamTable) insertChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.insert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
sep := len(channelIDs) / p.queryNodeNum()
|
||||
index := p.sliceIndex()
|
||||
if index == -1 {
|
||||
panic("queryNodeID not Match with Config")
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
}
|
||||
return channels
|
||||
start := index * sep
|
||||
return ret[start : start+sep]
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.search")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.search")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return channels
|
||||
return ret
|
||||
}
|
||||
|
||||
func (p *ParamTable) searchResultChannelNames() []string {
|
||||
ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
prefix, err := p.Load("msgChannel.chanNamePrefix.searchResult")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
prefix += "-"
|
||||
channelRange, err := p.Load("msgChannel.channelRange.searchResult")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
chanRange := strings.Split(channelRange, ",")
|
||||
if len(chanRange) != 2 {
|
||||
panic("Illegal channel range num")
|
||||
}
|
||||
channelBegin, err := strconv.Atoi(chanRange[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
channelEnd, err := strconv.Atoi(chanRange[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if channelBegin < 0 || channelEnd < 0 {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
if channelBegin > channelEnd {
|
||||
panic("Illegal channel range value")
|
||||
}
|
||||
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
|
||||
|
||||
channels := make([]string, channelEnd-channelBegin)
|
||||
for i := 0; i < channelEnd-channelBegin; i++ {
|
||||
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
|
||||
var ret []string
|
||||
for _, ID := range channelIDs {
|
||||
ret = append(ret, prefix+strconv.Itoa(ID))
|
||||
}
|
||||
return channels
|
||||
return ret
|
||||
}
|
||||
|
||||
func (p *ParamTable) msgChannelSubName() string {
|
||||
|
@ -337,7 +197,11 @@ func (p *ParamTable) msgChannelSubName() string {
|
|||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
return name
|
||||
queryNodeIDStr, err := p.Load("_QueryNodeID")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return name + "-" + queryNodeIDStr
|
||||
}
|
||||
|
||||
func (p *ParamTable) statsChannelName() string {
|
||||
|
@ -347,3 +211,18 @@ func (p *ParamTable) statsChannelName() string {
|
|||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (p *ParamTable) sliceIndex() int {
|
||||
queryNodeID := p.QueryNodeID()
|
||||
queryNodeIDList := p.QueryNodeIDList()
|
||||
for i := 0; i < len(queryNodeIDList); i++ {
|
||||
if queryNodeID == queryNodeIDList[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *ParamTable) queryNodeNum() int {
|
||||
return len(p.QueryNodeIDList())
|
||||
}
|
||||
|
|
|
@ -1,128 +1,109 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParamTable_Init(t *testing.T) {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func TestParamTable_PulsarAddress(t *testing.T) {
|
||||
Params.Init()
|
||||
address, err := Params.pulsarAddress()
|
||||
assert.NoError(t, err)
|
||||
split := strings.Split(address, ":")
|
||||
assert.Equal(t, split[0], "pulsar")
|
||||
assert.Equal(t, split[len(split)-1], "6650")
|
||||
assert.Equal(t, "pulsar", split[0])
|
||||
assert.Equal(t, "6650", split[len(split)-1])
|
||||
}
|
||||
|
||||
func TestParamTable_QueryNodeID(t *testing.T) {
|
||||
Params.Init()
|
||||
id := Params.queryNodeID()
|
||||
assert.Equal(t, id, 0)
|
||||
id := Params.QueryNodeID()
|
||||
assert.Contains(t, Params.QueryNodeIDList(), id)
|
||||
}
|
||||
|
||||
func TestParamTable_insertChannelRange(t *testing.T) {
|
||||
Params.Init()
|
||||
channelRange := Params.insertChannelRange()
|
||||
assert.Equal(t, len(channelRange), 2)
|
||||
assert.Equal(t, channelRange[0], 0)
|
||||
assert.Equal(t, channelRange[1], 1)
|
||||
assert.Equal(t, 2, len(channelRange))
|
||||
}
|
||||
|
||||
func TestParamTable_statsServiceTimeInterval(t *testing.T) {
|
||||
Params.Init()
|
||||
interval := Params.statsPublishInterval()
|
||||
assert.Equal(t, interval, 1000)
|
||||
assert.Equal(t, 1000, interval)
|
||||
}
|
||||
|
||||
func TestParamTable_statsMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.statsReceiveBufSize()
|
||||
assert.Equal(t, bufSize, int64(64))
|
||||
assert.Equal(t, int64(64), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_insertMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.insertReceiveBufSize()
|
||||
assert.Equal(t, bufSize, int64(1024))
|
||||
assert.Equal(t, int64(1024), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_searchMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchReceiveBufSize()
|
||||
assert.Equal(t, bufSize, int64(512))
|
||||
assert.Equal(t, int64(512), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_searchResultMsgStreamReceiveBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchResultReceiveBufSize()
|
||||
assert.Equal(t, bufSize, int64(64))
|
||||
assert.Equal(t, int64(64), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_searchPulsarBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.searchPulsarBufSize()
|
||||
assert.Equal(t, bufSize, int64(512))
|
||||
assert.Equal(t, int64(512), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_insertPulsarBufSize(t *testing.T) {
|
||||
Params.Init()
|
||||
bufSize := Params.insertPulsarBufSize()
|
||||
assert.Equal(t, bufSize, int64(1024))
|
||||
assert.Equal(t, int64(1024), bufSize)
|
||||
}
|
||||
|
||||
func TestParamTable_flowGraphMaxQueueLength(t *testing.T) {
|
||||
Params.Init()
|
||||
length := Params.flowGraphMaxQueueLength()
|
||||
assert.Equal(t, length, int32(1024))
|
||||
assert.Equal(t, int32(1024), length)
|
||||
}
|
||||
|
||||
func TestParamTable_flowGraphMaxParallelism(t *testing.T) {
|
||||
Params.Init()
|
||||
maxParallelism := Params.flowGraphMaxParallelism()
|
||||
assert.Equal(t, maxParallelism, int32(1024))
|
||||
assert.Equal(t, int32(1024), maxParallelism)
|
||||
}
|
||||
|
||||
func TestParamTable_insertChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.insertChannelNames()
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "insert-0")
|
||||
channelRange := Params.insertChannelRange()
|
||||
num := channelRange[1] - channelRange[0]
|
||||
num = num / Params.queryNodeNum()
|
||||
assert.Equal(t, num, len(names))
|
||||
start := num * Params.sliceIndex()
|
||||
assert.Equal(t, fmt.Sprintf("insert-%d", channelRange[start]), names[0])
|
||||
}
|
||||
|
||||
func TestParamTable_searchChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.searchChannelNames()
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "search-0")
|
||||
assert.Equal(t, "search-0", names[0])
|
||||
}
|
||||
|
||||
func TestParamTable_searchResultChannelNames(t *testing.T) {
|
||||
Params.Init()
|
||||
names := Params.searchResultChannelNames()
|
||||
assert.Equal(t, len(names), 1)
|
||||
assert.Equal(t, names[0], "searchResult-0")
|
||||
assert.NotNil(t, names)
|
||||
}
|
||||
|
||||
func TestParamTable_msgChannelSubName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.msgChannelSubName()
|
||||
assert.Equal(t, name, "queryNode")
|
||||
expectName := fmt.Sprintf("queryNode-%d", Params.QueryNodeID())
|
||||
assert.Equal(t, expectName, name)
|
||||
}
|
||||
|
||||
func TestParamTable_statsChannelName(t *testing.T) {
|
||||
Params.Init()
|
||||
name := Params.statsChannelName()
|
||||
assert.Equal(t, name, "query-node-stats")
|
||||
assert.Equal(t, "query-node-stats", name)
|
||||
}
|
||||
|
||||
func TestParamTable_metaRootPath(t *testing.T) {
|
||||
Params.Init()
|
||||
path := Params.metaRootPath()
|
||||
assert.Equal(t, path, "by-dev/meta")
|
||||
assert.Equal(t, "by-dev/meta", path)
|
||||
}
|
||||
|
|
|
@ -1,77 +1,20 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestPartition_Segments(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
node := newQueryNode()
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
collectionID := UniqueID(0)
|
||||
initTestMeta(t, node, collectionName, collectionID, 0)
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
for _, tag := range collectionMeta.PartitionTags {
|
||||
err := (*node.replica).addPartition(collection.ID(), tag)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
collectionMeta := collection.meta
|
||||
|
||||
partitions := collection.Partitions()
|
||||
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
|
||||
|
@ -80,12 +23,12 @@ func TestPartition_Segments(t *testing.T) {
|
|||
|
||||
const segmentNum = 3
|
||||
for i := 0; i < segmentNum; i++ {
|
||||
err := (*node.replica).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
|
||||
err := node.replica.addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
segments := targetPartition.Segments()
|
||||
assert.Equal(t, segmentNum, len(*segments))
|
||||
assert.Equal(t, segmentNum+1, len(*segments))
|
||||
}
|
||||
|
||||
func TestPartition_newPartition(t *testing.T) {
|
||||
|
|
|
@ -8,59 +8,17 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestPlan_Plan(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
|
@ -74,52 +32,13 @@ func TestPlan_Plan(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPlan_PlaceholderGroup(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
|
|
|
@ -17,11 +17,12 @@ import (
|
|||
)
|
||||
|
||||
type QueryNode struct {
|
||||
ctx context.Context
|
||||
queryNodeLoopCtx context.Context
|
||||
queryNodeLoopCancel func()
|
||||
|
||||
QueryNodeID uint64
|
||||
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
|
||||
dataSyncService *dataSyncService
|
||||
metaService *metaService
|
||||
|
@ -29,7 +30,14 @@ type QueryNode struct {
|
|||
statsService *statsService
|
||||
}
|
||||
|
||||
func Init() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
||||
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
|
||||
segmentsMap := make(map[int64]*Segment)
|
||||
collections := make([]*Collection, 0)
|
||||
|
||||
|
@ -43,11 +51,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
|||
}
|
||||
|
||||
return &QueryNode{
|
||||
ctx: ctx,
|
||||
queryNodeLoopCtx: ctx1,
|
||||
queryNodeLoopCancel: cancel,
|
||||
QueryNodeID: queryNodeID,
|
||||
|
||||
QueryNodeID: queryNodeID,
|
||||
|
||||
replica: &replica,
|
||||
replica: replica,
|
||||
|
||||
dataSyncService: nil,
|
||||
metaService: nil,
|
||||
|
@ -56,31 +64,34 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
|
|||
}
|
||||
}
|
||||
|
||||
func (node *QueryNode) Start() {
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
node.metaService = newMetaService(node.ctx, node.replica)
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
func (node *QueryNode) Start() error {
|
||||
// todo add connectMaster logic
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
|
||||
go node.dataSyncService.start()
|
||||
go node.searchService.start()
|
||||
go node.metaService.start()
|
||||
node.statsService.start()
|
||||
go node.statsService.start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) Close() {
|
||||
<-node.ctx.Done()
|
||||
node.queryNodeLoopCancel()
|
||||
|
||||
// free collectionReplica
|
||||
(*node.replica).freeAll()
|
||||
node.replica.freeAll()
|
||||
|
||||
// close services
|
||||
if node.dataSyncService != nil {
|
||||
(*node.dataSyncService).close()
|
||||
node.dataSyncService.close()
|
||||
}
|
||||
if node.searchService != nil {
|
||||
(*node.searchService).close()
|
||||
node.searchService.close()
|
||||
}
|
||||
if node.statsService != nil {
|
||||
(*node.statsService).close()
|
||||
node.statsService.close()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,18 +2,93 @@ package querynode
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
const ctxTimeInMillisecond = 200
|
||||
const closeWithDeadline = true
|
||||
|
||||
// NOTE: start pulsar and etcd before test
|
||||
func TestQueryNode_start(t *testing.T) {
|
||||
func setup() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
IndexParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "metric_type",
|
||||
Value: "L2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: collectionID,
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
return &collectionMeta
|
||||
}
|
||||
|
||||
func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) {
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = node.replica.addCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := node.replica.getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
assert.Equal(t, node.replica.getCollectionNum(), 1)
|
||||
|
||||
err = node.replica.addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = node.replica.addSegment(segmentID, collectionMeta.PartitionTags[0], collectionID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func newQueryNode() *QueryNode {
|
||||
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
|
@ -23,7 +98,21 @@ func TestQueryNode_start(t *testing.T) {
|
|||
ctx = context.Background()
|
||||
}
|
||||
|
||||
node := NewQueryNode(ctx, 0)
|
||||
node.Start()
|
||||
node.Close()
|
||||
svr := NewQueryNode(ctx, 0)
|
||||
return svr
|
||||
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setup()
|
||||
exitCode := m.Run()
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
// NOTE: start pulsar and etcd before test
|
||||
func TestQueryNode_Start(t *testing.T) {
|
||||
localNode := newQueryNode()
|
||||
err := localNode.Start()
|
||||
assert.Nil(t, err)
|
||||
localNode.Close()
|
||||
}
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
Params.Init()
|
||||
}
|
||||
|
||||
func StartQueryNode(ctx context.Context) {
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
node.Start()
|
||||
}
|
|
@ -9,63 +9,19 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestReduce_AllFunc(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
segmentID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
segment := newSegment(collection, segmentID)
|
||||
assert.Equal(t, segmentID, segment.segmentID)
|
||||
|
||||
const DIM = 16
|
||||
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
|
|
|
@ -4,10 +4,12 @@ import "C"
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
|
@ -19,7 +21,7 @@ type searchService struct {
|
|||
wait sync.WaitGroup
|
||||
cancel context.CancelFunc
|
||||
|
||||
replica *collectionReplica
|
||||
replica collectionReplica
|
||||
tSafeWatcher *tSafeWatcher
|
||||
|
||||
serviceableTime Timestamp
|
||||
|
@ -27,13 +29,14 @@ type searchService struct {
|
|||
|
||||
msgBuffer chan msgstream.TsMsg
|
||||
unsolvedMsg []msgstream.TsMsg
|
||||
searchMsgStream *msgstream.MsgStream
|
||||
searchResultMsgStream *msgstream.MsgStream
|
||||
searchMsgStream msgstream.MsgStream
|
||||
searchResultMsgStream msgstream.MsgStream
|
||||
queryNodeID UniqueID
|
||||
}
|
||||
|
||||
type ResultEntityIds []UniqueID
|
||||
|
||||
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
|
||||
func newSearchService(ctx context.Context, replica collectionReplica) *searchService {
|
||||
receiveBufSize := Params.searchReceiveBufSize()
|
||||
pulsarBufSize := Params.searchPulsarBufSize()
|
||||
|
||||
|
@ -69,14 +72,15 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
|
|||
replica: replica,
|
||||
tSafeWatcher: newTSafeWatcher(),
|
||||
|
||||
searchMsgStream: &inputStream,
|
||||
searchResultMsgStream: &outputStream,
|
||||
searchMsgStream: inputStream,
|
||||
searchResultMsgStream: outputStream,
|
||||
queryNodeID: Params.QueryNodeID(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ss *searchService) start() {
|
||||
(*ss.searchMsgStream).Start()
|
||||
(*ss.searchResultMsgStream).Start()
|
||||
ss.searchMsgStream.Start()
|
||||
ss.searchResultMsgStream.Start()
|
||||
ss.register()
|
||||
ss.wait.Add(2)
|
||||
go ss.receiveSearchMsg()
|
||||
|
@ -85,20 +89,24 @@ func (ss *searchService) start() {
|
|||
}
|
||||
|
||||
func (ss *searchService) close() {
|
||||
(*ss.searchMsgStream).Close()
|
||||
(*ss.searchResultMsgStream).Close()
|
||||
if ss.searchMsgStream != nil {
|
||||
ss.searchMsgStream.Close()
|
||||
}
|
||||
if ss.searchResultMsgStream != nil {
|
||||
ss.searchResultMsgStream.Close()
|
||||
}
|
||||
ss.cancel()
|
||||
}
|
||||
|
||||
func (ss *searchService) register() {
|
||||
tSafe := (*(ss.replica)).getTSafe()
|
||||
(*tSafe).registerTSafeWatcher(ss.tSafeWatcher)
|
||||
tSafe := ss.replica.getTSafe()
|
||||
tSafe.registerTSafeWatcher(ss.tSafeWatcher)
|
||||
}
|
||||
|
||||
func (ss *searchService) waitNewTSafe() Timestamp {
|
||||
// block until dataSyncService updating tSafe
|
||||
ss.tSafeWatcher.hasUpdate()
|
||||
timestamp := (*(*ss.replica).getTSafe()).get()
|
||||
timestamp := ss.replica.getTSafe().get()
|
||||
return timestamp
|
||||
}
|
||||
|
||||
|
@ -122,7 +130,7 @@ func (ss *searchService) receiveSearchMsg() {
|
|||
case <-ss.ctx.Done():
|
||||
return
|
||||
default:
|
||||
msgPack := (*ss.searchMsgStream).Consume()
|
||||
msgPack := ss.searchMsgStream.Consume()
|
||||
if msgPack == nil || len(msgPack.Msgs) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
@ -219,7 +227,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
}
|
||||
collectionName := query.CollectionName
|
||||
partitionTags := query.PartitionTags
|
||||
collection, err := (*ss.replica).getCollectionByName(collectionName)
|
||||
collection, err := ss.replica.getCollectionByName(collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -241,14 +249,14 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
matchedSegments := make([]*Segment, 0)
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
|
||||
hasPartition := ss.replica.hasPartition(collectionID, partitionTag)
|
||||
if !hasPartition {
|
||||
return errors.New("search Failed, invalid partitionTag")
|
||||
}
|
||||
}
|
||||
|
||||
for _, partitionTag := range partitionTags {
|
||||
partition, _ := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
|
||||
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
|
||||
for _, segment := range partition.segments {
|
||||
//fmt.Println("dsl = ", dsl)
|
||||
|
||||
|
@ -268,13 +276,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
|
||||
ReqID: searchMsg.ReqID,
|
||||
ProxyID: searchMsg.ProxyID,
|
||||
QueryNodeID: searchMsg.ProxyID,
|
||||
QueryNodeID: ss.queryNodeID,
|
||||
Timestamp: searchTimestamp,
|
||||
ResultChannelID: searchMsg.ResultChannelID,
|
||||
Hits: nil,
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
|
@ -333,7 +341,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
Hits: hits,
|
||||
}
|
||||
searchResultMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
SearchResult: results,
|
||||
}
|
||||
err = ss.publishSearchResult(searchResultMsg)
|
||||
|
@ -350,9 +358,10 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
|
|||
}
|
||||
|
||||
func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
|
||||
fmt.Println("Public SearchResult", msg.HashKeys())
|
||||
msgPack := msgstream.MsgPack{}
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
err := (*ss.searchResultMsgStream).Produce(&msgPack)
|
||||
err := ss.searchResultMsgStream.Produce(&msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -377,11 +386,11 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s
|
|||
}
|
||||
|
||||
tsMsg := &msgstream.SearchResultMsg{
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
|
||||
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
|
||||
SearchResult: results,
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
|
||||
err := (*ss.searchResultMsgStream).Produce(&msgPack)
|
||||
err := ss.searchResultMsgStream.Produce(&msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -13,80 +13,15 @@ import (
|
|||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
func TestSearch_Search(t *testing.T) {
|
||||
Params.Init()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
node := NewQueryNode(context.Background(), 0)
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 10
|
||||
|
@ -158,14 +93,14 @@ func TestSearch_Search(t *testing.T) {
|
|||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
searchStream.Start()
|
||||
err = searchStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
go node.searchService.start()
|
||||
|
||||
// start insert
|
||||
|
@ -235,7 +170,7 @@ func TestSearch_Search(t *testing.T) {
|
|||
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
|
||||
|
||||
// pulsar produce
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
insertStream.Start()
|
||||
|
@ -245,83 +180,19 @@ func TestSearch_Search(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
func TestSearch_SearchMultiSegments(t *testing.T) {
|
||||
Params.Init()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
node := NewQueryNode(context.Background(), 0)
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test data generate
|
||||
const msgLength = 1024
|
||||
|
@ -393,14 +264,14 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
msgPackSearch := msgstream.MsgPack{}
|
||||
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
searchStream.Start()
|
||||
err = searchStream.Produce(&msgPackSearch)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.searchService = newSearchService(node.ctx, node.replica)
|
||||
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
|
||||
go node.searchService.start()
|
||||
|
||||
// start insert
|
||||
|
@ -474,7 +345,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
|
||||
|
||||
// pulsar produce
|
||||
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
insertStream.SetPulsarClient(pulsarURL)
|
||||
insertStream.CreatePulsarProducers(insertProducerChannels)
|
||||
insertStream.Start()
|
||||
|
@ -484,11 +355,10 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// dataSync
|
||||
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
|
||||
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
|
||||
go node.dataSyncService.start()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
cancel()
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"math"
|
||||
|
@ -10,61 +9,21 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
|
||||
)
|
||||
|
||||
//-------------------------------------------------------------------------------------- constructor and destructor
|
||||
func TestSegment_newSegment(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -74,52 +33,15 @@ func TestSegment_newSegment(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_deleteSegment(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -131,52 +53,15 @@ func TestSegment_deleteSegment(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- stats functions
|
||||
func TestSegment_getRowCount(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -219,52 +104,15 @@ func TestSegment_getRowCount(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_getDeletedCount(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -313,52 +161,15 @@ func TestSegment_getDeletedCount(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_getMemSize(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -402,53 +213,15 @@ func TestSegment_getMemSize(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- dm & search functions
|
||||
func TestSegment_segmentInsert(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
assert.Equal(t, segmentID, segment.segmentID)
|
||||
|
@ -486,52 +259,15 @@ func TestSegment_segmentInsert(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentDelete(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -576,55 +312,15 @@ func TestSegment_segmentDelete(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentSearch(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -661,13 +357,6 @@ func TestSegment_segmentSearch(t *testing.T) {
|
|||
|
||||
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
const receiveBufSize = 1024
|
||||
searchProducerChannels := Params.searchChannelNames()
|
||||
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
searchStream.SetPulsarClient(pulsarURL)
|
||||
searchStream.CreatePulsarProducers(searchProducerChannels)
|
||||
|
||||
var searchRawData []byte
|
||||
for _, ele := range vec {
|
||||
buf := make([]byte, 4)
|
||||
|
@ -708,52 +397,15 @@ func TestSegment_segmentSearch(t *testing.T) {
|
|||
|
||||
//-------------------------------------------------------------------------------------- preDm functions
|
||||
func TestSegment_segmentPreInsert(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
@ -787,52 +439,15 @@ func TestSegment_segmentPreInsert(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSegment_segmentPreDelete(t *testing.T) {
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: "collection0",
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
collectionName := "collection0"
|
||||
collectionID := UniqueID(0)
|
||||
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
|
||||
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
collection := newCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
collection := newCollection(collectionMeta, collectionMetaBlob)
|
||||
assert.Equal(t, collection.meta.Schema.Name, collectionName)
|
||||
assert.Equal(t, collection.meta.ID, collectionID)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
segment := newSegment(collection, segmentID)
|
||||
|
|
|
@ -14,11 +14,11 @@ import (
|
|||
|
||||
type statsService struct {
|
||||
ctx context.Context
|
||||
statsStream *msgstream.MsgStream
|
||||
replica *collectionReplica
|
||||
statsStream msgstream.MsgStream
|
||||
replica collectionReplica
|
||||
}
|
||||
|
||||
func newStatsService(ctx context.Context, replica *collectionReplica) *statsService {
|
||||
func newStatsService(ctx context.Context, replica collectionReplica) *statsService {
|
||||
|
||||
return &statsService{
|
||||
ctx: ctx,
|
||||
|
@ -44,8 +44,8 @@ func (sService *statsService) start() {
|
|||
|
||||
var statsMsgStream msgstream.MsgStream = statsStream
|
||||
|
||||
sService.statsStream = &statsMsgStream
|
||||
(*sService.statsStream).Start()
|
||||
sService.statsStream = statsMsgStream
|
||||
sService.statsStream.Start()
|
||||
|
||||
// start service
|
||||
fmt.Println("do segments statistic in ", strconv.Itoa(sleepTimeInterval), "ms")
|
||||
|
@ -60,11 +60,13 @@ func (sService *statsService) start() {
|
|||
}
|
||||
|
||||
func (sService *statsService) close() {
|
||||
(*sService.statsStream).Close()
|
||||
if sService.statsStream != nil {
|
||||
sService.statsStream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (sService *statsService) sendSegmentStatistic() {
|
||||
statisticData := (*sService.replica).getSegmentStatistics()
|
||||
statisticData := sService.replica.getSegmentStatistics()
|
||||
|
||||
// fmt.Println("Publish segment statistic")
|
||||
// fmt.Println(statisticData)
|
||||
|
@ -82,7 +84,7 @@ func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSeg
|
|||
var msgPack = msgstream.MsgPack{
|
||||
Msgs: []msgstream.TsMsg{msg},
|
||||
}
|
||||
err := (*sService.statsStream).Produce(&msgPack)
|
||||
err := sService.statsStream.Produce(&msgPack)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
|
|
@ -1,193 +1,42 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
// NOTE: start pulsar before test
|
||||
func TestStatsService_start(t *testing.T) {
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// start stats service
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
node.statsService.start()
|
||||
node.Close()
|
||||
}
|
||||
|
||||
// NOTE: start pulsar before test
|
||||
//NOTE: start pulsar before test
|
||||
func TestSegmentManagement_SegmentStatisticService(t *testing.T) {
|
||||
Params.Init()
|
||||
var ctx context.Context
|
||||
|
||||
if closeWithDeadline {
|
||||
var cancel context.CancelFunc
|
||||
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
|
||||
ctx, cancel = context.WithDeadline(context.Background(), d)
|
||||
defer cancel()
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// init query node
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
node := NewQueryNode(ctx, 0)
|
||||
|
||||
// init meta
|
||||
collectionName := "collection0"
|
||||
fieldVec := schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldInt := schemapb.FieldSchema{
|
||||
Name: "age",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_INT32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := schemapb.CollectionSchema{
|
||||
Name: collectionName,
|
||||
AutoID: true,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
&fieldVec, &fieldInt,
|
||||
},
|
||||
}
|
||||
|
||||
collectionMeta := etcdpb.CollectionMeta{
|
||||
ID: UniqueID(0),
|
||||
Schema: &schema,
|
||||
CreateTime: Timestamp(0),
|
||||
SegmentIDs: []UniqueID{0},
|
||||
PartitionTags: []string{"default"},
|
||||
}
|
||||
|
||||
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
|
||||
assert.NotEqual(t, "", collectionMetaBlob)
|
||||
|
||||
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
|
||||
assert.NoError(t, err)
|
||||
|
||||
collection, err := (*node.replica).getCollectionByName(collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, collection.meta.Schema.Name, "collection0")
|
||||
assert.Equal(t, collection.meta.ID, UniqueID(0))
|
||||
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
|
||||
|
||||
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
segmentID := UniqueID(0)
|
||||
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
|
||||
assert.NoError(t, err)
|
||||
node := newQueryNode()
|
||||
initTestMeta(t, node, "collection0", 0, 0)
|
||||
|
||||
const receiveBufSize = 1024
|
||||
// start pulsar
|
||||
producerChannels := []string{Params.statsChannelName()}
|
||||
|
||||
statsStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
|
||||
pulsarURL, _ := Params.pulsarAddress()
|
||||
|
||||
statsStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
|
||||
statsStream.SetPulsarClient(pulsarURL)
|
||||
statsStream.CreatePulsarProducers(producerChannels)
|
||||
|
||||
var statsMsgStream msgstream.MsgStream = statsStream
|
||||
|
||||
node.statsService = newStatsService(node.ctx, node.replica)
|
||||
node.statsService.statsStream = &statsMsgStream
|
||||
(*node.statsService.statsStream).Start()
|
||||
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
|
||||
node.statsService.statsStream = statsMsgStream
|
||||
node.statsService.statsStream.Start()
|
||||
|
||||
// send stats
|
||||
node.statsService.sendSegmentStatistic()
|
||||
node.Close()
|
||||
}
|
||||
|
|
|
@ -36,11 +36,11 @@ type tSafeImpl struct {
|
|||
watcherList []*tSafeWatcher
|
||||
}
|
||||
|
||||
func newTSafe() *tSafe {
|
||||
func newTSafe() tSafe {
|
||||
var t tSafe = &tSafeImpl{
|
||||
watcherList: make([]*tSafeWatcher, 0),
|
||||
}
|
||||
return &t
|
||||
return t
|
||||
}
|
||||
|
||||
func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) {
|
||||
|
|
|
@ -9,13 +9,13 @@ import (
|
|||
func TestTSafe_GetAndSet(t *testing.T) {
|
||||
tSafe := newTSafe()
|
||||
watcher := newTSafeWatcher()
|
||||
(*tSafe).registerTSafeWatcher(watcher)
|
||||
tSafe.registerTSafeWatcher(watcher)
|
||||
|
||||
go func() {
|
||||
watcher.hasUpdate()
|
||||
timestamp := (*tSafe).get()
|
||||
timestamp := tSafe.get()
|
||||
assert.Equal(t, timestamp, Timestamp(1000))
|
||||
}()
|
||||
|
||||
(*tSafe).set(Timestamp(1000))
|
||||
tSafe.set(Timestamp(1000))
|
||||
}
|
||||
|
|
|
@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(values, (dimension / 8) * length);
|
||||
auto ast = builder->AppendValues(values, length);
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
|
|||
st.error_msg = ErrorMsg("payload has finished");
|
||||
return st;
|
||||
}
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
|
||||
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
|
||||
if (!ast.ok()) {
|
||||
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
|
||||
st.error_msg = ErrorMsg(ast.message());
|
||||
|
@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() * 8;
|
||||
*length = array->length() / array->byte_width();
|
||||
*length = array->length();
|
||||
*values = (uint8_t *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
return st;
|
||||
}
|
||||
*dimension = array->byte_width() / sizeof(float);
|
||||
*length = array->length() / array->byte_width();
|
||||
*length = array->length();
|
||||
*values = (float *) array->raw_values();
|
||||
return st;
|
||||
}
|
||||
|
@ -478,12 +478,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
|
|||
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
|
||||
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
|
||||
if (p->array == nullptr) return 0;
|
||||
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
|
||||
if (ba == nullptr) {
|
||||
return p->array->length();
|
||||
} else {
|
||||
return ba->length() / ba->byte_width();
|
||||
}
|
||||
return p->array->length();
|
||||
}
|
||||
|
||||
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {
|
||||
|
|
|
@ -5,6 +5,7 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
typedef void *CPayloadWriter;
|
||||
|
||||
|
@ -19,7 +20,7 @@ typedef struct CStatus {
|
|||
} CStatus;
|
||||
|
||||
CPayloadWriter NewPayloadWriter(int columnType);
|
||||
//CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
|
||||
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
|
||||
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
|
||||
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
|
||||
|
@ -39,7 +40,7 @@ CStatus ReleasePayloadWriter(CPayloadWriter handler);
|
|||
|
||||
typedef void *CPayloadReader;
|
||||
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
|
||||
//CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
|
||||
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
|
||||
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
|
||||
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
|
||||
|
@ -55,4 +56,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
|
|||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -70,38 +70,38 @@ TEST(wrapper, inoutstream) {
|
|||
ASSERT_EQ(inarray->Value(4), 5);
|
||||
}
|
||||
|
||||
//TEST(wrapper, boolean) {
|
||||
// auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
// bool data[] = {true, false, true, false};
|
||||
//
|
||||
// auto st = AddBooleanToPayload(payload, data, 4);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = FinishPayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// auto cb = GetPayloadBufferFromWriter(payload);
|
||||
// ASSERT_GT(cb.length, 0);
|
||||
// ASSERT_NE(cb.data, nullptr);
|
||||
// auto nums = GetPayloadLengthFromWriter(payload);
|
||||
// ASSERT_EQ(nums, 4);
|
||||
//
|
||||
// auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
// bool *values;
|
||||
// int length;
|
||||
// st = GetBoolFromPayload(reader, &values, &length);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// ASSERT_NE(values, nullptr);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// length = GetPayloadLengthFromReader(reader);
|
||||
// ASSERT_EQ(length, 4);
|
||||
// for (int i = 0; i < length; i++) {
|
||||
// ASSERT_EQ(data[i], values[i]);
|
||||
// }
|
||||
//
|
||||
// st = ReleasePayloadWriter(payload);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
// st = ReleasePayloadReader(reader);
|
||||
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
//}
|
||||
TEST(wrapper, boolean) {
|
||||
auto payload = NewPayloadWriter(ColumnType::BOOL);
|
||||
bool data[] = {true, false, true, false};
|
||||
|
||||
auto st = AddBooleanToPayload(payload, data, 4);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = FinishPayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
auto cb = GetPayloadBufferFromWriter(payload);
|
||||
ASSERT_GT(cb.length, 0);
|
||||
ASSERT_NE(cb.data, nullptr);
|
||||
auto nums = GetPayloadLengthFromWriter(payload);
|
||||
ASSERT_EQ(nums, 4);
|
||||
|
||||
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
|
||||
bool *values;
|
||||
int length;
|
||||
st = GetBoolFromPayload(reader, &values, &length);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
ASSERT_NE(values, nullptr);
|
||||
ASSERT_EQ(length, 4);
|
||||
length = GetPayloadLengthFromReader(reader);
|
||||
ASSERT_EQ(length, 4);
|
||||
for (int i = 0; i < length; i++) {
|
||||
ASSERT_EQ(data[i], values[i]);
|
||||
}
|
||||
|
||||
st = ReleasePayloadWriter(payload);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
st = ReleasePayloadReader(reader);
|
||||
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
|
||||
}
|
||||
|
||||
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
|
||||
auto payload = NewPayloadWriter(COLUMN_TYPE); \
|
||||
|
|
|
@ -16,25 +16,311 @@ import (
|
|||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
type PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
}
|
||||
type (
|
||||
PayloadWriter struct {
|
||||
payloadWriterPtr C.CPayloadWriter
|
||||
colType schemapb.DataType
|
||||
}
|
||||
|
||||
PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
colType schemapb.DataType
|
||||
}
|
||||
)
|
||||
|
||||
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
|
||||
w := C.NewPayloadWriter(C.int(colType))
|
||||
if w == nil {
|
||||
return nil, errors.New("create Payload writer failed")
|
||||
}
|
||||
return &PayloadWriter{payloadWriterPtr: w}, nil
|
||||
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
|
||||
switch len(dim) {
|
||||
case 0:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, ok := msgs.([]bool)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBoolToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, ok := msgs.([]int8)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt8ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, ok := msgs.([]int16)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt16ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, ok := msgs.([]int32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt32ToPayload(val)
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, ok := msgs.([]int64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddInt64ToPayload(val)
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatToPayload(val)
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, ok := msgs.([]float64)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddDoubleToPayload(val)
|
||||
|
||||
case schemapb.DataType_STRING:
|
||||
val, ok := msgs.(string)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddOneStringToPayload(val)
|
||||
}
|
||||
case 1:
|
||||
switch w.colType {
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
val, ok := msgs.([]byte)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddBinaryVectorToPayload(val, dim[0])
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
val, ok := msgs.([]float32)
|
||||
if !ok {
|
||||
return errors.New("incorrect data type")
|
||||
}
|
||||
return w.AddFloatVectorToPayload(val, dim[0])
|
||||
}
|
||||
|
||||
default:
|
||||
return errors.New("incorrect input numbers")
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
|
||||
length := len(msgs)
|
||||
if length <= 0 {
|
||||
return errors.Errorf("can't add empty msgs into payload")
|
||||
}
|
||||
|
||||
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
|
||||
cLength := C.int(length)
|
||||
|
||||
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
|
||||
|
||||
errCode := commonpb.ErrorCode(status.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(status.error_msg)
|
||||
defer C.free(unsafe.Pointer(status.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
|
||||
if len(msg) == 0 {
|
||||
length := len(msg)
|
||||
if length == 0 {
|
||||
return errors.New("can't add empty string into payload")
|
||||
}
|
||||
cstr := C.CString(msg)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cstr, C.int(len(msg)))
|
||||
|
||||
cmsg := C.CString(msg)
|
||||
clength := C.int(length)
|
||||
defer C.free(unsafe.Pointer(cmsg))
|
||||
|
||||
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
|
||||
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
|
||||
length := len(binVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty binVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.uint8_t)(&binVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / (dim / 8))
|
||||
|
||||
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return errors.New(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dimension > 0 && (%8 == 0)
|
||||
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
|
||||
length := len(floatVec)
|
||||
if length <= 0 {
|
||||
return errors.New("can't add empty floatVec into payload")
|
||||
}
|
||||
|
||||
if dim <= 0 {
|
||||
return errors.New("dimension should be greater than 0")
|
||||
}
|
||||
|
||||
cBinVec := (*C.float)(&floatVec[0])
|
||||
cDim := C.int(dim)
|
||||
cLength := C.int(length / dim)
|
||||
|
||||
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
|
@ -56,13 +342,13 @@ func (w *PayloadWriter) FinishPayloadWriter() error {
|
|||
}
|
||||
|
||||
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
|
||||
//See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
|
||||
pointer := unsafe.Pointer(cb.data)
|
||||
length := int(cb.length)
|
||||
if length <= 0 {
|
||||
return nil, errors.New("empty buffer")
|
||||
}
|
||||
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
slice := (*[1 << 28]byte)(pointer)[:length:length]
|
||||
return slice, nil
|
||||
}
|
||||
|
@ -87,16 +373,71 @@ func (w *PayloadWriter) Close() error {
|
|||
return w.ReleasePayloadWriter()
|
||||
}
|
||||
|
||||
type PayloadReader struct {
|
||||
payloadReaderPtr C.CPayloadReader
|
||||
}
|
||||
|
||||
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New("create Payload reader failed, buffer is empty")
|
||||
}
|
||||
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
|
||||
return &PayloadReader{payloadReaderPtr: r}, nil
|
||||
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
|
||||
}
|
||||
|
||||
// Params:
|
||||
// `idx`: String index
|
||||
// Return:
|
||||
// `interface{}`: all types.
|
||||
// `int`: length, only meaningful to FLOAT/BINARY VECTOR type.
|
||||
// `error`: error.
|
||||
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
|
||||
switch len(idx) {
|
||||
case 1:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_STRING:
|
||||
val, err := r.GetOneStringFromPayload(idx[0])
|
||||
return val, 0, err
|
||||
}
|
||||
case 0:
|
||||
switch r.colType {
|
||||
case schemapb.DataType_BOOL:
|
||||
val, err := r.GetBoolFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT8:
|
||||
val, err := r.GetInt8FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT16:
|
||||
val, err := r.GetInt16FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT32:
|
||||
val, err := r.GetInt32FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_INT64:
|
||||
val, err := r.GetInt64FromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_FLOAT:
|
||||
val, err := r.GetFloatFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_DOUBLE:
|
||||
val, err := r.GetDoubleFromPayload()
|
||||
return val, 0, err
|
||||
|
||||
case schemapb.DataType_VECTOR_BINARY:
|
||||
return r.GetBinaryVectorFromPayload()
|
||||
|
||||
case schemapb.DataType_VECTOR_FLOAT:
|
||||
return r.GetFloatVectorFromPayload()
|
||||
default:
|
||||
return nil, 0, errors.New("Unknown type")
|
||||
}
|
||||
default:
|
||||
return nil, 0, errors.New("incorrect number of index")
|
||||
}
|
||||
|
||||
return nil, 0, errors.New("unknown error")
|
||||
}
|
||||
|
||||
func (r *PayloadReader) ReleasePayloadReader() error {
|
||||
|
@ -110,18 +451,169 @@ func (r *PayloadReader) ReleasePayloadReader() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
|
||||
var cMsg *C.bool
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
|
||||
var cMsg *C.int8_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
|
||||
var cMsg *C.int16_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
|
||||
var cMsg *C.int32_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
|
||||
var cMsg *C.int64_t
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
|
||||
var cMsg *C.float
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
|
||||
var cMsg *C.double
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, errors.New(msg)
|
||||
}
|
||||
|
||||
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
|
||||
var cStr *C.char
|
||||
var strSize C.int
|
||||
var cSize C.int
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
|
||||
|
||||
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &strSize)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return C.GoStringN(cStr, strSize), nil
|
||||
return C.GoStringN(cStr, cSize), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
|
||||
var cMsg *C.uint8_t
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := (cDim / 8) * cLen
|
||||
|
||||
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
// ,dimension, error
|
||||
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
||||
var cMsg *C.float
|
||||
var cDim C.int
|
||||
var cLen C.int
|
||||
|
||||
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
|
||||
errCode := commonpb.ErrorCode(st.error_code)
|
||||
if errCode != commonpb.ErrorCode_SUCCESS {
|
||||
msg := C.GoString(st.error_msg)
|
||||
defer C.free(unsafe.Pointer(st.error_msg))
|
||||
return nil, 0, errors.New(msg)
|
||||
}
|
||||
length := cDim * cLen
|
||||
|
||||
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
|
||||
return slice, int(cDim), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
|
||||
|
|
|
@ -1,54 +1,426 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
|
||||
)
|
||||
|
||||
func TestNewPayloadWriter(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, w)
|
||||
err = w.Close()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPayLoadString(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
func TestPayload_ReaderandWriter(t *testing.T) {
|
||||
|
||||
t.Run("TestBool", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBoolToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]bool{false, false, false, false})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 8)
|
||||
bools, err := r.GetBoolFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
ibools, _, err := r.GetDataFromPayload()
|
||||
bools = ibools.([]bool)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
|
||||
defer r.ReleasePayloadReader()
|
||||
|
||||
})
|
||||
|
||||
t.Run("TestInt8", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT8)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt8ToPayload([]int8{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int8{4, 5, 6})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int8s, err := r.GetInt8FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
|
||||
iint8s, _, err := r.GetDataFromPayload()
|
||||
int8s = iint8s.([]int8)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt16", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT16)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt16ToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int16{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
int16s, err := r.GetInt16FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
|
||||
iint16s, _, err := r.GetDataFromPayload()
|
||||
int16s = iint16s.([]int16)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT32)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt32ToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int32{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int32s, err := r.GetInt32FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
|
||||
iint32s, _, err := r.GetDataFromPayload()
|
||||
int32s = iint32s.([]int32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestInt64", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_INT64)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddInt64ToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]int64{1, 2, 3})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
int64s, err := r.GetInt64FromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
|
||||
iint64s, _, err := r.GetDataFromPayload()
|
||||
int64s = iint64s.([]int64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloat32", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float32s, err := r.GetFloatFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
|
||||
ifloat32s, _, err := r.GetDataFromPayload()
|
||||
float32s = ifloat32s.([]float32)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestDouble", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 6, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
float64s, err := r.GetDoubleFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
|
||||
ifloat64s, _, err := r.GetDataFromPayload()
|
||||
float64s = ifloat64s.([]float64)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestAddOneString", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_STRING)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddOneStringToPayload("hello0")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello1")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddOneStringToPayload("hello2")
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload("hello3")
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
|
||||
assert.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
str0, err := r.GetOneStringFromPayload(0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
str1, err := r.GetOneStringFromPayload(1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
str2, err := r.GetOneStringFromPayload(2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
str3, err := r.GetOneStringFromPayload(3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
istr0, _, err := r.GetDataFromPayload(0)
|
||||
str0 = istr0.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str0, "hello0")
|
||||
|
||||
istr1, _, err := r.GetDataFromPayload(1)
|
||||
str1 = istr1.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str1, "hello1")
|
||||
|
||||
istr2, _, err := r.GetDataFromPayload(2)
|
||||
str2 = istr2.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str2, "hello2")
|
||||
|
||||
istr3, _, err := r.GetDataFromPayload(3)
|
||||
str3 = istr3.(string)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, str3, "hello3")
|
||||
|
||||
err = r.ReleasePayloadReader()
|
||||
assert.Nil(t, err)
|
||||
err = w.ReleasePayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("TestBinaryVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
in := make([]byte, 16)
|
||||
for i := 0; i < 16; i++ {
|
||||
in[i] = 1
|
||||
}
|
||||
in2 := make([]byte, 8)
|
||||
for i := 0; i < 8; i++ {
|
||||
in2[i] = 1
|
||||
}
|
||||
|
||||
err = w.AddBinaryVectorToPayload(in, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload(in2, 8)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 24, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 24)
|
||||
|
||||
binVecs, dim, err := r.GetBinaryVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
fmt.Println(binVecs)
|
||||
|
||||
ibinVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
binVecs = ibinVecs.([]byte)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
|
||||
t.Run("TestFloatVector", func(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
|
||||
assert.Nil(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
length, err := w.GetPayloadLengthFromWriter()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, length)
|
||||
defer w.ReleasePayloadWriter()
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
assert.Nil(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
|
||||
require.Nil(t, err)
|
||||
length, err = r.GetPayloadLengthFromReader()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
|
||||
floatVecs, dim, err := r.GetFloatVectorFromPayload()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
|
||||
ifloatVecs, dim, err := r.GetDataFromPayload()
|
||||
assert.Nil(t, err)
|
||||
floatVecs = ifloatVecs.([]float32)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
|
||||
defer r.ReleasePayloadReader()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -16,13 +16,17 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/viper"
|
||||
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
|
||||
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type UniqueID = typeutil.UniqueID
|
||||
|
||||
type Base interface {
|
||||
Load(key string) (string, error)
|
||||
LoadRange(key, endKey string, limit int) ([]string, []string, error)
|
||||
|
@ -38,7 +42,18 @@ type BaseTable struct {
|
|||
|
||||
func (gp *BaseTable) Init() {
|
||||
gp.params = memkv.NewMemoryKV()
|
||||
err := gp.LoadYaml("config.yaml")
|
||||
|
||||
err := gp.LoadYaml("milvus.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = gp.LoadYaml("advanced/common.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = gp.LoadYaml("advanced/channel.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -146,3 +161,140 @@ func (gp *BaseTable) Remove(key string) error {
|
|||
func (gp *BaseTable) Save(key, value string) error {
|
||||
return gp.params.Save(strings.ToLower(key), value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseFloat(key string) float64 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.ParseFloat(valueStr, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt64(key string) int64 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int64(value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt32(key string) int32 {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int32(value)
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ParseInt(key string) int {
|
||||
valueStr, err := gp.Load(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
value, err := strconv.Atoi(valueStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (gp *BaseTable) WriteNodeIDList() []UniqueID {
|
||||
proxyIDStr, err := gp.Load("nodeID.writeNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load write node id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (gp *BaseTable) ProxyIDList() []UniqueID {
|
||||
proxyIDStr, err := gp.Load("nodeID.proxyIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
proxyIDs := strings.Split(proxyIDStr, ",")
|
||||
for _, i := range proxyIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (gp *BaseTable) QueryNodeIDList() []UniqueID {
|
||||
queryNodeIDStr, err := gp.Load("nodeID.queryNodeIDList")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ret []UniqueID
|
||||
queryNodeIDs := strings.Split(queryNodeIDStr, ",")
|
||||
for _, i := range queryNodeIDs {
|
||||
v, err := strconv.Atoi(i)
|
||||
if err != nil {
|
||||
log.Panicf("load proxy id list error, %s", err.Error())
|
||||
}
|
||||
ret = append(ret, UniqueID(v))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// package methods
|
||||
|
||||
func ConvertRangeToIntRange(rangeStr, sep string) []int {
|
||||
items := strings.Split(rangeStr, sep)
|
||||
if len(items) != 2 {
|
||||
panic("Illegal range ")
|
||||
}
|
||||
|
||||
startStr := items[0]
|
||||
endStr := items[1]
|
||||
start, err := strconv.Atoi(startStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
end, err := strconv.Atoi(endStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if start < 0 || end < 0 {
|
||||
panic("Illegal range value")
|
||||
}
|
||||
if start > end {
|
||||
panic("Illegal range value, start > end")
|
||||
}
|
||||
return []int{start, end}
|
||||
}
|
||||
|
||||
func ConvertRangeToIntSlice(rangeStr, sep string) []int {
|
||||
rangeSlice := ConvertRangeToIntRange(rangeStr, sep)
|
||||
start, end := rangeSlice[0], rangeSlice[1]
|
||||
var ret []int
|
||||
for i := start; i < end; i++ {
|
||||
ret = append(ret, i)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
package paramtable
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -21,6 +22,8 @@ var Params = BaseTable{}
|
|||
|
||||
func TestMain(m *testing.M) {
|
||||
Params.Init()
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
//func TestMain
|
||||
|
@ -55,13 +58,13 @@ func TestGlobalParamsTable_SaveAndLoad(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGlobalParamsTable_LoadRange(t *testing.T) {
|
||||
_ = Params.Save("abc", "10")
|
||||
_ = Params.Save("fghz", "20")
|
||||
_ = Params.Save("bcde", "1.1")
|
||||
_ = Params.Save("abcd", "testSaveAndLoad")
|
||||
_ = Params.Save("zhi", "12")
|
||||
_ = Params.Save("xxxaab", "10")
|
||||
_ = Params.Save("xxxfghz", "20")
|
||||
_ = Params.Save("xxxbcde", "1.1")
|
||||
_ = Params.Save("xxxabcd", "testSaveAndLoad")
|
||||
_ = Params.Save("xxxzhi", "12")
|
||||
|
||||
keys, values, err := Params.LoadRange("a", "g", 10)
|
||||
keys, values, err := Params.LoadRange("xxxa", "xxxg", 10)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, len(keys))
|
||||
assert.Equal(t, "10", values[0])
|
||||
|
@ -97,24 +100,17 @@ func TestGlobalParamsTable_Remove(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGlobalParamsTable_LoadYaml(t *testing.T) {
|
||||
err := Params.LoadYaml("config.yaml")
|
||||
err := Params.LoadYaml("milvus.yaml")
|
||||
assert.Nil(t, err)
|
||||
|
||||
value1, err1 := Params.Load("etcd.address")
|
||||
value2, err2 := Params.Load("pulsar.port")
|
||||
value3, err3 := Params.Load("reader.topicend")
|
||||
value4, err4 := Params.Load("proxy.pulsarTopics.readerTopicPrefix")
|
||||
value5, err5 := Params.Load("proxy.network.address")
|
||||
err = Params.LoadYaml("advanced/channel.yaml")
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, value1, "localhost")
|
||||
assert.Equal(t, value2, "6650")
|
||||
assert.Equal(t, value3, "128")
|
||||
assert.Equal(t, value4, "milvusReader")
|
||||
assert.Equal(t, value5, "0.0.0.0")
|
||||
_, err = Params.Load("etcd.address")
|
||||
assert.Nil(t, err)
|
||||
_, err = Params.Load("pulsar.port")
|
||||
assert.Nil(t, err)
|
||||
_, err = Params.Load("msgChannel.channelRange.insert")
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Nil(t, err1)
|
||||
assert.Nil(t, err2)
|
||||
assert.Nil(t, err3)
|
||||
assert.Nil(t, err4)
|
||||
assert.Nil(t, err5)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue