Add impl cgo of parquet

Signed-off-by: XuanYang-cn <xuan.yang@zilliz.com>
pull/4973/head^2
XuanYang-cn 2020-12-08 14:41:04 +08:00 committed by yefu.chen
parent 65ce1f97e7
commit 32977e270c
50 changed files with 1965 additions and 3813 deletions

.gitignore (vendored)

@ -56,3 +56,6 @@ cmake_build/
.DS_Store
*.swp
cwrapper_build
**/.clangd/*
**/compile_commands.json
**/.lint

View File

@ -3,7 +3,7 @@ timeout(time: 10, unit: 'MINUTES') {
sh '. ./before-install.sh && unset http_proxy && unset https_proxy && ./check_cache.sh -l $CCACHE_ARTFACTORY_URL --cache_dir=\$CCACHE_DIR -f ccache-\$OS_NAME-\$BUILD_ENV_IMAGE_ID.tar.gz || echo \"ccache files not found!\"'
}
sh '. ./scripts/before-install.sh && make check-proto-product && make verifiers && make install'
sh '. ./scripts/before-install.sh && make install'
dir ("scripts") {
withCredentials([usernamePassword(credentialsId: "${env.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {

View File

@ -4,7 +4,10 @@ try {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d pulsar'
dir ('build/docker/deploy') {
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} pull'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d master'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} up -d proxy'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=1 -d querynode'
sh 'docker-compose -p ${DOCKER_COMPOSE_PROJECT_NAME} run -e QUERY_NODE_ID=2 -d querynode'
}
dir ('build/docker/test') {

View File

@ -41,9 +41,9 @@ fmt:
lint:
@echo "Running $@ check"
@GO111MODULE=on ${GOPATH}/bin/golangci-lint cache clean
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=5m --config ./.golangci.yml ./tests/go/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/...
@GO111MODULE=on ${GOPATH}/bin/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/go/...
ruleguard:
@echo "Running $@ check"

build/docker/test/.env (new file)

@ -0,0 +1,4 @@
SOURCE_REPO=milvusdb
TARGET_REPO=milvusdb
SOURCE_TAG=latest
TARGET_TAG=latest

View File

@ -2,6 +2,7 @@ package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
@ -13,8 +14,7 @@ import (
func main() {
proxy.Init()
// Creates server.
fmt.Println("ProxyID is", proxy.Params.ProxyID())
ctx, cancel := context.WithCancel(context.Background())
svr, err := proxy.CreateProxy(ctx)
if err != nil {

View File

@ -2,18 +2,24 @@ package main
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"go.uber.org/zap"
"github.com/zilliztech/milvus-distributed/internal/querynode"
)
func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
querynode.Init()
fmt.Println("QueryNodeID is", querynode.Params.QueryNodeID())
// Creates server.
ctx, cancel := context.WithCancel(context.Background())
svr := querynode.NewQueryNode(ctx, 0)
sc := make(chan os.Signal, 1)
signal.Notify(sc,
@ -28,8 +34,14 @@ func main() {
cancel()
}()
querynode.StartQueryNode(ctx)
if err := svr.Start(); err != nil {
log.Fatal("run server failed", zap.Error(err))
}
<-ctx.Done()
log.Print("Got signal to exit", zap.String("signal", sig.String()))
svr.Close()
switch sig {
case syscall.SIGTERM:
exit(0)
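
The reworked entrypoint follows a start/stop pattern: create the node with a cancellable context, call Start, translate OS signals into cancellation, then Close once the context is done. A minimal self-contained sketch of that pattern, with a hypothetical Server type standing in for querynode.QueryNode:

package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"syscall"
)

// Server stands in for querynode.QueryNode; Start and Close are assumed
// to behave like the methods used in the diff above.
type Server struct{}

func (s *Server) Start() error { return nil }
func (s *Server) Close()       {}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	svr := &Server{}
	if err := svr.Start(); err != nil {
		log.Fatal("run server failed: ", err)
	}

	// Translate SIGTERM/SIGINT into context cancellation, as the new main does.
	sc := make(chan os.Signal, 1)
	signal.Notify(sc, syscall.SIGTERM, syscall.SIGINT)
	go func() {
		sig := <-sc
		log.Print("Got signal to exit: ", sig.String())
		cancel()
	}()

	<-ctx.Done()
	svr.Close()
}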

View File

@ -32,7 +32,7 @@ msgChannel:
# default channel range [0, 1)
channelRange:
insert: [0, 1]
insert: [0, 2]
delete: [0, 1]
dataDefinition: [0,1]
k2s: [0, 1]
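
Per the comment above, a channel range [begin, end) is half-open, so insert: [0, 2] now produces two insert channels instead of one. A small sketch of the expansion, assuming the prefix-<id> naming scheme used elsewhere in this diff:

package main

import "fmt"

// expandChannels expands a half-open [begin, end) channel range into
// per-channel topic names ("insert-0", "insert-1", ...).
func expandChannels(prefix string, begin, end int) []string {
	names := make([]string, 0, end-begin)
	for i := begin; i < end; i++ {
		names = append(names, fmt.Sprintf("%s-%d", prefix, i))
	}
	return names
}

func main() {
	fmt.Println(expandChannels("insert", 0, 2)) // [insert-0 insert-1]
}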

View File

@ -1,118 +0,0 @@
# Copyright (C) 2019-2020 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under the License.
master:
address: localhost
port: 53100
pulsarmoniterinterval: 1
pulsartopic: "monitor-topic"
proxyidlist: [1, 2]
proxyTimeSyncChannels: ["proxy1", "proxy2"]
proxyTimeSyncSubName: "proxy-topic"
softTimeTickBarrierInterval: 500
writeidlist: [3, 4]
writeTimeSyncChannels: ["write3", "write4"]
writeTimeSyncSubName: "write-topic"
dmTimeSyncChannels: ["dm5", "dm6"]
k2sTimeSyncChannels: ["k2s7", "k2s8"]
defaultSizePerRecord: 1024
minimumAssignSize: 1048576
segmentThreshold: 536870912
segmentExpireDuration: 2000
segmentThresholdFactor: 0.75
querynodenum: 1
writenodenum: 1
statsChannels: "statistic"
etcd:
address: localhost
port: 2379
rootpath: by-dev
segthreshold: 10000
minio:
address: localhost
port: 9000
accessKeyID: minioadmin
secretAccessKey: minioadmin
useSSL: false
timesync:
interval: 400
storage:
driver: TIKV
address: localhost
port: 2379
accesskey:
secretkey:
pulsar:
authentication: false
user: user-default
token: eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJKb2UifQ.ipevRNuRP6HflG8cFKnmUPtypruRC4fb1DWtoLL62SY
address: localhost
port: 6650
topicnum: 128
reader:
clientid: 0
stopflag: -1
readerqueuesize: 10000
searchchansize: 10000
key2segchansize: 10000
topicstart: 0
topicend: 128
writer:
clientid: 0
stopflag: -2
readerqueuesize: 10000
searchbyidchansize: 10000
parallelism: 100
topicstart: 0
topicend: 128
bucket: "zilliz-hz"
proxy:
timezone: UTC+8
proxy_id: 1
numReaderNodes: 2
tsoSaveInterval: 200
timeTickInterval: 200
pulsarTopics:
readerTopicPrefix: "milvusReader"
numReaderTopics: 2
deleteTopic: "milvusDeleter"
queryTopic: "milvusQuery"
resultTopic: "milvusResult"
resultGroup: "milvusResultGroup"
timeTickTopic: "milvusTimeTick"
network:
address: 0.0.0.0
port: 19530
logs:
level: debug
trace.enable: true
path: /tmp/logs
max_log_file_size: 1024MB
log_rotate_num: 0
storage:
path: /var/lib/milvus
auto_flush_interval: 1

View File

@ -12,7 +12,7 @@
nodeID: # will be deprecated after v0.2
proxyIDList: [0]
queryNodeIDList: [2]
queryNodeIDList: [1, 2]
writeNodeIDList: [3]
etcd:
@ -23,6 +23,13 @@ etcd:
kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath
segThreshold: 10000
minio:
address: localhost
port: 9000
accessKeyID: minioadmin
secretAccessKey: minioadmin
useSSL: false
pulsar:
address: localhost
port: 6650

View File

@ -1,150 +0,0 @@
package conf
import (
"io/ioutil"
"path"
"runtime"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
storagetype "github.com/zilliztech/milvus-distributed/internal/storage/type"
yaml "gopkg.in/yaml.v2"
)
type UniqueID = typeutil.UniqueID
// yaml.MapSlice
type MasterConfig struct {
Address string
Port int32
PulsarMonitorInterval int32
PulsarTopic string
SegmentThreshold float32
SegmentExpireDuration int64
ProxyIDList []UniqueID
QueryNodeNum int
WriteNodeNum int
}
type EtcdConfig struct {
Address string
Port int32
Rootpath string
Segthreshold int64
}
type TimeSyncConfig struct {
Interval int32
}
type StorageConfig struct {
Driver storagetype.DriverType
Address string
Port int32
Accesskey string
Secretkey string
}
type PulsarConfig struct {
Authentication bool
User string
Token string
Address string
Port int32
TopicNum int
}
type ProxyConfig struct {
Timezone string `yaml:"timezone"`
ProxyID int `yaml:"proxy_id"`
NumReaderNodes int `yaml:"numReaderNodes"`
TosSaveInterval int `yaml:"tsoSaveInterval"`
TimeTickInterval int `yaml:"timeTickInterval"`
PulsarTopics struct {
ReaderTopicPrefix string `yaml:"readerTopicPrefix"`
NumReaderTopics int `yaml:"numReaderTopics"`
DeleteTopic string `yaml:"deleteTopic"`
QueryTopic string `yaml:"queryTopic"`
ResultTopic string `yaml:"resultTopic"`
ResultGroup string `yaml:"resultGroup"`
TimeTickTopic string `yaml:"timeTickTopic"`
} `yaml:"pulsarTopics"`
Network struct {
Address string `yaml:"address"`
Port int `yaml:"port"`
} `yaml:"network"`
Logs struct {
Level string `yaml:"level"`
TraceEnable bool `yaml:"trace.enable"`
Path string `yaml:"path"`
MaxLogFileSize string `yaml:"max_log_file_size"`
LogRotateNum int `yaml:"log_rotate_num"`
} `yaml:"logs"`
Storage struct {
Path string `yaml:"path"`
AutoFlushInterval int `yaml:"auto_flush_interval"`
} `yaml:"storage"`
}
type Reader struct {
ClientID int
StopFlag int64
ReaderQueueSize int
SearchChanSize int
Key2SegChanSize int
TopicStart int
TopicEnd int
}
type Writer struct {
ClientID int
StopFlag int64
ReaderQueueSize int
SearchByIDChanSize int
Parallelism int
TopicStart int
TopicEnd int
Bucket string
}
type ServerConfig struct {
Master MasterConfig
Etcd EtcdConfig
Timesync TimeSyncConfig
Storage StorageConfig
Pulsar PulsarConfig
Writer Writer
Reader Reader
Proxy ProxyConfig
}
var Config ServerConfig
// func init() {
// load_config()
// }
func getConfigsDir() string {
_, fpath, _, _ := runtime.Caller(0)
configPath := path.Dir(fpath) + "/../../configs/"
configPath = path.Dir(configPath)
return configPath
}
func LoadConfigWithPath(yamlFilePath string) {
source, err := ioutil.ReadFile(yamlFilePath)
if err != nil {
panic(err)
}
err = yaml.Unmarshal(source, &Config)
if err != nil {
panic(err)
}
//fmt.Printf("Result: %v\n", Config)
}
func LoadConfig(yamlFile string) {
filePath := path.Join(getConfigsDir(), yamlFile)
LoadConfigWithPath(filePath)
}

View File

@ -1,10 +0,0 @@
package conf
import (
"fmt"
"testing"
)
func TestMain(m *testing.M) {
fmt.Printf("Result: %v\n", Config)
}

View File

@ -1,14 +0,0 @@
package mockkv
import (
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
)
// use MemoryKV to mock EtcdKV
func NewEtcdKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}
func NewMemoryKV() *memkv.MemoryKV {
return memkv.NewMemoryKV()
}
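
Since both constructors return a plain MemoryKV, tests can swap the mock in wherever an etcd-backed KV is expected. A usage sketch; the import path and the Save/Load methods are assumptions based on the shared kv interface:

package mockkv_test

import (
	"testing"

	"github.com/zilliztech/milvus-distributed/internal/master/mockkv" // assumed path
)

func TestMockEtcdKV(t *testing.T) {
	kv := mockkv.NewEtcdKV()
	if err := kv.Save("collection/0", "meta-blob"); err != nil {
		t.Fatal(err)
	}
	value, err := kv.Load("collection/0")
	if err != nil || value != "meta-blob" {
		t.Fatalf("unexpected load result: %q, %v", value, err)
	}
}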

View File

@ -55,19 +55,8 @@ var Params ParamTable
func (p *ParamTable) Init() {
// load yaml
p.BaseTable.Init()
err := p.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/master.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/common.yaml")
err := p.LoadYaml("advanced/master.yaml")
if err != nil {
panic(err)
}
@ -115,15 +104,7 @@ func (p *ParamTable) initAddress() {
}
func (p *ParamTable) initPort() {
masterPort, err := p.Load("master.port")
if err != nil {
panic(err)
}
port, err := strconv.Atoi(masterPort)
if err != nil {
panic(err)
}
p.Port = port
p.Port = p.ParseInt("master.port")
}
func (p *ParamTable) initEtcdAddress() {
@ -167,117 +148,40 @@ func (p *ParamTable) initKvRootPath() {
}
func (p *ParamTable) initTopicNum() {
insertChannelRange, err := p.Load("msgChannel.channelRange.insert")
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
channelRange := strings.Split(insertChannelRange, ",")
if len(channelRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(channelRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(channelRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
p.TopicNum = channelEnd
rangeSlice := paramtable.ConvertRangeToIntRange(iRangeStr, ",")
p.TopicNum = rangeSlice[1] - rangeSlice[0]
}
func (p *ParamTable) initSegmentSize() {
threshold, err := p.Load("master.segment.size")
if err != nil {
panic(err)
}
segmentThreshold, err := strconv.ParseFloat(threshold, 64)
if err != nil {
panic(err)
}
p.SegmentSize = segmentThreshold
p.SegmentSize = p.ParseFloat("master.segment.size")
}
func (p *ParamTable) initSegmentSizeFactor() {
segFactor, err := p.Load("master.segment.sizeFactor")
if err != nil {
panic(err)
}
factor, err := strconv.ParseFloat(segFactor, 64)
if err != nil {
panic(err)
}
p.SegmentSizeFactor = factor
p.SegmentSizeFactor = p.ParseFloat("master.segment.sizeFactor")
}
func (p *ParamTable) initDefaultRecordSize() {
size, err := p.Load("master.segment.defaultSizePerRecord")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.DefaultRecordSize = res
p.DefaultRecordSize = p.ParseInt64("master.segment.defaultSizePerRecord")
}
func (p *ParamTable) initMinSegIDAssignCnt() {
size, err := p.Load("master.segment.minIDAssignCnt")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.MinSegIDAssignCnt = res
p.MinSegIDAssignCnt = p.ParseInt64("master.segment.minIDAssignCnt")
}
func (p *ParamTable) initMaxSegIDAssignCnt() {
size, err := p.Load("master.segment.maxIDAssignCnt")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(size, 10, 64)
if err != nil {
panic(err)
}
p.MaxSegIDAssignCnt = res
p.MaxSegIDAssignCnt = p.ParseInt64("master.segment.maxIDAssignCnt")
}
func (p *ParamTable) initSegIDAssignExpiration() {
duration, err := p.Load("master.segment.IDAssignExpiration")
if err != nil {
panic(err)
}
res, err := strconv.ParseInt(duration, 10, 64)
if err != nil {
panic(err)
}
p.SegIDAssignExpiration = res
p.SegIDAssignExpiration = p.ParseInt64("master.segment.IDAssignExpiration")
}
func (p *ParamTable) initQueryNodeNum() {
id, err := p.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
}
ids := strings.Split(id, ",")
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
}
p.QueryNodeNum = len(ids)
p.QueryNodeNum = len(p.QueryNodeIDList())
}
func (p *ParamTable) initQueryNodeStatsChannelName() {
@ -289,20 +193,7 @@ func (p *ParamTable) initQueryNodeStatsChannelName() {
}
func (p *ParamTable) initProxyIDList() {
id, err := p.Load("nodeID.proxyIDList")
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ids := strings.Split(id, ",")
idList := make([]typeutil.UniqueID, 0, len(ids))
for _, i := range ids {
v, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
idList = append(idList, typeutil.UniqueID(v))
}
p.ProxyIDList = idList
p.ProxyIDList = p.BaseTable.ProxyIDList()
}
func (p *ParamTable) initProxyTimeTickChannelNames() {
@ -347,20 +238,7 @@ func (p *ParamTable) initSoftTimeTickBarrierInterval() {
}
func (p *ParamTable) initWriteNodeIDList() {
id, err := p.Load("nodeID.writeNodeIDList")
if err != nil {
log.Panic(err)
}
ids := strings.Split(id, ",")
idlist := make([]typeutil.UniqueID, 0, len(ids))
for _, i := range ids {
v, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
idlist = append(idlist, typeutil.UniqueID(v))
}
p.WriteNodeIDList = idlist
p.WriteNodeIDList = p.BaseTable.WriteNodeIDList()
}
func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
@ -385,81 +263,57 @@ func (p *ParamTable) initWriteNodeTimeTickChannelNames() {
}
func (p *ParamTable) initDDChannelNames() {
ch, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
prefix, err := p.Load("msgChannel.chanNamePrefix.dataDefinition")
if err != nil {
log.Fatal(err)
panic(err)
}
id, err := p.Load("nodeID.queryNodeIDList")
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.dataDefinition")
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
panic(err)
}
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load query node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
p.DDChannelNames = channels
p.DDChannelNames = ret
}
func (p *ParamTable) initInsertChannelNames() {
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := p.Load("msgChannel.channelRange.insert")
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
p.InsertChannelNames = channels
p.InsertChannelNames = ret
}
func (p *ParamTable) initK2SChannelNames() {
ch, err := p.Load("msgChannel.chanNamePrefix.k2s")
prefix, err := p.Load("msgChannel.chanNamePrefix.k2s")
if err != nil {
log.Fatal(err)
panic(err)
}
id, err := p.Load("nodeID.writeNodeIDList")
prefix += "-"
iRangeStr, err := p.Load("msgChannel.channelRange.k2s")
if err != nil {
log.Panicf("load write node id list error, %s", err.Error())
panic(err)
}
ids := strings.Split(id, ",")
channels := make([]string, 0, len(ids))
for _, i := range ids {
_, err := strconv.ParseInt(i, 10, 64)
if err != nil {
log.Panicf("load write node id list error, %s", err.Error())
}
channels = append(channels, ch+"-"+i)
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
p.K2SChannelNames = channels
p.K2SChannelNames = ret
}
func (p *ParamTable) initMaxPartitionNum() {
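
The refactor replaces the hand-rolled strconv parsing and range validation with shared paramtable helpers (ParseInt, ParseFloat, ParseInt64, ConvertRangeToIntRange, ConvertRangeToIntSlice). The helper names come from the diff; the bodies below are a sketch reconstructed from the validation the deleted inline code performed:

package paramtable

import (
	"strconv"
	"strings"
)

// ConvertRangeToIntRange parses "begin,end" into [begin, end], panicking on
// malformed input just as the old inline code did.
func ConvertRangeToIntRange(rangeStr, sep string) []int {
	items := strings.Split(rangeStr, sep)
	if len(items) != 2 {
		panic("Illegal channel range num")
	}
	begin, err := strconv.Atoi(strings.TrimSpace(items[0]))
	if err != nil {
		panic(err)
	}
	end, err := strconv.Atoi(strings.TrimSpace(items[1]))
	if err != nil {
		panic(err)
	}
	if begin < 0 || end < 0 || begin > end {
		panic("Illegal channel range value")
	}
	return []int{begin, end}
}

// ConvertRangeToIntSlice expands the half-open range into its member IDs,
// which callers then turn into channel names like "insert-0".
func ConvertRangeToIntSlice(rangeStr, sep string) []int {
	r := ConvertRangeToIntRange(rangeStr, sep)
	ret := make([]int, 0, r[1]-r[0])
	for i := r[0]; i < r[1]; i++ {
		ret = append(ret, i)
	}
	return ret
}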

View File

@ -1,6 +1,7 @@
package master
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@ -32,7 +33,7 @@ func TestParamTable_KVRootPath(t *testing.T) {
func TestParamTable_TopicNum(t *testing.T) {
num := Params.TopicNum
assert.Equal(t, num, 1)
fmt.Println("TopicNum:", num)
}
func TestParamTable_SegmentSize(t *testing.T) {
@ -67,7 +68,7 @@ func TestParamTable_SegIDAssignExpiration(t *testing.T) {
func TestParamTable_QueryNodeNum(t *testing.T) {
num := Params.QueryNodeNum
assert.Equal(t, num, 1)
fmt.Println("QueryNodeNum", num)
}
func TestParamTable_QueryNodeStatsChannelName(t *testing.T) {
@ -111,12 +112,11 @@ func TestParamTable_WriteNodeTimeTickChannelNames(t *testing.T) {
func TestParamTable_InsertChannelNames(t *testing.T) {
names := Params.InsertChannelNames
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
assert.Equal(t, Params.TopicNum, len(names))
}
func TestParamTable_K2SChannelNames(t *testing.T) {
names := Params.K2SChannelNames
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "k2s-3")
assert.Equal(t, names[0], "k2s-0")
}

View File

@ -19,19 +19,8 @@ var Params ParamTable
func (pt *ParamTable) Init() {
pt.BaseTable.Init()
err := pt.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/proxy.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
err = pt.LoadYaml("advanced/common.yaml")
err := pt.LoadYaml("advanced/proxy.yaml")
if err != nil {
panic(err)
}
@ -48,15 +37,24 @@ func (pt *ParamTable) Init() {
pt.Save("_proxyID", proxyIDStr)
}
func (pt *ParamTable) NetWorkAddress() string {
addr, err := pt.Load("proxy.network.address")
func (pt *ParamTable) NetworkPort() int {
return pt.ParseInt("proxy.port")
}
func (pt *ParamTable) NetworkAddress() string {
addr, err := pt.Load("proxy.address")
if err != nil {
panic(err)
}
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip proxy.network.address")
hostName, _ := net.LookupHost(addr)
if len(hostName) <= 0 {
if ip := net.ParseIP(addr); ip == nil {
panic("invalid ip proxy.address")
}
}
port, err := pt.Load("proxy.network.port")
port, err := pt.Load("proxy.port")
if err != nil {
panic(err)
}
@ -88,23 +86,6 @@ func (pt *ParamTable) ProxyNum() int {
return len(ret)
}
func (pt *ParamTable) ProxyIDList() []UniqueID {
proxyIDStr, err := pt.Load("nodeID.proxyIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (pt *ParamTable) queryNodeNum() int {
return len(pt.queryNodeIDList())
}
@ -150,25 +131,6 @@ func (pt *ParamTable) TimeTickInterval() time.Duration {
return time.Duration(interval) * time.Millisecond
}
func (pt *ParamTable) convertRangeToSlice(rangeStr, sep string) []int {
channelIDs := strings.Split(rangeStr, sep)
startStr := channelIDs[0]
endStr := channelIDs[1]
start, err := strconv.Atoi(startStr)
if err != nil {
panic(err)
}
end, err := strconv.Atoi(endStr)
if err != nil {
panic(err)
}
var ret []int
for i := start; i < end; i++ {
ret = append(ret, i)
}
return ret
}
func (pt *ParamTable) sliceIndex() int {
proxyID := pt.ProxyID()
proxyIDList := pt.ProxyIDList()
@ -190,7 +152,7 @@ func (pt *ParamTable) InsertChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := pt.convertRangeToSlice(iRangeStr, ",")
channelIDs := paramtable.ConvertRangeToIntSlice(iRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
@ -216,19 +178,12 @@ func (pt *ParamTable) DeleteChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := pt.convertRangeToSlice(dRangeStr, ",")
channelIDs := paramtable.ConvertRangeToIntSlice(dRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
return ret
}
func (pt *ParamTable) K2SChannelNames() []string {
@ -241,19 +196,12 @@ func (pt *ParamTable) K2SChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := pt.convertRangeToSlice(k2sRangeStr, ",")
channelIDs := paramtable.ConvertRangeToIntSlice(k2sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
proxyNum := pt.ProxyNum()
sep := len(channelIDs) / proxyNum
index := pt.sliceIndex()
if index == -1 {
panic("ProxyID not Match with Config")
}
start := index * sep
return ret[start : start+sep]
return ret
}
func (pt *ParamTable) SearchChannelNames() []string {
@ -261,8 +209,17 @@ func (pt *ParamTable) SearchChannelNames() []string {
if err != nil {
panic(err)
}
prefix += "-0"
return []string{prefix}
prefix += "-"
sRangeStr, err := pt.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return ret
}
func (pt *ParamTable) SearchResultChannelNames() []string {
@ -275,7 +232,7 @@ func (pt *ParamTable) SearchResultChannelNames() []string {
if err != nil {
panic(err)
}
channelIDs := pt.convertRangeToSlice(sRangeStr, ",")
channelIDs := paramtable.ConvertRangeToIntSlice(sRangeStr, ",")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
@ -321,144 +278,24 @@ func (pt *ParamTable) DataDefinitionChannelNames() []string {
return []string{prefix}
}
func (pt *ParamTable) parseInt64(key string) int64 {
valueStr, err := pt.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int64(value)
}
func (pt *ParamTable) MsgStreamInsertBufSize() int64 {
return pt.parseInt64("proxy.msgStream.insert.bufSize")
return pt.ParseInt64("proxy.msgStream.insert.bufSize")
}
func (pt *ParamTable) MsgStreamSearchBufSize() int64 {
return pt.parseInt64("proxy.msgStream.search.bufSize")
return pt.ParseInt64("proxy.msgStream.search.bufSize")
}
func (pt *ParamTable) MsgStreamSearchResultBufSize() int64 {
return pt.parseInt64("proxy.msgStream.searchResult.recvBufSize")
return pt.ParseInt64("proxy.msgStream.searchResult.recvBufSize")
}
func (pt *ParamTable) MsgStreamSearchResultPulsarBufSize() int64 {
return pt.parseInt64("proxy.msgStream.searchResult.pulsarBufSize")
return pt.ParseInt64("proxy.msgStream.searchResult.pulsarBufSize")
}
func (pt *ParamTable) MsgStreamTimeTickBufSize() int64 {
return pt.parseInt64("proxy.msgStream.timeTick.bufSize")
}
func (pt *ParamTable) insertChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
}
func (pt *ParamTable) searchResultChannelNames() []string {
ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
channelRange, err := pt.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
return pt.ParseInt64("proxy.msgStream.timeTick.bufSize")
}
func (pt *ParamTable) MaxNameLength() int64 {
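
The local parseInt64 helper is deleted in favor of methods promoted onto BaseTable (ParseInt, ParseInt64, ParseFloat). A sketch of the promoted ParseInt64, assuming it keeps the same Load-then-convert, panic-on-error behavior as the deleted local version:

package paramtable

import "strconv"

// Minimal stand-in for the real BaseTable; Load is assumed to look keys up
// in the loaded YAML tree.
type BaseTable struct{ values map[string]string }

func (gp *BaseTable) Load(key string) (string, error) {
	return gp.values[key], nil
}

// ParseInt64 centralizes the Load-then-convert dance so every caller can
// drop its own error plumbing.
func (gp *BaseTable) ParseInt64(key string) int64 {
	valueStr, err := gp.Load(key)
	if err != nil {
		panic(err)
	}
	value, err := strconv.ParseInt(valueStr, 10, 64)
	if err != nil {
		panic(err)
	}
	return value
}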

View File

@ -5,6 +5,7 @@ import (
"log"
"math/rand"
"net"
"strconv"
"sync"
"time"
@ -59,7 +60,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize())
p.queryMsgStream.SetPulsarClient(pulsarAddress)
p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames())
p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames())
masterAddr := Params.MasterAddress()
idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr)
@ -83,7 +84,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) {
p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize())
p.manipulationMsgStream.SetPulsarClient(pulsarAddress)
p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames())
p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames())
repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false)
}
@ -137,7 +138,7 @@ func (p *Proxy) AddCloseCallback(callbacks ...func()) {
func (p *Proxy) grpcLoop() {
defer p.proxyLoopWg.Done()
lis, err := net.Listen("tcp", Params.NetWorkAddress())
lis, err := net.Listen("tcp", ":"+strconv.Itoa(Params.NetworkPort()))
if err != nil {
log.Fatalf("Proxy grpc server fatal error=%v", err)
}
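
Two related networking changes land here: grpcLoop now binds by port alone (all interfaces), and NetworkAddress in the previous file accepts hostnames by trying a DNS lookup before falling back to the literal-IP check. A runnable sketch of both behaviors, with the port as a stand-in for Params.NetworkPort():

package main

import (
	"fmt"
	"net"
	"strconv"
)

func main() {
	// New listen behavior: bind the port on all interfaces.
	port := 19530 // stand-in for Params.NetworkPort()
	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
	if err != nil {
		panic(err)
	}
	defer lis.Close()
	fmt.Println("proxy listening on", lis.Addr())

	// New address validation: accept anything that resolves as a hostname,
	// and only then insist on a literal IP.
	addr := "localhost"
	if hosts, _ := net.LookupHost(addr); len(hosts) == 0 {
		if ip := net.ParseIP(addr); ip == nil {
			panic("invalid ip proxy.address")
		}
	}
	fmt.Println(addr, "accepted as proxy.address")
}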

View File

@ -100,7 +100,7 @@ func setup() {
startMaster(ctx)
startProxy(ctx)
proxyAddr := Params.NetWorkAddress()
proxyAddr := Params.NetworkAddress()
addr := strings.Split(proxyAddr, ":")
if addr[0] == "0.0.0.0" {
proxyAddr = "127.0.0.1:" + addr[1]

View File

@ -364,7 +364,7 @@ func (sched *TaskScheduler) queryResultLoop() {
unmarshal := msgstream.NewUnmarshalDispatcher()
queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize())
queryResultMsgStream.SetPulsarClient(Params.PulsarAddress())
queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(),
queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(),
Params.ProxySubName(),
unmarshal,
Params.MsgStreamSearchResultPulsarBufSize())

View File

@ -31,7 +31,7 @@ import (
* is up-to-date.
*/
type collectionReplica interface {
getTSafe() *tSafe
getTSafe() tSafe
// collection
getCollectionNum() int
@ -68,11 +68,11 @@ type collectionReplicaImpl struct {
collections []*Collection
segments map[UniqueID]*Segment
tSafe *tSafe
tSafe tSafe
}
//----------------------------------------------------------------------------------------------------- tSafe
func (colReplica *collectionReplicaImpl) getTSafe() *tSafe {
func (colReplica *collectionReplicaImpl) getTSafe() tSafe {
return colReplica.tSafe
}
@ -111,6 +111,7 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID)
if col.ID() == collectionID {
for _, p := range *col.Partitions() {
for _, s := range *p.Segments() {
deleteSegment(colReplica.segments[s.ID()])
delete(colReplica.segments, s.ID())
}
}
@ -202,6 +203,7 @@ func (colReplica *collectionReplicaImpl) removePartition(collectionID UniqueID,
for _, p := range *collection.Partitions() {
if p.Tag() == partitionTag {
for _, s := range *p.Segments() {
deleteSegment(colReplica.segments[s.ID()])
delete(colReplica.segments, s.ID())
}
} else {
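
The replica handle changes from *collectionReplica (a pointer to an interface) to the interface value itself. An interface value in Go already carries a pointer to the concrete collectionReplicaImpl, so the extra indirection, and every (*x). dereference in the callers below, can go. A minimal illustration with a trimmed-down analogue of the interface:

package main

import "fmt"

// replica is a trimmed-down analogue of collectionReplica.
type replica interface {
	getCollectionNum() int
}

type replicaImpl struct{ collections []string }

func (r *replicaImpl) getCollectionNum() int { return len(r.collections) }

// Before: use(r *replica) forced callers to write (*r).getCollectionNum().
// After: the interface value carries the *replicaImpl pointer, so methods
// are invoked directly and mutations remain visible to all holders.
func use(r replica) int { return r.getCollectionNum() }

func main() {
	var r replica = &replicaImpl{collections: []string{"collection0"}}
	fmt.Println(use(r)) // 1
}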

File diff suppressed because it is too large.

View File

@ -1,179 +1,48 @@
package querynode
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestCollection_Partitions(t *testing.T) {
ctx := context.Background()
node := NewQueryNode(ctx, 0)
node := newQueryNode()
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
collection, err := node.replica.getCollectionByName(collectionName)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
for _, tag := range collectionMeta.PartitionTags {
err := (*node.replica).addPartition(collection.ID(), tag)
assert.NoError(t, err)
}
partitions := collection.Partitions()
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
assert.Equal(t, 1, len(*partitions))
}
func TestCollection_newCollection(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
}
func TestCollection_deleteCollection(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
deleteCollection(collection)
}
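
The boilerplate schema literals are factored into genTestCollectionMeta and initTestMeta helpers. A plausible reconstruction of genTestCollectionMeta, assembled from the fields the old tests built inline (the real helper may also set index params such as metric_type, which the metaservice tests below now expect):

package querynode

import (
	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
	"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)

// genTestCollectionMeta rebuilds the collection meta the tests previously
// declared inline: a 16-dim float vector field plus an int32 "age" field.
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
	fieldVec := schemapb.FieldSchema{
		Name:         "vec",
		IsPrimaryKey: false,
		DataType:     schemapb.DataType_VECTOR_FLOAT,
		TypeParams:   []*commonpb.KeyValuePair{{Key: "dim", Value: "16"}},
	}
	fieldInt := schemapb.FieldSchema{
		Name:         "age",
		IsPrimaryKey: false,
		DataType:     schemapb.DataType_INT32,
		TypeParams:   []*commonpb.KeyValuePair{{Key: "dim", Value: "1"}},
	}
	schema := schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{&fieldVec, &fieldInt},
	}
	return &etcdpb.CollectionMeta{
		ID:            collectionID,
		Schema:        &schema,
		CreateTime:    Timestamp(0),
		SegmentIDs:    []UniqueID{0},
		PartitionTags: []string{"default"},
	}
}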

View File

@ -11,10 +11,10 @@ type dataSyncService struct {
ctx context.Context
fg *flowgraph.TimeTickedFlowGraph
replica *collectionReplica
replica collectionReplica
}
func newDataSyncService(ctx context.Context, replica *collectionReplica) *dataSyncService {
func newDataSyncService(ctx context.Context, replica collectionReplica) *dataSyncService {
return &dataSyncService{
ctx: ctx,

View File

@ -1,101 +1,22 @@
package querynode
import (
"context"
"encoding/binary"
"math"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
// NOTE: start pulsar before test
func TestDataSyncService_Start(t *testing.T) {
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
// test data generate
const msgLength = 10
const DIM = 16
@ -179,25 +100,25 @@ func TestDataSyncService_Start(t *testing.T) {
// pulsar produce
const receiveBufSize = 1024
producerChannels := Params.insertChannelNames()
pulsarURL, _ := Params.pulsarAddress()
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(producerChannels)
var insertMsgStream msgstream.MsgStream = insertStream
insertMsgStream.Start()
err = insertMsgStream.Produce(&msgPack)
err := insertMsgStream.Produce(&msgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(&timeTickMsgPack)
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
go node.dataSyncService.start()
node.Close()
<-ctx.Done()
}

View File

@ -10,7 +10,7 @@ import (
type insertNode struct {
BaseNode
replica *collectionReplica
replica collectionReplica
}
type InsertData struct {
@ -58,13 +58,13 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
insertData.insertRecords[task.SegmentID] = append(insertData.insertRecords[task.SegmentID], task.RowData...)
// check if segment exists, if not, create this segment
if !(*iNode.replica).hasSegment(task.SegmentID) {
collection, err := (*iNode.replica).getCollectionByName(task.CollectionName)
if !iNode.replica.hasSegment(task.SegmentID) {
collection, err := iNode.replica.getCollectionByName(task.CollectionName)
if err != nil {
log.Println(err)
continue
}
err = (*iNode.replica).addSegment(task.SegmentID, task.PartitionTag, collection.ID())
err = iNode.replica.addSegment(task.SegmentID, task.PartitionTag, collection.ID())
if err != nil {
log.Println(err)
continue
@ -74,7 +74,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
// 2. do preInsert
for segmentID := range insertData.insertRecords {
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
if err != nil {
log.Println("preInsert failed")
// TODO: add error handling
@ -102,7 +102,7 @@ func (iNode *insertNode) Operate(in []*Msg) []*Msg {
}
func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *sync.WaitGroup) {
var targetSegment, err = (*iNode.replica).getSegmentByID(segmentID)
var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
if err != nil {
log.Println("cannot find segment:", segmentID)
// TODO: add error handling
@ -127,7 +127,7 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn
wg.Done()
}
func newInsertNode(replica *collectionReplica) *insertNode {
func newInsertNode(replica collectionReplica) *insertNode {
maxQueueLength := Params.flowGraphMaxQueueLength()
maxParallelism := Params.flowGraphMaxParallelism()

View File

@ -6,7 +6,7 @@ import (
type serviceTimeNode struct {
BaseNode
replica *collectionReplica
replica collectionReplica
}
func (stNode *serviceTimeNode) Name() string {
@ -28,12 +28,12 @@ func (stNode *serviceTimeNode) Operate(in []*Msg) []*Msg {
}
// update service time
(*(*stNode.replica).getTSafe()).set(serviceTimeMsg.timeRange.timestampMax)
stNode.replica.getTSafe().set(serviceTimeMsg.timeRange.timestampMax)
//fmt.Println("update tSafe to:", getPhysicalTime(serviceTimeMsg.timeRange.timestampMax))
return nil
}
func newServiceTimeNode(replica *collectionReplica) *serviceTimeNode {
func newServiceTimeNode(replica collectionReplica) *serviceTimeNode {
maxQueueLength := Params.flowGraphMaxQueueLength()
maxParallelism := Params.flowGraphMaxParallelism()
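
getTSafe now returns the tSafe interface directly, so serviceTimeNode publishes the new service time with a plain set call instead of the old double dereference. A minimal sketch of a mutex-guarded tSafe; the get/set surface is implied by the diff, everything else is an assumption:

package main

import (
	"fmt"
	"sync"
)

type Timestamp = uint64

// tSafe is sketched as the small get/set surface the diff implies; the real
// type may additionally notify watchers when the service time advances.
type tSafe interface {
	get() Timestamp
	set(t Timestamp)
}

type tSafeImpl struct {
	mu sync.Mutex
	t  Timestamp
}

func (ts *tSafeImpl) get() Timestamp {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	return ts.t
}

func (ts *tSafeImpl) set(t Timestamp) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	ts.t = t
}

func main() {
	var safe tSafe = &tSafeImpl{}
	safe.set(42) // serviceTimeNode: publish timeRange.timestampMax
	fmt.Println(safe.get())
}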

View File

@ -26,10 +26,10 @@ const (
type metaService struct {
ctx context.Context
kvBase *etcdkv.EtcdKV
replica *collectionReplica
replica collectionReplica
}
func newMetaService(ctx context.Context, replica *collectionReplica) *metaService {
func newMetaService(ctx context.Context, replica collectionReplica) *metaService {
ETCDAddr := Params.etcdAddress()
MetaRootPath := Params.metaRootPath()
@ -149,12 +149,12 @@ func (mService *metaService) processCollectionCreate(id string, value string) {
col := mService.collectionUnmarshal(value)
if col != nil {
err := (*mService.replica).addCollection(col, value)
err := mService.replica.addCollection(col, value)
if err != nil {
log.Println(err)
}
for _, partitionTag := range col.PartitionTags {
err = (*mService.replica).addPartition(col.ID, partitionTag)
err = mService.replica.addPartition(col.ID, partitionTag)
if err != nil {
log.Println(err)
}
@ -173,7 +173,7 @@ func (mService *metaService) processSegmentCreate(id string, value string) {
// TODO: what if seg == nil? We need to notify master and return rpc request failed
if seg != nil {
err := (*mService.replica).addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
err := mService.replica.addSegment(seg.SegmentID, seg.PartitionTag, seg.CollectionID)
if err != nil {
log.Println(err)
return
@ -202,7 +202,7 @@ func (mService *metaService) processSegmentModify(id string, value string) {
}
if seg != nil {
targetSegment, err := (*mService.replica).getSegmentByID(seg.SegmentID)
targetSegment, err := mService.replica.getSegmentByID(seg.SegmentID)
if err != nil {
log.Println(err)
return
@ -218,11 +218,11 @@ func (mService *metaService) processCollectionModify(id string, value string) {
col := mService.collectionUnmarshal(value)
if col != nil {
err := (*mService.replica).addPartitionsByCollectionMeta(col)
err := mService.replica.addPartitionsByCollectionMeta(col)
if err != nil {
log.Println(err)
}
err = (*mService.replica).removePartitionsByCollectionMeta(col)
err = mService.replica.removePartitionsByCollectionMeta(col)
if err != nil {
log.Println(err)
}
@ -249,7 +249,7 @@ func (mService *metaService) processSegmentDelete(id string) {
log.Println("Cannot parse segment id:" + id)
}
err = (*mService.replica).removeSegment(segmentID)
err = mService.replica.removeSegment(segmentID)
if err != nil {
log.Println(err)
return
@ -264,7 +264,7 @@ func (mService *metaService) processCollectionDelete(id string) {
log.Println("Cannot parse collection id:" + id)
}
err = (*mService.replica).removeCollection(collectionID)
err = mService.replica.removeCollection(collectionID)
if err != nil {
log.Println(err)
return

View File

@ -3,23 +3,13 @@ package querynode
import (
"context"
"math"
"os"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestMain(m *testing.M) {
Params.Init()
exitCode := m.Run()
os.Exit(exitCode)
}
func TestMetaService_start(t *testing.T) {
var ctx context.Context
@ -37,6 +27,7 @@ func TestMetaService_start(t *testing.T) {
node.metaService = newMetaService(ctx, node.replica)
(*node.metaService).start()
node.Close()
}
func TestMetaService_getCollectionObjId(t *testing.T) {
@ -119,47 +110,9 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T)
func TestMetaService_printCollectionStruct(t *testing.T) {
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
printCollectionStruct(&collectionMeta)
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
printCollectionStruct(collectionMeta)
}
func TestMetaService_printSegmentStruct(t *testing.T) {
@ -178,13 +131,8 @@ func TestMetaService_printSegmentStruct(t *testing.T) {
}
func TestMetaService_processCollectionCreate(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `schema: <
@ -196,6 +144,10 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -212,71 +164,21 @@ func TestMetaService_processCollectionCreate(t *testing.T) {
node.metaService.processCollectionCreate(id, value)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentCreate(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `partition_tag: "default"
@ -287,19 +189,15 @@ func TestMetaService_processSegmentCreate(t *testing.T) {
(*node.metaService).processSegmentCreate(id, value)
s, err := (*node.replica).getSegmentByID(UniqueID(0))
s, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processCreate(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -311,6 +209,10 @@ func TestMetaService_processCreate(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -326,10 +228,10 @@ func TestMetaService_processCreate(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
@ -341,68 +243,19 @@ func TestMetaService_processCreate(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
s, err := (*node.replica).getSegmentByID(UniqueID(0))
s, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentModify(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
collectionID := UniqueID(0)
segmentID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, segmentID)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `partition_tag: "default"
@ -412,9 +265,9 @@ func TestMetaService_processSegmentModify(t *testing.T) {
`
(*node.metaService).processSegmentCreate(id, value)
s, err := (*node.replica).getSegmentByID(UniqueID(0))
s, err := node.replica.getSegmentByID(segmentID)
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
assert.Equal(t, s.segmentID, segmentID)
newValue := `partition_tag: "default"
channel_start: 0
@ -424,19 +277,15 @@ func TestMetaService_processSegmentModify(t *testing.T) {
// TODO: modify segment for testing processCollectionModify
(*node.metaService).processSegmentModify(id, newValue)
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
seg, err := node.replica.getSegmentByID(segmentID)
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
assert.Equal(t, seg.segmentID, segmentID)
node.Close()
}
func TestMetaService_processCollectionModify(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `schema: <
@ -448,6 +297,10 @@ func TestMetaService_processCollectionModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -465,24 +318,24 @@ func TestMetaService_processCollectionModify(t *testing.T) {
`
(*node.metaService).processCollectionCreate(id, value)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, false)
newValue := `schema: <
@ -494,6 +347,10 @@ func TestMetaService_processCollectionModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -511,32 +368,28 @@ func TestMetaService_processCollectionModify(t *testing.T) {
`
(*node.metaService).processCollectionModify(id, newValue)
collection, err = (*node.replica).getCollectionByName("test")
collection, err = node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, false)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, true)
node.Close()
}
func TestMetaService_processModify(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -548,6 +401,10 @@ func TestMetaService_processModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -565,24 +422,24 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err := (*node.replica).getPartitionNum(UniqueID(0))
partitionNum, err := node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition := (*node.replica).hasPartition(UniqueID(0), "p0")
hasPartition := node.replica.hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, false)
key2 := "by-dev/meta/segment/0"
@ -593,7 +450,7 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
s, err := (*node.replica).getSegmentByID(UniqueID(0))
s, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, s.segmentID, UniqueID(0))
@ -608,6 +465,10 @@ func TestMetaService_processModify(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -625,21 +486,21 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processModify(key1, msg3)
collection, err = (*node.replica).getCollectionByName("test")
collection, err = node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
partitionNum, err = (*node.replica).getPartitionNum(UniqueID(0))
partitionNum, err = node.replica.getPartitionNum(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, partitionNum, 3)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p0")
hasPartition = node.replica.hasPartition(UniqueID(0), "p0")
assert.Equal(t, hasPartition, false)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p1")
hasPartition = node.replica.hasPartition(UniqueID(0), "p1")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p2")
hasPartition = node.replica.hasPartition(UniqueID(0), "p2")
assert.Equal(t, hasPartition, true)
hasPartition = (*node.replica).hasPartition(UniqueID(0), "p3")
hasPartition = node.replica.hasPartition(UniqueID(0), "p3")
assert.Equal(t, hasPartition, true)
msg4 := `partition_tag: "p1"
@ -649,68 +510,18 @@ func TestMetaService_processModify(t *testing.T) {
`
(*node.metaService).processModify(key2, msg4)
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
seg, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
node.Close()
}
func TestMetaService_processSegmentDelete(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
colMetaBlob := proto.MarshalTextString(&collectionMeta)
err := (*node.replica).addCollection(&collectionMeta, string(colMetaBlob))
assert.NoError(t, err)
err = (*node.replica).addPartition(UniqueID(0), "default")
assert.NoError(t, err)
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `partition_tag: "default"
@ -720,23 +531,19 @@ func TestMetaService_processSegmentDelete(t *testing.T) {
`
(*node.metaService).processSegmentCreate(id, value)
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
seg, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
(*node.metaService).processSegmentDelete("0")
mapSize := (*node.replica).getSegmentNum()
mapSize := node.replica.getSegmentNum()
assert.Equal(t, mapSize, 0)
node.Close()
}
func TestMetaService_processCollectionDelete(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
id := "0"
value := `schema: <
@ -748,6 +555,10 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -763,26 +574,22 @@ func TestMetaService_processCollectionDelete(t *testing.T) {
`
(*node.metaService).processCollectionCreate(id, value)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
(*node.metaService).processCollectionDelete(id)
collectionNum = (*node.replica).getCollectionNum()
collectionNum = node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 0)
node.Close()
}
func TestMetaService_processDelete(t *testing.T) {
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), d)
defer cancel()
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
key1 := "by-dev/meta/collection/0"
msg1 := `schema: <
@ -794,6 +601,10 @@ func TestMetaService_processDelete(t *testing.T) {
key: "dim"
value: "16"
>
index_params: <
key: "metric_type"
value: "L2"
>
>
fields: <
name: "age"
@ -809,10 +620,10 @@ func TestMetaService_processDelete(t *testing.T) {
`
(*node.metaService).processCreate(key1, msg1)
collectionNum := (*node.replica).getCollectionNum()
collectionNum := node.replica.getCollectionNum()
assert.Equal(t, collectionNum, 1)
collection, err := (*node.replica).getCollectionByName("test")
collection, err := node.replica.getCollectionByName("test")
assert.NoError(t, err)
assert.Equal(t, collection.ID(), UniqueID(0))
@ -824,77 +635,48 @@ func TestMetaService_processDelete(t *testing.T) {
`
(*node.metaService).processCreate(key2, msg2)
seg, err := (*node.replica).getSegmentByID(UniqueID(0))
seg, err := node.replica.getSegmentByID(UniqueID(0))
assert.NoError(t, err)
assert.Equal(t, seg.segmentID, UniqueID(0))
(*node.metaService).processDelete(key1)
collectionsSize := (*node.replica).getCollectionNum()
collectionsSize := node.replica.getCollectionNum()
assert.Equal(t, collectionsSize, 0)
mapSize := (*node.replica).getSegmentNum()
mapSize := node.replica.getSegmentNum()
assert.Equal(t, mapSize, 0)
node.Close()
}
func TestMetaService_processResp(t *testing.T) {
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
metaChan := (*node.metaService).kvBase.WatchWithPrefix("")
select {
case <-node.ctx.Done():
case <-node.queryNodeLoopCtx.Done():
return
case resp := <-metaChan:
_ = (*node.metaService).processResp(resp)
}
node.Close()
}
func TestMetaService_loadCollections(t *testing.T) {
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
err2 := (*node.metaService).loadCollections()
assert.Nil(t, err2)
node.Close()
}
func TestMetaService_loadSegments(t *testing.T) {
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init metaService
node := NewQueryNode(ctx, 0)
node.metaService = newMetaService(ctx, node.replica)
node := newQueryNode()
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
err2 := (*node.metaService).loadSegments()
assert.Nil(t, err2)
node.Close()
}

View File

@ -2,8 +2,8 @@ package querynode
import (
"log"
"os"
"strconv"
"strings"
"github.com/zilliztech/milvus-distributed/internal/util/paramtable"
)
@ -21,15 +21,16 @@ func (p *ParamTable) Init() {
panic(err)
}
err = p.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = p.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
queryNodeIDStr := os.Getenv("QUERY_NODE_ID")
if queryNodeIDStr == "" {
queryNodeIDList := p.QueryNodeIDList()
if len(queryNodeIDList) <= 0 {
queryNodeIDStr = "0"
} else {
queryNodeIDStr = strconv.Itoa(int(queryNodeIDList[0]))
}
}
p.Save("_queryNodeID", queryNodeIDStr)
}
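
The rewritten Init resolves the node ID in three steps: an explicit QUERY_NODE_ID environment variable wins, otherwise the first entry of the configured ID list is used, and an empty list falls back to "0". A minimal standalone sketch of that fallback chain (resolveQueryNodeID is a hypothetical name for illustration, not part of this change):

// Hypothetical helper mirroring the resolution order in Init above.
func resolveQueryNodeID(envValue string, configured []UniqueID) string {
    if envValue != "" {
        return envValue // explicit override from the environment
    }
    if len(configured) == 0 {
        return "0" // nothing configured: default to a single node with ID 0
    }
    return strconv.Itoa(int(configured[0])) // first configured ID wins
}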
func (p *ParamTable) pulsarAddress() (string, error) {
@ -40,8 +41,8 @@ func (p *ParamTable) pulsarAddress() (string, error) {
return url, nil
}
func (p *ParamTable) queryNodeID() int {
queryNodeID, err := p.Load("reader.clientid")
func (p *ParamTable) QueryNodeID() UniqueID {
queryNodeID, err := p.Load("_queryNodeID")
if err != nil {
panic(err)
}
@ -49,7 +50,7 @@ func (p *ParamTable) queryNodeID() int {
if err != nil {
panic(err)
}
return id
return UniqueID(id)
}
func (p *ParamTable) insertChannelRange() []int {
@ -57,138 +58,47 @@ func (p *ParamTable) insertChannelRange() []int {
if err != nil {
panic(err)
}
channelRange := strings.Split(insertChannelRange, ",")
if len(channelRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(channelRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(channelRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
return []int{channelBegin, channelEnd}
return paramtable.ConvertRangeToIntRange(insertChannelRange, ",")
}
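
The hand-rolled parsing that used to live here (split on the separator, Atoi both endpoints, reject negative or inverted ranges) moves into paramtable.ConvertRangeToIntRange. A sketch of what that helper is assumed to do, reconstructed from the deleted validation logic:

// Assumed behavior of ConvertRangeToIntRange for input like "0,1",
// pieced together from the checks removed above; the real helper
// lives in internal/util/paramtable.
func convertRangeToIntRange(rangeStr, sep string) []int {
    parts := strings.Split(rangeStr, sep)
    if len(parts) != 2 {
        panic("Illegal channel range num")
    }
    begin, err := strconv.Atoi(parts[0])
    if err != nil {
        panic(err)
    }
    end, err := strconv.Atoi(parts[1])
    if err != nil {
        panic(err)
    }
    if begin < 0 || end < 0 || begin > end {
        panic("Illegal channel range value")
    }
    return []int{begin, end}
}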
// advanced params
// stats
func (p *ParamTable) statsPublishInterval() int {
timeInterval, err := p.Load("queryNode.stats.publishInterval")
if err != nil {
panic(err)
}
interval, err := strconv.Atoi(timeInterval)
if err != nil {
panic(err)
}
return interval
return p.ParseInt("queryNode.stats.publishInterval")
}
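
The accessors below all shrink the same way: the Load-then-Atoi-then-panic boilerplate collapses into the new ParseInt/ParseInt32/ParseInt64 helpers. Their assumed shape, reconstructed from the code they replace (the Int32/Int64 variants presumably add a cast, as the deleted code did):

// Assumed shape of the new ParseInt helper; parseIntSketch is a
// hypothetical name for illustration only.
func (p *ParamTable) parseIntSketch(key string) int {
    value, err := p.Load(key)
    if err != nil {
        panic(err)
    }
    n, err := strconv.Atoi(value)
    if err != nil {
        panic(err)
    }
    return n
}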
// dataSync:
func (p *ParamTable) flowGraphMaxQueueLength() int32 {
queueLength, err := p.Load("queryNode.dataSync.flowGraph.maxQueueLength")
if err != nil {
panic(err)
}
length, err := strconv.Atoi(queueLength)
if err != nil {
panic(err)
}
return int32(length)
return p.ParseInt32("queryNode.dataSync.flowGraph.maxQueueLength")
}
func (p *ParamTable) flowGraphMaxParallelism() int32 {
maxParallelism, err := p.Load("queryNode.dataSync.flowGraph.maxParallelism")
if err != nil {
panic(err)
}
maxPara, err := strconv.Atoi(maxParallelism)
if err != nil {
panic(err)
}
return int32(maxPara)
return p.ParseInt32("queryNode.dataSync.flowGraph.maxParallelism")
}
// msgStream
func (p *ParamTable) insertReceiveBufSize() int64 {
revBufSize, err := p.Load("queryNode.msgStream.insert.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.insert.recvBufSize")
}
func (p *ParamTable) insertPulsarBufSize() int64 {
pulsarBufSize, err := p.Load("queryNode.msgStream.insert.pulsarBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(pulsarBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.insert.pulsarBufSize")
}
func (p *ParamTable) searchReceiveBufSize() int64 {
revBufSize, err := p.Load("queryNode.msgStream.search.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.search.recvBufSize")
}
func (p *ParamTable) searchPulsarBufSize() int64 {
pulsarBufSize, err := p.Load("queryNode.msgStream.search.pulsarBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(pulsarBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.search.pulsarBufSize")
}
func (p *ParamTable) searchResultReceiveBufSize() int64 {
revBufSize, err := p.Load("queryNode.msgStream.searchResult.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.searchResult.recvBufSize")
}
func (p *ParamTable) statsReceiveBufSize() int64 {
revBufSize, err := p.Load("queryNode.msgStream.stats.recvBufSize")
if err != nil {
panic(err)
}
bufSize, err := strconv.Atoi(revBufSize)
if err != nil {
panic(err)
}
return int64(bufSize)
return p.ParseInt64("queryNode.msgStream.stats.recvBufSize")
}
func (p *ParamTable) etcdAddress() string {
@ -212,123 +122,73 @@ func (p *ParamTable) metaRootPath() string {
}
func (p *ParamTable) gracefulTime() int64 {
gracefulTime, err := p.Load("queryNode.gracefulTime")
if err != nil {
panic(err)
}
time, err := strconv.Atoi(gracefulTime)
if err != nil {
panic(err)
}
return int64(time)
return p.ParseInt64("queryNode.gracefulTime")
}
func (p *ParamTable) insertChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.insert")
prefix, err := p.Load("msgChannel.chanNamePrefix.insert")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.insert")
if err != nil {
panic(err)
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
sep := len(channelIDs) / p.queryNodeNum()
index := p.sliceIndex()
if index == -1 {
panic("queryNodeID not Match with Config")
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
}
return channels
start := index * sep
return ret[start : start+sep]
}
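
A worked example of the sharding above, with assumed values: if the "0,2" channel range expands to channelIDs [0, 1] and two query nodes are configured, then sep = 2/2 = 1, the node at slice index 0 serves ["insert-0"], and the node at index 1 serves ["insert-1"]. Two caveats follow from the arithmetic: the integer division silently drops remainder channels when the channel count is not a multiple of queryNodeNum, and a node whose ID is absent from the configured list hits the sliceIndex panic.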
func (p *ParamTable) searchChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.search")
prefix, err := p.Load("msgChannel.chanNamePrefix.search")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.search")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return channels
return ret
}
func (p *ParamTable) searchResultChannelNames() []string {
ch, err := p.Load("msgChannel.chanNamePrefix.searchResult")
prefix, err := p.Load("msgChannel.chanNamePrefix.searchResult")
if err != nil {
log.Fatal(err)
}
prefix += "-"
channelRange, err := p.Load("msgChannel.channelRange.searchResult")
if err != nil {
panic(err)
}
chanRange := strings.Split(channelRange, ",")
if len(chanRange) != 2 {
panic("Illegal channel range num")
}
channelBegin, err := strconv.Atoi(chanRange[0])
if err != nil {
panic(err)
}
channelEnd, err := strconv.Atoi(chanRange[1])
if err != nil {
panic(err)
}
if channelBegin < 0 || channelEnd < 0 {
panic("Illegal channel range value")
}
if channelBegin > channelEnd {
panic("Illegal channel range value")
}
channelIDs := paramtable.ConvertRangeToIntSlice(channelRange, ",")
channels := make([]string, channelEnd-channelBegin)
for i := 0; i < channelEnd-channelBegin; i++ {
channels[i] = ch + "-" + strconv.Itoa(channelBegin+i)
var ret []string
for _, ID := range channelIDs {
ret = append(ret, prefix+strconv.Itoa(ID))
}
return channels
return ret
}
func (p *ParamTable) msgChannelSubName() string {
@ -337,7 +197,11 @@ func (p *ParamTable) msgChannelSubName() string {
if err != nil {
log.Panic(err)
}
return name
queryNodeIDStr, err := p.Load("_queryNodeID")
if err != nil {
panic(err)
}
return name + "-" + queryNodeIDStr
}
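
Appending the node ID gives every query node a distinct subscription name (for example "queryNode-1" for node 1, matching the expectation in the updated ParamTable test below), so multiple nodes consuming the same channels no longer collide on one subscription.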
func (p *ParamTable) statsChannelName() string {
@ -347,3 +211,18 @@ func (p *ParamTable) statsChannelName() string {
}
return channels
}
func (p *ParamTable) sliceIndex() int {
queryNodeID := p.QueryNodeID()
queryNodeIDList := p.QueryNodeIDList()
for i := 0; i < len(queryNodeIDList); i++ {
if queryNodeID == queryNodeIDList[i] {
return i
}
}
return -1
}
func (p *ParamTable) queryNodeNum() int {
return len(p.QueryNodeIDList())
}

View File

@ -1,128 +1,109 @@
package querynode
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParamTable_Init(t *testing.T) {
Params.Init()
}
func TestParamTable_PulsarAddress(t *testing.T) {
Params.Init()
address, err := Params.pulsarAddress()
assert.NoError(t, err)
split := strings.Split(address, ":")
assert.Equal(t, split[0], "pulsar")
assert.Equal(t, split[len(split)-1], "6650")
assert.Equal(t, "pulsar", split[0])
assert.Equal(t, "6650", split[len(split)-1])
}
func TestParamTable_QueryNodeID(t *testing.T) {
Params.Init()
id := Params.queryNodeID()
assert.Equal(t, id, 0)
id := Params.QueryNodeID()
assert.Contains(t, Params.QueryNodeIDList(), id)
}
func TestParamTable_insertChannelRange(t *testing.T) {
Params.Init()
channelRange := Params.insertChannelRange()
assert.Equal(t, len(channelRange), 2)
assert.Equal(t, channelRange[0], 0)
assert.Equal(t, channelRange[1], 1)
assert.Equal(t, 2, len(channelRange))
}
func TestParamTable_statsServiceTimeInterval(t *testing.T) {
Params.Init()
interval := Params.statsPublishInterval()
assert.Equal(t, interval, 1000)
assert.Equal(t, 1000, interval)
}
func TestParamTable_statsMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.statsReceiveBufSize()
assert.Equal(t, bufSize, int64(64))
assert.Equal(t, int64(64), bufSize)
}
func TestParamTable_insertMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.insertReceiveBufSize()
assert.Equal(t, bufSize, int64(1024))
assert.Equal(t, int64(1024), bufSize)
}
func TestParamTable_searchMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchReceiveBufSize()
assert.Equal(t, bufSize, int64(512))
assert.Equal(t, int64(512), bufSize)
}
func TestParamTable_searchResultMsgStreamReceiveBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchResultReceiveBufSize()
assert.Equal(t, bufSize, int64(64))
assert.Equal(t, int64(64), bufSize)
}
func TestParamTable_searchPulsarBufSize(t *testing.T) {
Params.Init()
bufSize := Params.searchPulsarBufSize()
assert.Equal(t, bufSize, int64(512))
assert.Equal(t, int64(512), bufSize)
}
func TestParamTable_insertPulsarBufSize(t *testing.T) {
Params.Init()
bufSize := Params.insertPulsarBufSize()
assert.Equal(t, bufSize, int64(1024))
assert.Equal(t, int64(1024), bufSize)
}
func TestParamTable_flowGraphMaxQueueLength(t *testing.T) {
Params.Init()
length := Params.flowGraphMaxQueueLength()
assert.Equal(t, length, int32(1024))
assert.Equal(t, int32(1024), length)
}
func TestParamTable_flowGraphMaxParallelism(t *testing.T) {
Params.Init()
maxParallelism := Params.flowGraphMaxParallelism()
assert.Equal(t, maxParallelism, int32(1024))
assert.Equal(t, int32(1024), maxParallelism)
}
func TestParamTable_insertChannelNames(t *testing.T) {
Params.Init()
names := Params.insertChannelNames()
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "insert-0")
channelRange := Params.insertChannelRange()
num := channelRange[1] - channelRange[0]
num = num / Params.queryNodeNum()
assert.Equal(t, num, len(names))
start := num * Params.sliceIndex()
assert.Equal(t, fmt.Sprintf("insert-%d", channelRange[0]+start), names[0])
}
func TestParamTable_searchChannelNames(t *testing.T) {
Params.Init()
names := Params.searchChannelNames()
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "search-0")
assert.Equal(t, "search-0", names[0])
}
func TestParamTable_searchResultChannelNames(t *testing.T) {
Params.Init()
names := Params.searchResultChannelNames()
assert.Equal(t, len(names), 1)
assert.Equal(t, names[0], "searchResult-0")
assert.NotNil(t, names)
}
func TestParamTable_msgChannelSubName(t *testing.T) {
Params.Init()
name := Params.msgChannelSubName()
assert.Equal(t, name, "queryNode")
expectName := fmt.Sprintf("queryNode-%d", Params.QueryNodeID())
assert.Equal(t, expectName, name)
}
func TestParamTable_statsChannelName(t *testing.T) {
Params.Init()
name := Params.statsChannelName()
assert.Equal(t, name, "query-node-stats")
assert.Equal(t, "query-node-stats", name)
}
func TestParamTable_metaRootPath(t *testing.T) {
Params.Init()
path := Params.metaRootPath()
assert.Equal(t, path, "by-dev/meta")
assert.Equal(t, "by-dev/meta", path)
}

View File

@ -1,77 +1,20 @@
package querynode
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestPartition_Segments(t *testing.T) {
ctx := context.Background()
node := NewQueryNode(ctx, 0)
node := newQueryNode()
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
collection, err := node.replica.getCollectionByName(collectionName)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
for _, tag := range collectionMeta.PartitionTags {
err := (*node.replica).addPartition(collection.ID(), tag)
assert.NoError(t, err)
}
collectionMeta := collection.meta
partitions := collection.Partitions()
assert.Equal(t, len(collectionMeta.PartitionTags), len(*partitions))
@ -80,12 +23,12 @@ func TestPartition_Segments(t *testing.T) {
const segmentNum = 3
for i := 0; i < segmentNum; i++ {
err := (*node.replica).addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
err := node.replica.addSegment(UniqueID(i), targetPartition.partitionTag, collection.ID())
assert.NoError(t, err)
}
segments := targetPartition.Segments()
assert.Equal(t, segmentNum, len(*segments))
assert.Equal(t, segmentNum+1, len(*segments))
}
func TestPartition_newPartition(t *testing.T) {

View File

@ -8,59 +8,17 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestPlan_Plan(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
@ -74,52 +32,13 @@ func TestPlan_Plan(t *testing.T) {
}
func TestPlan_PlaceholderGroup(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"

View File

@ -17,11 +17,12 @@ import (
)
type QueryNode struct {
ctx context.Context
queryNodeLoopCtx context.Context
queryNodeLoopCancel func()
QueryNodeID uint64
replica *collectionReplica
replica collectionReplica
dataSyncService *dataSyncService
metaService *metaService
@ -29,7 +30,14 @@ type QueryNode struct {
statsService *statsService
}
func Init() {
Params.Init()
}
func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
ctx1, cancel := context.WithCancel(ctx)
segmentsMap := make(map[int64]*Segment)
collections := make([]*Collection, 0)
@ -43,11 +51,11 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
}
return &QueryNode{
ctx: ctx,
queryNodeLoopCtx: ctx1,
queryNodeLoopCancel: cancel,
QueryNodeID: queryNodeID,
QueryNodeID: queryNodeID,
replica: &replica,
replica: replica,
dataSyncService: nil,
metaService: nil,
@ -56,31 +64,34 @@ func NewQueryNode(ctx context.Context, queryNodeID uint64) *QueryNode {
}
}
func (node *QueryNode) Start() {
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
node.searchService = newSearchService(node.ctx, node.replica)
node.metaService = newMetaService(node.ctx, node.replica)
node.statsService = newStatsService(node.ctx, node.replica)
func (node *QueryNode) Start() error {
// todo add connectMaster logic
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
go node.dataSyncService.start()
go node.searchService.start()
go node.metaService.start()
node.statsService.start()
go node.statsService.start()
return nil
}
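
Start now returns an error (always nil for the moment, with the master connection still a TODO) and launches statsService on its own goroutine like the other three services, so the call no longer blocks; lifetime is governed entirely by the loop context that Close cancels.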
func (node *QueryNode) Close() {
<-node.ctx.Done()
node.queryNodeLoopCancel()
// free collectionReplica
(*node.replica).freeAll()
node.replica.freeAll()
// close services
if node.dataSyncService != nil {
(*node.dataSyncService).close()
node.dataSyncService.close()
}
if node.searchService != nil {
(*node.searchService).close()
node.searchService.close()
}
if node.statsService != nil {
(*node.statsService).close()
node.statsService.close()
}
}

View File

@ -2,18 +2,93 @@ package querynode
import (
"context"
"os"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
const ctxTimeInMillisecond = 200
const closeWithDeadline = true
// NOTE: start pulsar and etcd before test
func TestQueryNode_start(t *testing.T) {
func setup() {
Params.Init()
}
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "L2",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: collectionID,
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
return &collectionMeta
}
func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) {
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = node.replica.addCollection(collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := node.replica.getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
assert.Equal(t, node.replica.getCollectionNum(), 1)
err = node.replica.addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
err = node.replica.addSegment(segmentID, collectionMeta.PartitionTags[0], collectionID)
assert.NoError(t, err)
}
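
Typical usage in the rewritten tests, combining initTestMeta with the newQueryNode constructor defined just below (the IDs shown are the ones the tests use):

// Build a node on the shared deadline context, then seed one
// collection, one partition, and one segment before exercising it.
node := newQueryNode()
initTestMeta(t, node, "collection0", UniqueID(0), UniqueID(0))
// ... exercise the node ...
node.Close()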
func newQueryNode() *QueryNode {
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
@ -23,7 +98,21 @@ func TestQueryNode_start(t *testing.T) {
ctx = context.Background()
}
node := NewQueryNode(ctx, 0)
node.Start()
node.Close()
svr := NewQueryNode(ctx, 0)
return svr
}
func TestMain(m *testing.M) {
setup()
exitCode := m.Run()
os.Exit(exitCode)
}
// NOTE: start pulsar and etcd before test
func TestQueryNode_Start(t *testing.T) {
localNode := newQueryNode()
err := localNode.Start()
assert.Nil(t, err)
localNode.Close()
}

View File

@ -1,15 +0,0 @@
package querynode
import (
"context"
)
func Init() {
Params.Init()
}
func StartQueryNode(ctx context.Context) {
node := NewQueryNode(ctx, 0)
node.Start()
}

View File

@ -9,63 +9,19 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestReduce_AllFunc(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
segmentID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
segmentID := UniqueID(0)
collection := newCollection(collectionMeta, collectionMetaBlob)
segment := newSegment(collection, segmentID)
assert.Equal(t, segmentID, segment.segmentID)
const DIM = 16
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}

View File

@ -4,10 +4,12 @@ import "C"
import (
"context"
"errors"
"github.com/golang/protobuf/proto"
"fmt"
"log"
"sync"
"github.com/golang/protobuf/proto"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
@ -19,7 +21,7 @@ type searchService struct {
wait sync.WaitGroup
cancel context.CancelFunc
replica *collectionReplica
replica collectionReplica
tSafeWatcher *tSafeWatcher
serviceableTime Timestamp
@ -27,13 +29,14 @@ type searchService struct {
msgBuffer chan msgstream.TsMsg
unsolvedMsg []msgstream.TsMsg
searchMsgStream *msgstream.MsgStream
searchResultMsgStream *msgstream.MsgStream
searchMsgStream msgstream.MsgStream
searchResultMsgStream msgstream.MsgStream
queryNodeID UniqueID
}
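
The replica and both msg streams are now held as plain interface values instead of pointers to interfaces, which is what removes the (*ss.xxx) dereferences throughout the methods below and lets close() nil-check the streams directly.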
type ResultEntityIds []UniqueID
func newSearchService(ctx context.Context, replica *collectionReplica) *searchService {
func newSearchService(ctx context.Context, replica collectionReplica) *searchService {
receiveBufSize := Params.searchReceiveBufSize()
pulsarBufSize := Params.searchPulsarBufSize()
@ -69,14 +72,15 @@ func newSearchService(ctx context.Context, replica *collectionReplica) *searchSe
replica: replica,
tSafeWatcher: newTSafeWatcher(),
searchMsgStream: &inputStream,
searchResultMsgStream: &outputStream,
searchMsgStream: inputStream,
searchResultMsgStream: outputStream,
queryNodeID: Params.QueryNodeID(),
}
}
func (ss *searchService) start() {
(*ss.searchMsgStream).Start()
(*ss.searchResultMsgStream).Start()
ss.searchMsgStream.Start()
ss.searchResultMsgStream.Start()
ss.register()
ss.wait.Add(2)
go ss.receiveSearchMsg()
@ -85,20 +89,24 @@ func (ss *searchService) start() {
}
func (ss *searchService) close() {
(*ss.searchMsgStream).Close()
(*ss.searchResultMsgStream).Close()
if ss.searchMsgStream != nil {
ss.searchMsgStream.Close()
}
if ss.searchResultMsgStream != nil {
ss.searchResultMsgStream.Close()
}
ss.cancel()
}
func (ss *searchService) register() {
tSafe := (*(ss.replica)).getTSafe()
(*tSafe).registerTSafeWatcher(ss.tSafeWatcher)
tSafe := ss.replica.getTSafe()
tSafe.registerTSafeWatcher(ss.tSafeWatcher)
}
func (ss *searchService) waitNewTSafe() Timestamp {
// block until dataSyncService updating tSafe
ss.tSafeWatcher.hasUpdate()
timestamp := (*(*ss.replica).getTSafe()).get()
timestamp := ss.replica.getTSafe().get()
return timestamp
}
@ -122,7 +130,7 @@ func (ss *searchService) receiveSearchMsg() {
case <-ss.ctx.Done():
return
default:
msgPack := (*ss.searchMsgStream).Consume()
msgPack := ss.searchMsgStream.Consume()
if msgPack == nil || len(msgPack.Msgs) <= 0 {
continue
}
@ -219,7 +227,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
collectionName := query.CollectionName
partitionTags := query.PartitionTags
collection, err := (*ss.replica).getCollectionByName(collectionName)
collection, err := ss.replica.getCollectionByName(collectionName)
if err != nil {
return err
}
@ -241,14 +249,14 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
matchedSegments := make([]*Segment, 0)
for _, partitionTag := range partitionTags {
hasPartition := (*ss.replica).hasPartition(collectionID, partitionTag)
hasPartition := ss.replica.hasPartition(collectionID, partitionTag)
if !hasPartition {
return errors.New("search Failed, invalid partitionTag")
}
}
for _, partitionTag := range partitionTags {
partition, _ := (*ss.replica).getPartitionByTag(collectionID, partitionTag)
partition, _ := ss.replica.getPartitionByTag(collectionID, partitionTag)
for _, segment := range partition.segments {
//fmt.Println("dsl = ", dsl)
@ -268,13 +276,13 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_SUCCESS},
ReqID: searchMsg.ReqID,
ProxyID: searchMsg.ProxyID,
QueryNodeID: searchMsg.ProxyID,
QueryNodeID: ss.queryNodeID,
Timestamp: searchTimestamp,
ResultChannelID: searchMsg.ResultChannelID,
Hits: nil,
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
SearchResult: results,
}
err = ss.publishSearchResult(searchResultMsg)
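
Two behavioral fixes ride along with the dereference cleanup in this function: QueryNodeID is now stamped with the node's own ID rather than the proxy's, and HashValues carries the request's ResultChannelID instead of a constant 0, so each result message hashes back to the channel its proxy is listening on.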
@ -333,7 +341,7 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
Hits: hits,
}
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
SearchResult: results,
}
err = ss.publishSearchResult(searchResultMsg)
@ -350,9 +358,10 @@ func (ss *searchService) search(msg msgstream.TsMsg) error {
}
func (ss *searchService) publishSearchResult(msg msgstream.TsMsg) error {
fmt.Println("Public SearchResult", msg.HashKeys())
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, msg)
err := (*ss.searchResultMsgStream).Produce(&msgPack)
err := ss.searchResultMsgStream.Produce(&msgPack)
if err != nil {
return err
}
@ -377,11 +386,11 @@ func (ss *searchService) publishFailedSearchResult(msg msgstream.TsMsg, errMsg s
}
tsMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{0}},
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(searchMsg.ResultChannelID)}},
SearchResult: results,
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
err := (*ss.searchResultMsgStream).Produce(&msgPack)
err := ss.searchResultMsgStream.Produce(&msgPack)
if err != nil {
return err
}

View File

@ -13,80 +13,15 @@ import (
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
func TestSearch_Search(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
node := NewQueryNode(context.Background(), 0)
initTestMeta(t, node, "collection0", 0, 0)
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 10
@ -158,14 +93,14 @@ func TestSearch_Search(t *testing.T) {
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.ctx, node.replica)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
go node.searchService.start()
// start insert
@ -235,7 +170,7 @@ func TestSearch_Search(t *testing.T) {
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
// pulsar produce
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
insertStream.Start()
@ -245,83 +180,19 @@ func TestSearch_Search(t *testing.T) {
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
go node.dataSyncService.start()
time.Sleep(1 * time.Second)
cancel()
node.Close()
}
func TestSearch_SearchMultiSegments(t *testing.T) {
Params.Init()
ctx, cancel := context.WithCancel(context.Background())
node := NewQueryNode(context.Background(), 0)
initTestMeta(t, node, "collection0", 0, 0)
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// test data generate
const msgLength = 1024
@ -393,14 +264,14 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
msgPackSearch := msgstream.MsgPack{}
msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
searchStream.Start()
err = searchStream.Produce(&msgPackSearch)
assert.NoError(t, err)
node.searchService = newSearchService(node.ctx, node.replica)
node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica)
go node.searchService.start()
// start insert
@ -474,7 +345,7 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
timeTickMsgPack.Msgs = append(timeTickMsgPack.Msgs, timeTickMsg)
// pulsar produce
insertStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream.SetPulsarClient(pulsarURL)
insertStream.CreatePulsarProducers(insertProducerChannels)
insertStream.Start()
@ -484,11 +355,10 @@ func TestSearch_SearchMultiSegments(t *testing.T) {
assert.NoError(t, err)
// dataSync
node.dataSyncService = newDataSyncService(node.ctx, node.replica)
node.dataSyncService = newDataSyncService(node.queryNodeLoopCtx, node.replica)
go node.dataSyncService.start()
time.Sleep(1 * time.Second)
cancel()
node.Close()
}

View File

@ -1,7 +1,6 @@
package querynode
import (
"context"
"encoding/binary"
"log"
"math"
@ -10,61 +9,21 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
)
//-------------------------------------------------------------------------------------- constructor and destructor
func TestSegment_newSegment(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -74,52 +33,15 @@ func TestSegment_newSegment(t *testing.T) {
}
func TestSegment_deleteSegment(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -131,52 +53,15 @@ func TestSegment_deleteSegment(t *testing.T) {
//-------------------------------------------------------------------------------------- stats functions
func TestSegment_getRowCount(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -219,52 +104,15 @@ func TestSegment_getRowCount(t *testing.T) {
}
func TestSegment_getDeletedCount(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -313,52 +161,15 @@ func TestSegment_getDeletedCount(t *testing.T) {
}
func TestSegment_getMemSize(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -402,53 +213,15 @@ func TestSegment_getMemSize(t *testing.T) {
//-------------------------------------------------------------------------------------- dm & search functions
func TestSegment_segmentInsert(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
assert.Equal(t, segmentID, segment.segmentID)
@ -486,52 +259,15 @@ func TestSegment_segmentInsert(t *testing.T) {
}
func TestSegment_segmentDelete(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -576,55 +312,15 @@ func TestSegment_segmentDelete(t *testing.T) {
}
func TestSegment_segmentSearch(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -661,13 +357,6 @@ func TestSegment_segmentSearch(t *testing.T) {
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
pulsarURL, _ := Params.pulsarAddress()
const receiveBufSize = 1024
searchProducerChannels := Params.searchChannelNames()
searchStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
searchStream.SetPulsarClient(pulsarURL)
searchStream.CreatePulsarProducers(searchProducerChannels)
var searchRawData []byte
for _, ele := range vec {
buf := make([]byte, 4)
@ -708,52 +397,15 @@ func TestSegment_segmentSearch(t *testing.T) {
//-------------------------------------------------------------------------------------- preDm functions
func TestSegment_segmentPreInsert(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)
@ -787,52 +439,15 @@ func TestSegment_segmentPreInsert(t *testing.T) {
}
func TestSegment_segmentPreDelete(t *testing.T) {
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: "collection0",
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMetaBlob := proto.MarshalTextString(collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
collection := newCollection(&collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
collection := newCollection(collectionMeta, collectionMetaBlob)
assert.Equal(t, collection.meta.Schema.Name, collectionName)
assert.Equal(t, collection.meta.ID, collectionID)
segmentID := UniqueID(0)
segment := newSegment(collection, segmentID)

View File

@ -14,11 +14,11 @@ import (
type statsService struct {
ctx context.Context
statsStream *msgstream.MsgStream
replica *collectionReplica
statsStream msgstream.MsgStream
replica collectionReplica
}
func newStatsService(ctx context.Context, replica *collectionReplica) *statsService {
func newStatsService(ctx context.Context, replica collectionReplica) *statsService {
return &statsService{
ctx: ctx,
@ -44,8 +44,8 @@ func (sService *statsService) start() {
var statsMsgStream msgstream.MsgStream = statsStream
sService.statsStream = &statsMsgStream
(*sService.statsStream).Start()
sService.statsStream = statsMsgStream
sService.statsStream.Start()
// start service
fmt.Println("do segments statistic in ", strconv.Itoa(sleepTimeInterval), "ms")
@ -60,11 +60,13 @@ func (sService *statsService) start() {
}
func (sService *statsService) close() {
(*sService.statsStream).Close()
if sService.statsStream != nil {
sService.statsStream.Close()
}
}
func (sService *statsService) sendSegmentStatistic() {
statisticData := (*sService.replica).getSegmentStatistics()
statisticData := sService.replica.getSegmentStatistics()
// fmt.Println("Publish segment statistic")
// fmt.Println(statisticData)
@ -82,7 +84,7 @@ func (sService *statsService) publicStatistic(statistic *internalpb.QueryNodeSeg
var msgPack = msgstream.MsgPack{
Msgs: []msgstream.TsMsg{msg},
}
err := (*sService.statsStream).Produce(&msgPack)
err := sService.statsStream.Produce(&msgPack)
if err != nil {
log.Println(err)
}
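The hunks above drop the pointer-to-interface indirection: statsStream and replica become plain interface values, so calls no longer need the (*x).Method() dereference. A minimal sketch of the before/after pattern, with hypothetical names:
type stream interface {
	Produce(msg string) error
}

type statsLike struct {
	s stream // interface value; a *stream would only add a useless indirection
}

func (sl *statsLike) send(msg string) error {
	return sl.s.Produce(msg) // before the refactor this read (*sl.s).Produce(msg)
}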

View File

@ -1,193 +1,42 @@
package querynode
import (
"context"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
// NOTE: start pulsar before test
func TestStatsService_start(t *testing.T) {
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
// start stats service
node.statsService = newStatsService(node.ctx, node.replica)
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
node.statsService.start()
node.Close()
}
// NOTE: start pulsar before test
//NOTE: start pulsar before test
func TestSegmentManagement_SegmentStatisticService(t *testing.T) {
Params.Init()
var ctx context.Context
if closeWithDeadline {
var cancel context.CancelFunc
d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond)
ctx, cancel = context.WithDeadline(context.Background(), d)
defer cancel()
} else {
ctx = context.Background()
}
// init query node
pulsarURL, _ := Params.pulsarAddress()
node := NewQueryNode(ctx, 0)
// init meta
collectionName := "collection0"
fieldVec := schemapb.FieldSchema{
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
}
fieldInt := schemapb.FieldSchema{
Name: "age",
IsPrimaryKey: false,
DataType: schemapb.DataType_INT32,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "1",
},
},
}
schema := schemapb.CollectionSchema{
Name: collectionName,
AutoID: true,
Fields: []*schemapb.FieldSchema{
&fieldVec, &fieldInt,
},
}
collectionMeta := etcdpb.CollectionMeta{
ID: UniqueID(0),
Schema: &schema,
CreateTime: Timestamp(0),
SegmentIDs: []UniqueID{0},
PartitionTags: []string{"default"},
}
collectionMetaBlob := proto.MarshalTextString(&collectionMeta)
assert.NotEqual(t, "", collectionMetaBlob)
var err = (*node.replica).addCollection(&collectionMeta, collectionMetaBlob)
assert.NoError(t, err)
collection, err := (*node.replica).getCollectionByName(collectionName)
assert.NoError(t, err)
assert.Equal(t, collection.meta.Schema.Name, "collection0")
assert.Equal(t, collection.meta.ID, UniqueID(0))
assert.Equal(t, (*node.replica).getCollectionNum(), 1)
err = (*node.replica).addPartition(collection.ID(), collectionMeta.PartitionTags[0])
assert.NoError(t, err)
segmentID := UniqueID(0)
err = (*node.replica).addSegment(segmentID, collectionMeta.PartitionTags[0], UniqueID(0))
assert.NoError(t, err)
node := newQueryNode()
initTestMeta(t, node, "collection0", 0, 0)
const receiveBufSize = 1024
// start pulsar
producerChannels := []string{Params.statsChannelName()}
statsStream := msgstream.NewPulsarMsgStream(ctx, receiveBufSize)
pulsarURL, _ := Params.pulsarAddress()
statsStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
statsStream.SetPulsarClient(pulsarURL)
statsStream.CreatePulsarProducers(producerChannels)
var statsMsgStream msgstream.MsgStream = statsStream
node.statsService = newStatsService(node.ctx, node.replica)
node.statsService.statsStream = &statsMsgStream
(*node.statsService.statsStream).Start()
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica)
node.statsService.statsStream = statsMsgStream
node.statsService.statsStream.Start()
// send stats
node.statsService.sendSegmentStatistic()
node.Close()
}

View File

@ -36,11 +36,11 @@ type tSafeImpl struct {
watcherList []*tSafeWatcher
}
func newTSafe() *tSafe {
func newTSafe() tSafe {
var t tSafe = &tSafeImpl{
watcherList: make([]*tSafeWatcher, 0),
}
return &t
return t
}
func (ts *tSafeImpl) registerTSafeWatcher(t *tSafeWatcher) {

View File

@ -9,13 +9,13 @@ import (
func TestTSafe_GetAndSet(t *testing.T) {
tSafe := newTSafe()
watcher := newTSafeWatcher()
(*tSafe).registerTSafeWatcher(watcher)
tSafe.registerTSafeWatcher(watcher)
go func() {
watcher.hasUpdate()
timestamp := (*tSafe).get()
timestamp := tSafe.get()
assert.Equal(t, timestamp, Timestamp(1000))
}()
(*tSafe).set(Timestamp(1000))
tSafe.set(Timestamp(1000))
}

View File

@ -206,7 +206,7 @@ extern "C" CStatus AddBinaryVectorToPayload(CPayloadWriter payloadWriter, uint8_
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(values, (dimension / 8) * length);
auto ast = builder->AppendValues(values, length);
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
@ -249,7 +249,7 @@ extern "C" CStatus AddFloatVectorToPayload(CPayloadWriter payloadWriter, float *
st.error_msg = ErrorMsg("payload has finished");
return st;
}
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), dimension * length * sizeof(float));
auto ast = builder->AppendValues(reinterpret_cast<const uint8_t *>(values), length);
if (!ast.ok()) {
st.error_code = static_cast<int>(ErrorCode::UNEXPECTED_ERROR);
st.error_msg = ErrorMsg(ast.message());
@ -451,7 +451,7 @@ extern "C" CStatus GetBinaryVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() * 8;
*length = array->length() / array->byte_width();
*length = array->length();
*values = (uint8_t *) array->raw_values();
return st;
}
@ -470,7 +470,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
return st;
}
*dimension = array->byte_width() / sizeof(float);
*length = array->length() / array->byte_width();
*length = array->length();
*values = (float *) array->raw_values();
return st;
}
@ -478,12 +478,7 @@ extern "C" CStatus GetFloatVectorFromPayload(CPayloadReader payloadReader,
extern "C" int GetPayloadLengthFromReader(CPayloadReader payloadReader) {
auto p = reinterpret_cast<wrapper::PayloadReader *>(payloadReader);
if (p->array == nullptr) return 0;
auto ba = std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(p->array);
if (ba == nullptr) {
return p->array->length();
} else {
return ba->length() / ba->byte_width();
}
return p->array->length();
}
extern "C" CStatus ReleasePayloadReader(CPayloadReader payloadReader) {

View File

@ -5,6 +5,7 @@ extern "C" {
#endif
#include <stdint.h>
#include <stdbool.h>
typedef void *CPayloadWriter;
@ -19,7 +20,7 @@ typedef struct CStatus {
} CStatus;
CPayloadWriter NewPayloadWriter(int columnType);
//CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddBooleanToPayload(CPayloadWriter payloadWriter, bool *values, int length);
CStatus AddInt8ToPayload(CPayloadWriter payloadWriter, int8_t *values, int length);
CStatus AddInt16ToPayload(CPayloadWriter payloadWriter, int16_t *values, int length);
CStatus AddInt32ToPayload(CPayloadWriter payloadWriter, int32_t *values, int length);
@ -39,7 +40,7 @@ CStatus ReleasePayloadWriter(CPayloadWriter handler);
typedef void *CPayloadReader;
CPayloadReader NewPayloadReader(int columnType, uint8_t *buffer, int64_t buf_size);
//CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetBoolFromPayload(CPayloadReader payloadReader, bool **values, int *length);
CStatus GetInt8FromPayload(CPayloadReader payloadReader, int8_t **values, int *length);
CStatus GetInt16FromPayload(CPayloadReader payloadReader, int16_t **values, int *length);
CStatus GetInt32FromPayload(CPayloadReader payloadReader, int32_t **values, int *length);
@ -55,4 +56,4 @@ CStatus ReleasePayloadReader(CPayloadReader payloadReader);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -70,38 +70,38 @@ TEST(wrapper, inoutstream) {
ASSERT_EQ(inarray->Value(4), 5);
}
//TEST(wrapper, boolean) {
// auto payload = NewPayloadWriter(ColumnType::BOOL);
// bool data[] = {true, false, true, false};
//
// auto st = AddBooleanToPayload(payload, data, 4);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// st = FinishPayloadWriter(payload);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// auto cb = GetPayloadBufferFromWriter(payload);
// ASSERT_GT(cb.length, 0);
// ASSERT_NE(cb.data, nullptr);
// auto nums = GetPayloadLengthFromWriter(payload);
// ASSERT_EQ(nums, 4);
//
// auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
// bool *values;
// int length;
// st = GetBoolFromPayload(reader, &values, &length);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// ASSERT_NE(values, nullptr);
// ASSERT_EQ(length, 4);
// length = GetPayloadLengthFromReader(reader);
// ASSERT_EQ(length, 4);
// for (int i = 0; i < length; i++) {
// ASSERT_EQ(data[i], values[i]);
// }
//
// st = ReleasePayloadWriter(payload);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
// st = ReleasePayloadReader(reader);
// ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
//}
TEST(wrapper, boolean) {
auto payload = NewPayloadWriter(ColumnType::BOOL);
bool data[] = {true, false, true, false};
auto st = AddBooleanToPayload(payload, data, 4);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = FinishPayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
auto cb = GetPayloadBufferFromWriter(payload);
ASSERT_GT(cb.length, 0);
ASSERT_NE(cb.data, nullptr);
auto nums = GetPayloadLengthFromWriter(payload);
ASSERT_EQ(nums, 4);
auto reader = NewPayloadReader(ColumnType::BOOL, (uint8_t *) cb.data, cb.length);
bool *values;
int length;
st = GetBoolFromPayload(reader, &values, &length);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
ASSERT_NE(values, nullptr);
ASSERT_EQ(length, 4);
length = GetPayloadLengthFromReader(reader);
ASSERT_EQ(length, 4);
for (int i = 0; i < length; i++) {
ASSERT_EQ(data[i], values[i]);
}
st = ReleasePayloadWriter(payload);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
st = ReleasePayloadReader(reader);
ASSERT_EQ(st.error_code, ErrorCode::SUCCESS);
}
#define NUMERIC_TEST(TEST_NAME, COLUMN_TYPE, DATA_TYPE, ADD_FUNC, GET_FUNC, ARRAY_TYPE) TEST(wrapper, TEST_NAME) { \
auto payload = NewPayloadWriter(COLUMN_TYPE); \

View File

@ -16,25 +16,311 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
type PayloadWriter struct {
payloadWriterPtr C.CPayloadWriter
}
type (
PayloadWriter struct {
payloadWriterPtr C.CPayloadWriter
colType schemapb.DataType
}
PayloadReader struct {
payloadReaderPtr C.CPayloadReader
colType schemapb.DataType
}
)
func NewPayloadWriter(colType schemapb.DataType) (*PayloadWriter, error) {
w := C.NewPayloadWriter(C.int(colType))
if w == nil {
return nil, errors.New("create Payload writer failed")
}
return &PayloadWriter{payloadWriterPtr: w}, nil
return &PayloadWriter{payloadWriterPtr: w, colType: colType}, nil
}
func (w *PayloadWriter) AddDataToPayload(msgs interface{}, dim ...int) error {
switch len(dim) {
case 0:
switch w.colType {
case schemapb.DataType_BOOL:
val, ok := msgs.([]bool)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBoolToPayload(val)
case schemapb.DataType_INT8:
val, ok := msgs.([]int8)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt8ToPayload(val)
case schemapb.DataType_INT16:
val, ok := msgs.([]int16)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt16ToPayload(val)
case schemapb.DataType_INT32:
val, ok := msgs.([]int32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt32ToPayload(val)
case schemapb.DataType_INT64:
val, ok := msgs.([]int64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddInt64ToPayload(val)
case schemapb.DataType_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatToPayload(val)
case schemapb.DataType_DOUBLE:
val, ok := msgs.([]float64)
if !ok {
return errors.New("incorrect data type")
}
return w.AddDoubleToPayload(val)
case schemapb.DataType_STRING:
val, ok := msgs.(string)
if !ok {
return errors.New("incorrect data type")
}
return w.AddOneStringToPayload(val)
}
case 1:
switch w.colType {
case schemapb.DataType_VECTOR_BINARY:
val, ok := msgs.([]byte)
if !ok {
return errors.New("incorrect data type")
}
return w.AddBinaryVectorToPayload(val, dim[0])
case schemapb.DataType_VECTOR_FLOAT:
val, ok := msgs.([]float32)
if !ok {
return errors.New("incorrect data type")
}
return w.AddFloatVectorToPayload(val, dim[0])
}
default:
return errors.New("incorrect input numbers")
}
return errors.New("unsupported data type")
}
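A usage sketch for the dispatcher (values illustrative): scalar column types pass no dim argument, vector types pass exactly one.
func writeColumns() error {
	w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
	if err != nil {
		return err
	}
	defer w.Close()
	// Two 2-dimensional float vectors, flattened; dim is the single variadic argument.
	if err := w.AddDataToPayload([]float32{1.0, 2.0, 3.0, 4.0}, 2); err != nil {
		return err
	}
	return w.FinishPayloadWriter()
}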
func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.bool)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int8_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int16_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int32_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.int64_t)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.float)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error {
length := len(msgs)
if length <= 0 {
return errors.Errorf("can't add empty msgs into payload")
}
cMsgs := (*C.double)(unsafe.Pointer(&msgs[0]))
cLength := C.int(length)
status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength)
errCode := commonpb.ErrorCode(status.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(status.error_msg)
defer C.free(unsafe.Pointer(status.error_msg))
return errors.New(msg)
}
return nil
}
func (w *PayloadWriter) AddOneStringToPayload(msg string) error {
if len(msg) == 0 {
length := len(msg)
if length == 0 {
return errors.New("can't add empty string into payload")
}
cstr := C.CString(msg)
defer C.free(unsafe.Pointer(cstr))
st := C.AddOneStringToPayload(w.payloadWriterPtr, cstr, C.int(len(msg)))
cmsg := C.CString(msg)
clength := C.int(length)
defer C.free(unsafe.Pointer(cmsg))
st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dim: vector dimension in bits; must be > 0 and a multiple of 8 (only dim > 0 is checked here)
func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error {
length := len(binVec)
if length <= 0 {
return errors.New("can't add empty binVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.uint8_t)(&binVec[0])
cDim := C.int(dim)
cLength := C.int(length / (dim / 8))
st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return errors.New(msg)
}
return nil
}
// dim: vector dimension; must be > 0
func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) error {
length := len(floatVec)
if length <= 0 {
return errors.New("can't add empty floatVec into payload")
}
if dim <= 0 {
return errors.New("dimension should be greater than 0")
}
cBinVec := (*C.float)(&floatVec[0])
cDim := C.int(dim)
cLength := C.int(length / dim)
st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
@ -56,13 +342,13 @@ func (w *PayloadWriter) FinishPayloadWriter() error {
}
func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
//See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
cb := C.GetPayloadBufferFromWriter(w.payloadWriterPtr)
pointer := unsafe.Pointer(cb.data)
length := int(cb.length)
if length <= 0 {
return nil, errors.New("empty buffer")
}
// refer to: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
slice := (*[1 << 28]byte)(pointer)[:length:length]
return slice, nil
}
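A minimal sketch of the slice-from-C-pointer technique used above (see the linked cgo wiki page): the 1 << 28 bound is only an upper limit for the fake array type, while the explicit length and capacity keep Go from reading past the C buffer.
import "unsafe"

// cArrayToBytes reinterprets a C buffer as a Go []byte without copying.
func cArrayToBytes(p unsafe.Pointer, n int) []byte {
	return (*[1 << 28]byte)(p)[:n:n]
}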
@ -87,16 +373,71 @@ func (w *PayloadWriter) Close() error {
return w.ReleasePayloadWriter()
}
type PayloadReader struct {
payloadReaderPtr C.CPayloadReader
}
func NewPayloadReader(colType schemapb.DataType, buf []byte) (*PayloadReader, error) {
if len(buf) == 0 {
return nil, errors.New("create Payload reader failed, buffer is empty")
}
r := C.NewPayloadReader(C.int(colType), (*C.uchar)(unsafe.Pointer(&buf[0])), C.long(len(buf)))
return &PayloadReader{payloadReaderPtr: r}, nil
return &PayloadReader{payloadReaderPtr: r, colType: colType}, nil
}
// Params:
//	`idx`: string index; only used for STRING columns.
// Return:
//	`interface{}`: the column data; its concrete type depends on the column type.
//	`int`: dimension; only meaningful for FLOAT/BINARY vector types.
//	`error`: error.
func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) {
switch len(idx) {
case 1:
switch r.colType {
case schemapb.DataType_STRING:
val, err := r.GetOneStringFromPayload(idx[0])
return val, 0, err
}
case 0:
switch r.colType {
case schemapb.DataType_BOOL:
val, err := r.GetBoolFromPayload()
return val, 0, err
case schemapb.DataType_INT8:
val, err := r.GetInt8FromPayload()
return val, 0, err
case schemapb.DataType_INT16:
val, err := r.GetInt16FromPayload()
return val, 0, err
case schemapb.DataType_INT32:
val, err := r.GetInt32FromPayload()
return val, 0, err
case schemapb.DataType_INT64:
val, err := r.GetInt64FromPayload()
return val, 0, err
case schemapb.DataType_FLOAT:
val, err := r.GetFloatFromPayload()
return val, 0, err
case schemapb.DataType_DOUBLE:
val, err := r.GetDoubleFromPayload()
return val, 0, err
case schemapb.DataType_VECTOR_BINARY:
return r.GetBinaryVectorFromPayload()
case schemapb.DataType_VECTOR_FLOAT:
return r.GetFloatVectorFromPayload()
default:
return nil, 0, errors.New("Unknown type")
}
default:
return nil, 0, errors.New("incorrect number of index")
}
return nil, 0, errors.New("unknown error")
}
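Consumer-side sketch (illustrative): callers type-assert the interface{} result according to the reader's column type; dim is only set for vector columns.
func readColumn(r *PayloadReader) error {
	val, dim, err := r.GetDataFromPayload()
	if err != nil {
		return err
	}
	switch r.colType {
	case schemapb.DataType_VECTOR_FLOAT:
		vecs := val.([]float32)
		_ = vecs[:dim] // first vector
	case schemapb.DataType_INT64:
		_ = val.([]int64)
	}
	return nil
}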
func (r *PayloadReader) ReleasePayloadReader() error {
@ -110,18 +451,169 @@ func (r *PayloadReader) ReleasePayloadReader() error {
return nil
}
func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) {
var cMsg *C.bool
var cSize C.int
st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) {
var cMsg *C.int8_t
var cSize C.int
st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) {
var cMsg *C.int16_t
var cSize C.int
st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) {
var cMsg *C.int32_t
var cSize C.int
st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) {
var cMsg *C.int64_t
var cSize C.int
st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) {
var cMsg *C.float
var cSize C.int
st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) {
var cMsg *C.double
var cSize C.int
st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, errors.New(msg)
}
slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize]
return slice, nil
}
func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) {
var cStr *C.char
var strSize C.int
var cSize C.int
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize)
st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &strSize)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return "", errors.New(msg)
}
return C.GoStringN(cStr, strSize), nil
return C.GoStringN(cStr, cSize), nil
}
// Returns: vector data, dimension, error.
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
var cMsg *C.uint8_t
var cDim C.int
var cLen C.int
st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := (cDim / 8) * cLen
slice := (*[1 << 28]byte)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
// Returns: vector data, dimension, error.
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
var cMsg *C.float
var cDim C.int
var cLen C.int
st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen)
errCode := commonpb.ErrorCode(st.error_code)
if errCode != commonpb.ErrorCode_SUCCESS {
msg := C.GoString(st.error_msg)
defer C.free(unsafe.Pointer(st.error_msg))
return nil, 0, errors.New(msg)
}
length := cDim * cLen
slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:length:length]
return slice, int(cDim), nil
}
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {

View File

@ -1,54 +1,426 @@
package storage
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
func TestNewPayloadWriter(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
assert.Nil(t, err)
assert.NotNil(t, w)
err = w.Close()
assert.Nil(t, err)
}
func TestPayLoadString(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello0")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello1")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello2")
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, length, 3)
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
assert.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 3)
str0, err := r.GetOneStringFromPayload(0)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
str1, err := r.GetOneStringFromPayload(1)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
str2, err := r.GetOneStringFromPayload(2)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
err = r.ReleasePayloadReader()
assert.Nil(t, err)
err = w.ReleasePayloadWriter()
assert.Nil(t, err)
func TestPayload_ReaderAndWriter(t *testing.T) {
t.Run("TestBool", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_BOOL)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBoolToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.AddDataToPayload([]bool{false, false, false, false})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 8, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_BOOL, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 8)
bools, err := r.GetBoolFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
ibools, _, err := r.GetDataFromPayload()
bools = ibools.([]bool)
assert.Nil(t, err)
assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools)
defer r.ReleasePayloadReader()
})
t.Run("TestInt8", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT8)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt8ToPayload([]int8{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int8{4, 5, 6})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT8, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int8s, err := r.GetInt8FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
iint8s, _, err := r.GetDataFromPayload()
int8s = iint8s.([]int8)
assert.Nil(t, err)
assert.ElementsMatch(t, []int8{1, 2, 3, 4, 5, 6}, int8s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt16", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT16)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt16ToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int16{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT16, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int16s, err := r.GetInt16FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
iint16s, _, err := r.GetDataFromPayload()
int16s = iint16s.([]int16)
assert.Nil(t, err)
assert.ElementsMatch(t, []int16{1, 2, 3, 1, 2, 3}, int16s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT32)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt32ToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int32{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT32, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int32s, err := r.GetInt32FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
iint32s, _, err := r.GetDataFromPayload()
int32s = iint32s.([]int32)
assert.Nil(t, err)
assert.ElementsMatch(t, []int32{1, 2, 3, 1, 2, 3}, int32s)
defer r.ReleasePayloadReader()
})
t.Run("TestInt64", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_INT64)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddInt64ToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.AddDataToPayload([]int64{1, 2, 3})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_INT64, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
int64s, err := r.GetInt64FromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
iint64s, _, err := r.GetDataFromPayload()
int64s = iint64s.([]int64)
assert.Nil(t, err)
assert.ElementsMatch(t, []int64{1, 2, 3, 1, 2, 3}, int64s)
defer r.ReleasePayloadReader()
})
t.Run("TestFloat32", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float32s, err := r.GetFloatFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
ifloat32s, _, err := r.GetDataFromPayload()
float32s = ifloat32s.([]float32)
assert.Nil(t, err)
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float32s)
defer r.ReleasePayloadReader()
})
t.Run("TestDouble", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_DOUBLE)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddDoubleToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.AddDataToPayload([]float64{1.0, 2.0, 3.0})
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 6, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_DOUBLE, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 6)
float64s, err := r.GetDoubleFromPayload()
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
ifloat64s, _, err := r.GetDataFromPayload()
float64s = ifloat64s.([]float64)
assert.Nil(t, err)
assert.ElementsMatch(t, []float64{1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, float64s)
defer r.ReleasePayloadReader()
})
t.Run("TestAddOneString", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_STRING)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddOneStringToPayload("hello0")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello1")
assert.Nil(t, err)
err = w.AddOneStringToPayload("hello2")
assert.Nil(t, err)
err = w.AddDataToPayload("hello3")
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, length, 4)
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_STRING, buffer)
assert.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
str0, err := r.GetOneStringFromPayload(0)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
str1, err := r.GetOneStringFromPayload(1)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
str2, err := r.GetOneStringFromPayload(2)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
str3, err := r.GetOneStringFromPayload(3)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
istr0, _, err := r.GetDataFromPayload(0)
str0 = istr0.(string)
assert.Nil(t, err)
assert.Equal(t, str0, "hello0")
istr1, _, err := r.GetDataFromPayload(1)
str1 = istr1.(string)
assert.Nil(t, err)
assert.Equal(t, str1, "hello1")
istr2, _, err := r.GetDataFromPayload(2)
str2 = istr2.(string)
assert.Nil(t, err)
assert.Equal(t, str2, "hello2")
istr3, _, err := r.GetDataFromPayload(3)
str3 = istr3.(string)
assert.Nil(t, err)
assert.Equal(t, str3, "hello3")
err = r.ReleasePayloadReader()
assert.Nil(t, err)
err = w.ReleasePayloadWriter()
assert.Nil(t, err)
})
t.Run("TestBinaryVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_BINARY)
require.Nil(t, err)
require.NotNil(t, w)
in := make([]byte, 16)
for i := 0; i < 16; i++ {
in[i] = 1
}
in2 := make([]byte, 8)
for i := 0; i < 8; i++ {
in2[i] = 1
}
err = w.AddBinaryVectorToPayload(in, 8)
assert.Nil(t, err)
err = w.AddDataToPayload(in2, 8)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 24, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_BINARY, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 24)
binVecs, dim, err := r.GetBinaryVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
fmt.Println(binVecs)
ibinVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
binVecs = ibinVecs.([]byte)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
defer r.ReleasePayloadReader()
})
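Worked check of the assertions above, assuming dim = 8 bits, i.e. one byte per binary vector:
func binaryVectorCounts() (int, int) {
	const dim = 8                    // bits per binary vector
	const byteWidth = dim / 8        // 1 byte per vector
	raw := 16 + 8                    // bytes appended by the two Add calls
	vectors := raw / byteWidth       // 24 values in the Arrow array => length 24
	readBytes := byteWidth * vectors // 24 bytes returned by GetBinaryVectorFromPayload
	return vectors, readBytes
}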
t.Run("TestFloatVector", func(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_VECTOR_FLOAT)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
assert.Nil(t, err)
err = w.AddDataToPayload([]float32{3.0, 4.0}, 1)
assert.Nil(t, err)
err = w.FinishPayloadWriter()
assert.Nil(t, err)
length, err := w.GetPayloadLengthFromWriter()
assert.Nil(t, err)
assert.Equal(t, 4, length)
defer w.ReleasePayloadWriter()
buffer, err := w.GetPayloadBufferFromWriter()
assert.Nil(t, err)
r, err := NewPayloadReader(schemapb.DataType_VECTOR_FLOAT, buffer)
require.Nil(t, err)
length, err = r.GetPayloadLengthFromReader()
assert.Nil(t, err)
assert.Equal(t, length, 4)
floatVecs, dim, err := r.GetFloatVectorFromPayload()
assert.Nil(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
ifloatVecs, dim, err := r.GetDataFromPayload()
assert.Nil(t, err)
floatVecs = ifloatVecs.([]float32)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
assert.ElementsMatch(t, []float32{1.0, 2.0, 3.0, 4.0}, floatVecs)
defer r.ReleasePayloadReader()
})
}

View File

@ -16,13 +16,17 @@ import (
"os"
"path"
"runtime"
"strconv"
"strings"
"github.com/spf13/cast"
"github.com/spf13/viper"
memkv "github.com/zilliztech/milvus-distributed/internal/kv/mem"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
)
type UniqueID = typeutil.UniqueID
type Base interface {
Load(key string) (string, error)
LoadRange(key, endKey string, limit int) ([]string, []string, error)
@ -38,7 +42,18 @@ type BaseTable struct {
func (gp *BaseTable) Init() {
gp.params = memkv.NewMemoryKV()
err := gp.LoadYaml("config.yaml")
err := gp.LoadYaml("milvus.yaml")
if err != nil {
panic(err)
}
err = gp.LoadYaml("advanced/common.yaml")
if err != nil {
panic(err)
}
err = gp.LoadYaml("advanced/channel.yaml")
if err != nil {
panic(err)
}
@ -146,3 +161,140 @@ func (gp *BaseTable) Remove(key string) error {
func (gp *BaseTable) Save(key, value string) error {
return gp.params.Save(strings.ToLower(key), value)
}
func (gp *BaseTable) ParseFloat(key string) float64 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.ParseFloat(valueStr, 64)
if err != nil {
panic(err)
}
return value
}
func (gp *BaseTable) ParseInt64(key string) int64 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int64(value)
}
func (gp *BaseTable) ParseInt32(key string) int32 {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return int32(value)
}
func (gp *BaseTable) ParseInt(key string) int {
valueStr, err := gp.Load(key)
if err != nil {
panic(err)
}
value, err := strconv.Atoi(valueStr)
if err != nil {
panic(err)
}
return value
}
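Usage sketch for the new parse helpers; the key names below are hypothetical and depend on which yaml files Init loaded. Each helper panics rather than returning an error, so callers rely on Init having run first.
func loadTunables(gp *BaseTable) (int64, float64) {
	interval := gp.ParseInt64("proxy.timeTickInterval") // hypothetical key
	graceful := gp.ParseFloat("queryNode.gracefulTime") // hypothetical key
	return interval, graceful
}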
func (gp *BaseTable) WriteNodeIDList() []UniqueID {
proxyIDStr, err := gp.Load("nodeID.writeNodeIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load write node id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (gp *BaseTable) ProxyIDList() []UniqueID {
proxyIDStr, err := gp.Load("nodeID.proxyIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
proxyIDs := strings.Split(proxyIDStr, ",")
for _, i := range proxyIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
func (gp *BaseTable) QueryNodeIDList() []UniqueID {
queryNodeIDStr, err := gp.Load("nodeID.queryNodeIDList")
if err != nil {
panic(err)
}
var ret []UniqueID
queryNodeIDs := strings.Split(queryNodeIDStr, ",")
for _, i := range queryNodeIDs {
v, err := strconv.Atoi(i)
if err != nil {
log.Panicf("load proxy id list error, %s", err.Error())
}
ret = append(ret, UniqueID(v))
}
return ret
}
// package methods
func ConvertRangeToIntRange(rangeStr, sep string) []int {
items := strings.Split(rangeStr, sep)
if len(items) != 2 {
panic("Illegal range ")
}
startStr := items[0]
endStr := items[1]
start, err := strconv.Atoi(startStr)
if err != nil {
panic(err)
}
end, err := strconv.Atoi(endStr)
if err != nil {
panic(err)
}
if start < 0 || end < 0 {
panic("Illegal range value")
}
if start > end {
panic("Illegal range value, start > end")
}
return []int{start, end}
}
func ConvertRangeToIntSlice(rangeStr, sep string) []int {
rangeSlice := ConvertRangeToIntRange(rangeStr, sep)
start, end := rangeSlice[0], rangeSlice[1]
var ret []int
for i := start; i < end; i++ {
ret = append(ret, i)
}
return ret
}
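Worked example for the range helpers: a channel-range string like "0,8" splits into start and end bounds, and the slice form enumerates the half-open interval [start, end).
func exampleRanges() ([]int, []int) {
	bounds := ConvertRangeToIntRange("0,8", ",") // []int{0, 8}
	ids := ConvertRangeToIntSlice("0,8", ",")    // []int{0, 1, 2, 3, 4, 5, 6, 7}
	return bounds, ids
}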

View File

@ -12,6 +12,7 @@
package paramtable
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
@ -21,6 +22,8 @@ var Params = BaseTable{}
func TestMain(m *testing.M) {
Params.Init()
code := m.Run()
os.Exit(code)
}
//func TestMain
@ -55,13 +58,13 @@ func TestGlobalParamsTable_SaveAndLoad(t *testing.T) {
}
func TestGlobalParamsTable_LoadRange(t *testing.T) {
_ = Params.Save("abc", "10")
_ = Params.Save("fghz", "20")
_ = Params.Save("bcde", "1.1")
_ = Params.Save("abcd", "testSaveAndLoad")
_ = Params.Save("zhi", "12")
_ = Params.Save("xxxaab", "10")
_ = Params.Save("xxxfghz", "20")
_ = Params.Save("xxxbcde", "1.1")
_ = Params.Save("xxxabcd", "testSaveAndLoad")
_ = Params.Save("xxxzhi", "12")
keys, values, err := Params.LoadRange("a", "g", 10)
keys, values, err := Params.LoadRange("xxxa", "xxxg", 10)
assert.Nil(t, err)
assert.Equal(t, 4, len(keys))
assert.Equal(t, "10", values[0])
@ -97,24 +100,17 @@ func TestGlobalParamsTable_Remove(t *testing.T) {
}
func TestGlobalParamsTable_LoadYaml(t *testing.T) {
err := Params.LoadYaml("config.yaml")
err := Params.LoadYaml("milvus.yaml")
assert.Nil(t, err)
value1, err1 := Params.Load("etcd.address")
value2, err2 := Params.Load("pulsar.port")
value3, err3 := Params.Load("reader.topicend")
value4, err4 := Params.Load("proxy.pulsarTopics.readerTopicPrefix")
value5, err5 := Params.Load("proxy.network.address")
err = Params.LoadYaml("advanced/channel.yaml")
assert.Nil(t, err)
assert.Equal(t, value1, "localhost")
assert.Equal(t, value2, "6650")
assert.Equal(t, value3, "128")
assert.Equal(t, value4, "milvusReader")
assert.Equal(t, value5, "0.0.0.0")
_, err = Params.Load("etcd.address")
assert.Nil(t, err)
_, err = Params.Load("pulsar.port")
assert.Nil(t, err)
_, err = Params.Load("msgChannel.channelRange.insert")
assert.Nil(t, err)
assert.Nil(t, err1)
assert.Nil(t, err2)
assert.Nil(t, err3)
assert.Nil(t, err4)
assert.Nil(t, err5)
}