enhance: cache collection schema attributes to reduce proxy cpu (#29668)

See also #29113

The collection schema is crucial when performing search/query but some
of the information is calculated for every request.

This PR changes the schema field of the cached collection info into a
utility `schemaInfo` type to store some stable results, say pk field,
partitionKeyEnabled, etc. It also provides a field name to ID map for
search/query services.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/29138/merge
congqixia 2024-01-04 17:28:46 +08:00 committed by GitHub
parent a988daf143
commit 4f8c540c77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 185 additions and 118 deletions

View File

@ -120,7 +120,7 @@ func (h *HandlersV1) checkDatabase(ctx context.Context, c *gin.Context, dbName s
func (h *HandlersV1) describeCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string) (*schemapb.CollectionSchema, error) {
collSchema, err := proxy.GetCachedCollectionSchema(ctx, dbName, collectionName)
if err == nil {
return collSchema, nil
return collSchema.CollectionSchema, nil
}
req := milvuspb.DescribeCollectionRequest{
DbName: dbName,

View File

@ -69,7 +69,7 @@ type Cache interface {
// GetPartitionsIndex returns a partition names in partition key indexed order.
GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error)
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
DeprecateShardCache(database, collectionName string)
expireShardLeaderCache(ctx context.Context)
@ -99,9 +99,8 @@ type collectionBasicInfo struct {
}
type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
// partInfo map[string]*partitionInfo
collID typeutil.UniqueID
schema *schemaInfo
partInfo *partitionInfos
leaderMutex sync.RWMutex
shardLeaders *shardLeaders
@ -110,6 +109,51 @@ type collectionInfo struct {
consistencyLevel commonpb.ConsistencyLevel
}
// schemaInfo is a helper struct that wraps *schemapb.CollectionSchema
// with extra fields mapping and methods, caching stable per-collection
// attributes so they are not recomputed for every request.
type schemaInfo struct {
	*schemapb.CollectionSchema
	fieldMap             *typeutil.ConcurrentMap[string, int64] // field name to id mapping
	hasPartitionKeyField bool                                   // true when any field is flagged as partition key
	pkField              *schemapb.FieldSchema                  // cached primary key field; nil when schema has none
}
// newSchemaInfo builds a schemaInfo from the raw collection schema in a
// single pass over the fields, pre-computing the field name→ID map, the
// partition-key flag and the primary key field so later lookups are O(1).
// pkField stays nil when the schema declares no primary key field; callers
// detect that via GetPkField.
func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo {
	fieldMap := typeutil.NewConcurrentMap[string, int64]()
	// renamed from hasPartitionkey for MixedCaps consistency with the
	// hasPartitionKeyField struct field
	hasPartitionKey := false
	var pkField *schemapb.FieldSchema
	for _, field := range schema.GetFields() {
		fieldMap.Insert(field.GetName(), field.GetFieldID())
		if field.GetIsPartitionKey() {
			hasPartitionKey = true
		}
		if field.GetIsPrimaryKey() {
			pkField = field
		}
	}
	return &schemaInfo{
		CollectionSchema:     schema,
		fieldMap:             fieldMap,
		hasPartitionKeyField: hasPartitionKey,
		pkField:              pkField,
	}
}
// MapFieldID looks up the field ID for the given field name in the cached
// mapping. The second return value reports whether the name was found.
func (s *schemaInfo) MapFieldID(name string) (int64, bool) {
	fieldID, ok := s.fieldMap.Get(name)
	return fieldID, ok
}
// IsPartitionKeyCollection reports whether the collection schema contains a
// field flagged as partition key (pre-computed at schemaInfo construction).
func (s *schemaInfo) IsPartitionKeyCollection() bool {
	return s.hasPartitionKeyField
}
// GetPkField returns the cached primary key field schema. It returns an
// internal-service error when the schema declares no primary key field.
func (s *schemaInfo) GetPkField() (*schemapb.FieldSchema, error) {
	if pk := s.pkField; pk != nil {
		return pk, nil
	}
	return nil, merr.WrapErrServiceInternal("pk field not found")
}
// partitionInfos contains the cached collection partition informations.
type partitionInfos struct {
partitionInfos []*partitionInfo
@ -396,7 +440,7 @@ func (m *MetaCache) getFullCollectionInfo(ctx context.Context, database, collect
return collInfo, nil
}
func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemapb.CollectionSchema, error) {
func (m *MetaCache) GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) {
m.mu.RLock()
var collInfo *collectionInfo
var ok bool
@ -445,7 +489,7 @@ func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse,
if !ok {
m.collInfo[database][collectionName] = &collectionInfo{}
}
m.collInfo[database][collectionName].schema = coll.Schema
m.collInfo[database][collectionName].schema = newSchemaInfo(coll.Schema)
m.collInfo[database][collectionName].collID = coll.CollectionID
m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp
m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp

View File

@ -208,7 +208,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -220,7 +220,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection2",
@ -234,7 +234,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -290,7 +290,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 1)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -302,7 +302,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection2")
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection2",
@ -316,7 +316,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.Equal(t, rootCoord.GetAccessCount(), 2)
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -340,7 +340,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
schema, err = globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -349,7 +349,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
rootCoord.Error = true
// should be cached with no error
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",
@ -410,7 +410,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
// GetCollectionSchema will never fail
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, "collection1")
assert.NoError(t, err)
assert.Equal(t, schema, &schemapb.CollectionSchema{
assert.Equal(t, schema.CollectionSchema, &schemapb.CollectionSchema{
AutoID: true,
Fields: []*schemapb.FieldSchema{},
Name: "collection1",

View File

@ -8,8 +8,6 @@ import (
internalpb "github.com/milvus-io/milvus/internal/proto/internalpb"
mock "github.com/stretchr/testify/mock"
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
typeutil "github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -226,19 +224,19 @@ func (_c *MockCache_GetCollectionName_Call) RunAndReturn(run func(context.Contex
}
// GetCollectionSchema provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemapb.CollectionSchema, error) {
func (_m *MockCache) GetCollectionSchema(ctx context.Context, database string, collectionName string) (*schemaInfo, error) {
ret := _m.Called(ctx, database, collectionName)
var r0 *schemapb.CollectionSchema
var r0 *schemaInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemapb.CollectionSchema, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*schemaInfo, error)); ok {
return rf(ctx, database, collectionName)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemapb.CollectionSchema); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string) *schemaInfo); ok {
r0 = rf(ctx, database, collectionName)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*schemapb.CollectionSchema)
r0 = ret.Get(0).(*schemaInfo)
}
}
@ -271,12 +269,12 @@ func (_c *MockCache_GetCollectionSchema_Call) Run(run func(ctx context.Context,
return _c
}
func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemapb.CollectionSchema, _a1 error) *MockCache_GetCollectionSchema_Call {
func (_c *MockCache_GetCollectionSchema_Call) Return(_a0 *schemaInfo, _a1 error) *MockCache_GetCollectionSchema_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemapb.CollectionSchema, error)) *MockCache_GetCollectionSchema_Call {
func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Context, string, string) (*schemaInfo, error)) *MockCache_GetCollectionSchema_Call {
_c.Call.Return(run)
return _c
}

