Fix panic caused by upgrade (#23833)

Signed-off-by: xiaofan-luan <xiaofan.luan@zilliz.com>
pull/23844/head
Xiaofan 2023-05-01 23:06:37 -07:00 committed by GitHub
parent 4b26c0afb3
commit 87d790f052
7 changed files with 136 additions and 56 deletions

View File

@ -630,7 +630,10 @@ func (s *Server) watchNodes(revision int64) {
)
s.nodeMgr.Add(session.NewNodeInfo(nodeID, addr))
s.nodeUpEventChan <- nodeID
s.notifyNodeUp <- struct{}{}
select {
case s.notifyNodeUp <- struct{}{}:
default:
}
case sessionutil.SessionUpdateEvent:
nodeID := event.Session.ServerID
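The direct send on s.notifyNodeUp could block the watchNodes loop whenever no goroutine was ready to receive; the select with an empty default turns it into a best-effort wake-up that drops the signal if one is already pending. Below is a minimal standalone sketch of this non-blocking notification pattern, assuming a buffered channel of capacity one; the notifier type and names are illustrative, not from the Milvus code.

package main

import "fmt"

// notifier coalesces wake-up signals: with a buffer of one, at most a single
// notification is pending and extra signals are dropped instead of blocking the sender.
type notifier struct {
	ch chan struct{}
}

func newNotifier() *notifier {
	return &notifier{ch: make(chan struct{}, 1)}
}

// notify never blocks: if a signal is already queued, the new one is discarded.
func (n *notifier) notify() {
	select {
	case n.ch <- struct{}{}:
	default:
	}
}

func main() {
	n := newNotifier()
	// The sender may fire many times with no receiver and never deadlocks.
	for i := 0; i < 5; i++ {
		n.notify()
	}
	// The consumer sees a single coalesced signal.
	<-n.ch
	fmt.Println("handled node-up notification")
}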

View File

@ -20,6 +20,8 @@ import (
"container/heap"
"context"
"fmt"
"strconv"
"strings"
"sync"
"github.com/milvus-io/milvus/pkg/metrics"
@ -30,6 +32,7 @@ import (
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type dmlMsgStream struct {
@ -356,3 +359,32 @@ func genChannelNames(prefix string, num int64) []string {
}
return results
}
// parseChannelNameIndex extracts the numeric suffix after the last "_" in a channel name,
// e.g. "rootcoord-dml_7" -> 7; it panics on a malformed name.
func parseChannelNameIndex(channelName string) int {
index := strings.LastIndex(channelName, "_")
if index < 0 {
log.Error("invalid channel name", zap.String("chanName", channelName))
panic("invalid channel name: " + channelName)
}
index, err := strconv.Atoi(channelName[index+1:])
if err != nil {
log.Error("invalid channel name", zap.String("chanName", channelName), zap.Error(err))
panic("invalid channel name: " + channelName)
}
return index
}
// getNeedChanNum returns how many DML channels must exist: the configured number,
// or more if existing collections already use channels with higher indexes.
func getNeedChanNum(setNum int, chanMap map[typeutil.UniqueID][]string) int {
// grow the configured number to cover the largest channel index already in use
maxChanUsed := setNum
for _, chanNames := range chanMap {
for _, chanName := range chanNames {
index := parseChannelNameIndex(chanName)
if maxChanUsed < index+1 {
maxChanUsed = index + 1
}
}
}
return maxChanUsed
}
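In other words, getNeedChanNum sizes the channel pool from the highest channel index recorded in metadata, not from how many names are in use, so a restart with a smaller configured DmlChannelNum still recreates every channel an existing collection refers to. Below is a minimal standalone sketch of that rule, reusing the channel names from the new test further down; the helper names are illustrative, not the Milvus API.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// channelIndex mirrors parseChannelNameIndex: the numeric suffix after the last "_".
func channelIndex(name string) int {
	i := strings.LastIndex(name, "_")
	if i < 0 {
		panic("invalid channel name: " + name)
	}
	idx, err := strconv.Atoi(name[i+1:])
	if err != nil {
		panic("invalid channel name: " + name)
	}
	return idx
}

// neededChannels mirrors getNeedChanNum: max(configured, highest index in use + 1).
func neededChannels(configured int, chanMap map[int64][]string) int {
	need := configured
	for _, names := range chanMap {
		for _, name := range names {
			if idx := channelIndex(name); idx+1 > need {
				need = idx + 1
			}
		}
	}
	return need
}

func main() {
	// Collections created before the upgrade already sit on channels 4, 8, 2 and 9.
	chanMap := map[int64][]string{
		100: {"rootcoord-dml_4", "rootcoord-dml_8"},
		102: {"rootcoord-dml_2", "rootcoord-dml_9"},
	}
	// Even if the configured default is only 2, ten channels (indexes 0..9) are
	// needed so that rootcoord-dml_9 can still be resolved after the restart.
	fmt.Println(neededChannels(2, chanMap)) // 10
}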

View File

@ -113,10 +113,10 @@ func (c *chanTsMsg) getTimetick(channelName string) typeutil.Timestamp {
func newTimeTickSync(ctx context.Context, sourceID int64, factory msgstream.Factory, chanMap map[typeutil.UniqueID][]string) *timetickSync {
// if the number of channels already used by existing collections is greater than the
// currently configured default, keep the old channels rather than shrinking the pool
defaultChanNum := getNeedChanNum(Params.RootCoordCfg.DmlChannelNum.GetAsInt(), chanMap)
chanNum := getNeedChanNum(Params.RootCoordCfg.DmlChannelNum.GetAsInt(), chanMap)
// initialize dml channels used for insert
dmlChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDml.GetValue(), int64(defaultChanNum))
dmlChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDml.GetValue(), int64(chanNum))
// recover physical channels for all collections
for collID, chanNames := range chanMap {

View File

@ -176,3 +176,80 @@ func Test_ttHistogram_get(t *testing.T) {
assert.Equal(t, typeutil.ZeroTimestamp, h.get("ch1"))
assert.Equal(t, typeutil.ZeroTimestamp, h.get("ch2"))
}
func TestTimetickSyncWithExistChannels(t *testing.T) {
ctx := context.Background()
sourceID := int64(100)
factory := dependency.NewDefaultFactory(true)
//chanMap := map[typeutil.UniqueID][]string{
// int64(1): {"rootcoord-dml_0"},
//}
var baseParams = paramtable.BaseTable{}
baseParams.Save("msgChannel.chanNamePrefix.rootCoordDml", "common.chanNamePrefix.rootCoordDml")
baseParams.Save("msgChannel.chanNamePrefix.rootCoordDelta", "common.chanNamePrefix.rootCoordDelta")
chans := map[UniqueID][]string{}
chans[UniqueID(100)] = []string{"rootcoord-dml_4", "rootcoord-dml_8"}
chans[UniqueID(102)] = []string{"rootcoord-dml_2", "rootcoord-dml_9"}
ttSync := newTimeTickSync(ctx, sourceID, factory, chans)
var wg sync.WaitGroup
wg.Add(1)
t.Run("sendToChannel", func(t *testing.T) {
defer wg.Done()
ttSync.sendToChannel()
ttSync.sess2ChanTsMap[1] = nil
ttSync.sendToChannel()
msg := &internalpb.ChannelTimeTickMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_TimeTick,
},
}
ttSync.sess2ChanTsMap[1] = newChanTsMsg(msg, 1)
ttSync.sendToChannel()
})
wg.Add(1)
t.Run("assign channels", func(t *testing.T) {
defer wg.Done()
channels := ttSync.getDmlChannelNames(int(4))
assert.Equal(t, channels, []string{"rootcoord-dml_0", "rootcoord-dml_1", "rootcoord-dml_3", "rootcoord-dml_5"})
channels = ttSync.getDmlChannelNames(int(4))
assert.Equal(t, channels, []string{"rootcoord-dml_6", "rootcoord-dml_7", "rootcoord-dml_0", "rootcoord-dml_1"})
})
// test get new channels
}
func TestTimetickSyncInvalidName(t *testing.T) {
ctx := context.Background()
sourceID := int64(100)
factory := dependency.NewDefaultFactory(true)
//chanMap := map[typeutil.UniqueID][]string{
// int64(1): {"rootcoord-dml_0"},
//}
var baseParams = paramtable.BaseTable{}
baseParams.Save("msgChannel.chanNamePrefix.rootCoordDml", "common.chanNamePrefix.rootCoordDml")
baseParams.Save("msgChannel.chanNamePrefix.rootCoordDelta", "common.chanNamePrefix.rootCoordDelta")
chans := map[UniqueID][]string{}
chans[UniqueID(100)] = []string{"rootcoord-dml4"}
assert.Panics(t, func() {
newTimeTickSync(ctx, sourceID, factory, chans)
})
chans = map[UniqueID][]string{}
chans[UniqueID(102)] = []string{"rootcoord-dml_a"}
assert.Panics(t, func() {
newTimeTickSync(ctx, sourceID, factory, chans)
})
}

View File

@ -135,16 +135,3 @@ func getTravelTs(req TimeTravelRequest) Timestamp {
func isMaxTs(ts Timestamp) bool {
return ts == typeutil.MaxTimestamp
}
func getNeedChanNum(setNum int, chanMap map[typeutil.UniqueID][]string) int {
chanNames := typeutil.NewSet[string]()
for _, chanName := range chanMap {
chanNames.Insert(chanName...)
}
ret := chanNames.Len()
if setNum > chanNames.Len() {
ret = setNum
}
return ret
}
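For comparison, the implementation removed here counted distinct channel names and took the larger of that count and the configured value. Because channels are generated as prefix_0 … prefix_{n-1} (see genChannelNames), a count of distinct names says nothing about which indexes exist: with the data from the deleted Test_GetNeedChanNum in the next file, only three channels (_0.._2) would be created while the collections reference _101.._103, presumably the mismatch behind the upgrade panic this PR addresses. A small illustrative sketch of the difference, not the Milvus implementation:

package main

import "fmt"

func main() {
	// Channel map from the deleted Test_GetNeedChanNum: three collections on
	// rootcoord-dml_101, _102 and _103.
	chanMap := map[int64][]string{
		1: {"rootcoord-dml_101"},
		2: {"rootcoord-dml_102"},
		3: {"rootcoord-dml_103"},
	}

	// Old rule: number of distinct channel names.
	distinct := map[string]struct{}{}
	for _, names := range chanMap {
		for _, name := range names {
			distinct[name] = struct{}{}
		}
	}
	fmt.Println("old rule:", len(distinct), "channels") // 3 -> only _0.._2 would be created

	// New rule: highest referenced index + 1, so _101.._103 exist again.
	fmt.Println("new rule:", 103+1, "channels") // 104
}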

View File

@ -148,22 +148,3 @@ func Test_isMaxTs(t *testing.T) {
})
}
}
func Test_GetNeedChanNum(t *testing.T) {
chanMap := map[typeutil.UniqueID][]string{
int64(1): {"rootcoord-dml_101"},
int64(2): {"rootcoord-dml_102"},
int64(3): {"rootcoord-dml_103"},
}
num := getNeedChanNum(2, chanMap)
assert.Equal(t, num, 3)
chanMap = map[typeutil.UniqueID][]string{
int64(1): {"rootcoord-dml_101", "rootcoord-dml_102"},
int64(2): {"rootcoord-dml_102", "rootcoord-dml_101"},
int64(3): {"rootcoord-dml_103", "rootcoord-dml_102"},
}
num = getNeedChanNum(2, chanMap)
assert.Equal(t, num, 3)
}

View File

@ -101,7 +101,7 @@ class TestCompactionParams(TestcaseBase):
target = c_plans.plans[0].target
# verify queryNode loads the compacted segments
cost = 30
cost = 180
start = time()
while time() - start < cost:
collection_w.load()
@ -242,8 +242,8 @@ class TestCompactionParams(TestcaseBase):
while True:
if collection_w.num_entities == exp_num_entities_after_compact:
break
if time() - start > 60:
raise MilvusException(1, "Auto delete ratio compaction cost more than 60s")
if time() - start > 180:
raise MilvusException(1, "Auto delete ratio compaction cost more than 180s")
sleep(1)
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
@ -388,10 +388,10 @@ class TestCompactionParams(TestcaseBase):
collection_w.load()
replicas = collection_w.get_replicas()[0]
replica_num = len(replicas.groups)
cost = 60
cost = 180
start = time()
while time() - start < cost:
sleep(2.0)
sleep(1.0)
collection_w.load()
segment_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
if len(segment_info) == 1*replica_num:
@ -598,10 +598,10 @@ class TestCompactionOperation(TestcaseBase):
c_plans = collection_w.get_compaction_plans(check_task=CheckTasks.check_merge_compact)[0]
# waiting for handoff to complete, then search
cost = 60
cost = 180
start = time()
while True:
sleep(5)
sleep(1)
segment_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
if len(segment_info) != 0 and segment_info[0].segmentID == c_plans.plans[0].target:
log.debug(segment_info)
@ -877,9 +877,9 @@ class TestCompactionOperation(TestcaseBase):
collection_w.load()
start = time()
cost = 60
cost = 180
while True:
sleep(5)
sleep(1)
segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
# verify segments reach the threshold, auto-merging ten segments into one
@ -887,7 +887,7 @@ class TestCompactionOperation(TestcaseBase):
break
end = time()
if end - start > cost:
raise MilvusException(1, "Compact merge two segments more than 60s")
raise MilvusException(1, "Compact merge two segments more than 180s")
assert c_plans.plans[0].target == segments_info[0].segmentID
@pytest.mark.tags(CaseLabel.L2)
@ -949,9 +949,9 @@ class TestCompactionOperation(TestcaseBase):
collection_w.load()
start = time()
cost = 60
cost = 180
while True:
sleep(5)
sleep(1)
segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
# verify segments reach the threshold, auto-merging ten segments into one
@ -959,7 +959,7 @@ class TestCompactionOperation(TestcaseBase):
break
end = time()
if end - start > cost:
raise MilvusException(1, "Compact auto and manual more than 60s")
raise MilvusException(1, "Compact auto and manual more than 180s")
assert segments_info[0].segmentID == c_plans.plans[0].target
@pytest.mark.tags(CaseLabel.L1)
@ -988,10 +988,10 @@ class TestCompactionOperation(TestcaseBase):
target = c_plans.plans[0].target
collection_w.load()
cost = 60
cost = 180
start = time()
while True:
sleep(5)
sleep(1)
segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
# verify segments reach the threshold, auto-merging ten segments into one
@ -999,7 +999,7 @@ class TestCompactionOperation(TestcaseBase):
break
end = time()
if end - start > cost:
raise MilvusException(1, "Compact merge multiple segments more than 60s")
raise MilvusException(1, "Compact merge multiple segments more than 180s")
replicas = collection_w.get_replicas()[0]
replica_num = len(replicas.groups)
assert len(segments_info) == 1*replica_num
@ -1059,13 +1059,13 @@ class TestCompactionOperation(TestcaseBase):
log.debug(collection_w.index())
# Estimated auto-merging takes 30s
cost = 120
cost = 180
collection_w.load()
replicas = collection_w.get_replicas()[0]
replica_num = len(replicas.groups)
start = time()
while True:
sleep(5)
sleep(1)
segments_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]
# verify segments reach the threshold, auto-merging ten segments into one
@ -1073,7 +1073,7 @@ class TestCompactionOperation(TestcaseBase):
break
end = time()
if end - start > cost:
raise MilvusException(1, "Compact auto-merge more than 60s")
raise MilvusException(1, "Compact auto-merge more than 180s")
@pytest.mark.tags(CaseLabel.L2)
def test_compact_less_threshold_no_merge(self):
@ -1298,10 +1298,10 @@ class TestCompactionOperation(TestcaseBase):
collection_w.get_compaction_plans()
# waiting for the new segment to be indexed and compacted
compact_cost = 120
compact_cost = 180
start = time()
while True:
sleep(10)
sleep(1)
collection_w.load()
# new segment compacted
seg_info = self.utility_wrap.get_query_segment_info(collection_w.name)[0]