// Mirror of https://github.com/milvus-io/milvus.git
// Original file: 73 lines, 2.0 KiB, Go.

package querycoord
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/milvus-io/milvus/internal/kv"
|
|
memkv "github.com/milvus-io/milvus/internal/kv/mem"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func Test_segmentsInfo_getSegment(t *testing.T) {
|
|
s := newSegmentsInfo(createTestKv(t))
|
|
assert.Nil(t, s.loadSegments())
|
|
got := s.getSegment(1)
|
|
assert.EqualValues(t, 1, got.GetSegmentID())
|
|
got = s.getSegment(2)
|
|
assert.EqualValues(t, 2, got.GetSegmentID())
|
|
|
|
segment := &querypb.SegmentInfo{SegmentID: 3, CollectionID: 3}
|
|
assert.Nil(t, s.saveSegment(segment))
|
|
got = s.getSegment(3)
|
|
assert.NotNil(t, got)
|
|
assert.True(t, proto.Equal(segment, got))
|
|
|
|
assert.Nil(t, s.removeSegment(segment))
|
|
got = s.getSegment(3)
|
|
assert.Nil(t, got)
|
|
}
|
|
|
|
func Test_segmentsInfo_getSegments(t *testing.T) {
|
|
s := newSegmentsInfo(createTestKv(t))
|
|
assert.Nil(t, s.loadSegments())
|
|
got := s.getSegments()
|
|
assert.ElementsMatch(t, []int64{1, 2}, collectSegmentIDs(got))
|
|
|
|
segment := &querypb.SegmentInfo{SegmentID: 3, CollectionID: 3}
|
|
assert.Nil(t, s.saveSegment(segment))
|
|
got = s.getSegments()
|
|
assert.ElementsMatch(t, []int64{1, 2, 3}, collectSegmentIDs(got))
|
|
assert.Nil(t, s.saveSegment(segment))
|
|
got = s.getSegments()
|
|
assert.ElementsMatch(t, []int64{1, 2, 3}, collectSegmentIDs(got))
|
|
|
|
assert.Nil(t, s.removeSegment(segment))
|
|
got = s.getSegments()
|
|
assert.ElementsMatch(t, []int64{1, 2}, collectSegmentIDs(got))
|
|
}
|
|
|
|
func createTestKv(t *testing.T) kv.TxnKV {
|
|
kv := memkv.NewMemoryKV()
|
|
segments := []*querypb.SegmentInfo{
|
|
{SegmentID: 1, CollectionID: 1},
|
|
{SegmentID: 2, CollectionID: 2},
|
|
}
|
|
for _, segment := range segments {
|
|
k := getSegmentKey(segment)
|
|
v, err := proto.Marshal(segment)
|
|
assert.Nil(t, err)
|
|
assert.Nil(t, kv.Save(k, string(v)))
|
|
}
|
|
return kv
|
|
}
|
|
|
|
func collectSegmentIDs(segments []*querypb.SegmentInfo) []int64 {
|
|
res := make([]int64, 0, len(segments))
|
|
for _, s := range segments {
|
|
res = append(res, s.GetSegmentID())
|
|
}
|
|
return res
|
|
}
|