mirror of https://github.com/milvus-io/milvus.git
Fix segmentPtr concurrent visit and destroy (#19560)
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com> Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/19528/head
parent
ccf30d358c
commit
d79f88c5f6
|
@ -166,7 +166,7 @@ func (dNode *deleteNode) delete(deleteData *deleteData, segmentID UniqueID, wg *
|
|||
return fmt.Errorf("getSegmentByID failed, err = %s", err)
|
||||
}
|
||||
|
||||
if targetSegment.segmentType != segmentTypeSealed {
|
||||
if targetSegment.getType() != segmentTypeSealed {
|
||||
return fmt.Errorf("unexpected segmentType when delete, segmentID = %d, segmentType = %s", segmentID, targetSegment.segmentType.String())
|
||||
}
|
||||
|
||||
|
|
|
@ -175,6 +175,8 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
|
|||
|
||||
// 2. do preInsert
|
||||
for segmentID := range iData.insertRecords {
|
||||
log := log.With(
|
||||
zap.Int64("segmentID", segmentID))
|
||||
var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
|
||||
if err != nil {
|
||||
// should not happen, segment should be created before
|
||||
|
@ -187,6 +189,10 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
|
|||
if targetSegment != nil {
|
||||
offset, err := targetSegment.segmentPreInsert(numOfRecords)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSegmentUnhealthy) {
|
||||
log.Debug("segment removed before preInsert")
|
||||
continue
|
||||
}
|
||||
// error occurs when cgo function `PreInsert` failed
|
||||
err = fmt.Errorf("segmentPreInsert failed, segmentID = %d, err = %s", segmentID, err)
|
||||
log.Error(err.Error(), zap.Int64("collectionID", iNode.collectionID), zap.String("channel", iNode.channel))
|
||||
|
@ -256,7 +262,9 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
|
|||
panic(err)
|
||||
}
|
||||
offset := segment.segmentPreDelete(len(pks))
|
||||
delData.deleteOffset[segmentID] = offset
|
||||
if offset >= 0 {
|
||||
delData.deleteOffset[segmentID] = offset
|
||||
}
|
||||
}
|
||||
|
||||
// 3. do delete
|
||||
|
@ -364,6 +372,9 @@ func filterSegmentsByPKs(pks []primaryKey, timestamps []Timestamp, segment *Segm
|
|||
|
||||
// insert would execute insert operations for specific growing segment
|
||||
func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID) error {
|
||||
log := log.With(
|
||||
zap.Int64("collectionID", iNode.collectionID),
|
||||
zap.Int64("segmentID", segmentID))
|
||||
var targetSegment, err = iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getSegmentByID failed, err = %s", err)
|
||||
|
@ -379,15 +390,22 @@ func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID) error {
|
|||
|
||||
err = targetSegment.segmentInsert(offsets, ids, timestamps, insertRecord)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSegmentUnhealthy) {
|
||||
log.Debug("segment removed before insert")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("segmentInsert failed, segmentID = %d, err = %s", segmentID, err)
|
||||
}
|
||||
|
||||
log.Debug("Do insert done", zap.Int("len", len(iData.insertIDs[segmentID])), zap.Int64("collectionID", targetSegment.collectionID), zap.Int64("segmentID", segmentID))
|
||||
log.Debug("Do insert done", zap.Int("len", len(iData.insertIDs[segmentID])))
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete would execute delete operations for specific growing segment
|
||||
func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) error {
|
||||
log := log.With(
|
||||
zap.Int64("collectionID", iNode.collectionID),
|
||||
zap.Int64("segmentID", segmentID))
|
||||
targetSegment, err := iNode.metaReplica.getSegmentByID(segmentID, segmentTypeGrowing)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSegmentNotFound) {
|
||||
|
@ -400,7 +418,7 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) erro
|
|||
return fmt.Errorf("getSegmentByID failed, err = %s", err)
|
||||
}
|
||||
|
||||
if targetSegment.segmentType != segmentTypeGrowing {
|
||||
if targetSegment.getType() != segmentTypeGrowing {
|
||||
return fmt.Errorf("unexpected segmentType when delete, segmentType = %s", targetSegment.segmentType.String())
|
||||
}
|
||||
|
||||
|
@ -410,10 +428,14 @@ func (iNode *insertNode) delete(deleteData *deleteData, segmentID UniqueID) erro
|
|||
|
||||
err = targetSegment.segmentDelete(offset, ids, timestamps)
|
||||
if err != nil {
|
||||
if errors.Is(err, errSegmentUnhealthy) {
|
||||
log.Debug("segment removed before delete")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("segmentDelete failed, err = %s", err)
|
||||
}
|
||||
|
||||
log.Debug("Do delete done", zap.Int("len", len(deleteData.deleteIDs[segmentID])), zap.Int64("segmentID", segmentID))
|
||||
log.Debug("Do delete done", zap.Int("len", len(deleteData.deleteIDs[segmentID])))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -887,7 +887,7 @@ func (replica *metaReplica) getSegmentInfo(segment *Segment) *querypb.SegmentInf
|
|||
IndexName: indexName,
|
||||
IndexID: indexID,
|
||||
DmChannel: segment.vChannelID,
|
||||
SegmentState: segment.segmentType,
|
||||
SegmentState: segment.getType(),
|
||||
IndexInfos: indexInfos,
|
||||
NodeIds: []UniqueID{Params.QueryNodeCfg.GetNodeID()},
|
||||
}
|
||||
|
|
|
@ -226,7 +226,7 @@ func TestStreaming_search(t *testing.T) {
|
|||
seg, err := streaming.getSegmentByID(defaultSegmentID, segmentTypeGrowing)
|
||||
assert.NoError(t, err)
|
||||
|
||||
seg.segmentPtr = nil
|
||||
seg.setUnhealthy()
|
||||
|
||||
_, _, _, err = searchStreaming(context.TODO(), streaming, searchReq,
|
||||
defaultCollectionID,
|
||||
|
|
|
@ -36,12 +36,14 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus/internal/util/concurrency"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/util/timerecord"
|
||||
|
||||
"github.com/bits-and-blooms/bloom/v3"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/api/commonpb"
|
||||
|
@ -66,6 +68,8 @@ const (
|
|||
maxBloomFalsePositive float64 = 0.005
|
||||
)
|
||||
|
||||
var errSegmentUnhealthy = errors.New("segment unhealthy")
|
||||
|
||||
// IndexedFieldInfo contains binlog info of vector field
|
||||
type IndexedFieldInfo struct {
|
||||
fieldBinlog *datapb.FieldBinlog
|
||||
|
@ -74,6 +78,7 @@ type IndexedFieldInfo struct {
|
|||
|
||||
// Segment is a wrapper of the underlying C-structure segment.
|
||||
type Segment struct {
|
||||
mut sync.RWMutex // protects segmentPtr
|
||||
segmentPtr C.CSegmentInterface
|
||||
|
||||
segmentID UniqueID
|
||||
|
@ -85,16 +90,13 @@ type Segment struct {
|
|||
lastMemSize int64
|
||||
lastRowCount int64
|
||||
|
||||
rmMutex sync.RWMutex // guards recentlyModified
|
||||
recentlyModified bool
|
||||
|
||||
typeMu sync.RWMutex // guards segmentType
|
||||
segmentType segmentType
|
||||
recentlyModified *atomic.Bool
|
||||
segmentType *atomic.Int32
|
||||
destroyed *atomic.Bool
|
||||
|
||||
idBinlogRowSizes []int64
|
||||
|
||||
indexedFieldMutex sync.RWMutex // guards indexedFieldInfos
|
||||
indexedFieldInfos map[UniqueID]*IndexedFieldInfo
|
||||
indexedFieldInfos *typeutil.ConcurrentMap[UniqueID, *IndexedFieldInfo]
|
||||
|
||||
pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment
|
||||
|
||||
|
@ -115,56 +117,52 @@ func (s *Segment) getIDBinlogRowSizes() []int64 {
|
|||
}
|
||||
|
||||
func (s *Segment) setRecentlyModified(modify bool) {
|
||||
s.rmMutex.Lock()
|
||||
defer s.rmMutex.Unlock()
|
||||
s.recentlyModified = modify
|
||||
s.recentlyModified.Store(modify)
|
||||
}
|
||||
|
||||
func (s *Segment) getRecentlyModified() bool {
|
||||
s.rmMutex.RLock()
|
||||
defer s.rmMutex.RUnlock()
|
||||
return s.recentlyModified
|
||||
return s.recentlyModified.Load()
|
||||
}
|
||||
|
||||
func (s *Segment) setType(segType segmentType) {
|
||||
s.typeMu.Lock()
|
||||
defer s.typeMu.Unlock()
|
||||
s.segmentType = segType
|
||||
s.segmentType.Store(int32(segType))
|
||||
}
|
||||
|
||||
func (s *Segment) getType() segmentType {
|
||||
s.typeMu.RLock()
|
||||
defer s.typeMu.RUnlock()
|
||||
return s.segmentType
|
||||
return commonpb.SegmentState(s.segmentType.Load())
|
||||
}
|
||||
|
||||
func (s *Segment) setIndexedFieldInfo(fieldID UniqueID, info *IndexedFieldInfo) {
|
||||
s.indexedFieldMutex.Lock()
|
||||
defer s.indexedFieldMutex.Unlock()
|
||||
s.indexedFieldInfos[fieldID] = info
|
||||
s.indexedFieldInfos.Insert(fieldID, info)
|
||||
}
|
||||
|
||||
func (s *Segment) getIndexedFieldInfo(fieldID UniqueID) (*IndexedFieldInfo, error) {
|
||||
s.indexedFieldMutex.RLock()
|
||||
defer s.indexedFieldMutex.RUnlock()
|
||||
if info, ok := s.indexedFieldInfos[fieldID]; ok {
|
||||
return &IndexedFieldInfo{
|
||||
fieldBinlog: info.fieldBinlog,
|
||||
indexInfo: info.indexInfo,
|
||||
}, nil
|
||||
info, ok := s.indexedFieldInfos.Get(fieldID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Invalid fieldID %d", fieldID)
|
||||
}
|
||||
return nil, fmt.Errorf("Invalid fieldID %d", fieldID)
|
||||
return &IndexedFieldInfo{
|
||||
fieldBinlog: info.fieldBinlog,
|
||||
indexInfo: info.indexInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Segment) hasLoadIndexForIndexedField(fieldID int64) bool {
|
||||
s.indexedFieldMutex.RLock()
|
||||
defer s.indexedFieldMutex.RUnlock()
|
||||
|
||||
if fieldInfo, ok := s.indexedFieldInfos[fieldID]; ok {
|
||||
return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex
|
||||
fieldInfo, ok := s.indexedFieldInfos.Get(fieldID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return fieldInfo.indexInfo != nil && fieldInfo.indexInfo.EnableIndex
|
||||
}
|
||||
|
||||
return false
|
||||
// healthy checks whether it's safe to use `segmentPtr`.
|
||||
// shall acquire mut.RLock before check this flag.
|
||||
func (s *Segment) healthy() bool {
|
||||
return !s.destroyed.Load()
|
||||
}
|
||||
|
||||
func (s *Segment) setUnhealthy() {
|
||||
s.destroyed.Store(true)
|
||||
}
|
||||
|
||||
func newSegment(collection *Collection,
|
||||
|
@ -210,13 +208,15 @@ func newSegment(collection *Collection,
|
|||
|
||||
var segment = &Segment{
|
||||
segmentPtr: segmentPtr,
|
||||
segmentType: segType,
|
||||
segmentType: atomic.NewInt32(int32(segType)),
|
||||
segmentID: segmentID,
|
||||
partitionID: partitionID,
|
||||
collectionID: collectionID,
|
||||
version: version,
|
||||
vChannelID: vChannelID,
|
||||
indexedFieldInfos: make(map[UniqueID]*IndexedFieldInfo),
|
||||
indexedFieldInfos: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](),
|
||||
recentlyModified: atomic.NewBool(false),
|
||||
destroyed: atomic.NewBool(false),
|
||||
|
||||
pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive),
|
||||
pool: pool,
|
||||
|
@ -230,16 +230,22 @@ func deleteSegment(segment *Segment) {
|
|||
void
|
||||
deleteSegment(CSegmentInterface segment);
|
||||
*/
|
||||
if segment.segmentPtr == nil {
|
||||
var cPtr C.CSegmentInterface
|
||||
// wait all read ops finished
|
||||
segment.mut.Lock()
|
||||
segment.setUnhealthy()
|
||||
cPtr = segment.segmentPtr
|
||||
segment.segmentPtr = nil
|
||||
segment.mut.Unlock()
|
||||
|
||||
if cPtr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
cPtr := segment.segmentPtr
|
||||
segment.pool.Submit(func() (interface{}, error) {
|
||||
C.DeleteSegment(cPtr)
|
||||
return nil, nil
|
||||
}).Await()
|
||||
segment.segmentPtr = nil
|
||||
|
||||
log.Info("delete segment from memory",
|
||||
zap.Int64("collectionID", segment.collectionID),
|
||||
|
@ -253,7 +259,9 @@ func (s *Segment) getRealCount() int64 {
|
|||
int64_t
|
||||
GetRealCount(CSegmentInterface c_segment);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1
|
||||
}
|
||||
var rowCount C.int64_t
|
||||
|
@ -270,7 +278,9 @@ func (s *Segment) getRowCount() int64 {
|
|||
long int
|
||||
getRowCount(CSegmentInterface c_segment);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1
|
||||
}
|
||||
var rowCount C.int64_t
|
||||
|
@ -287,7 +297,9 @@ func (s *Segment) getDeletedCount() int64 {
|
|||
long int
|
||||
getDeletedCount(CSegmentInterface c_segment);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1
|
||||
}
|
||||
|
||||
|
@ -305,7 +317,9 @@ func (s *Segment) getMemSize() int64 {
|
|||
long int
|
||||
GetMemoryUsageInBytes(CSegmentInterface c_segment);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1
|
||||
}
|
||||
var memoryUsageInBytes C.int64_t
|
||||
|
@ -327,8 +341,10 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
|
|||
long int* result_ids,
|
||||
float* result_distances);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
return nil, errors.New("null seg core pointer")
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return nil, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
if searchReq.plan == nil {
|
||||
|
@ -363,8 +379,10 @@ func (s *Segment) search(searchReq *searchRequest) (*SearchResult, error) {
|
|||
}
|
||||
|
||||
func (s *Segment) retrieve(plan *RetrievePlan) (*segcorepb.RetrieveResults, error) {
|
||||
if s.segmentPtr == nil {
|
||||
return nil, errors.New("null seg core pointer")
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return nil, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
var retrieveResult RetrieveResult
|
||||
|
@ -621,9 +639,16 @@ func (s *Segment) segmentPreInsert(numOfRecords int) (int64, error) {
|
|||
long int
|
||||
PreInsert(CSegmentInterface c_segment, long int size);
|
||||
*/
|
||||
if s.segmentType != segmentTypeGrowing {
|
||||
if s.getType() != segmentTypeGrowing {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1, fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
var offset int64
|
||||
var status C.CStatus
|
||||
cOffset := (*C.int64_t)(&offset)
|
||||
|
@ -642,6 +667,11 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
|
|||
long int
|
||||
PreDelete(CSegmentInterface c_segment, long int size);
|
||||
*/
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return -1
|
||||
}
|
||||
|
||||
var offset C.int64_t
|
||||
s.pool.Submit(func() (interface{}, error) {
|
||||
|
@ -654,12 +684,14 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 {
|
|||
}
|
||||
|
||||
func (s *Segment) segmentInsert(offset int64, entityIDs []UniqueID, timestamps []Timestamp, record *segcorepb.InsertRecord) error {
|
||||
if s.segmentType != segmentTypeGrowing {
|
||||
if s.getType() != segmentTypeGrowing {
|
||||
return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String())
|
||||
}
|
||||
|
||||
if s.segmentPtr == nil {
|
||||
return errors.New("null seg core pointer")
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
insertRecordBlob, err := proto.Marshal(record)
|
||||
|
@ -708,8 +740,10 @@ func (s *Segment) segmentDelete(offset int64, entityIDs []primaryKey, timestamps
|
|||
return fmt.Errorf("empty pks to delete")
|
||||
}
|
||||
|
||||
if s.segmentPtr == nil {
|
||||
return errors.New("null seg core pointer")
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
if len(entityIDs) != len(timestamps) {
|
||||
|
@ -772,13 +806,15 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int64, data *sche
|
|||
CStatus
|
||||
LoadFieldData(CSegmentInterface c_segment, CLoadFieldDataInfo load_field_data_info);
|
||||
*/
|
||||
if s.segmentPtr == nil {
|
||||
return errors.New("null seg core pointer")
|
||||
}
|
||||
if s.segmentType != segmentTypeSealed {
|
||||
if s.getType() != segmentTypeSealed {
|
||||
errMsg := fmt.Sprintln("segmentLoadFieldData failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
dataBlob, err := proto.Marshal(data)
|
||||
if err != nil {
|
||||
|
@ -811,8 +847,10 @@ func (s *Segment) segmentLoadFieldData(fieldID int64, rowCount int64, data *sche
|
|||
}
|
||||
|
||||
func (s *Segment) segmentLoadDeletedRecord(primaryKeys []primaryKey, timestamps []Timestamp, rowCount int64) error {
|
||||
if s.segmentPtr == nil {
|
||||
return errors.New("null seg core pointer")
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
if len(primaryKeys) <= 0 {
|
||||
|
@ -893,15 +931,15 @@ func (s *Segment) segmentLoadIndexData(bytesIndex [][]byte, indexInfo *querypb.F
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if s.segmentPtr == nil {
|
||||
return errors.New("null seg core pointer")
|
||||
}
|
||||
|
||||
if s.segmentType != segmentTypeSealed {
|
||||
if s.getType() != segmentTypeSealed {
|
||||
errMsg := fmt.Sprintln("updateSegmentIndex failed, illegal segment type ", s.segmentType, "segmentID = ", s.ID())
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
s.mut.RLock()
|
||||
defer s.mut.RUnlock()
|
||||
if !s.healthy() {
|
||||
return fmt.Errorf("%w(segmentID=%d)", errSegmentUnhealthy, s.segmentID)
|
||||
}
|
||||
|
||||
var status C.CStatus
|
||||
s.pool.Submit(func() (interface{}, error) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/api/commonpb"
|
||||
"github.com/milvus-io/milvus/api/schemapb"
|
||||
|
@ -904,11 +905,12 @@ func TestSegmentLoader_getFieldType(t *testing.T) {
|
|||
loader := &segmentLoader{metaReplica: replica}
|
||||
|
||||
// failed to get collection.
|
||||
segment := &Segment{segmentType: segmentTypeSealed}
|
||||
segment := &Segment{segmentType: atomic.NewInt32(0)}
|
||||
segment.setType(segmentTypeSealed)
|
||||
_, err := loader.getFieldType(segment, 100)
|
||||
assert.Error(t, err)
|
||||
|
||||
segment.segmentType = segmentTypeGrowing
|
||||
segment.setType(segmentTypeGrowing)
|
||||
_, err = loader.getFieldType(segment, 100)
|
||||
assert.Error(t, err)
|
||||
|
||||
|
@ -927,12 +929,12 @@ func TestSegmentLoader_getFieldType(t *testing.T) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
segment.segmentType = segmentTypeGrowing
|
||||
segment.setType(segmentTypeGrowing)
|
||||
fieldType, err := loader.getFieldType(segment, 100)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, schemapb.DataType_Int64, fieldType)
|
||||
|
||||
segment.segmentType = segmentTypeSealed
|
||||
segment.setType(segmentTypeSealed)
|
||||
fieldType, err = loader.getFieldType(segment, 100)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, schemapb.DataType_Int64, fieldType)
|
||||
|
|
|
@ -89,7 +89,7 @@ func TestSegment_deleteSegment(t *testing.T) {
|
|||
t.Run("test delete nil ptr", func(t *testing.T) {
|
||||
s, err := genSimpleSealedSegment(defaultMsgLength)
|
||||
assert.NoError(t, err)
|
||||
s.segmentPtr = nil
|
||||
s.setUnhealthy()
|
||||
deleteSegment(s)
|
||||
})
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func TestSegment_getRowCount(t *testing.T) {
|
|||
t.Run("test getRowCount nil ptr", func(t *testing.T) {
|
||||
s, err := genSimpleSealedSegment(defaultMsgLength)
|
||||
assert.NoError(t, err)
|
||||
s.segmentPtr = nil
|
||||
s.setUnhealthy()
|
||||
res := s.getRowCount()
|
||||
assert.Equal(t, int64(-1), res)
|
||||
})
|
||||
|
@ -273,7 +273,7 @@ func TestSegment_getDeletedCount(t *testing.T) {
|
|||
t.Run("test getDeletedCount nil ptr", func(t *testing.T) {
|
||||
s, err := genSimpleSealedSegment(defaultMsgLength)
|
||||
assert.NoError(t, err)
|
||||
s.segmentPtr = nil
|
||||
s.setUnhealthy()
|
||||
res := s.getDeletedCount()
|
||||
assert.Equal(t, int64(-1), res)
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue