mirror of https://github.com/milvus-io/milvus.git

Fix CollectionNotExists error when searching and retrieving vectors (#26524)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
Branch: pull/26547/head
parent 9131a0aa56
commit 1e5836221a
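Overview: before this change, the proxy helpers getPartitionIDs, checkFullLoaded, and getPartitionProgress resolved the database via GetCurDBNameFromContextOrDefault(ctx), and searchTask.Requery built its retrieve request without a DbName. A search or retrieve that named a collection in a non-default database could therefore be resolved against the default database and fail with CollectionNotExists. The hunks below thread the request's DbName into those helpers and into the requery request instead. The standalone Go sketch that follows only illustrates the idea; the types and the fallback logic are hypothetical stand-ins, not Milvus code.

package main

import (
	"context"
	"fmt"
)

type ctxKey string

const dbKey ctxKey = "dbName"

// request is a hypothetical stand-in for a proxy request that carries a DbName.
type request struct {
	DbName         string
	CollectionName string
}

// dbFromContextOrDefault mimics a context-based fallback in the spirit of
// GetCurDBNameFromContextOrDefault; the real helper's behavior is not shown here.
func dbFromContextOrDefault(ctx context.Context) string {
	if v, ok := ctx.Value(dbKey).(string); ok && v != "" {
		return v
	}
	return "default"
}

// qualify builds the fully qualified collection identifier used for a lookup.
func qualify(dbName, collectionName string) string {
	return fmt.Sprintf("%s/%s", dbName, collectionName)
}

func main() {
	ctx := context.Background()
	req := request{DbName: "test_db", CollectionName: "book"}

	// Old pattern: the database comes from the context fallback, losing the
	// database named in the request.
	fmt.Println(qualify(dbFromContextOrDefault(ctx), req.CollectionName)) // default/book

	// New pattern: the database comes from the request itself.
	fmt.Println(qualify(req.DbName, req.CollectionName)) // test_db/book
}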
@@ -1653,7 +1653,7 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
 		}
 	} else {
 		if loadProgress, refreshProgress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
-			request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
+			request.GetPartitionNames(), request.GetCollectionName(), collectionID, request.GetDbName()); err != nil {
 			return getErrResponse(err), nil
 		}
 	}
@@ -1755,7 +1755,7 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt
 		}
 	} else {
 		if progress, _, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
-			request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
+			request.GetPartitionNames(), request.GetCollectionName(), collectionID, request.GetDbName()); err != nil {
 			if errors.Is(err, ErrInsufficientMemory) {
 				return &milvuspb.GetLoadStateResponse{
 					Status: InSufficientMemoryStatus(request.GetCollectionName()),
@@ -330,14 +330,14 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
 			return err
 		}
 		partitionKeys := ParsePartitionKeys(expr)
-		hashedPartitionNames, err := assignPartitionKeys(ctx, "", t.request.CollectionName, partitionKeys)
+		hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
 		if err != nil {
 			return err
 		}

 		partitionNames = append(partitionNames, hashedPartitionNames...)
 	}
-	t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.CollectionName, partitionNames)
+	t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
 	if err != nil {
 		return err
 	}
@@ -68,14 +68,14 @@ type searchTask struct {
 	lb LBPolicy
 }

-func getPartitionIDs(ctx context.Context, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
+func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
 	for _, tag := range partitionNames {
 		if err := validatePartitionTag(tag, false); err != nil {
 			return nil, err
 		}
 	}

-	partitionsMap, err := globalMetaCache.GetPartitions(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
+	partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
 	if err != nil {
 		return nil, err
 	}
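The signature change above matters because globalMetaCache takes the database name as part of the lookup alongside the collection name. The standalone sketch below is a hypothetical cache (not the real globalMetaCache) that only shows why the database component of the key matters.

package main

import (
	"errors"
	"fmt"
)

// partitionCache is a hypothetical dbName -> collectionName -> partitionName -> ID map.
type partitionCache map[string]map[string]map[string]int64

func (c partitionCache) GetPartitions(dbName, collectionName string) (map[string]int64, error) {
	collections, ok := c[dbName]
	if !ok {
		return nil, errors.New("collection not exists")
	}
	partitions, ok := collections[collectionName]
	if !ok {
		return nil, errors.New("collection not exists")
	}
	return partitions, nil
}

func main() {
	cache := partitionCache{
		"test_db": {"book": {"_default": 100}},
	}

	// Looking up with the context-default database misses even though the
	// collection exists, which is the failure mode this commit addresses.
	if _, err := cache.GetPartitions("default", "book"); err != nil {
		fmt.Println("default db lookup:", err)
	}

	// Looking up with the database carried by the request succeeds.
	parts, _ := cache.GetPartitions("test_db", "book")
	fmt.Println("request db lookup:", parts)
}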
@@ -351,7 +351,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 	}

 	// translate partition name to partition ids. Use regex-pattern to match partition name.
-	t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, collectionName, partitionNames)
+	t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, partitionNames)
 	if err != nil {
 		log.Warn("failed to get partition ids", zap.Error(err))
 		return err
@@ -579,6 +579,7 @@ func (t *searchTask) Requery() error {
 		Base: &commonpb.MsgBase{
 			MsgType: commonpb.MsgType_Retrieve,
 		},
+		DbName:         t.request.GetDbName(),
 		CollectionName: t.request.GetCollectionName(),
 		Expr:           expr,
 		OutputFields:   t.request.GetOutputFields(),
@@ -114,7 +114,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
 	if err != nil { // err is not nil if collection not exists
 		return err
 	}
-	partIDs, err := getPartitionIDs(ctx, g.collectionName, g.partitionNames)
+	partIDs, err := getPartitionIDs(ctx, g.request.GetDbName(), g.collectionName, g.partitionNames)
 	if err != nil { // err is not nil if partition not exists
 		return err
 	}
@@ -131,7 +131,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
 	}

 	// check if collection/partitions are loaded into query node
-	loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
+	loaded, unloaded, err := checkFullLoaded(ctx, g.qc, g.request.GetDbName(), g.collectionName, g.GetStatisticsRequest.CollectionID, partIDs)
 	log := log.Ctx(ctx).With(
 		zap.String("collectionName", g.collectionName),
 		zap.Int64("collectionID", g.CollectionID),
@@ -312,14 +312,14 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64

 // checkFullLoaded check if collection / partition was fully loaded into QueryNode
 // return loaded partitions, unloaded partitions and error
-func checkFullLoaded(ctx context.Context, qc types.QueryCoord, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
+func checkFullLoaded(ctx context.Context, qc types.QueryCoord, dbName string, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) {
 	var loadedPartitionIDs []UniqueID
 	var unloadPartitionIDs []UniqueID

 	// TODO: Consider to check if partition loaded from cache to save rpc.
-	info, err := globalMetaCache.GetCollectionInfo(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, collectionID)
+	info, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collectionName, collectionID)
 	if err != nil {
-		return nil, nil, fmt.Errorf("GetCollectionInfo failed, collectionName = %s,collectionID = %d, err = %s", collectionName, collectionID, err)
+		return nil, nil, fmt.Errorf("GetCollectionInfo failed, dbName = %s, collectionName = %s,collectionID = %d, err = %s", dbName, collectionName, collectionID, err)
 	}

 	// If request to search partitions
@@ -1219,12 +1219,13 @@ func getPartitionProgress(
 	partitionNames []string,
 	collectionName string,
 	collectionID int64,
+	dbName string,
 ) (loadProgress int64, refreshProgress int64, err error) {
 	IDs2Names := make(map[int64]string)
 	partitionIDs := make([]int64, 0)
 	for _, partitionName := range partitionNames {
 		var partitionID int64
-		partitionID, err = globalMetaCache.GetPartitionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName, partitionName)
+		partitionID, err = globalMetaCache.GetPartitionID(ctx, dbName, collectionName, partitionName)
 		if err != nil {
 			return
 		}
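Note that dbName is appended as the trailing parameter of getPartitionProgress rather than placed next to collectionName, which is why the callers in the GetLoadingProgress and GetLoadState hunks above add request.GetDbName() at the end of the argument list, and why the test in the next hunk gains a trailing empty string (presumably standing for the default database).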
@@ -1808,6 +1808,6 @@ func Test_GetPartitionProgressFailed(t *testing.T) {
 			Reason: "Unexpected error",
 		},
 	}, nil)
-	_, _, err := getPartitionProgress(context.TODO(), qc, &commonpb.MsgBase{}, []string{}, "", 1)
+	_, _, err := getPartitionProgress(context.TODO(), qc, &commonpb.MsgBase{}, []string{}, "", 1, "")
 	assert.Error(t, err)
 }
@@ -38,6 +38,8 @@ import (
 type TestGetVectorSuite struct {
 	integration.MiniClusterSuite

+	dbName string
+
 	// test params
 	nq   int
 	topK int
@@ -62,6 +64,14 @@ func (s *TestGetVectorSuite) run() {
 		dim = 128
 	)

+	if len(s.dbName) > 0 {
+		createDataBaseStatus, err := s.Cluster.Proxy.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
+			DbName: s.dbName,
+		})
+		s.Require().NoError(err)
+		s.Require().Equal(createDataBaseStatus.GetErrorCode(), commonpb.ErrorCode_Success)
+	}
+
 	pkFieldName := "pkField"
 	vecFieldName := "vecField"
 	pk := &schemapb.FieldSchema{
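The database is created only when s.dbName is non-empty, so existing suite configurations that leave dbName unset keep targeting the default database exactly as before; when it is set, every subsequent request in run() names the database explicitly, as the following hunks show.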
@@ -98,6 +108,7 @@ func (s *TestGetVectorSuite) run() {
 	s.Require().NoError(err)

 	createCollectionStatus, err := s.Cluster.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
+		DbName:         s.dbName,
 		CollectionName: collection,
 		Schema:         marshaledSchema,
 		ShardsNum:      2,
@@ -120,6 +131,7 @@ func (s *TestGetVectorSuite) run() {
 	fieldsData = append(fieldsData, vecFieldData)
 	hashKeys := integration.GenerateHashKeys(NB)
 	_, err = s.Cluster.Proxy.Insert(ctx, &milvuspb.InsertRequest{
+		DbName:         s.dbName,
 		CollectionName: collection,
 		FieldsData:     fieldsData,
 		HashKeys:       hashKeys,
@@ -130,6 +142,7 @@ func (s *TestGetVectorSuite) run() {

 	// flush
 	flushResp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
+		DbName:          s.dbName,
 		CollectionNames: []string{collection},
 	})
 	s.Require().NoError(err)
@@ -146,6 +159,7 @@ func (s *TestGetVectorSuite) run() {

 	// create index
 	_, err = s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
+		DbName:         s.dbName,
 		CollectionName: collection,
 		FieldName:      vecFieldName,
 		IndexName:      "_default",
@@ -154,15 +168,16 @@ func (s *TestGetVectorSuite) run() {
 	s.Require().NoError(err)
 	s.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)

-	s.WaitForIndexBuilt(ctx, collection, vecFieldName)
+	s.WaitForIndexBuiltWithDB(ctx, s.dbName, collection, vecFieldName)

 	// load
 	_, err = s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
+		DbName:         s.dbName,
 		CollectionName: collection,
 	})
 	s.Require().NoError(err)
 	s.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)
-	s.WaitForLoad(ctx, collection)
+	s.WaitForLoadWithDB(ctx, s.dbName, collection)

 	// search
 	nq := s.nq
@@ -170,7 +185,7 @@ func (s *TestGetVectorSuite) run() {

 	outputFields := []string{vecFieldName}
 	params := integration.GetSearchParams(s.indexType, s.metricType)
-	searchReq := integration.ConstructSearchRequest("", collection, "",
+	searchReq := integration.ConstructSearchRequest(s.dbName, collection, "",
 		vecFieldName, s.vecType, outputFields, s.metricType, params, nq, dim, topk, -1)

 	searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq)
@@ -248,6 +263,7 @@ func (s *TestGetVectorSuite) run() {
 	}

 	status, err := s.Cluster.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{
+		DbName:         s.dbName,
 		CollectionName: collection,
 	})
 	s.Require().NoError(err)
@@ -365,6 +381,18 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() {
 	s.run()
 }

+func (s *TestGetVectorSuite) TestGetVector_With_DB_Name() {
+	s.dbName = "test_db"
+	s.nq = 10
+	s.topK = 10
+	s.indexType = integration.IndexHNSW
+	s.metricType = metric.L2
+	s.pkType = schemapb.DataType_Int64
+	s.vecType = schemapb.DataType_FloatVector
+	s.searchFailed = false
+	s.run()
+}
+
 //func (s *TestGetVectorSuite) TestGetVector_DISKANN() {
 //	s.nq = 10
 //	s.topK = 10
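The new TestGetVector_With_DB_Name case reuses run() end to end against the named database. Other configurations could exercise the same path in the same style; the case below is a hypothetical example, not part of this commit, and assumes a metric.IP constant exists alongside metric.L2.

func (s *TestGetVectorSuite) TestGetVector_With_DB_Name_IP() {
	// Hypothetical: same named-database flow as TestGetVector_With_DB_Name, IP metric.
	s.dbName = "test_db_ip"
	s.nq = 10
	s.topK = 10
	s.indexType = integration.IndexHNSW
	s.metricType = metric.IP
	s.pkType = schemapb.DataType_Int64
	s.vecType = schemapb.DataType_FloatVector
	s.searchFailed = false
	s.run()
}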
@@ -43,9 +43,18 @@ const (
 	IndexDISKANN = indexparamcheck.IndexDISKANN
 )

+func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
+	s.waitForIndexBuiltInternal(ctx, dbName, collection, field)
+}
+
 func (s *MiniClusterSuite) WaitForIndexBuilt(ctx context.Context, collection, field string) {
+	s.waitForIndexBuiltInternal(ctx, "", collection, field)
+}
+
+func (s *MiniClusterSuite) waitForIndexBuiltInternal(ctx context.Context, dbName, collection, field string) {
 	getIndexBuilt := func() bool {
 		resp, err := s.Cluster.Proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
+			DbName:         dbName,
 			CollectionName: collection,
 			FieldName:      field,
 		})
@@ -44,10 +44,19 @@ const (
 	LimitKey = "limit"
 )

+func (s *MiniClusterSuite) WaitForLoadWithDB(ctx context.Context, dbName, collection string) {
+	s.waitForLoadInternal(ctx, dbName, collection)
+}
+
 func (s *MiniClusterSuite) WaitForLoad(ctx context.Context, collection string) {
+	s.waitForLoadInternal(ctx, "", collection)
+}
+
+func (s *MiniClusterSuite) waitForLoadInternal(ctx context.Context, dbName, collection string) {
 	cluster := s.Cluster
 	getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
 		loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
+			DbName:         dbName,
 			CollectionName: collection,
 		})
 		if err != nil {
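Both helper hunks follow the same compatibility pattern: the pre-existing exported helper keeps its signature and delegates to a new internal function with an empty database name, while a *WithDB variant exposes the extra parameter, so existing callers keep compiling. A generic, self-contained sketch of that pattern follows; the names are hypothetical, not the suite's own.

package main

import (
	"context"
	"fmt"
)

type Suite struct{}

// WaitForThingWithDB is the new variant that takes an explicit database name.
func (s *Suite) WaitForThingWithDB(ctx context.Context, dbName, collection string) {
	s.waitForThingInternal(ctx, dbName, collection)
}

// WaitForThing keeps the old signature; an empty dbName stands for the default database.
func (s *Suite) WaitForThing(ctx context.Context, collection string) {
	s.waitForThingInternal(ctx, "", collection)
}

// waitForThingInternal holds the shared implementation.
func (s *Suite) waitForThingInternal(ctx context.Context, dbName, collection string) {
	fmt.Printf("waiting in db %q for collection %q\n", dbName, collection)
}

func main() {
	s := &Suite{}
	s.WaitForThing(context.Background(), "book")
	s.WaitForThingWithDB(context.Background(), "test_db", "book")
}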