View File

@ -1487,7 +1487,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
),
DbID: 0,
CollectionID: collID,
Schema: collSchema,
Schema: collSchema.CollectionSchema,
ReplicaNumber: t.ReplicaNumber,
FieldIndexID: fieldIndexIDs,
Refresh: t.Refresh,
@ -1738,7 +1738,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
DbID: 0,
CollectionID: collID,
PartitionIDs: partitionIDs,
Schema: collSchema,
Schema: collSchema.CollectionSchema,
ReplicaNumber: t.ReplicaNumber,
FieldIndexID: fieldIndexIDs,
Refresh: t.Refresh,

View File

@ -230,7 +230,7 @@ type deleteRunner struct {
tsoAllocatorIns tsoAllocator
// delete info
schema *schemapb.CollectionSchema
schema *schemaInfo
collectionID UniqueID
partitionID UniqueID
partitionKeyMode bool
@ -264,8 +264,8 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
return ErrWithLog(log, "Failed to get collection schema", err)
}
dr.partitionKeyMode = hasParitionKeyModeField(dr.schema)
// get prititionIDs of delete
dr.partitionKeyMode = dr.schema.IsPartitionKeyCollection()
// get partitionIDs of delete
dr.partitionID = common.InvalidPartitionID
if len(dr.req.PartitionName) > 0 {
if dr.partitionKeyMode {
@ -300,12 +300,12 @@ func (dr *deleteRunner) Init(ctx context.Context) error {
}
func (dr *deleteRunner) Run(ctx context.Context) error {
plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr)
plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr)
if err != nil {
return fmt.Errorf("failed to create expr plan, expr = %s", dr.req.GetExpr())
}
isSimple, pk, numRow := getPrimaryKeysFromPlan(dr.schema, plan)
isSimple, pk, numRow := getPrimaryKeysFromPlan(dr.schema.CollectionSchema, plan)
if isSimple {
// if could get delete.primaryKeys from delete expr
err := dr.simpleDelete(ctx, pk, numRow)
@ -379,7 +379,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe
zap.Int64("nodeID", nodeID))
// set plan
_, outputFieldIDs := translatePkOutputFields(dr.schema)
_, outputFieldIDs := translatePkOutputFields(dr.schema.CollectionSchema)
outputFieldIDs = append(outputFieldIDs, common.TimeStampField)
plan.OutputFieldIds = outputFieldIDs

View File

@ -234,7 +234,7 @@ func TestDeleteRunner_Init(t *testing.T) {
// channels := []string{"test_channel"}
dbName := "test_1"
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
@ -253,6 +253,7 @@ func TestDeleteRunner_Init(t *testing.T) {
},
},
}
schema := newSchemaInfo(collSchema)
t.Run("empty collection name", func(t *testing.T) {
dr := deleteRunner{}
@ -312,7 +313,7 @@ func TestDeleteRunner_Init(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
@ -325,7 +326,7 @@ func TestDeleteRunner_Init(t *testing.T) {
IsPartitionKey: true,
},
},
}, nil)
}), nil)
globalMetaCache = cache
assert.Error(t, dr.Init(context.Background()))
@ -440,7 +441,7 @@ func TestDeleteRunner_Run(t *testing.T) {
queue.Start()
defer queue.Close()
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
@ -459,6 +460,7 @@ func TestDeleteRunner_Run(t *testing.T) {
},
},
}
schema := newSchemaInfo(collSchema)
metaCache := NewMockCache(t)
metaCache.EXPECT().GetCollectionID(mock.Anything, dbName, collectionName).Return(collectionID, nil).Maybe()
@ -474,6 +476,7 @@ func TestDeleteRunner_Run(t *testing.T) {
req: &milvuspb.DeleteRequest{
Expr: "????",
},
schema: schema,
}
assert.Error(t, dr.Run(context.Background()))
})
@ -838,7 +841,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) {
queue.Start()
defer queue.Close()
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Name: "test_delete",
Description: "",
AutoID: false,
@ -859,7 +862,9 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) {
}
// test partitionKey mode
schema.Fields[1].IsPartitionKey = true
collSchema.Fields[1].IsPartitionKey = true
schema := newSchemaInfo(collSchema)
partitionMaps := make(map[string]int64)
partitionMaps["test_0"] = 1
partitionMaps["test_1"] = 2
@ -930,7 +935,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) {
globalMetaCache = mockCache
defer func() { globalMetaCache = nil }()
plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr)
plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr)
assert.NoError(t, err)
queryFunc := dr.getStreamingQueryAndDelteFunc(plan)
assert.Error(t, queryFunc(ctx, 1, qn))
@ -973,7 +978,7 @@ func TestDeleteRunner_StreamingQueryAndDelteFunc(t *testing.T) {
globalMetaCache = mockCache
defer func() { globalMetaCache = nil }()
plan, err := planparserv2.CreateRetrievePlan(dr.schema, dr.req.Expr)
plan, err := planparserv2.CreateRetrievePlan(dr.schema.CollectionSchema, dr.req.Expr)
assert.NoError(t, err)
queryFunc := dr.getStreamingQueryAndDelteFunc(plan)
assert.Error(t, queryFunc(ctx, 1, qn))

View File

@ -294,7 +294,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
log.Error("failed to get collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to get collection schema: %s", err)
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to parse collection schema: %s", err)
@ -616,7 +616,7 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error {
log.Error("failed to get collection schema", zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", err)
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.Error(err))
return fmt.Errorf("failed to parse collection schema: %s", err)
@ -740,7 +740,7 @@ func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error {
log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName())
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.String("collection_name", schema.GetName()), zap.Error(err))
return fmt.Errorf("failed to parse collection schema: %s", dit.GetCollectionName())

View File

@ -245,7 +245,7 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newTestSchema(), nil)
).Return(newSchemaInfo(newTestSchema()), nil)
globalMetaCache = mockCache

