enhance: refactor createCollection in RESTful API (#36790)

1.  support isClusteringKey in restful api;
2. throw err if passed invalid 'enableDynamicField' params
3. parameters in indexparams are not processed properly, related with
#36365

Signed-off-by: lixinguo <xinguo.li@zilliz.com>
Co-authored-by: lixinguo <xinguo.li@zilliz.com>
pull/36878/head
smellthemoon 2024-10-15 10:29:22 +08:00 committed by GitHub
parent d566b0ceff
commit c9752bd2e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 8 deletions

View File

@ -1152,8 +1152,14 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
}
enableDynamic := EnableDynamic
if enStr, ok := httpReq.Params["enableDynamicField"]; ok {
if en, err := strconv.ParseBool(fmt.Sprintf("%v", enStr)); err == nil {
enableDynamic = en
enableDynamic, err = strconv.ParseBool(fmt.Sprintf("%v", enStr))
if err != nil {
log.Ctx(ctx).Warn("high level restful api, parse enableDynamicField fail", zap.Error(err), zap.Any("request", anyReq))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: "parse enableDynamicField fail, err:" + err.Error(),
})
return nil, err
}
}
schema, err = proto.Marshal(&schemapb.CollectionSchema{
@ -1340,7 +1346,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
if err != nil {
return resp, err
}
if httpReq.Schema.Fields == nil || len(httpReq.Schema.Fields) == 0 {
if len(httpReq.Schema.Fields) == 0 {
if len(httpReq.MetricType) == 0 {
httpReq.MetricType = DefaultMetricType
}
@ -1377,8 +1383,15 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe
IndexName: indexParam.IndexName,
ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: indexParam.MetricType}},
}
for key, value := range indexParam.Params {
createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
createIndexReq.ExtraParams, err = convertToExtraParams(indexParam)
if err != nil {
// will not happen
log.Ctx(ctx).Warn("high level restful api, convertToExtraParams fail", zap.Error(err), zap.Any("request", anyReq))
HTTPAbortReturn(c, http.StatusOK, gin.H{
HTTPReturnCode: merr.Code(err),
HTTPReturnMessage: err.Error(),
})
return resp, err
}
statusResponse, err := wrapperProxyWithLimit(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest))

View File

@ -847,7 +847,7 @@ func TestCreateCollection(t *testing.T) {
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64","isClusteringKey":true, "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_xxx", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
@ -983,6 +983,13 @@ func TestCreateCollection(t *testing.T) {
errMsg: "convert defaultValue fail, err:Wrong defaultValue type: invalid parameter[expected=number][actual=10]",
errCode: 1100,
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2, "idType": "Varchar",` +
`"params": {"max_length": 256, "enableDynamicField": 100, "shardsNum": 2, "consistencyLevel": "unknown", "ttlSeconds": 3600}}`),
errMsg: "parse enableDynamicField fail, err:strconv.ParseBool: parsing \"100\": invalid syntax",
errCode: 65535,
})
for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {

View File

@ -285,8 +285,9 @@ func (req *GrantReq) GetDbName() string { return req.DbName }
type IndexParam struct {
FieldName string `json:"fieldName" binding:"required"`
IndexName string `json:"indexName" binding:"required"`
MetricType string `json:"metricType" binding:"required"`
IndexName string `json:"indexName"`
MetricType string `json:"metricType"`
IndexType string `json:"indexType"`
Params map[string]interface{} `json:"params"`
}
@ -319,6 +320,7 @@ type FieldSchema struct {
ElementDataType string `json:"elementDataType"`
IsPrimary bool `json:"isPrimary"`
IsPartitionKey bool `json:"isPartitionKey"`
IsClusteringKey bool `json:"isClusteringKey"`
ElementTypeParams map[string]interface{} `json:"elementTypeParams" binding:"required"`
Nullable bool `json:"nullable" binding:"required"`
DefaultValue interface{} `json:"defaultValue" binding:"required"`

View File

@ -1482,3 +1482,21 @@ func convertDefaultValue(value interface{}, dataType schemapb.DataType) (*schema
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("Unexpected default value type: %d", dataType))
}
}
func convertToExtraParams(indexParam IndexParam) ([]*commonpb.KeyValuePair, error) {
var params []*commonpb.KeyValuePair
if indexParam.IndexType != "" {
params = append(params, &commonpb.KeyValuePair{Key: common.IndexTypeKey, Value: indexParam.IndexType})
}
if indexParam.MetricType != "" {
params = append(params, &commonpb.KeyValuePair{Key: common.MetricTypeKey, Value: indexParam.MetricType})
}
if len(indexParam.Params) != 0 {
v, err := json.Marshal(indexParam.Params)
if err != nil {
return nil, err
}
params = append(params, &commonpb.KeyValuePair{Key: common.IndexParamsKey, Value: string(v)})
}
return params, nil
}

View File

@ -1737,3 +1737,27 @@ func TestConvertConsistencyLevel(t *testing.T) {
_, _, err = convertConsistencyLevel("test")
assert.NotNil(t, err)
}
func TestConvertToExtraParams(t *testing.T) {
indexParams := IndexParam{
MetricType: "L2",
IndexType: "IVF_FLAT",
Params: map[string]interface{}{
"nlist": 128,
},
}
params, err := convertToExtraParams(indexParams)
assert.Equal(t, nil, err)
assert.Equal(t, 3, len(params))
for _, pair := range params {
if pair.Key == common.MetricTypeKey {
assert.Equal(t, "L2", pair.Value)
}
if pair.Key == common.IndexTypeKey {
assert.Equal(t, "IVF_FLAT", pair.Value)
}
if pair.Key == common.IndexParamsKey {
assert.Equal(t, string("{\"nlist\":128}"), pair.Value)
}
}
}