milvus/internal/querycoord/cluster_test.go

689 lines
20 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoord
import (
"context"
"encoding/json"
"errors"
"fmt"
"path"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
minioKV "github.com/milvus-io/milvus/internal/kv/minio"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
const (
indexID = UniqueID(0)
dimKey = "dim"
rowIDFieldID = 0
timestampFieldID = 1
indexName = "query-coord-index-0"
defaultDim = 128
defaultMetricType = "JACCARD"
defaultKVRootPath = "query-coord-unittest"
)
type constFieldParam struct {
id int64
dataType schemapb.DataType
}
var simpleConstField = constFieldParam{
id: 101,
dataType: schemapb.DataType_Int32,
}
type vecFieldParam struct {
id int64
dim int
metricType string
vecType schemapb.DataType
}
var simpleVecField = vecFieldParam{
id: 100,
dim: defaultDim,
metricType: defaultMetricType,
vecType: schemapb.DataType_FloatVector,
}
var timestampField = constFieldParam{
id: timestampFieldID,
dataType: schemapb.DataType_Int64,
}
var uidField = constFieldParam{
id: rowIDFieldID,
dataType: schemapb.DataType_Int64,
}
type indexParam = map[string]string
func segSizeEstimateForTest(segments *querypb.LoadSegmentsRequest, dataKV kv.DataKV) (int64, error) {
sizePerRecord, err := typeutil.EstimateSizePerRecord(segments.Schema)
if err != nil {
return 0, err
}
sizeOfReq := int64(0)
for _, loadInfo := range segments.Infos {
sizeOfReq += int64(sizePerRecord) * loadInfo.NumOfRows
}
return sizeOfReq, nil
}
func genCollectionMeta(collectionID UniqueID, schema *schemapb.CollectionSchema) *etcdpb.CollectionMeta {
colInfo := &etcdpb.CollectionMeta{
ID: collectionID,
Schema: schema,
PartitionIDs: []UniqueID{defaultPartitionID},
}
return colInfo
}
// ---------- unittest util functions ----------
// functions of inserting data init
func genInsertData(msgLength int, schema *schemapb.CollectionSchema) (*storage.InsertData, error) {
insertData := &storage.InsertData{
Data: make(map[int64]storage.FieldData),
}
for _, f := range schema.Fields {
switch f.DataType {
case schemapb.DataType_Bool:
data := make([]bool, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = true
}
insertData.Data[f.FieldID] = &storage.BoolFieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Int8:
data := make([]int8, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = int8(i)
}
insertData.Data[f.FieldID] = &storage.Int8FieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Int16:
data := make([]int16, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = int16(i)
}
insertData.Data[f.FieldID] = &storage.Int16FieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Int32:
data := make([]int32, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = int32(i)
}
insertData.Data[f.FieldID] = &storage.Int32FieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Int64:
data := make([]int64, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = int64(i)
}
insertData.Data[f.FieldID] = &storage.Int64FieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Float:
data := make([]float32, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = float32(i)
}
insertData.Data[f.FieldID] = &storage.FloatFieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_Double:
data := make([]float64, msgLength)
for i := 0; i < msgLength; i++ {
data[i] = float64(i)
}
insertData.Data[f.FieldID] = &storage.DoubleFieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
}
case schemapb.DataType_FloatVector:
dim := simpleVecField.dim // if no dim specified, use simpleVecField's dim
for _, p := range f.TypeParams {
if p.Key == dimKey {
var err error
dim, err = strconv.Atoi(p.Value)
if err != nil {
return nil, err
}
}
}
data := make([]float32, 0)
for i := 0; i < msgLength; i++ {
for j := 0; j < dim; j++ {
data = append(data, float32(i*j)*0.1)
}
}
insertData.Data[f.FieldID] = &storage.FloatVectorFieldData{
NumRows: []int64{int64(msgLength)},
Data: data,
Dim: dim,
}
default:
err := errors.New("data type not supported")
return nil, err
}
}
return insertData, nil
}
func genStorageBlob(collectionID UniqueID,
partitionID UniqueID,
segmentID UniqueID,
msgLength int,
schema *schemapb.CollectionSchema) ([]*storage.Blob, error) {
collMeta := genCollectionMeta(collectionID, schema)
inCodec := storage.NewInsertCodec(collMeta)
insertData, err := genInsertData(msgLength, schema)
if err != nil {
return nil, err
}
// timestamp field not allowed 0 timestamp
if _, ok := insertData.Data[timestampFieldID]; ok {
insertData.Data[timestampFieldID].(*storage.Int64FieldData).Data[0] = 1
}
binLogs, _, err := inCodec.Serialize(partitionID, segmentID, insertData)
return binLogs, err
}
func saveSimpleBinLog(ctx context.Context, schema *schemapb.CollectionSchema, dataKV kv.DataKV) ([]*datapb.FieldBinlog, error) {
return saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultNumRowPerSegment, schema, dataKV)
}
func saveBinLog(ctx context.Context,
collectionID UniqueID,
partitionID UniqueID,
segmentID UniqueID,
msgLength int,
schema *schemapb.CollectionSchema,
dataKV kv.DataKV) ([]*datapb.FieldBinlog, error) {
binLogs, err := genStorageBlob(collectionID, partitionID, segmentID, msgLength, schema)
if err != nil {
return nil, err
}
log.Debug(".. [query coord unittest] Saving bin logs to MinIO ..", zap.Int("number", len(binLogs)))
kvs := make(map[string]string, len(binLogs))
// write insert binlog
fieldBinlog := make([]*datapb.FieldBinlog, 0)
for _, blob := range binLogs {
fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64)
log.Debug("[QueryCoord unittest] save binlog", zap.Int64("fieldID", fieldID))
if err != nil {
return nil, err
}
key := genKey(collectionID, partitionID, segmentID, fieldID)
kvs[key] = string(blob.Value[:])
fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{
FieldID: fieldID,
Binlogs: []*datapb.Binlog{{LogPath: key}},
})
}
log.Debug("[QueryCoord unittest] save binlog file to MinIO/S3")
err = dataKV.MultiSave(kvs)
return fieldBinlog, err
}
func genKey(collectionID, partitionID, segmentID UniqueID, fieldID int64) string {
ids := []string{
defaultKVRootPath,
strconv.FormatInt(collectionID, 10),
strconv.FormatInt(partitionID, 10),
strconv.FormatInt(segmentID, 10),
strconv.FormatInt(fieldID, 10),
}
return path.Join(ids...)
}
func genSimpleIndexParams() indexParam {
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
indexParams["dim"] = strconv.FormatInt(defaultDim, 10)
indexParams["k"] = "10"
indexParams["nlist"] = "100"
indexParams["nprobe"] = "10"
indexParams["m"] = "4"
indexParams["nbits"] = "8"
indexParams["metric_type"] = "L2"
indexParams["SLICE_SIZE"] = "400"
return indexParams
}
func generateIndex(segmentID UniqueID) ([]string, error) {
indexParams := genSimpleIndexParams()
var indexParamsKV []*commonpb.KeyValuePair
for key, value := range indexParams {
indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
Key: key,
Value: value,
})
}
typeParams := make(map[string]string)
typeParams["dim"] = strconv.Itoa(defaultDim)
var indexRowData []float32
for n := 0; n < defaultNumRowPerSegment; n++ {
for i := 0; i < defaultDim; i++ {
indexRowData = append(indexRowData, float32(n*i))
}
}
index, err := indexnode.NewCIndex(typeParams, indexParams)
if err != nil {
return nil, err
}
err = index.BuildFloatVecIndexWithoutIds(indexRowData)
if err != nil {
return nil, err
}
option := &minioKV.Option{
Address: Params.QueryCoordCfg.MinioEndPoint,
AccessKeyID: Params.QueryCoordCfg.MinioAccessKeyID,
SecretAccessKeyID: Params.QueryCoordCfg.MinioSecretAccessKey,
UseSSL: Params.QueryCoordCfg.MinioUseSSLStr,
BucketName: Params.QueryCoordCfg.MinioBucketName,
CreateBucket: true,
}
kv, err := minioKV.NewMinIOKV(context.Background(), option)
if err != nil {
return nil, err
}
// save index to minio
binarySet, err := index.Serialize()
if err != nil {
return nil, err
}
// serialize index params
indexCodec := storage.NewIndexFileBinlogCodec()
serializedIndexBlobs, err := indexCodec.Serialize(
0,
0,
0,
0,
0,
0,
indexParams,
indexName,
indexID,
binarySet,
)
if err != nil {
return nil, err
}
indexPaths := make([]string, 0)
for _, index := range serializedIndexBlobs {
p := strconv.Itoa(int(segmentID)) + "/" + index.Key
indexPaths = append(indexPaths, p)
err := kv.Save(p, string(index.Value))
if err != nil {
return nil, err
}
}
return indexPaths, nil
}
func TestQueryNodeCluster_getMetrics(t *testing.T) {
log.Info("TestQueryNodeCluster_getMetrics, todo")
}
func TestReloadClusterFromKV(t *testing.T) {
etcdCli, err := etcd.GetEtcdClient(&Params.BaseParams)
defer etcdCli.Close()
assert.Nil(t, err)
t.Run("Test LoadOnlineNodes", func(t *testing.T) {
refreshParams()
baseCtx := context.Background()
kv := etcdkv.NewEtcdKV(etcdCli, Params.QueryCoordCfg.MetaRootPath)
clusterSession := sessionutil.NewSession(context.Background(), Params.QueryCoordCfg.MetaRootPath, etcdCli)
clusterSession.Init(typeutil.QueryCoordRole, Params.QueryCoordCfg.Address, true, false)
clusterSession.Register()
cluster := &queryNodeCluster{
ctx: baseCtx,
client: kv,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
session: clusterSession,
segSizeEstimator: segSizeEstimateForTest,
}
queryNode, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
cluster.reloadFromKV()
nodeID := queryNode.queryNodeID
waitQueryNodeOnline(cluster, nodeID)
queryNode.stop()
err = removeNodeSession(queryNode.queryNodeID)
assert.Nil(t, err)
})
t.Run("Test LoadOfflineNodes", func(t *testing.T) {
refreshParams()
kv := etcdkv.NewEtcdKV(etcdCli, Params.QueryCoordCfg.MetaRootPath)
clusterSession := sessionutil.NewSession(context.Background(), Params.QueryCoordCfg.MetaRootPath, etcdCli)
clusterSession.Init(typeutil.QueryCoordRole, Params.QueryCoordCfg.Address, true, false)
clusterSession.Register()
cluster := &queryNodeCluster{
client: kv,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
session: clusterSession,
segSizeEstimator: segSizeEstimateForTest,
}
kvs := make(map[string]string)
session := &sessionutil.Session{
ServerID: 100,
Address: "localhost",
}
sessionBlob, err := json.Marshal(session)
assert.Nil(t, err)
sessionKey := fmt.Sprintf("%s/%d", queryNodeInfoPrefix, 100)
kvs[sessionKey] = string(sessionBlob)
err = kv.MultiSave(kvs)
assert.Nil(t, err)
cluster.reloadFromKV()
assert.Equal(t, 1, len(cluster.nodes))
err = removeAllSession()
assert.Nil(t, err)
})
}
func TestGrpcRequest(t *testing.T) {
refreshParams()
baseCtx, cancel := context.WithCancel(context.Background())
etcdCli, err := etcd.GetEtcdClient(&Params.BaseParams)
assert.Nil(t, err)
defer etcdCli.Close()
kv := etcdkv.NewEtcdKV(etcdCli, Params.QueryCoordCfg.MetaRootPath)
clusterSession := sessionutil.NewSession(context.Background(), Params.QueryCoordCfg.MetaRootPath, etcdCli)
clusterSession.Init(typeutil.QueryCoordRole, Params.QueryCoordCfg.Address, true, false)
clusterSession.Register()
factory := msgstream.NewPmsFactory()
m := map[string]interface{}{
"PulsarAddress": Params.QueryCoordCfg.PulsarAddress,
"ReceiveBufSize": 1024,
"PulsarBufSize": 1024}
err = factory.SetParams(m)
assert.Nil(t, err)
idAllocator := func() (UniqueID, error) {
return 0, nil
}
meta, err := newMeta(baseCtx, kv, factory, idAllocator)
assert.Nil(t, err)
cluster := &queryNodeCluster{
ctx: baseCtx,
cancel: cancel,
client: kv,
clusterMeta: meta,
nodes: make(map[int64]Node),
newNodeFn: newQueryNodeTest,
session: clusterSession,
segSizeEstimator: segSizeEstimateForTest,
}
t.Run("Test GetNodeInfoByIDWithNodeNotExist", func(t *testing.T) {
_, err := cluster.getNodeInfoByID(defaultQueryNodeID)
assert.NotNil(t, err)
})
t.Run("Test GetSegmentInfoByNodeWithNodeNotExist", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfoByNode(baseCtx, defaultQueryNodeID, getSegmentInfoReq)
assert.NotNil(t, err)
})
node, err := startQueryNodeServer(baseCtx)
assert.Nil(t, err)
nodeSession := node.session
nodeID := node.queryNodeID
cluster.registerNode(baseCtx, nodeSession, nodeID, disConnect)
waitQueryNodeOnline(cluster, nodeID)
t.Run("Test GetComponentInfos", func(t *testing.T) {
infos := cluster.getComponentInfos(baseCtx)
assert.Equal(t, 1, len(infos))
})
t.Run("Test LoadSegments", func(t *testing.T) {
segmentLoadInfo := &querypb.SegmentLoadInfo{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
}
loadSegmentReq := &querypb.LoadSegmentsRequest{
DstNodeID: nodeID,
Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo},
Schema: genCollectionSchema(defaultCollectionID, false),
CollectionID: defaultCollectionID,
}
err := cluster.loadSegments(baseCtx, nodeID, loadSegmentReq)
assert.Nil(t, err)
})
t.Run("Test ReleaseSegments", func(t *testing.T) {
releaseSegmentReq := &querypb.ReleaseSegmentsRequest{
NodeID: nodeID,
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
SegmentIDs: []UniqueID{defaultSegmentID},
}
err := cluster.releaseSegments(baseCtx, nodeID, releaseSegmentReq)
assert.Nil(t, err)
})
t.Run("Test AddQueryChannel", func(t *testing.T) {
info := cluster.clusterMeta.getQueryChannelInfoByID(defaultCollectionID)
addQueryChannelReq := &querypb.AddQueryChannelRequest{
NodeID: nodeID,
CollectionID: defaultCollectionID,
QueryChannel: info.QueryChannel,
QueryResultChannel: info.QueryResultChannel,
}
err = cluster.addQueryChannel(baseCtx, nodeID, addQueryChannelReq)
assert.Nil(t, err)
})
t.Run("Test RemoveQueryChannel", func(t *testing.T) {
info := cluster.clusterMeta.getQueryChannelInfoByID(defaultCollectionID)
removeQueryChannelReq := &querypb.RemoveQueryChannelRequest{
NodeID: nodeID,
CollectionID: defaultCollectionID,
QueryChannel: info.QueryChannel,
QueryResultChannel: info.QueryResultChannel,
}
err = cluster.removeQueryChannel(baseCtx, nodeID, removeQueryChannelReq)
assert.Nil(t, err)
})
t.Run("Test GetSegmentInfo", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfo(baseCtx, getSegmentInfoReq)
assert.Nil(t, err)
})
t.Run("Test GetSegmentInfoByNode", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfoByNode(baseCtx, nodeID, getSegmentInfoReq)
assert.Nil(t, err)
})
node.getSegmentInfos = returnFailedGetSegmentInfoResult
t.Run("Test GetSegmentInfoFailed", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfo(baseCtx, getSegmentInfoReq)
assert.NotNil(t, err)
})
t.Run("Test GetSegmentInfoByNodeFailed", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfoByNode(baseCtx, nodeID, getSegmentInfoReq)
assert.NotNil(t, err)
})
node.getSegmentInfos = returnSuccessGetSegmentInfoResult
t.Run("Test GetNodeInfoByID", func(t *testing.T) {
res, err := cluster.getNodeInfoByID(nodeID)
assert.Nil(t, err)
assert.NotNil(t, res)
})
node.getMetrics = returnFailedGetMetricsResult
t.Run("Test GetNodeInfoByIDFailed", func(t *testing.T) {
_, err := cluster.getNodeInfoByID(nodeID)
assert.NotNil(t, err)
})
node.getMetrics = returnSuccessGetMetricsResult
cluster.stopNode(nodeID)
t.Run("Test GetSegmentInfoByNodeAfterNodeStop", func(t *testing.T) {
getSegmentInfoReq := &querypb.GetSegmentInfoRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SegmentInfo,
},
CollectionID: defaultCollectionID,
}
_, err = cluster.getSegmentInfoByNode(baseCtx, nodeID, getSegmentInfoReq)
assert.NotNil(t, err)
})
err = removeAllSession()
assert.Nil(t, err)
}
func TestEstimateSegmentSize(t *testing.T) {
refreshParams()
binlog := []*datapb.FieldBinlog{
{
FieldID: defaultVecFieldID,
Binlogs: []*datapb.Binlog{{LogPath: "by-dev/rand/path", LogSize: 1024}},
},
}
loadInfo := &querypb.SegmentLoadInfo{
SegmentID: defaultSegmentID,
PartitionID: defaultPartitionID,
CollectionID: defaultCollectionID,
BinlogPaths: binlog,
NumOfRows: defaultNumRowPerSegment,
}
loadReq := &querypb.LoadSegmentsRequest{
Infos: []*querypb.SegmentLoadInfo{loadInfo},
CollectionID: defaultCollectionID,
}
size, err := estimateSegmentsSize(loadReq, nil)
assert.NoError(t, err)
assert.Equal(t, int64(1024), size)
indexInfo := &querypb.VecFieldIndexInfo{
FieldID: defaultVecFieldID,
EnableIndex: true,
IndexSize: 2048,
}
loadInfo.IndexInfos = []*querypb.VecFieldIndexInfo{indexInfo}
size, err = estimateSegmentsSize(loadReq, nil)
assert.NoError(t, err)
assert.Equal(t, int64(2048), size)
}