View File

@ -116,7 +116,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err))
return err
}
it.schema = schema
it.schema = schema.CollectionSchema
rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
@ -164,7 +164,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
// set field ID to insert field data
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema)
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), schema.CollectionSchema)
if err != nil {
log.Info("set fieldID to fieldData failed",
zap.Error(err))
@ -199,7 +199,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck(), withMaxCapCheck()).
Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil {
Validate(it.insertMsg.GetFieldsData(), schema.CollectionSchema, it.insertMsg.NRows()); err != nil {
return err
}

View File

@ -52,7 +52,7 @@ type queryTask struct {
ids *schemapb.IDs
collectionName string
queryParams *queryParams
schema *schemapb.CollectionSchema
schema *schemaInfo
userOutputFields []string
@ -206,25 +206,25 @@ func (t *queryTask) createPlan(ctx context.Context) error {
cntMatch := matchCountRule(t.request.GetOutputFields())
if cntMatch {
var err error
t.plan, err = createCntPlan(t.request.GetExpr(), schema)
t.plan, err = createCntPlan(t.request.GetExpr(), schema.CollectionSchema)
t.userOutputFields = []string{"count(*)"}
return err
}
var err error
if t.plan == nil {
t.plan, err = planparserv2.CreateRetrievePlan(schema, t.request.Expr)
t.plan, err = planparserv2.CreateRetrievePlan(schema.CollectionSchema, t.request.Expr)
if err != nil {
return err
}
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, schema, true)
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, true)
if err != nil {
return err
}
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema)
outputFieldIDs, err := translateToOutputFieldIDs(t.request.GetOutputFields(), schema.CollectionSchema)
if err != nil {
return err
}
@ -453,7 +453,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
tr.CtxRecord(ctx, "reduceResultStart")
reducer := createMilvusReducer(ctx, t.queryParams, t.RetrieveRequest, t.schema, t.plan, t.collectionName)
reducer := createMilvusReducer(ctx, t.queryParams, t.RetrieveRequest, t.schema.CollectionSchema, t.plan, t.collectionName)
t.result, err = reducer.Reduce(toReduceResults)
if err != nil {

View File

@ -849,10 +849,22 @@ func Test_createCntPlan(t *testing.T) {
func Test_queryTask_createPlan(t *testing.T) {
t.Run("match count rule", func(t *testing.T) {
collSchema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "a",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
},
}
schema := newSchemaInfo(collSchema)
tsk := &queryTask{
request: &milvuspb.QueryRequest{
OutputFields: []string{"count(*)"},
},
schema: schema,
}
err := tsk.createPlan(context.TODO())
assert.NoError(t, err)
@ -866,13 +878,14 @@ func Test_queryTask_createPlan(t *testing.T) {
request: &milvuspb.QueryRequest{
OutputFields: []string{"a"},
},
schema: &schemaInfo{},
}
err := tsk.createPlan(context.TODO())
assert.Error(t, err)
})
t.Run("invalid expression", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
@ -882,6 +895,7 @@ func Test_queryTask_createPlan(t *testing.T) {
},
},
}
schema := newSchemaInfo(collSchema)
tsk := &queryTask{
schema: schema,
@ -895,7 +909,7 @@ func Test_queryTask_createPlan(t *testing.T) {
})
t.Run("invalid output fields", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
@ -905,6 +919,7 @@ func Test_queryTask_createPlan(t *testing.T) {
},
},
}
schema := newSchemaInfo(collSchema)
tsk := &queryTask{
schema: schema,

View File

@ -55,7 +55,7 @@ type searchTask struct {
tr *timerecord.TimeRecorder
collectionName string
schema *schemapb.CollectionSchema
schema *schemaInfo
requery bool
userOutputFields []string
@ -179,20 +179,14 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
}, offset, nil
}
func getOutputFieldIDs(schema *schemapb.CollectionSchema, outputFields []string) (outputFieldIDs []UniqueID, err error) {
func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
outputFieldIDs = make([]UniqueID, 0, len(outputFields))
for _, name := range outputFields {
hitField := false
for _, field := range schema.GetFields() {
if field.Name == name {
outputFieldIDs = append(outputFieldIDs, field.GetFieldID())
hitField = true
break
}
}
if !hitField {
id, ok := schema.MapFieldID(name)
if !ok {
return nil, fmt.Errorf("Field %s not exist", name)
}
outputFieldIDs = append(outputFieldIDs, id)
}
return outputFieldIDs, nil
}
@ -294,7 +288,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsField) == 0 {
vecFields := typeutil.GetVectorFieldSchemas(t.schema)
vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
if len(vecFields) == 0 {
return errors.New(AnnsFieldKey + " not found in schema")
}
@ -311,7 +305,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
}
t.offset = offset
plan, err := planparserv2.CreateSearchPlan(t.schema, t.request.Dsl, annsField, queryInfo)
plan, err := planparserv2.CreateSearchPlan(t.schema.CollectionSchema, t.request.Dsl, annsField, queryInfo)
if err != nil {
log.Warn("failed to create query plan", zap.Error(err),
zap.String("dsl", t.request.Dsl), // may be very large if large term passed.
@ -489,7 +483,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.Int("number of valid search results", len(validSearchResults)))
tr.CtxRecord(ctx, "reduceResultStart")
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(t.schema)
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
@ -582,7 +576,7 @@ func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
}
func (t *searchTask) Requery() error {
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema)
pkField, err := t.schema.GetPkField()
if err != nil {
return err
}

