diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index a15447de79..2094c28560 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -208,14 +208,8 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques // merge with previous params because create index would not pass mmap params indexes := s.meta.indexMeta.GetFieldIndexes(req.GetCollectionID(), req.GetFieldID(), req.GetIndexName()) if len(indexes) == 1 { - req.UserIndexParams, err = UpdateParams(indexes[0], indexes[0].UserIndexParams, req.GetUserIndexParams()) - if err != nil { - return merr.Status(err), nil - } - req.IndexParams, err = UpdateParams(indexes[0], indexes[0].IndexParams, req.GetIndexParams()) - if err != nil { - return merr.Status(err), nil - } + req.UserIndexParams = UpdateParams(indexes[0], indexes[0].UserIndexParams, req.GetUserIndexParams()) + req.IndexParams = UpdateParams(indexes[0], indexes[0].IndexParams, req.GetIndexParams()) } if indexID == 0 { @@ -246,6 +240,11 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques UserIndexParams: req.GetUserIndexParams(), } + if err := ValidateIndexParams(index); err != nil { + metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() + return merr.Status(err), nil + } + // Get flushed segments and create index err = s.meta.indexMeta.CreateIndex(index) if err != nil { @@ -267,22 +266,27 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques return merr.Success(), nil } -func ValidateIndexParams(index *model.Index, key, value string) error { - switch key { - case common.MmapEnabledKey: - indexType := GetIndexType(index.IndexParams) - if !indexparamcheck.IsMmapSupported(indexType) { - return merr.WrapErrParameterInvalidMsg("index type %s does not support mmap", indexType) - } +func ValidateIndexParams(index *model.Index) error { + for _, paramSet := range [][]*commonpb.KeyValuePair{index.IndexParams, index.UserIndexParams} { + for _, param := range paramSet { + switch param.GetKey() { + case common.MmapEnabledKey: + indexType := GetIndexType(index.IndexParams) + if !indexparamcheck.IsMmapSupported(indexType) { + return merr.WrapErrParameterInvalidMsg("index type %s does not support mmap", indexType) + } - if _, err := strconv.ParseBool(value); err != nil { - return merr.WrapErrParameterInvalidMsg("invalid %s value: %s, expected: true, false", key, value) + if _, err := strconv.ParseBool(param.GetValue()); err != nil { + return merr.WrapErrParameterInvalidMsg("invalid %s value: %s, expected: true, false", param.GetKey(), param.GetValue()) + } + } } } + return nil } -func UpdateParams(index *model.Index, from []*commonpb.KeyValuePair, updates []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { +func UpdateParams(index *model.Index, from []*commonpb.KeyValuePair, updates []*commonpb.KeyValuePair) []*commonpb.KeyValuePair { params := make(map[string]string) for _, param := range from { params[param.GetKey()] = param.GetValue() @@ -290,10 +294,6 @@ func UpdateParams(index *model.Index, from []*commonpb.KeyValuePair, updates []* // update the params for _, param := range updates { - if err := ValidateIndexParams(index, param.GetKey(), param.GetValue()); err != nil { - log.Warn("failed to alter index params", zap.Error(err)) - return nil, err - } params[param.GetKey()] = param.GetValue() } @@ -302,7 +302,7 @@ func UpdateParams(index *model.Index, from []*commonpb.KeyValuePair, updates []* Key: k, Value: v, } - }), nil + }) } func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest) (*commonpb.Status, error) { @@ -329,10 +329,7 @@ func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest) for _, index := range indexes { // update user index params - newUserIndexParams, err := UpdateParams(index, index.UserIndexParams, req.GetParams()) - if err != nil { - return merr.Status(err), nil - } + newUserIndexParams := UpdateParams(index, index.UserIndexParams, req.GetParams()) log.Info("alter index user index params", zap.String("indexName", index.IndexName), zap.Any("params", newUserIndexParams), @@ -340,15 +337,16 @@ func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest) index.UserIndexParams = newUserIndexParams // update index params - newIndexParams, err := UpdateParams(index, index.IndexParams, req.GetParams()) - if err != nil { - return merr.Status(err), nil - } + newIndexParams := UpdateParams(index, index.IndexParams, req.GetParams()) log.Info("alter index user index params", zap.String("indexName", index.IndexName), zap.Any("params", newIndexParams), ) index.IndexParams = newIndexParams + + if err := ValidateIndexParams(index); err != nil { + return merr.Status(err), nil + } } err := s.meta.indexMeta.AlterIndex(ctx, indexes...) diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index 6de769d78b..d5aede0660 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -112,8 +112,7 @@ func TestServer_CreateIndex(t *testing.T) { s.broker = broker.NewCoordinatorBroker(b) resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.Error(t, merr.CheckRPCCall(resp, err)) assert.Equal(t, "mock error", resp.GetReason()) }) @@ -169,22 +168,19 @@ func TestServer_CreateIndex(t *testing.T) { t.Run("success", func(t *testing.T) { resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + assert.NoError(t, merr.CheckRPCCall(resp, err)) }) t.Run("success with index exist", func(t *testing.T) { req.IndexName = "" resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + assert.NoError(t, merr.CheckRPCCall(resp, err)) }) t.Run("server not healthy", func(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Abnormal) resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) + assert.Error(t, merr.CheckRPCCall(resp, err)) }) req.IndexName = "FieldFloatVector" @@ -192,8 +188,7 @@ func TestServer_CreateIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Healthy) req.FieldID++ resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.Error(t, merr.CheckRPCCall(resp, err)) }) t.Run("alloc ID fail", func(t *testing.T) { @@ -201,8 +196,7 @@ func TestServer_CreateIndex(t *testing.T) { s.allocator = &FailsAllocator{allocIDSucceed: false} s.meta.indexMeta.indexes = map[UniqueID]map[UniqueID]*model.Index{} resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.Error(t, merr.CheckRPCCall(resp, err)) }) t.Run("not support disk index", func(t *testing.T) { @@ -216,8 +210,35 @@ func TestServer_CreateIndex(t *testing.T) { } s.indexNodeManager = NewNodeManager(ctx, defaultIndexNodeCreatorFunc) resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.Error(t, merr.CheckRPCCall(resp, err)) + }) + + t.Run("disk index with mmap", func(t *testing.T) { + s.allocator = newMockAllocator() + s.meta.indexMeta.indexes = map[UniqueID]map[UniqueID]*model.Index{} + req.IndexParams = []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "DISKANN", + }, + { + Key: common.MmapEnabledKey, + Value: "true", + }, + } + nodeManager := NewNodeManager(ctx, defaultIndexNodeCreatorFunc) + s.indexNodeManager = nodeManager + mockNode := mocks.NewMockIndexNodeClient(t) + s.indexNodeManager.lock.Lock() + s.indexNodeManager.nodeClients[1001] = mockNode + s.indexNodeManager.lock.Unlock() + mockNode.EXPECT().GetJobStats(mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + EnableDisk: true, + }, nil) + + resp, err := s.CreateIndex(ctx, req) + assert.Error(t, merr.CheckRPCCall(resp, err)) }) t.Run("save index fail", func(t *testing.T) { @@ -234,8 +255,7 @@ func TestServer_CreateIndex(t *testing.T) { }, } resp, err := s.CreateIndex(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.Error(t, merr.CheckRPCCall(resp, err)) }) } diff --git a/internal/datacoord/indexnode_manager.go b/internal/datacoord/indexnode_manager.go index 7a6721f72b..d454743790 100644 --- a/internal/datacoord/indexnode_manager.go +++ b/internal/datacoord/indexnode_manager.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" ) // IndexNodeManager is used to manage the client of IndexNode. @@ -175,15 +176,10 @@ func (nm *IndexNodeManager) ClientSupportDisk() bool { go func() { defer wg.Done() resp, err := client.GetJobStats(ctx, &indexpb.GetJobStatsRequest{}) - if err != nil { + if err := merr.CheckRPCCall(resp, err); err != nil { log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) return } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), - zap.String("reason", resp.GetStatus().GetReason())) - return - } log.Debug("get job stats success", zap.Int64("nodeID", nodeID), zap.Bool("enable disk", resp.GetEnableDisk())) if resp.GetEnableDisk() { nodeMutex.Lock()