From e91eafd871b6c2ea85d31e8e91193cda2c4eaefa Mon Sep 17 00:00:00 2001
From: yukun
Date: Wed, 20 Oct 2021 16:34:36 +0800
Subject: [PATCH] Fix Bitsetview bug in segcore (#10272)

Signed-off-by: fishpenguin
---
 .../index/thirdparty/faiss/utils/BitsetView.h | 11 ++++-
 .../thirdparty/faiss/utils/ConcurrentBitset.h |  2 +-
 .../core/src/segcore/SegmentGrowingImpl.cpp   | 11 +++--
 internal/querynode/flow_graph_insert_node.go  | 47 +++++++------
 .../querynode/flow_graph_insert_node_test.go  |  8 ++--
 5 files changed, 38 insertions(+), 41 deletions(-)

diff --git a/internal/core/src/index/thirdparty/faiss/utils/BitsetView.h b/internal/core/src/index/thirdparty/faiss/utils/BitsetView.h
index ecd1f89bcc..75c884cac5 100644
--- a/internal/core/src/index/thirdparty/faiss/utils/BitsetView.h
+++ b/internal/core/src/index/thirdparty/faiss/utils/BitsetView.h
@@ -24,7 +24,11 @@ class BitsetView {
     BitsetView(const uint8_t* blocks, int64_t size) : blocks_(blocks), size_(size) {
     }
 
-    BitsetView(const ConcurrentBitset& bitset) : blocks_(bitset.data()), size_(bitset.count()) {
+    BitsetView(const ConcurrentBitset& bitset) : size_(bitset.count()) {
+        // memcpy(block_data_.data(), bitset.data(), bitset.size());
+        // blocks_ = block_data_.data();
+        blocks_ = new uint8_t[bitset.size()];
+        memcpy(mutable_data(), bitset.data(), bitset.size());
     }
 
     BitsetView(const ConcurrentBitsetPtr& bitset_ptr) {
@@ -59,6 +63,11 @@ class BitsetView {
         return blocks_;
     }
 
+    uint8_t*
+    mutable_data() {
+        return const_cast<uint8_t*>(blocks_);
+    }
+
     operator bool() const {
         return !empty();
     }
diff --git a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h b/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h
index 44221f9322..4076fcc54a 100644
--- a/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h
+++ b/internal/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h
@@ -30,7 +30,7 @@ class ConcurrentBitset {
         }
     }
 
-    explicit ConcurrentBitset(size_t count, const uint8_t* data) : bitset_(((count + 8 - 1) >> 3)) {
+    explicit ConcurrentBitset(size_t count, const uint8_t* data) : count_(count), bitset_(((count + 8 - 1) >> 3)) {
         memcpy(mutable_data(), data, (count + 8 - 1) >> 3);
     }
 
diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp
index ec61dcfb87..abb17bed28 100644
--- a/internal/core/src/segcore/SegmentGrowingImpl.cpp
+++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp
@@ -127,14 +127,14 @@ SegmentGrowingImpl::get_filtered_bitmap(const BitsetView& bitset, int64_t ins_ba
     }
     AssertInfo(bitmap_holder, "bitmap_holder is null");
     auto deleted_bitmap = bitmap_holder->bitmap_ptr;
-    AssertInfo(deleted_bitmap->count() == bitset.u8size(), "Deleted bitmap count not equal to filtered bitmap count");
+    AssertInfo(deleted_bitmap->count() == bitset.size(), "Deleted bitmap count not equal to filtered bitmap count");
 
-    auto filtered_bitmap =
-        std::make_shared<faiss::ConcurrentBitset>(faiss::ConcurrentBitset(bitset.u8size(), bitset.data()));
+    auto filtered_bitmap = std::make_shared<faiss::ConcurrentBitset>(bitset.size(), bitset.data());
 
     auto final_bitmap = (*deleted_bitmap.get()) | (*filtered_bitmap.get());
 
-    return BitsetView(final_bitmap);
+    BitsetView res = BitsetView(final_bitmap);
+    return res;
 }
 
 Status
@@ -245,10 +245,12 @@ SegmentGrowingImpl::Delete(int64_t reserved_begin,
     std::vector<idx_t> uids(size);
     std::vector<Timestamp> timestamps(size);
     // #pragma omp parallel for
+    std::cout << "zzzz: " << size << std::endl;
     for (int index = 0; index < size; ++index) {
         auto [t, uid] = ordering[index];
         timestamps[index] = t;
         uids[index] = uid;
+        std::cout << "In Segcore Delete: " << uid << std::endl;
     }
     deleted_record_.timestamps_.set_data(reserved_begin, timestamps.data(), size);
     deleted_record_.uids_.set_data(reserved_begin, uids.data(), size);
@@ -293,7 +295,6 @@ SegmentGrowingImpl::vector_search(int64_t vec_count,
                                   Timestamp timestamp,
                                   const BitsetView& bitset,
                                   SearchResult& output) const {
-    // TODO(yukun): get final filtered bitmap
     auto& sealed_indexing = this->get_sealed_indexing_record();
     if (sealed_indexing.is_ready(search_info.field_offset_)) {
         query::SearchOnSealed(this->get_schema(), sealed_indexing, search_info, query_data, query_count, bitset,
diff --git a/internal/querynode/flow_graph_insert_node.go b/internal/querynode/flow_graph_insert_node.go
index b6f22c5ace..9ced25e57d 100644
--- a/internal/querynode/flow_graph_insert_node.go
+++ b/internal/querynode/flow_graph_insert_node.go
@@ -172,43 +172,26 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 				log.Warn(err.Error())
 				continue
 			}
-			exist, err := filterSegmentsByPKs(delMsg.PrimaryKeys, segment)
+			pks, err := filterSegmentsByPKs(delMsg.PrimaryKeys, segment)
 			if err != nil {
 				log.Warn(err.Error())
 				continue
 			}
-			if exist {
-				offset := segment.segmentPreDelete(len(delMsg.PrimaryKeys))
+			if len(pks) > 0 {
+				offset := segment.segmentPreDelete(len(pks))
 				if err != nil {
 					log.Warn(err.Error())
 					continue
 				}
-				delData.deleteIDs[segmentID] = append(delData.deleteIDs[segmentID], delMsg.PrimaryKeys...)
-				delData.deleteTimestamps[segmentID] = append(delData.deleteTimestamps[segmentID], delMsg.Timestamps...)
+				delData.deleteIDs[segmentID] = append(delData.deleteIDs[segmentID], pks...)
+				// TODO(yukun) get offset of pks
+				delData.deleteTimestamps[segmentID] = append(delData.deleteTimestamps[segmentID], delMsg.Timestamps[:len(pks)]...)
 				delData.deleteOffset[segmentID] = offset
 			}
 		}
 	}
 
-	// 2. do preDelete
-	for segmentID := range delData.deleteIDs {
-		var targetSegment, err = iNode.replica.getSegmentByID(segmentID)
-		if err != nil {
-			log.Warn(err.Error())
-		}
-
-		var numOfRecords = len(delData.deleteIDs[segmentID])
-		if targetSegment != nil {
-			offset := targetSegment.segmentPreDelete(numOfRecords)
-			if err != nil {
-				log.Warn(err.Error())
-			}
-			delData.deleteOffset[segmentID] = offset
-			log.Debug("insertNode operator", zap.Int("delete size", numOfRecords), zap.Int64("delete offset", offset), zap.Int64("segment id", segmentID))
-		}
-	}
-
-	// 3. do delete
+	// 2. do delete
 	for segmentID := range delData.deleteIDs {
 		wg.Add(1)
 		go iNode.delete(delData, segmentID, &wg)
@@ -225,20 +208,24 @@ func (iNode *insertNode) Operate(in []flowgraph.Msg) []flowgraph.Msg {
 	return []Msg{res}
 }
 
-func filterSegmentsByPKs(pks []int64, segment *Segment) (bool, error) {
+func filterSegmentsByPKs(pks []int64, segment *Segment) ([]int64, error) {
 	if pks == nil {
-		return false, fmt.Errorf("pks is nil when getSegmentsByPKs")
+		return nil, fmt.Errorf("pks is nil when getSegmentsByPKs")
 	}
 	if segment == nil {
-		return false, fmt.Errorf("segments is nil when getSegmentsByPKs")
+		return nil, fmt.Errorf("segments is nil when getSegmentsByPKs")
 	}
 	buf := make([]byte, 8)
+	res := make([]int64, 0)
 	for _, pk := range pks {
 		binary.BigEndian.PutUint64(buf, uint64(pk))
 		exist := segment.pkFilter.Test(buf)
-		return exist, nil
+		if exist {
+			res = append(res, pk)
+		}
 	}
-	return false, nil
+	log.Debug("In filterSegmentsByPKs", zap.Any("pk", res), zap.Any("segment", segment.segmentID))
+	return res, nil
 }
 
 func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID, wg *sync.WaitGroup) {
@@ -270,7 +257,7 @@ func (iNode *insertNode) insert(iData *insertData, segmentID UniqueID, wg *sync.
 	}
 
 	log.Debug("Do insert done", zap.Int("len", len(iData.insertIDs[segmentID])),
-		zap.Int64("segmentID", segmentID))
+		zap.Int64("segmentID", segmentID), zap.Any("IDS", iData.insertPKs))
 	wg.Done()
 }
 
diff --git a/internal/querynode/flow_graph_insert_node_test.go b/internal/querynode/flow_graph_insert_node_test.go
index c7c98556a2..7e643cc854 100644
--- a/internal/querynode/flow_graph_insert_node_test.go
+++ b/internal/querynode/flow_graph_insert_node_test.go
@@ -373,13 +373,13 @@ func TestGetSegmentsByPKs(t *testing.T) {
 		segmentID: 1,
 		pkFilter:  filter,
 	}
-	exist, err := filterSegmentsByPKs([]int64{0, 1, 2, 3, 4}, segment)
+	pks, err := filterSegmentsByPKs([]int64{0, 1, 2, 3, 4}, segment)
 	assert.Nil(t, err)
-	assert.True(t, exist)
+	assert.Equal(t, len(pks), 3)
 
-	exist, err = filterSegmentsByPKs([]int64{}, segment)
+	pks, err = filterSegmentsByPKs([]int64{}, segment)
 	assert.Nil(t, err)
-	assert.False(t, exist)
+	assert.Equal(t, len(pks), 0)
 	_, err = filterSegmentsByPKs(nil, segment)
 	assert.NotNil(t, err)
 	_, err = filterSegmentsByPKs([]int64{0, 1, 2, 3, 4}, nil)