mirror of https://github.com/milvus-io/milvus.git
Add unittest for loadIndex service
Signed-off-by: xige-16 <xi.ge@zilliz.com>
pull/4973/head^2
parent
2031c54746
commit
d599407e2b
|
@ -32,6 +32,12 @@ NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info) {
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info) {
|
||||
auto info = (LoadIndexInfo*)c_load_index_info;
|
||||
delete info;
|
||||
}
|
||||
|
||||
CStatus
|
||||
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* c_index_key, const char* c_index_value) {
|
||||
try {
|
||||
|
|
|
@ -25,6 +25,9 @@ typedef void* CBinarySet;
|
|||
CStatus
|
||||
NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info);
|
||||
|
||||
// Destroys the LoadIndexInfo object behind the given handle.
// The handle must not be used again after this call.
void
|
||||
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info);
|
||||
|
||||
CStatus
|
||||
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value);
|
||||
|
||||
|
|
|
@ -176,8 +176,9 @@ FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
|
|||
|
||||
CStatus
|
||||
UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
|
||||
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
|
||||
try {
|
||||
auto segment = (milvus::segcore::SegmentBase*)c_segment;
|
||||
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
|
||||
auto status = CStatus();
|
||||
status.error_code = Success;
|
||||
status.error_msg = "";
|
||||
|
@ -189,7 +190,6 @@ UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
|
|||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
int
|
||||
|
|
|
@ -685,6 +685,49 @@ TEST(CApiTest, Reduce) {
|
|||
DeleteSegment(segment);
|
||||
}
|
||||
|
||||
// Exercises the LoadIndexInfo C API: builds and serializes an IVFPQ index,
// then feeds index params, field info and the serialized binary set into a
// LoadIndexInfo handle, checking the status returned by every call.
//
// Status checks use gtest assertions rather than <cassert> assert(): assert()
// is compiled out under NDEBUG and, when active, aborts the whole test binary
// instead of reporting a single failed test.
TEST(CApiTest, LoadIndexInfo) {
    // generate index
    constexpr auto DIM = 16;
    constexpr auto K = 10;

    auto N = 1024 * 10;
    auto [raw_data, timestamps, uids] = generate_data(N);
    auto indexing = std::make_shared<milvus::knowhere::IVFPQ>();
    auto conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM},
                                         {milvus::knowhere::meta::TOPK, K},
                                         {milvus::knowhere::IndexParams::nlist, 100},
                                         {milvus::knowhere::IndexParams::nprobe, 4},
                                         {milvus::knowhere::IndexParams::m, 4},
                                         {milvus::knowhere::IndexParams::nbits, 8},
                                         {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
                                         {milvus::knowhere::meta::DEVICEID, 0}};

    auto database = milvus::knowhere::GenDataset(N, DIM, raw_data.data());
    indexing->Train(database, conf);
    indexing->AddWithoutIds(database, conf);
    EXPECT_EQ(indexing->Count(), N);
    EXPECT_EQ(indexing->Dim(), DIM);
    auto binary_set = indexing->Serialize(conf);
    auto c_binary_set = static_cast<CBinarySet>(&binary_set);

    CLoadIndexInfo c_load_index_info = nullptr;
    auto status = NewLoadIndexInfo(&c_load_index_info);
    // Fatal on failure: nothing below is meaningful without a handle.
    ASSERT_EQ(status.error_code, Success);

    // From here on use non-fatal EXPECT_* so DeleteLoadIndexInfo always runs.
    std::string index_param_key1 = "index_type";
    std::string index_param_value1 = "IVF_PQ";
    status = AppendIndexParam(c_load_index_info, index_param_key1.c_str(), index_param_value1.c_str());
    EXPECT_EQ(status.error_code, Success);

    std::string index_param_key2 = "index_mode";
    std::string index_param_value2 = "cpu";
    status = AppendIndexParam(c_load_index_info, index_param_key2.c_str(), index_param_value2.c_str());
    EXPECT_EQ(status.error_code, Success);

    std::string field_name = "field0";
    status = AppendFieldInfo(c_load_index_info, field_name.c_str(), 0);
    EXPECT_EQ(status.error_code, Success);

    status = AppendIndex(c_load_index_info, c_binary_set);
    EXPECT_EQ(status.error_code, Success);

    DeleteLoadIndexInfo(c_load_index_info);
}
|
||||
|
||||
TEST(CApiTest, LoadIndex_Search) {
|
||||
// generator index
|
||||
constexpr auto DIM = 16;
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
)
|
||||
|
||||
|
@ -21,18 +22,28 @@ func NewLoadIndexClient(ctx context.Context, pulsarAddress string, loadIndexChan
|
|||
}
|
||||
}
|
||||
|
||||
func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, indexParam map[string]string) error {
|
||||
// TODO:: add indexParam to proto
|
||||
func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, fieldName string, indexParams map[string]string) error {
|
||||
baseMsg := msgstream.BaseMsg{
|
||||
BeginTimestamp: 0,
|
||||
EndTimestamp: 0,
|
||||
HashValues: []uint32{0},
|
||||
}
|
||||
|
||||
var indexParamsKV []*commonpb.KeyValuePair
|
||||
for indexParam := range indexParams {
|
||||
indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
|
||||
Key: indexParam,
|
||||
Value: indexParams[indexParam],
|
||||
})
|
||||
}
|
||||
|
||||
loadIndexRequest := internalPb.LoadIndex{
|
||||
MsgType: internalPb.MsgType_kLoadIndex,
|
||||
SegmentID: segmentID,
|
||||
FieldID: fieldID,
|
||||
IndexPaths: indexPaths,
|
||||
MsgType: internalPb.MsgType_kLoadIndex,
|
||||
SegmentID: segmentID,
|
||||
FieldName: fieldName,
|
||||
FieldID: fieldID,
|
||||
IndexPaths: indexPaths,
|
||||
IndexParams: indexParamsKV,
|
||||
}
|
||||
|
||||
loadIndexMsg := &msgstream.LoadIndexMsg{
|
||||
|
|
|
@ -18,7 +18,7 @@ type LoadIndexInfo struct {
|
|||
cLoadIndexInfo C.CLoadIndexInfo
|
||||
}
|
||||
|
||||
func NewLoadIndexInfo() (*LoadIndexInfo, error) {
|
||||
func newLoadIndexInfo() (*LoadIndexInfo, error) {
|
||||
var cLoadIndexInfo C.CLoadIndexInfo
|
||||
status := C.NewLoadIndexInfo(&cLoadIndexInfo)
|
||||
errorCode := status.error_code
|
||||
|
@ -31,7 +31,11 @@ func NewLoadIndexInfo() (*LoadIndexInfo, error) {
|
|||
return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) error {
|
||||
func deleteLoadIndexInfo(info *LoadIndexInfo) {
|
||||
C.DeleteLoadIndexInfo(info.cLoadIndexInfo)
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error {
|
||||
cIndexKey := C.CString(indexKey)
|
||||
cIndexValue := C.CString(indexValue)
|
||||
status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue)
|
||||
|
@ -45,7 +49,7 @@ func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error {
|
||||
func (li *LoadIndexInfo) appendFieldInfo(fieldName string, fieldID int64) error {
|
||||
cFieldName := C.CString(fieldName)
|
||||
cFieldID := C.long(fieldID)
|
||||
status := C.AppendFieldInfo(li.cLoadIndexInfo, cFieldName, cFieldID)
|
||||
|
@ -59,7 +63,7 @@ func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (li *LoadIndexInfo) AppendIndex(bytesIndex [][]byte, indexKeys []string) error {
|
||||
func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) error {
|
||||
var cBinarySet C.CBinarySet
|
||||
status := C.NewBinarySet(&cBinarySet)
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
)
|
||||
|
||||
func TestLoadIndexInfo(t *testing.T) {
|
||||
indexParams := make([]*commonpb.KeyValuePair, 0)
|
||||
indexParams = append(indexParams, &commonpb.KeyValuePair{
|
||||
Key: "index_type",
|
||||
Value: "IVF_PQ",
|
||||
})
|
||||
indexParams = append(indexParams, &commonpb.KeyValuePair{
|
||||
Key: "index_mode",
|
||||
Value: "cpu",
|
||||
})
|
||||
|
||||
indexBytes := make([][]byte, 0)
|
||||
indexValue := make([]byte, 10)
|
||||
indexBytes = append(indexBytes, indexValue)
|
||||
indexPaths := make([]string, 0)
|
||||
indexPaths = append(indexPaths, "index-0")
|
||||
|
||||
loadIndexInfo, err := newLoadIndexInfo()
|
||||
assert.Nil(t, err)
|
||||
for _, indexParam := range indexParams {
|
||||
loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
|
||||
}
|
||||
loadIndexInfo.appendFieldInfo("field0", 0)
|
||||
loadIndexInfo.appendIndex(indexBytes, indexPaths)
|
||||
|
||||
deleteLoadIndexInfo(loadIndexInfo)
|
||||
}
|
|
@ -107,17 +107,28 @@ func (lis *loadIndexService) start() {
|
|||
log.Println("type assertion failed for LoadIndexMsg")
|
||||
continue
|
||||
}
|
||||
/* TODO: debug
|
||||
// 1. use msg's index paths to get index bytes
|
||||
indexBuffer := lis.loadIndex(indexMsg.IndexPaths)
|
||||
// 2. use index bytes and index path to update segment
|
||||
err := lis.updateSegmentIndex(indexBuffer, indexMsg.IndexPaths, indexMsg.SegmentID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
*/
|
||||
// 3. update segment index stats
|
||||
//// 1. use msg's index paths to get index bytes
|
||||
//var indexBuffer [][]byte
|
||||
//var err error
|
||||
//fn := func() error {
|
||||
// indexBuffer, err = lis.loadIndex(indexMsg.IndexPaths)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//err = msgstream.Retry(5, time.Millisecond*200, fn)
|
||||
//if err != nil {
|
||||
// log.Println(err)
|
||||
// continue
|
||||
//}
|
||||
//// 2. use index bytes and index path to update segment
|
||||
//err = lis.updateSegmentIndex(indexBuffer, indexMsg)
|
||||
//if err != nil {
|
||||
// log.Println(err)
|
||||
// continue
|
||||
//}
|
||||
//3. update segment index stats
|
||||
err := lis.updateSegmentIndexStats(indexMsg)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
@ -216,7 +227,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd
|
|||
return nil
|
||||
}
|
||||
|
||||
func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte {
|
||||
func (lis *loadIndexService) loadIndex(indexPath []string) ([][]byte, error) {
|
||||
index := make([][]byte, 0)
|
||||
|
||||
for _, path := range indexPath {
|
||||
|
@ -224,13 +235,12 @@ func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte {
|
|||
binarySetKey := filepath.Base(path)
|
||||
indexPiece, err := (*lis.client).Load(binarySetKey)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
index = append(index, []byte(indexPiece))
|
||||
}
|
||||
|
||||
return index
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error {
|
||||
|
@ -239,21 +249,22 @@ func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMs
|
|||
return err
|
||||
}
|
||||
|
||||
loadIndexInfo, err := NewLoadIndexInfo()
|
||||
loadIndexInfo, err := newLoadIndexInfo()
|
||||
defer deleteLoadIndexInfo(loadIndexInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = loadIndexInfo.AppendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID)
|
||||
err = loadIndexInfo.appendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, indexParam := range loadIndexMsg.IndexParams {
|
||||
err = loadIndexInfo.AppendIndexParam(indexParam.Key, indexParam.Value)
|
||||
err = loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = loadIndexInfo.AppendIndex(bytesIndex, loadIndexMsg.IndexPaths)
|
||||
err = loadIndexInfo.appendIndex(bytesIndex, loadIndexMsg.IndexPaths)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package querynode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sort"
|
||||
|
@ -11,8 +12,26 @@ import (
|
|||
"github.com/zilliztech/milvus-distributed/internal/msgstream"
|
||||
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
|
||||
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
|
||||
"github.com/zilliztech/milvus-distributed/internal/querynode/client"
|
||||
)
|
||||
|
||||
func TestLoadIndexClient_LoadIndex(t *testing.T) {
|
||||
pulsarURL := Params.PulsarAddress
|
||||
loadIndexChannels := Params.LoadIndexChannelNames
|
||||
loadIndexClient := client.NewLoadIndexClient(context.Background(), pulsarURL, loadIndexChannels)
|
||||
|
||||
loadIndexPath := "collection0-segment0-field0"
|
||||
loadIndexPaths := make([]string, 0)
|
||||
loadIndexPaths = append(loadIndexPaths, loadIndexPath)
|
||||
|
||||
indexParams := make(map[string]string)
|
||||
indexParams["index_type"] = "IVF_PQ"
|
||||
indexParams["index_mode"] = "cpu"
|
||||
|
||||
loadIndexClient.LoadIndex(loadIndexPaths, 0, 0, "field0", indexParams)
|
||||
loadIndexClient.Close()
|
||||
}
|
||||
|
||||
func TestLoadIndexService_PulsarAddress(t *testing.T) {
|
||||
node := newQueryNode()
|
||||
collectionID := rand.Int63n(1000000)
|
||||
|
@ -125,24 +144,38 @@ func TestLoadIndexService_PulsarAddress(t *testing.T) {
|
|||
statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize)
|
||||
statsMs.Start()
|
||||
|
||||
receiveMsg := msgstream.MsgStream(statsMs).Consume()
|
||||
assert.NotNil(t, receiveMsg)
|
||||
assert.NotEqual(t, len(receiveMsg.Msgs), 0)
|
||||
statsMsg, ok := receiveMsg.Msgs[0].(*msgstream.QueryNodeStatsMsg)
|
||||
assert.Equal(t, ok, true)
|
||||
assert.Equal(t, len(statsMsg.FieldStats), 1)
|
||||
fieldStats0 := statsMsg.FieldStats[0]
|
||||
assert.Equal(t, fieldStats0.FieldID, fieldID)
|
||||
assert.Equal(t, fieldStats0.CollectionID, collectionID)
|
||||
assert.Equal(t, len(fieldStats0.IndexStats), 1)
|
||||
indexStats0 := fieldStats0.IndexStats[0]
|
||||
findFiledStats := false
|
||||
for {
|
||||
receiveMsg := msgstream.MsgStream(statsMs).Consume()
|
||||
assert.NotNil(t, receiveMsg)
|
||||
assert.NotEqual(t, len(receiveMsg.Msgs), 0)
|
||||
|
||||
params := indexStats0.IndexParams
|
||||
// sort index params by key
|
||||
sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key })
|
||||
indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams)
|
||||
assert.Equal(t, indexEqual, true)
|
||||
for _, msg := range receiveMsg.Msgs {
|
||||
statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
|
||||
if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 {
|
||||
continue
|
||||
}
|
||||
findFiledStats = true
|
||||
assert.Equal(t, ok, true)
|
||||
assert.Equal(t, len(statsMsg.FieldStats), 1)
|
||||
fieldStats0 := statsMsg.FieldStats[0]
|
||||
assert.Equal(t, fieldStats0.FieldID, fieldID)
|
||||
assert.Equal(t, fieldStats0.CollectionID, collectionID)
|
||||
assert.Equal(t, len(fieldStats0.IndexStats), 1)
|
||||
indexStats0 := fieldStats0.IndexStats[0]
|
||||
params := indexStats0.IndexParams
|
||||
// sort index params by key
|
||||
sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key })
|
||||
indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams)
|
||||
assert.Equal(t, indexEqual, true)
|
||||
}
|
||||
|
||||
if findFiledStats {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
defer assert.Equal(t, findFiledStats, true)
|
||||
<-node.queryNodeLoopCtx.Done()
|
||||
node.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue