Add unittest for loadIndex service

Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
xige-16 2020-12-26 14:16:51 +08:00 committed by yefu.chen
parent 2031c54746
commit d599407e2b
9 changed files with 194 additions and 47 deletions

View File

@ -32,6 +32,12 @@ NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info) {
}
}
void
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info) {
auto info = (LoadIndexInfo*)c_load_index_info;
delete info;
}
CStatus
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* c_index_key, const char* c_index_value) {
try {

View File

@ -25,6 +25,9 @@ typedef void* CBinarySet;
CStatus
NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info);
void
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info);
CStatus
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value);

View File

@ -176,8 +176,9 @@ FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
CStatus
UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
try {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
auto status = CStatus();
status.error_code = Success;
status.error_msg = "";
@ -189,7 +190,6 @@ UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
return status;
}
}
//////////////////////////////////////////////////////////////////
int

View File

@ -685,6 +685,49 @@ TEST(CApiTest, Reduce) {
DeleteSegment(segment);
}
TEST(CApiTest, LoadIndexInfo) {
// generator index
constexpr auto DIM = 16;
constexpr auto K = 10;
auto N = 1024 * 10;
auto [raw_data, timestamps, uids] = generate_data(N);
auto indexing = std::make_shared<milvus::knowhere::IVFPQ>();
auto conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0}};
auto database = milvus::knowhere::GenDataset(N, DIM, raw_data.data());
indexing->Train(database, conf);
indexing->AddWithoutIds(database, conf);
EXPECT_EQ(indexing->Count(), N);
EXPECT_EQ(indexing->Dim(), DIM);
auto binary_set = indexing->Serialize(conf);
CBinarySet c_binary_set = (CBinarySet)&binary_set;
void* c_load_index_info = nullptr;
auto status = NewLoadIndexInfo(&c_load_index_info);
assert(status.error_code == Success);
std::string index_param_key1 = "index_type";
std::string index_param_value1 = "IVF_PQ";
status = AppendIndexParam(c_load_index_info, index_param_key1.data(), index_param_value1.data());
std::string index_param_key2 = "index_mode";
std::string index_param_value2 = "cpu";
status = AppendIndexParam(c_load_index_info, index_param_key2.data(), index_param_value2.data());
assert(status.error_code == Success);
std::string field_name = "field0";
status = AppendFieldInfo(c_load_index_info, field_name.data(), 0);
assert(status.error_code == Success);
status = AppendIndex(c_load_index_info, c_binary_set);
assert(status.error_code == Success);
DeleteLoadIndexInfo(c_load_index_info);
}
TEST(CApiTest, LoadIndex_Search) {
// generator index
constexpr auto DIM = 16;

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
@ -21,18 +22,28 @@ func NewLoadIndexClient(ctx context.Context, pulsarAddress string, loadIndexChan
}
}
func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, indexParam map[string]string) error {
// TODO:: add indexParam to proto
func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, fieldName string, indexParams map[string]string) error {
baseMsg := msgstream.BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{0},
}
var indexParamsKV []*commonpb.KeyValuePair
for indexParam := range indexParams {
indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
Key: indexParam,
Value: indexParams[indexParam],
})
}
loadIndexRequest := internalPb.LoadIndex{
MsgType: internalPb.MsgType_kLoadIndex,
SegmentID: segmentID,
FieldID: fieldID,
IndexPaths: indexPaths,
MsgType: internalPb.MsgType_kLoadIndex,
SegmentID: segmentID,
FieldName: fieldName,
FieldID: fieldID,
IndexPaths: indexPaths,
IndexParams: indexParamsKV,
}
loadIndexMsg := &msgstream.LoadIndexMsg{

View File

@ -18,7 +18,7 @@ type LoadIndexInfo struct {
cLoadIndexInfo C.CLoadIndexInfo
}
func NewLoadIndexInfo() (*LoadIndexInfo, error) {
func newLoadIndexInfo() (*LoadIndexInfo, error) {
var cLoadIndexInfo C.CLoadIndexInfo
status := C.NewLoadIndexInfo(&cLoadIndexInfo)
errorCode := status.error_code
@ -31,7 +31,11 @@ func NewLoadIndexInfo() (*LoadIndexInfo, error) {
return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil
}
func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) error {
func deleteLoadIndexInfo(info *LoadIndexInfo) {
C.DeleteLoadIndexInfo(info.cLoadIndexInfo)
}
func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error {
cIndexKey := C.CString(indexKey)
cIndexValue := C.CString(indexValue)
status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue)
@ -45,7 +49,7 @@ func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) er
return nil
}
func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error {
func (li *LoadIndexInfo) appendFieldInfo(fieldName string, fieldID int64) error {
cFieldName := C.CString(fieldName)
cFieldID := C.long(fieldID)
status := C.AppendFieldInfo(li.cLoadIndexInfo, cFieldName, cFieldID)
@ -59,7 +63,7 @@ func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error
return nil
}
func (li *LoadIndexInfo) AppendIndex(bytesIndex [][]byte, indexKeys []string) error {
func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) error {
var cBinarySet C.CBinarySet
status := C.NewBinarySet(&cBinarySet)

View File

@ -0,0 +1,36 @@
package querynode
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
)
func TestLoadIndexInfo(t *testing.T) {
indexParams := make([]*commonpb.KeyValuePair, 0)
indexParams = append(indexParams, &commonpb.KeyValuePair{
Key: "index_type",
Value: "IVF_PQ",
})
indexParams = append(indexParams, &commonpb.KeyValuePair{
Key: "index_mode",
Value: "cpu",
})
indexBytes := make([][]byte, 0)
indexValue := make([]byte, 10)
indexBytes = append(indexBytes, indexValue)
indexPaths := make([]string, 0)
indexPaths = append(indexPaths, "index-0")
loadIndexInfo, err := newLoadIndexInfo()
assert.Nil(t, err)
for _, indexParam := range indexParams {
loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
}
loadIndexInfo.appendFieldInfo("field0", 0)
loadIndexInfo.appendIndex(indexBytes, indexPaths)
deleteLoadIndexInfo(loadIndexInfo)
}

View File

@ -107,17 +107,28 @@ func (lis *loadIndexService) start() {
log.Println("type assertion failed for LoadIndexMsg")
continue
}
/* TODO: debug
// 1. use msg's index paths to get index bytes
indexBuffer := lis.loadIndex(indexMsg.IndexPaths)
// 2. use index bytes and index path to update segment
err := lis.updateSegmentIndex(indexBuffer, indexMsg.IndexPaths, indexMsg.SegmentID)
if err != nil {
log.Println(err)
continue
}
*/
// 3. update segment index stats
//// 1. use msg's index paths to get index bytes
//var indexBuffer [][]byte
//var err error
//fn := func() error {
// indexBuffer, err = lis.loadIndex(indexMsg.IndexPaths)
// if err != nil {
// return err
// }
// return nil
//}
//err = msgstream.Retry(5, time.Millisecond*200, fn)
//if err != nil {
// log.Println(err)
// continue
//}
//// 2. use index bytes and index path to update segment
//err = lis.updateSegmentIndex(indexBuffer, indexMsg)
//if err != nil {
// log.Println(err)
// continue
//}
//3. update segment index stats
err := lis.updateSegmentIndexStats(indexMsg)
if err != nil {
log.Println(err)
@ -216,7 +227,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd
return nil
}
func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte {
func (lis *loadIndexService) loadIndex(indexPath []string) ([][]byte, error) {
index := make([][]byte, 0)
for _, path := range indexPath {
@ -224,13 +235,12 @@ func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte {
binarySetKey := filepath.Base(path)
indexPiece, err := (*lis.client).Load(binarySetKey)
if err != nil {
log.Println(err)
return nil
return nil, err
}
index = append(index, []byte(indexPiece))
}
return index
return index, nil
}
func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error {
@ -239,21 +249,22 @@ func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMs
return err
}
loadIndexInfo, err := NewLoadIndexInfo()
loadIndexInfo, err := newLoadIndexInfo()
defer deleteLoadIndexInfo(loadIndexInfo)
if err != nil {
return err
}
err = loadIndexInfo.AppendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID)
err = loadIndexInfo.appendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID)
if err != nil {
return err
}
for _, indexParam := range loadIndexMsg.IndexParams {
err = loadIndexInfo.AppendIndexParam(indexParam.Key, indexParam.Value)
err = loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
if err != nil {
return err
}
}
err = loadIndexInfo.AppendIndex(bytesIndex, loadIndexMsg.IndexPaths)
err = loadIndexInfo.appendIndex(bytesIndex, loadIndexMsg.IndexPaths)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package querynode
import (
"context"
"math"
"math/rand"
"sort"
@ -11,8 +12,26 @@ import (
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/querynode/client"
)
func TestLoadIndexClient_LoadIndex(t *testing.T) {
pulsarURL := Params.PulsarAddress
loadIndexChannels := Params.LoadIndexChannelNames
loadIndexClient := client.NewLoadIndexClient(context.Background(), pulsarURL, loadIndexChannels)
loadIndexPath := "collection0-segment0-field0"
loadIndexPaths := make([]string, 0)
loadIndexPaths = append(loadIndexPaths, loadIndexPath)
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
loadIndexClient.LoadIndex(loadIndexPaths, 0, 0, "field0", indexParams)
loadIndexClient.Close()
}
func TestLoadIndexService_PulsarAddress(t *testing.T) {
node := newQueryNode()
collectionID := rand.Int63n(1000000)
@ -125,24 +144,38 @@ func TestLoadIndexService_PulsarAddress(t *testing.T) {
statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize)
statsMs.Start()
receiveMsg := msgstream.MsgStream(statsMs).Consume()
assert.NotNil(t, receiveMsg)
assert.NotEqual(t, len(receiveMsg.Msgs), 0)
statsMsg, ok := receiveMsg.Msgs[0].(*msgstream.QueryNodeStatsMsg)
assert.Equal(t, ok, true)
assert.Equal(t, len(statsMsg.FieldStats), 1)
fieldStats0 := statsMsg.FieldStats[0]
assert.Equal(t, fieldStats0.FieldID, fieldID)
assert.Equal(t, fieldStats0.CollectionID, collectionID)
assert.Equal(t, len(fieldStats0.IndexStats), 1)
indexStats0 := fieldStats0.IndexStats[0]
findFiledStats := false
for {
receiveMsg := msgstream.MsgStream(statsMs).Consume()
assert.NotNil(t, receiveMsg)
assert.NotEqual(t, len(receiveMsg.Msgs), 0)
params := indexStats0.IndexParams
// sort index params by key
sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key })
indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams)
assert.Equal(t, indexEqual, true)
for _, msg := range receiveMsg.Msgs {
statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 {
continue
}
findFiledStats = true
assert.Equal(t, ok, true)
assert.Equal(t, len(statsMsg.FieldStats), 1)
fieldStats0 := statsMsg.FieldStats[0]
assert.Equal(t, fieldStats0.FieldID, fieldID)
assert.Equal(t, fieldStats0.CollectionID, collectionID)
assert.Equal(t, len(fieldStats0.IndexStats), 1)
indexStats0 := fieldStats0.IndexStats[0]
params := indexStats0.IndexParams
// sort index params by key
sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key })
indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams)
assert.Equal(t, indexEqual, true)
}
if findFiledStats {
break
}
}
defer assert.Equal(t, findFiledStats, true)
<-node.queryNodeLoopCtx.Done()
node.Close()
}