diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index aee332f98f..408205975a 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -55,10 +55,10 @@ import ( type Cache interface { // GetCollectionID get collection's id by name. GetCollectionID(ctx context.Context, database, collectionName string) (typeutil.UniqueID, error) - // GetDatabaseAndCollectionName get collection's name and database by id - GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) + // GetCollectionName get collection's name and database by id + GetCollectionName(ctx context.Context, collectionID int64) (string, error) // GetCollectionInfo get collection's information by name or collection id, such as schema, and etc. - GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) + GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionBasicInfo, error) // GetPartitionID get partition's identifier of specific collection. GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error) // GetPartitions get all partitions' id of specific collection. @@ -87,6 +87,14 @@ type Cache interface { RemoveDatabase(ctx context.Context, database string) } +type collectionBasicInfo struct { + collID typeutil.UniqueID + createdTimestamp uint64 + createdUtcTimestamp uint64 + consistencyLevel commonpb.ConsistencyLevel + partInfo map[string]*partitionInfo +} + type collectionInfo struct { collID typeutil.UniqueID schema *schemapb.CollectionSchema @@ -96,7 +104,23 @@ type collectionInfo struct { createdTimestamp uint64 createdUtcTimestamp uint64 consistencyLevel commonpb.ConsistencyLevel - database string +} + +// getBasicInfo get a basic info by deep copy. +func (info *collectionInfo) getBasicInfo() *collectionBasicInfo { + // Do a deep copy for all fields. + basicInfo := &collectionBasicInfo{ + collID: info.collID, + createdTimestamp: info.createdTimestamp, + createdUtcTimestamp: info.createdUtcTimestamp, + consistencyLevel: info.consistencyLevel, + partInfo: make(map[string]*partitionInfo, len(info.partInfo)), + } + for s, info := range info.partInfo { + info2 := *info + basicInfo.partInfo[s] = &info2 + } + return basicInfo } func (info *collectionInfo) isCollectionCached() bool { @@ -252,8 +276,8 @@ func (m *MetaCache) GetCollectionID(ctx context.Context, database, collectionNam return collInfo.collID, nil } -// GetDatabaseAndCollectionName returns the corresponding collection name for provided collection id -func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) { +// GetCollectionName returns the corresponding collection name for provided collection id +func (m *MetaCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { m.mu.RLock() var collInfo *collectionInfo for _, db := range m.collInfo { @@ -272,24 +296,22 @@ func (m *MetaCache) GetDatabaseAndCollectionName(ctx context.Context, collection m.mu.RUnlock() coll, err := m.describeCollection(ctx, "", "", collectionID) if err != nil { - return "", "", err + return "", err } m.mu.Lock() defer m.mu.Unlock() m.updateCollection(coll, coll.GetDbName(), coll.Schema.Name) metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - return coll.GetDbName(), coll.Schema.Name, nil + return coll.Schema.Name, nil } defer m.mu.RUnlock() metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() - return collInfo.database, collInfo.schema.Name, nil + return collInfo.schema.Name, nil } -// GetCollectionInfo returns the collection information related to provided collection name -// If the information is not found, proxy will try to fetch information for other source (RootCoord for now) -func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) { +func (m *MetaCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) { m.mu.RLock() var collInfo *collectionInfo var ok bool @@ -298,12 +320,49 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN if dbOk { collInfo, ok = db[collectionName] } - m.mu.RUnlock() method := "GetCollectionInfo" // if collInfo.collID != collectionID, means that the cache is not trustable // try to get collection according to collectionID if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID { + m.mu.RUnlock() + tr := timerecord.NewTimeRecorder("UpdateCache") + metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() + coll, err := m.describeCollection(ctx, database, "", collectionID) + if err != nil { + return nil, err + } + m.mu.Lock() + defer m.mu.Unlock() + m.updateCollection(coll, database, collectionName) + collInfo = m.collInfo[database][collectionName] + metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + return collInfo.getBasicInfo(), nil + } + defer m.mu.RUnlock() + + metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() + return collInfo.getBasicInfo(), nil +} + +// GetCollectionInfo returns the collection information related to provided collection name +// If the information is not found, proxy will try to fetch information for other source (RootCoord for now) +// TODO: may cause data race of this implementation, should be refactored in future. +func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collectionName string, collectionID int64) (*collectionInfo, error) { + m.mu.RLock() + var collInfo *collectionInfo + var ok bool + + db, dbOk := m.collInfo[database] + if dbOk { + collInfo, ok = db[collectionName] + } + + method := "GetCollectionInfo" + // if collInfo.collID != collectionID, means that the cache is not trustable + // try to get collection according to collectionID + if !ok || !collInfo.isCollectionCached() || collInfo.collID != collectionID { + m.mu.RUnlock() tr := timerecord.NewTimeRecorder("UpdateCache") metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() var coll *milvuspb.DescribeCollectionResponse @@ -320,8 +379,10 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, database, collectionN collInfo = m.collInfo[database][collectionName] m.mu.Unlock() metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + return collInfo, nil } + m.mu.RUnlock() metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() return collInfo, nil } @@ -707,7 +768,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col zap.String("collectionName", collectionName), zap.Int64("collectionID", collectionID)) - info, err := m.GetCollectionInfo(ctx, database, collectionName, collectionID) + info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID) if err != nil { return nil, err } @@ -764,7 +825,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col shards := parseShardLeaderList2QueryNode(resp.GetShards()) - info, err = m.GetCollectionInfo(ctx, database, collectionName, collectionID) + info, err = m.getFullCollectionInfo(ctx, database, collectionName, collectionID) if err != nil { return nil, fmt.Errorf("failed to get shards, collectionName %s, colectionID %d not found", collectionName, collectionID) } @@ -825,7 +886,6 @@ func (m *MetaCache) DeprecateShardCache(database, collectionName string) { if ok { info.deprecateLeaderCache() } - } func (m *MetaCache) expireShardLeaderCache(ctx context.Context) { diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index e29d241101..ab4974e343 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -259,7 +259,40 @@ func TestMetaCache_GetCollection(t *testing.T) { Fields: []*schemapb.FieldSchema{}, Name: "collection1", }) +} +func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { + ctx := context.Background() + rootCoord := &MockRootCoordClientInterface{} + queryCoord := &mocks.MockQueryCoord{} + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) + assert.NoError(t, err) + + // should be no data race. + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1) + assert.NoError(t, err) + assert.Equal(t, info.collID, int64(1)) + _ = info.consistencyLevel + _ = info.createdTimestamp + _ = info.createdUtcTimestamp + _ = info.partInfo + }() + go func() { + defer wg.Done() + info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1) + assert.NoError(t, err) + assert.Equal(t, info.collID, int64(1)) + _ = info.consistencyLevel + _ = info.createdTimestamp + _ = info.createdUtcTimestamp + _ = info.partInfo + }() + wg.Wait() } func TestMetaCache_GetCollectionName(t *testing.T) { @@ -270,9 +303,8 @@ func TestMetaCache_GetCollectionName(t *testing.T) { err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) - db, collection, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, 1) + collection, err := globalMetaCache.GetCollectionName(ctx, 1) assert.NoError(t, err) - assert.Equal(t, db, dbName) assert.Equal(t, collection, "collection1") assert.Equal(t, rootCoord.GetAccessCount(), 1) @@ -285,7 +317,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { Fields: []*schemapb.FieldSchema{}, Name: "collection1", }) - _, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1) + collection, err = globalMetaCache.GetCollectionName(ctx, 1) assert.Equal(t, rootCoord.GetAccessCount(), 1) assert.NoError(t, err) assert.Equal(t, collection, "collection1") @@ -299,7 +331,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { }) // test to get from cache, this should trigger root request - _, collection, err = globalMetaCache.GetDatabaseAndCollectionName(ctx, 1) + collection, err = globalMetaCache.GetCollectionName(ctx, 1) assert.Equal(t, rootCoord.GetAccessCount(), 2) assert.NoError(t, err) assert.Equal(t, collection, "collection1") @@ -397,7 +429,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { getCollectionCacheFunc := func(wg *sync.WaitGroup) { defer wg.Done() for i := 0; i < cnt; i++ { - //GetCollectionSchema will never fail + // GetCollectionSchema will never fail schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1") assert.NoError(t, err) assert.Equal(t, schema, &schemapb.CollectionSchema{ @@ -412,7 +444,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { getPartitionCacheFunc := func(wg *sync.WaitGroup) { defer wg.Done() for i := 0; i < cnt; i++ { - //GetPartitions may fail + // GetPartitions may fail globalMetaCache.GetPartitions(ctx, dbName, "collection1") time.Sleep(10 * time.Millisecond) } @@ -421,7 +453,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { invalidCacheFunc := func(wg *sync.WaitGroup) { defer wg.Done() for i := 0; i < cnt; i++ { - //periodically invalid collection cache + // periodically invalid collection cache globalMetaCache.RemoveCollection(ctx, dbName, "collection1") time.Sleep(10 * time.Millisecond) } @@ -574,7 +606,6 @@ func TestMetaCache_ClearShards(t *testing.T) { }) t.Run("Clear valid collection valid cache", func(t *testing.T) { - qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -731,6 +762,13 @@ func TestMetaCache_RemoveCollection(t *testing.T) { assert.NoError(t, err) // shouldn't access RootCoord again assert.Equal(t, rootCoord.GetAccessCount(), 3) + + globalMetaCache.RemoveCollectionsByID(ctx, UniqueID(1)) + // no collectionInfo of collection2, should access RootCoord + _, err = globalMetaCache.GetCollectionInfo(ctx, dbName, "collection1", 1) + assert.NoError(t, err) + // no collectionInfo of collection1, should access RootCoord + assert.Equal(t, rootCoord.GetAccessCount(), 4) } func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { diff --git a/internal/proxy/mock_cache_test.go b/internal/proxy/mock_cache_test.go index ac9d87e24d..0e457ed524 100644 --- a/internal/proxy/mock_cache_test.go +++ b/internal/proxy/mock_cache_test.go @@ -115,19 +115,19 @@ func (_c *MockCache_GetCollectionID_Call) RunAndReturn(run func(context.Context, } // GetCollectionInfo provides a mock function with given fields: ctx, database, collectionName, collectionID -func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionInfo, error) { +func (_m *MockCache) GetCollectionInfo(ctx context.Context, database string, collectionName string, collectionID int64) (*collectionBasicInfo, error) { ret := _m.Called(ctx, database, collectionName, collectionID) - var r0 *collectionInfo + var r0 *collectionBasicInfo var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionInfo, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*collectionBasicInfo, error)); ok { return rf(ctx, database, collectionName, collectionID) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionInfo); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *collectionBasicInfo); ok { r0 = rf(ctx, database, collectionName, collectionID) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*collectionInfo) + r0 = ret.Get(0).(*collectionBasicInfo) } } @@ -161,12 +161,65 @@ func (_c *MockCache_GetCollectionInfo_Call) Run(run func(ctx context.Context, da return _c } -func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionInfo, _a1 error) *MockCache_GetCollectionInfo_Call { +func (_c *MockCache_GetCollectionInfo_Call) Return(_a0 *collectionBasicInfo, _a1 error) *MockCache_GetCollectionInfo_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionInfo, error)) *MockCache_GetCollectionInfo_Call { +func (_c *MockCache_GetCollectionInfo_Call) RunAndReturn(run func(context.Context, string, string, int64) (*collectionBasicInfo, error)) *MockCache_GetCollectionInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetCollectionName provides a mock function with given fields: ctx, collectionID +func (_m *MockCache) GetCollectionName(ctx context.Context, collectionID int64) (string, error) { + ret := _m.Called(ctx, collectionID) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64) (string, error)); ok { + return rf(ctx, collectionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok { + r0 = rf(ctx, collectionID) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCache_GetCollectionName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionName' +type MockCache_GetCollectionName_Call struct { + *mock.Call +} + +// GetCollectionName is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +func (_e *MockCache_Expecter) GetCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetCollectionName_Call { + return &MockCache_GetCollectionName_Call{Call: _e.mock.On("GetCollectionName", ctx, collectionID)} +} + +func (_c *MockCache_GetCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetCollectionName_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockCache_GetCollectionName_Call) Return(_a0 string, _a1 error) *MockCache_GetCollectionName_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, error)) *MockCache_GetCollectionName_Call { _c.Call.Return(run) return _c } @@ -282,66 +335,6 @@ func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Contex return _c } -// GetDatabaseAndCollectionName provides a mock function with given fields: ctx, collectionID -func (_m *MockCache) GetDatabaseAndCollectionName(ctx context.Context, collectionID int64) (string, string, error) { - ret := _m.Called(ctx, collectionID) - - var r0 string - var r1 string - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (string, string, error)); ok { - return rf(ctx, collectionID) - } - if rf, ok := ret.Get(0).(func(context.Context, int64) string); ok { - r0 = rf(ctx, collectionID) - } else { - r0 = ret.Get(0).(string) - } - - if rf, ok := ret.Get(1).(func(context.Context, int64) string); ok { - r1 = rf(ctx, collectionID) - } else { - r1 = ret.Get(1).(string) - } - - if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok { - r2 = rf(ctx, collectionID) - } else { - r2 = ret.Error(2) - } - - return r0, r1, r2 -} - -// MockCache_GetDatabaseAndCollectionName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDatabaseAndCollectionName' -type MockCache_GetDatabaseAndCollectionName_Call struct { - *mock.Call -} - -// GetDatabaseAndCollectionName is a helper method to define mock.On call -// - ctx context.Context -// - collectionID int64 -func (_e *MockCache_Expecter) GetDatabaseAndCollectionName(ctx interface{}, collectionID interface{}) *MockCache_GetDatabaseAndCollectionName_Call { - return &MockCache_GetDatabaseAndCollectionName_Call{Call: _e.mock.On("GetDatabaseAndCollectionName", ctx, collectionID)} -} - -func (_c *MockCache_GetDatabaseAndCollectionName_Call) Run(run func(ctx context.Context, collectionID int64)) *MockCache_GetDatabaseAndCollectionName_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(int64)) - }) - return _c -} - -func (_c *MockCache_GetDatabaseAndCollectionName_Call) Return(_a0 string, _a1 string, _a2 error) *MockCache_GetDatabaseAndCollectionName_Call { - _c.Call.Return(_a0, _a1, _a2) - return _c -} - -func (_c *MockCache_GetDatabaseAndCollectionName_Call) RunAndReturn(run func(context.Context, int64) (string, string, error)) *MockCache_GetDatabaseAndCollectionName_Call { - _c.Call.Return(run) - return _c -} - // GetPartitionID provides a mock function with given fields: ctx, database, collectionName, partitionName func (_m *MockCache) GetPartitionID(ctx context.Context, database string, collectionName string, partitionName string) (int64, error) { ret := _m.Called(ctx, database, collectionName, partitionName) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 1dcc510af9..0c9cc5b235 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -521,7 +521,6 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error { } result, err := dct.rootCoord.DescribeCollection(ctx, dct.DescribeCollectionRequest) - if err != nil { return err } @@ -634,7 +633,6 @@ func (sct *showCollectionsTask) PreExecute(ctx context.Context) error { func (sct *showCollectionsTask) Execute(ctx context.Context) error { respFromRootCoord, err := sct.rootCoord.ShowCollections(ctx, sct.ShowCollectionsRequest) - if err != nil { return err } @@ -670,10 +668,9 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { sct.Base, commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections), ), - //DbID: sct.ShowCollectionsRequest.DbName, + // DbID: sct.ShowCollectionsRequest.DbName, CollectionIDs: collectionIDs, }) - if err != nil { return err } @@ -1179,7 +1176,6 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { CollectionID: collectionID, PartitionIDs: partitionIDs, }) - if err != nil { return err } @@ -2209,13 +2205,12 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error { resp, err := t.queryCoord.DescribeResourceGroup(ctx, &querypb.DescribeResourceGroupRequest{ ResourceGroup: t.ResourceGroup, }) - if err != nil { return err } getCollectionNameFunc := func(value int32, key int64) string { - _, name, err := globalMetaCache.GetDatabaseAndCollectionName(ctx, key) + name, err := globalMetaCache.GetCollectionName(ctx, key) if err != nil { // unreachable logic path return "unavailable_collection"