mirror of https://github.com/milvus-io/milvus.git
enhance: refactor createCollection in RESTful API (#36790)
1. support isClusteringKey in restful api; 2. throw err if passed invalid 'enableDynamicField' params 3. parameters in indexparams are not processed properly, related with #36365 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>pull/36878/head
parent
d566b0ceff
commit
c9752bd2e6
|
@ -1152,8 +1152,14 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||
}
|
||||
enableDynamic := EnableDynamic
|
||||
if enStr, ok := httpReq.Params["enableDynamicField"]; ok {
|
||||
if en, err := strconv.ParseBool(fmt.Sprintf("%v", enStr)); err == nil {
|
||||
enableDynamic = en
|
||||
enableDynamic, err = strconv.ParseBool(fmt.Sprintf("%v", enStr))
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, parse enableDynamicField fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: "parse enableDynamicField fail, err:" + err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
schema, err = proto.Marshal(&schemapb.CollectionSchema{
|
||||
|
@ -1340,7 +1346,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.Schema.Fields) == 0 {
|
||||
if len(httpReq.MetricType) == 0 {
|
||||
httpReq.MetricType = DefaultMetricType
|
||||
}
|
||||
|
@ -1377,8 +1383,15 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
|
|||
IndexName: indexParam.IndexName,
|
||||
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: indexParam.MetricType}},
|
||||
}
|
||||
for key, value := range indexParam.Params {
|
||||
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
|
||||
createIndexReq.ExtraParams, err = convertToExtraParams(indexParam)
|
||||
if err != nil {
|
||||
// will not happen
|
||||
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
HTTPReturnCode: merr.Code(err),
|
||||
HTTPReturnMessage: err.Error(),
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))
|
||||
|
|
|
@ -847,7 +847,7 @@ func TestCreateCollection(t *testing.T) {
|
|||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
|
||||
"fields": [
|
||||
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
|
||||
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
|
||||
{"fieldName": "word_count", "dataType": "Int64","isClusteringKey":true, "elementTypeParams": {}},
|
||||
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
|
||||
]
|
||||
}, "indexParams": [{"fieldName": "book_xxx", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
|
||||
|
@ -983,6 +983,13 @@ func TestCreateCollection(t *testing.T) {
|
|||
errMsg: "convert defaultValue fail, err:Wrong defaultValue type: invalid parameter[expected=number][actual=10]",
|
||||
errCode: 1100,
|
||||
})
|
||||
postTestCases = append(postTestCases, requestBodyTestCase{
|
||||
path: path,
|
||||
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
|
||||
`"params": {"max_length": 256, "enableDynamicField": 100, "shardsNum": 2, "consistencyLevel": "unknown", "ttlSeconds": 3600}}`),
|
||||
errMsg: "parse enableDynamicField fail, err:strconv.ParseBool: parsing \"100\": invalid syntax",
|
||||
errCode: 65535,
|
||||
})
|
||||
|
||||
for _, testcase := range postTestCases {
|
||||
t.Run("post"+testcase.path, func(t *testing.T) {
|
||||
|
|
|
@ -285,8 +285,9 @@ func (req *GrantReq) GetDbName() string { return req.DbName }
|
|||
|
||||
type IndexParam struct {
|
||||
FieldName string `json:"fieldName" binding:"required"`
|
||||
IndexName string `json:"indexName" binding:"required"`
|
||||
MetricType string `json:"metricType" binding:"required"`
|
||||
IndexName string `json:"indexName"`
|
||||
MetricType string `json:"metricType"`
|
||||
IndexType string `json:"indexType"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
}
|
||||
|
||||
|
@ -319,6 +320,7 @@ type FieldSchema struct {
|
|||
ElementDataType string `json:"elementDataType"`
|
||||
IsPrimary bool `json:"isPrimary"`
|
||||
IsPartitionKey bool `json:"isPartitionKey"`
|
||||
IsClusteringKey bool `json:"isClusteringKey"`
|
||||
ElementTypeParams map[string]interface{} `json:"elementTypeParams" binding:"required"`
|
||||
Nullable bool `json:"nullable" binding:"required"`
|
||||
DefaultValue interface{} `json:"defaultValue" binding:"required"`
|
||||
|
|
|
@ -1482,3 +1482,21 @@ func convertDefaultValue(value interface{}, dataType schemapb.DataType) (*schema
|
|||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("Unexpected default value type: %d", dataType))
|
||||
}
|
||||
}
|
||||
|
||||
func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, error) {
|
||||
var params []*commonpb.KeyValuePair
|
||||
if indexParam.IndexType != "" {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType})
|
||||
}
|
||||
if indexParam.MetricType != "" {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType})
|
||||
}
|
||||
if len(indexParam.Params) != 0 {
|
||||
v, err := json.Marshal(indexParam.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
params = append(params, &commonpb.KeyValuePair{Key: common.IndexParamsKey, Value: string(v)})
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
|
|
@ -1737,3 +1737,27 @@ func TestConvertConsistencyLevel(t *testing.T) {
|
|||
_, _, err = convertConsistencyLevel("test")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestConvertToExtraParams(t *testing.T) {
|
||||
indexParams := IndexParam{
|
||||
MetricType: "L2",
|
||||
IndexType: "IVF_FLAT",
|
||||
Params: map[string]interface{}{
|
||||
"nlist": 128,
|
||||
},
|
||||
}
|
||||
params, err := convertToExtraParams(indexParams)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, len(params))
|
||||
for _, pair := range params {
|
||||
if pair.Key == common.MetricTypeKey {
|
||||
assert.Equal(t, "L2", pair.Value)
|
||||
}
|
||||
if pair.Key == common.IndexTypeKey {
|
||||
assert.Equal(t, "IVF_FLAT", pair.Value)
|
||||
}
|
||||
if pair.Key == common.IndexParamsKey {
|
||||
assert.Equal(t, string("{\"nlist\":128}"), pair.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue