From 87d790f052f06291b450b7d37ac219be4cb2985e Mon Sep 17 00:00:00 2001
From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com>
Date: Mon, 1 May 2023 23:06:37 -0700
Subject: [PATCH] Fix upgrade cause panic (#23833)

Signed-off-by: xiaofan-luan
---
 internal/querycoordv2/server.go         |  5 +-
 internal/rootcoord/dml_channels.go      | 32 ++++++++
 internal/rootcoord/timeticksync.go      |  4 +-
 internal/rootcoord/timeticksync_test.go | 77 +++++++++++++++++++
 internal/rootcoord/util.go              | 13 ----
 internal/rootcoord/util_test.go         | 19 -----
 .../testcases/test_compaction.py        | 42 +++++-----
 7 files changed, 136 insertions(+), 56 deletions(-)

diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index 06aa788c7c..ebfb244fab 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -630,7 +630,10 @@ func (s *Server) watchNodes(revision int64) {
 			)
 			s.nodeMgr.Add(session.NewNodeInfo(nodeID, addr))
 			s.nodeUpEventChan <- nodeID
-			s.notifyNodeUp <- struct{}{}
+			select {
+			case s.notifyNodeUp <- struct{}{}:
+			default:
+			}
 
 		case sessionutil.SessionUpdateEvent:
 			nodeID := event.Session.ServerID
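The server.go change above replaces a blocking send on notifyNodeUp with a best-effort one: if a wakeup is already pending, the event loop drops the redundant signal instead of stalling behind a slow consumer. Below is a minimal standalone sketch of that select/default pattern; the names and the capacity-1 channel are assumptions of the sketch, not the Milvus identifiers.

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Wakeup signal with capacity 1: one pending notification is enough,
        // because the consumer drains all queued work when it wakes.
        notify := make(chan struct{}, 1)

        go func() {
            for range notify {
                fmt.Println("consumer: woke up, draining queued events")
                time.Sleep(20 * time.Millisecond)
            }
        }()

        for i := 0; i < 5; i++ {
            // A bare `notify <- struct{}{}` would block once a wakeup is
            // already queued; select with a default branch drops the
            // redundant signal, so the producer loop never stalls.
            select {
            case notify <- struct{}{}:
            default:
            }
        }
        time.Sleep(100 * time.Millisecond)
    }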
diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go
index ff50351502..78d6cb00b3 100644
--- a/internal/rootcoord/dml_channels.go
+++ b/internal/rootcoord/dml_channels.go
@@ -20,6 +20,8 @@ import (
 	"container/heap"
 	"context"
 	"fmt"
+	"strconv"
+	"strings"
 	"sync"
 
 	"github.com/milvus-io/milvus/pkg/metrics"
@@ -30,6 +32,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type dmlMsgStream struct {
@@ -356,3 +359,32 @@ func genChannelNames(prefix string, num int64) []string {
 	}
 	return results
 }
+
+func parseChannelNameIndex(channelName string) int {
+	index := strings.LastIndex(channelName, "_")
+	if index < 0 {
+		log.Error("invalid channel name", zap.String("chanName", channelName))
+		panic("invalid channel name: " + channelName)
+	}
+	index, err := strconv.Atoi(channelName[index+1:])
+	if err != nil {
+		log.Error("invalid channel name", zap.String("chanName", channelName), zap.Error(err))
+		panic("invalid channel name: " + channelName)
+	}
+	return index
+}
+
+func getNeedChanNum(setNum int, chanMap map[typeutil.UniqueID][]string) int {
+	// find the largest index among the channel names currently in use
+	maxChanUsed := setNum
+	for _, chanNames := range chanMap {
+		for _, chanName := range chanNames {
+			index := parseChannelNameIndex(chanName)
+			if maxChanUsed < index+1 {
+				maxChanUsed = index + 1
+			}
+		}
+	}
+
+	return maxChanUsed
+}
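Together, parseChannelNameIndex and getNeedChanNum size the DML channel pool by the largest numeric suffix found in the recovered channel names, not by how many distinct names the metadata holds. The standalone sketch below is a simplified reimplementation with illustrative names and numbers, not the Milvus code, and shows the sizing rule on the same kind of input the tests further down use:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    // suffixIndex extracts the trailing "_<n>" index of a channel name.
    func suffixIndex(name string) (int, error) {
        i := strings.LastIndex(name, "_")
        if i < 0 {
            return 0, fmt.Errorf("invalid channel name: %s", name)
        }
        return strconv.Atoi(name[i+1:])
    }

    func main() {
        // Channel names recovered from metadata after an upgrade.
        recovered := []string{"rootcoord-dml_4", "rootcoord-dml_8", "rootcoord-dml_2", "rootcoord-dml_9"}

        need := 2 // configured default, smaller than the old deployment used
        for _, name := range recovered {
            idx, err := suffixIndex(name)
            if err != nil {
                panic(err)
            }
            if idx+1 > need {
                need = idx + 1
            }
        }
        fmt.Println(need) // 10: indexes 0..9 must all exist for recovery
    }

A cluster that once wrote rootcoord-dml_9 into its metadata therefore gets a pool of at least 10 channels after the upgrade, even when the configured default has been lowered in the meantime.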
diff --git a/internal/rootcoord/timeticksync.go b/internal/rootcoord/timeticksync.go
index 7d0a9ab431..3ea59f3df9 100644
--- a/internal/rootcoord/timeticksync.go
+++ b/internal/rootcoord/timeticksync.go
@@ -113,10 +113,10 @@ func (c *chanTsMsg) getTimetick(channelName string) typeutil.Timestamp {
 func newTimeTickSync(ctx context.Context, sourceID int64, factory msgstream.Factory, chanMap map[typeutil.UniqueID][]string) *timetickSync {
 	// if the old channels number used by the user is greater than the set default value currently
 	// keep the old channels
-	defaultChanNum := getNeedChanNum(Params.RootCoordCfg.DmlChannelNum.GetAsInt(), chanMap)
+	chanNum := getNeedChanNum(Params.RootCoordCfg.DmlChannelNum.GetAsInt(), chanMap)
 
 	// initialize dml channels used for insert
-	dmlChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDml.GetValue(), int64(defaultChanNum))
+	dmlChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDml.GetValue(), int64(chanNum))
 
 	// recover physical channels for all collections
 	for collID, chanNames := range chanMap {
diff --git a/internal/rootcoord/timeticksync_test.go b/internal/rootcoord/timeticksync_test.go
index d70b70fd37..dad8e95b48 100644
--- a/internal/rootcoord/timeticksync_test.go
+++ b/internal/rootcoord/timeticksync_test.go
@@ -176,3 +176,80 @@ func Test_ttHistogram_get(t *testing.T) {
 	assert.Equal(t, typeutil.ZeroTimestamp, h.get("ch1"))
 	assert.Equal(t, typeutil.ZeroTimestamp, h.get("ch2"))
 }
+
+func TestTimetickSyncWithExistChannels(t *testing.T) {
+	ctx := context.Background()
+	sourceID := int64(100)
+
+	factory := dependency.NewDefaultFactory(true)
+
+	//chanMap := map[typeutil.UniqueID][]string{
+	//	int64(1): {"rootcoord-dml_0"},
+	//}
+
+	var baseParams = paramtable.BaseTable{}
+	baseParams.Save("msgChannel.chanNamePrefix.rootCoordDml", "common.chanNamePrefix.rootCoordDml")
+	baseParams.Save("msgChannel.chanNamePrefix.rootCoordDelta", "common.chanNamePrefix.rootCoordDelta")
+	chans := map[UniqueID][]string{}
+
+	chans[UniqueID(100)] = []string{"rootcoord-dml_4", "rootcoord-dml_8"}
+	chans[UniqueID(102)] = []string{"rootcoord-dml_2", "rootcoord-dml_9"}
+	ttSync := newTimeTickSync(ctx, sourceID, factory, chans)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	t.Run("sendToChannel", func(t *testing.T) {
+		defer wg.Done()
+		ttSync.sendToChannel()
+
+		ttSync.sess2ChanTsMap[1] = nil
+		ttSync.sendToChannel()
+
+		msg := &internalpb.ChannelTimeTickMsg{
+			Base: &commonpb.MsgBase{
+				MsgType: commonpb.MsgType_TimeTick,
+			},
+		}
+		ttSync.sess2ChanTsMap[1] = newChanTsMsg(msg, 1)
+		ttSync.sendToChannel()
+	})
+
+	wg.Add(1)
+	t.Run("assign channels", func(t *testing.T) {
+		defer wg.Done()
+		channels := ttSync.getDmlChannelNames(int(4))
+		assert.Equal(t, channels, []string{"rootcoord-dml_0", "rootcoord-dml_1", "rootcoord-dml_3", "rootcoord-dml_5"})
+
+		channels = ttSync.getDmlChannelNames(int(4))
+		assert.Equal(t, channels, []string{"rootcoord-dml_6", "rootcoord-dml_7", "rootcoord-dml_0", "rootcoord-dml_1"})
+	})
+
+	// test get new channels
+
+}
+
+func TestTimetickSyncInvalidName(t *testing.T) {
+	ctx := context.Background()
+	sourceID := int64(100)
+
+	factory := dependency.NewDefaultFactory(true)
+
+	//chanMap := map[typeutil.UniqueID][]string{
+	//	int64(1): {"rootcoord-dml_0"},
+	//}
+
+	var baseParams = paramtable.BaseTable{}
+	baseParams.Save("msgChannel.chanNamePrefix.rootCoordDml", "common.chanNamePrefix.rootCoordDml")
+	baseParams.Save("msgChannel.chanNamePrefix.rootCoordDelta", "common.chanNamePrefix.rootCoordDelta")
+	chans := map[UniqueID][]string{}
+	chans[UniqueID(100)] = []string{"rootcoord-dml4"}
+	assert.Panics(t, func() {
+		newTimeTickSync(ctx, sourceID, factory, chans)
+	})
+
+	chans = map[UniqueID][]string{}
+	chans[UniqueID(102)] = []string{"rootcoord-dml_a"}
+	assert.Panics(t, func() {
+		newTimeTickSync(ctx, sourceID, factory, chans)
+	})
+}
diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go
index b9eac5c6d5..3076c96c45 100644
--- a/internal/rootcoord/util.go
+++ b/internal/rootcoord/util.go
@@ -135,16 +135,3 @@ func getTravelTs(req TimeTravelRequest) Timestamp {
 func isMaxTs(ts Timestamp) bool {
 	return ts == typeutil.MaxTimestamp
 }
-
-func getNeedChanNum(setNum int, chanMap map[typeutil.UniqueID][]string) int {
-	chanNames := typeutil.NewSet[string]()
-	for _, chanName := range chanMap {
-		chanNames.Insert(chanName...)
-	}
-
-	ret := chanNames.Len()
-	if setNum > chanNames.Len() {
-		ret = setNum
-	}
-	return ret
-}
diff --git a/internal/rootcoord/util_test.go b/internal/rootcoord/util_test.go
index 3665a9738e..6046f6d78b 100644
--- a/internal/rootcoord/util_test.go
+++ b/internal/rootcoord/util_test.go
@@ -148,22 +148,3 @@ func Test_isMaxTs(t *testing.T) {
 		})
 	}
 }
-
-func Test_GetNeedChanNum(t *testing.T) {
-	chanMap := map[typeutil.UniqueID][]string{
-		int64(1): {"rootcoord-dml_101"},
-		int64(2): {"rootcoord-dml_102"},
-		int64(3): {"rootcoord-dml_103"},
-	}
-
-	num := getNeedChanNum(2, chanMap)
-	assert.Equal(t, num, 3)
-
-	chanMap = map[typeutil.UniqueID][]string{
-		int64(1): {"rootcoord-dml_101", "rootcoord-dml_102"},
-		int64(2): {"rootcoord-dml_102", "rootcoord-dml_101"},
-		int64(3): {"rootcoord-dml_103", "rootcoord-dml_102"},
-	}
-	num = getNeedChanNum(2, chanMap)
-	assert.Equal(t, num, 3)
-}
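The getNeedChanNum deleted from util.go above is the root cause of the upgrade panic: set cardinality says nothing about which indexes the recovered names actually carry. A minimal sketch of the old rule, on the same assumed input as the earlier sketch, makes the mismatch concrete:

    package main

    import "fmt"

    func main() {
        recovered := []string{"rootcoord-dml_2", "rootcoord-dml_9", "rootcoord-dml_4", "rootcoord-dml_8"}

        // Old rule: pool size = max(configured default, count of distinct names).
        distinct := map[string]bool{}
        for _, name := range recovered {
            distinct[name] = true
        }
        poolSize := len(distinct)
        fmt.Println(poolSize) // 4: only rootcoord-dml_0 .. rootcoord-dml_3 get created
    }

With a pool sized 4, recovering a collection bound to rootcoord-dml_9 asks the pool for a channel it never created, which is the panic this patch fixes; the new rule sizes the pool to maxIndex+1 instead.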
diff --git a/tests/python_client/testcases/test_compaction.py b/tests/python_client/testcases/test_compaction.py
index 5581ddb070..174d4c714b 100644
--- a/tests/python_client/testcases/test_compaction.py
+++ b/tests/python_client/testcases/test_compaction.py
@@ -101,7 +101,7 @@ class TestCompactionParams(TestcaseBase):
         target = c_plans.plans[0].target
 
         # verify queryNode load the compacted segments
-        cost = 30
+        cost = 180
         start = time()
         while time() - start < cost:
             collection_w.load()
@@ -242,8 +242,8 @@ class TestCompactionParams(TestcaseBase):
         while True:
             if collection_w.num_entities == exp_num_entities_after_compact:
                 break
-            if time() - start > 60:
-                raise MilvusException(1, "Auto delete ratio compaction cost more than 60s")
+            if time() - start > 180:
+                raise MilvusException(1, "Auto delete ratio compaction cost more than 180s")
             sleep(1)
 
         collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
@@ -388,10 +388,10 @@ class TestCompactionParams(TestcaseBase):
         collection_w.load()
         replicas = collection_w.get_replicas()[0]
         replica_num = len(replicas.groups)
-        cost = 60
+        cost = 180
         start = time()
         while time() - start < cost:
-            sleep(2.0)
+            sleep(1.0)
             collection_w.load()
             segment_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
             if len(segment_info) == 1*replica_num:
@@ -598,10 +598,10 @@ class TestCompactionOperation(TestcaseBase):
         c_plans = collection_w.get_compaction_plans(check_task=CheckTasks.check_merge_compact)[0]
 
         # waiting for handoff completed and search
-        cost = 60
+        cost = 180
         start = time()
         while True:
-            sleep(5)
+            sleep(1)
             segment_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
             if len(segment_info) != 0 and segment_info[0].segmentID == c_plans.plans[0].target:
                 log.debug(segment_info)
@@ -877,9 +877,9 @@ class TestCompactionOperation(TestcaseBase):
         collection_w.load()
 
         start = time()
-        cost = 60
+        cost = 180
         while True:
-            sleep(5)
+            sleep(1)
             segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
 
             # verify segments reaches threshold, auto-merge ten segments into one
@@ -887,7 +887,7 @@ class TestCompactionOperation(TestcaseBase):
                 break
             end = time()
             if end - start > cost:
-                raise MilvusException(1, "Compact merge two segments more than 60s")
+                raise MilvusException(1, "Compact merge two segments more than 180s")
         assert c_plans.plans[0].target == segments_info[0].segmentID
 
     @pytest.mark.tags(CaseLabel.L2)
@@ -949,9 +949,9 @@ class TestCompactionOperation(TestcaseBase):
         collection_w.load()
 
         start = time()
-        cost = 60
+        cost = 180
         while True:
-            sleep(5)
+            sleep(1)
             segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
 
             # verify segments reaches threshold, auto-merge ten segments into one
@@ -959,7 +959,7 @@ class TestCompactionOperation(TestcaseBase):
                 break
             end = time()
             if end - start > cost:
-                raise MilvusException(1, "Compact auto and manual more than 60s")
+                raise MilvusException(1, "Compact auto and manual more than 180s")
         assert segments_info[0].segmentID == c_plans.plans[0].target
 
     @pytest.mark.tags(CaseLabel.L1)
@@ -988,10 +988,10 @@ class TestCompactionOperation(TestcaseBase):
         target = c_plans.plans[0].target
 
         collection_w.load()
-        cost = 60
+        cost = 180
         start = time()
         while True:
-            sleep(5)
+            sleep(1)
            segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
 
             # verify segments reaches threshold, auto-merge ten segments into one
@@ -999,7 +999,7 @@ class TestCompactionOperation(TestcaseBase):
                 break
             end = time()
             if end - start > cost:
-                raise MilvusException(1, "Compact merge multiple segments more than 60s")
+                raise MilvusException(1, "Compact merge multiple segments more than 180s")
         replicas = collection_w.get_replicas()[0]
         replica_num = len(replicas.groups)
         assert len(segments_info) == 1*replica_num
@@ -1059,13 +1059,13 @@ class TestCompactionOperation(TestcaseBase):
         log.debug(collection_w.index())
 
         # Estimated auto-merging takes 30s
-        cost = 120
+        cost = 180
         collection_w.load()
         replicas = collection_w.get_replicas()[0]
         replica_num = len(replicas.groups)
         start = time()
         while True:
-            sleep(5)
+            sleep(1)
             segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
 
             # verify segments reaches threshold, auto-merge ten segments into one
@@ -1073,7 +1073,7 @@ class TestCompactionOperation(TestcaseBase):
                 break
             end = time()
             if end - start > cost:
-                raise MilvusException(1, "Compact auto-merge more than 60s")
+                raise MilvusException(1, "Compact auto-merge more than 180s")
 
     @pytest.mark.tags(CaseLabel.L2)
     def test_compact_less_threshold_no_merge(self):
@@ -1298,10 +1298,10 @@ class TestCompactionOperation(TestcaseBase):
         collection_w.get_compaction_plans()
 
         # waiting for new segment index and compact
-        compact_cost = 120
+        compact_cost = 180
         start = time()
         while True:
-            sleep(10)
+            sleep(1)
             collection_w.load()
             # new segment compacted
             seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
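The test_compaction.py changes all follow one pattern: each wait loop now polls once per second against a uniform 180s ceiling, instead of mixing 30s, 60s, and 120s deadlines with 2s, 5s, or 10s sleeps. The tighter poll interval reacts as soon as compaction or handoff actually finishes, while the more generous shared timeout keeps slower CI environments from failing before the cluster settles.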