View File

@ -1933,8 +1933,10 @@ func TestSearchTask_Requery(t *testing.T) {
collectionName := "col"
collectionID := UniqueID(0)
cache := NewMockCache(t)
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe()
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(constructCollectionSchema(pkField, vecField, dim, collection), nil).Maybe()
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe()
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
@ -1942,7 +1944,8 @@ func TestSearchTask_Requery(t *testing.T) {
globalMetaCache = cache
t.Run("Test normal", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (*internalpb.RetrieveResults, error) {
@ -2033,7 +2036,9 @@ func TestSearchTask_Requery(t *testing.T) {
})
t.Run("Test no primary key", func(t *testing.T) {
schema := &schemapb.CollectionSchema{}
collSchema := &schemapb.CollectionSchema{}
schema := newSchemaInfo(collSchema)
node := mocks.NewMockProxy(t)
qt := &searchTask{
@ -2056,7 +2061,8 @@ func TestSearchTask_Requery(t *testing.T) {
})
t.Run("Test requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))
@ -2089,7 +2095,8 @@ func TestSearchTask_Requery(t *testing.T) {
})
t.Run("Test postExecute with requery failed", func(t *testing.T) {
schema := constructCollectionSchema(pkField, vecField, dim, collection)
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, fmt.Errorf("mock err 1"))

View File

@ -434,7 +434,7 @@ func TestTranslateOutputFields(t *testing.T) {
var userOutputFields []string
var err error
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
@ -446,6 +446,7 @@ func TestTranslateOutputFields(t *testing.T) {
{Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector},
},
}
schema := newSchemaInfo(collSchema)
outputFields, userOutputFields, err = translateOutputFields([]string{}, schema, false)
assert.Equal(t, nil, err)
@ -527,7 +528,7 @@ func TestTranslateOutputFields(t *testing.T) {
assert.Error(t, err)
t.Run("enable dynamic schema", func(t *testing.T) {
schema := &schemapb.CollectionSchema{
collSchema := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
@ -540,6 +541,7 @@ func TestTranslateOutputFields(t *testing.T) {
{Name: common.MetaFieldName, FieldID: 102, DataType: schemapb.DataType_JSON, IsDynamic: true},
},
}
schema := newSchemaInfo(collSchema)
outputFields, userOutputFields, err = translateOutputFields([]string{"A", idFieldName}, schema, true)
assert.Equal(t, nil, err)
@ -1322,7 +1324,7 @@ func TestDropPartitionTask(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
globalMetaCache = mockCache
task := &dropPartitionTask{
@ -1373,7 +1375,7 @@ func TestDropPartitionTask(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
globalMetaCache = mockCache
task.PartitionName = "partition1"
err = task.PreExecute(ctx)
@ -1400,7 +1402,7 @@ func TestDropPartitionTask(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
globalMetaCache = mockCache
err = task.PreExecute(ctx)
assert.NoError(t, err)
@ -1426,7 +1428,7 @@ func TestDropPartitionTask(t *testing.T) {
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{}, nil)
).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
globalMetaCache = mockCache
err = task.PreExecute(ctx)
assert.Error(t, err)
@ -2136,7 +2138,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
@ -2153,7 +2155,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
AutoID: false,
},
},
}, nil)
}), nil)
globalMetaCache = cache
field, err := cit.getIndexedField(context.Background())
@ -2179,7 +2181,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName,
@ -2188,7 +2190,7 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
Name: fieldName, // duplicate
},
},
}, nil)
}), nil)
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
@ -2200,13 +2202,13 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName + fieldName,
},
},
}, nil)
}), nil)
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
@ -2348,7 +2350,7 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(&schemapb.CollectionSchema{
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
@ -2365,7 +2367,7 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
AutoID: false,
},
},
}, nil)
}), nil)
globalMetaCache = cache
cit.req.ExtraParams = []*commonpb.KeyValuePair{
{

View File

@ -59,7 +59,7 @@ type upsertTask struct {
chTicker channelsTimeTicker
vChannels []vChan
pChannels []pChan
schema *schemapb.CollectionSchema
schema *schemaInfo
partitionKeyMode bool
partitionKeys *schemapb.FieldData
}
@ -172,7 +172,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
it.result.SuccIndex = sliceIndex
if it.schema.EnableDynamicField {
err := checkDynamicFieldData(it.schema, it.upsertMsg.InsertMsg)
err := checkDynamicFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
if err != nil {
return err
}
@ -181,7 +181,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
// check primaryFieldData whether autoID is true or not
// only allow support autoID == false
var err error
it.result.IDs, err = checkPrimaryFieldData(it.schema, it.result, it.upsertMsg.InsertMsg, false)
it.result.IDs, err = checkPrimaryFieldData(it.schema.CollectionSchema, it.result, it.upsertMsg.InsertMsg, false)
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
if err != nil {
log.Warn("check primary field data and hash primary key failed when upsert",
@ -189,7 +189,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err
}
// set field ID to insert field data
err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
err = fillFieldIDBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema)
if err != nil {
log.Warn("insert set fieldID to fieldData failed when upsert",
zap.Error(err))
@ -197,8 +197,8 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
}
if it.partitionKeyMode {
fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema)
it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg)
pkFieldSchema, _ := it.schema.GetPkField()
it.partitionKeys, err = getPartitionKeyFieldData(pkFieldSchema, it.upsertMsg.InsertMsg)
if err != nil {
log.Warn("get partition keys from insert request failed",
zap.String("collectionName", collectionName),
@ -214,7 +214,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
}
if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema, it.upsertMsg.InsertMsg.NRows()); err != nil {
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema, it.upsertMsg.InsertMsg.NRows()); err != nil {
return err
}

View File

@ -73,6 +73,24 @@ func TestUpsertTask_CheckAligned(t *testing.T) {
numRows := 20
dim := 128
collSchema := &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkRowNums",
Description: "TestUpsertTask_checkRowNums",
AutoID: false,
Fields: []*schemapb.FieldSchema{
boolFieldSchema,
int8FieldSchema,
int16FieldSchema,
int32FieldSchema,
int64FieldSchema,
floatFieldSchema,
doubleFieldSchema,
floatVectorFieldSchema,
binaryVectorFieldSchema,
varCharFieldSchema,
},
}
schema := newSchemaInfo(collSchema)
case2 := upsertTask{
req: &milvuspb.UpsertRequest{
NumRows: uint32(numRows),
@ -80,23 +98,7 @@ func TestUpsertTask_CheckAligned(t *testing.T) {
},
rowIDs: generateInt64Array(numRows),
timestamps: generateUint64Array(numRows),
schema: &schemapb.CollectionSchema{
Name: "TestUpsertTask_checkRowNums",
Description: "TestUpsertTask_checkRowNums",
AutoID: false,
Fields: []*schemapb.FieldSchema{
boolFieldSchema,
int8FieldSchema,
int16FieldSchema,
int32FieldSchema,
int64FieldSchema,
floatFieldSchema,
doubleFieldSchema,
floatVectorFieldSchema,
binaryVectorFieldSchema,
varCharFieldSchema,
},
},
schema: schema,
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: msgpb.InsertRequest{},

View File

@ -978,7 +978,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
// output_fields=["*"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B,C,D]
// output_fields=["*",C] ==> [A,B,C,D]
func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema, addPrimary bool) ([]string, []string, error) {
func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) {
var primaryFieldName string
allFieldNameMap := make(map[string]bool)
resultFieldNameMap := make(map[string]bool)
@ -1006,7 +1006,7 @@ func translateOutputFields(outputFields []string, schema *schemapb.CollectionSch
userOutputFieldsMap[outputFieldName] = true
} else {
if schema.EnableDynamicField {
schemaH, err := typeutil.CreateSchemaHelper(schema)
schemaH, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
return nil, nil, err
}
@ -1447,7 +1447,7 @@ func assignPartitionKeys(ctx context.Context, dbName string, collName string, ke
return nil, err
}
partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema)
partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema.CollectionSchema)
if err != nil {
return nil, err
}
@ -1600,7 +1600,7 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
}
}
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemapb.CollectionSchema, error) {
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
if globalMetaCache != nil {
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
}