mirror of https://github.com/milvus-io/milvus.git
enhance: refactor createIndex in RESTful API (#37235)
Make the parameter input method consistent with miluvs-client. Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/37501/head
parent
40b770cb7b
commit
86fd3200be
|
@ -1876,8 +1876,16 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
|
|||
}
|
||||
c.Set(ContextRequest, req)
|
||||
|
||||
for key, value := range indexParam.Params {
|
||||
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
var err error
|
||||
req.ExtraParams, err = convertToExtraParams(indexParam)
|
||||
if err != nil {
|
||||
// will not happen
|
||||
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
|
||||
|
|
|
@ -709,6 +709,46 @@ func TestDocInDocOutSearch(t *testing.T) {
|
|||
sendReqAndVerify(t, testEngine, testcase.path, http.MethodPost, testcase)
|
||||
}
|
||||
|
||||
func TestCreateIndex(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false")
|
||||
defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key)
|
||||
|
||||
postTestCases := []requestBodyTestCase{}
|
||||
mp := mocks.NewMockProxy(t)
|
||||
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice()
|
||||
testEngine := initHTTPServerV2(mp, false)
|
||||
path := versionalV2(IndexCategory, CreateAction)
|
||||
// the previous format
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "params": {"index_type": "L2", "nlist": 10}}]}`),
|
||||
})
|
||||
// the current format
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2", "indexType": "L2", "params":{"nlist": 10}}]}`),
|
||||
})
|
||||
|
||||
for _, testcase := range postTestCases {
|
||||
t.Run("post"+testcase.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, testcase.path, bytes.NewReader(testcase.requestBody))
|
||||
w := httptest.NewRecorder()
|
||||
testEngine.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
fmt.Println(w.Body.String())
|
||||
returnBody := &ReturnErrMsg{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), returnBody)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, testcase.errCode, returnBody.Code)
|
||||
if testcase.errCode != 0 {
|
||||
assert.Equal(t, testcase.errMsg, returnBody.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCollection(t *testing.T) {
|
||||
paramtable.Init()
|
||||
// disable rate limit
|
||||
|
|
|
@ -1489,6 +1489,14 @@ func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, erro
|
|||
if indexParam.IndexType != "" {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType})
|
||||
}
|
||||
if indexParam.IndexType == "" {
|
||||
for key, value := range indexParam.Params {
|
||||
if key == common.IndexTypeKey {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: fmt.Sprintf("%v", value)})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if indexParam.MetricType != "" {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType})
|
||||
}
|
||||
|
|
|
@ -160,7 +160,7 @@ class TestCreateIndex(TestBase):
|
|||
# create index
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector",
|
||||
"indexParams": [{"fieldName": "word_count", "indexName": "word_count_vector", "indexType": "INVERTED",
|
||||
"params": {"index_type": "INVERTED"}}]
|
||||
}
|
||||
rsp = self.index_client.index_create(payload)
|
||||
|
@ -177,7 +177,7 @@ class TestCreateIndex(TestBase):
|
|||
for i in range(len(expected_index)):
|
||||
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
|
||||
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
|
||||
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
|
||||
assert expected_index[i]['indexType'] == actual_index[i]['indexType']
|
||||
|
||||
@pytest.mark.parametrize("index_type", ["BIN_FLAT", "BIN_IVF_FLAT"])
|
||||
@pytest.mark.parametrize("metric_type", ["JACCARD", "HAMMING"])
|
||||
|
@ -228,7 +228,7 @@ class TestCreateIndex(TestBase):
|
|||
index_name = "binary_vector_index"
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type,
|
||||
"indexParams": [{"fieldName": "binary_vector", "indexName": index_name, "metricType": metric_type, "indexType": index_type,
|
||||
"params": {"index_type": index_type}}]
|
||||
}
|
||||
if index_type == "BIN_IVF_FLAT":
|
||||
|
@ -247,7 +247,7 @@ class TestCreateIndex(TestBase):
|
|||
for i in range(len(expected_index)):
|
||||
assert expected_index[i]['fieldName'] == actual_index[i]['fieldName']
|
||||
assert expected_index[i]['indexName'] == actual_index[i]['indexName']
|
||||
assert expected_index[i]['params']['index_type'] == actual_index[i]['indexType']
|
||||
assert expected_index[i]['indexType'] == actual_index[i]['indexType']
|
||||
|
||||
@pytest.mark.parametrize("insert_round", [1])
|
||||
@pytest.mark.parametrize("auto_id", [True])
|
||||
|
|
Loading…
Reference in New Issue