enhance: [GoSDK] support expression template (#38568)

Related to #36672

This PR adds:
- Expression template support for search, query & hybrid search
- A fix for the hybrid search rerank param
- The reranker interface (migrated from the old Go SDK repo)

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/38577/head
congqixia 2024-12-19 11:20:47 +08:00 committed by GitHub
parent 78438ef41e
commit 01cfb1fd97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 471 additions and 24 deletions

View File

@ -158,8 +158,11 @@ func (c *Client) parseSearchResult(sch *entity.Schema, outputFields []string, fi
}
func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
req := option.Request()
var resultSet ResultSet
req, err := option.Request()
if err != nil {
return resultSet, err
}
collection, err := c.getCollection(ctx, req.GetCollectionName())
if err != nil {

View File

@ -17,9 +17,11 @@
package milvusclient
import (
"fmt"
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -137,3 +139,170 @@ func (s *SearchOptionSuite) TestPlaceHolder() {
func TestSearchOption(t *testing.T) {
suite.Run(t, new(SearchOptionSuite))
}
// TestAny2TmplValue checks any2TmplValue against each supported primitive and
// slice type, and verifies that unsupported inputs are rejected with an error.
func TestAny2TmplValue(t *testing.T) {
	t.Run("primitives", func(t *testing.T) {
		t.Run("int", func(t *testing.T) {
			v := rand.Int()
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetInt64Val())
		})

		t.Run("int32", func(t *testing.T) {
			v := rand.Int31()
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetInt64Val())
		})

		t.Run("int64", func(t *testing.T) {
			v := rand.Int63()
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetInt64Val())
		})

		t.Run("float32", func(t *testing.T) {
			v := rand.Float32()
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetFloatVal())
		})

		t.Run("float64", func(t *testing.T) {
			v := rand.Float64()
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetFloatVal())
		})

		t.Run("bool", func(t *testing.T) {
			val, err := any2TmplValue(true)
			assert.NoError(t, err)
			assert.True(t, val.GetBoolVal())
		})

		t.Run("string", func(t *testing.T) {
			v := fmt.Sprintf("%v", rand.Int())
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			assert.EqualValues(t, v, val.GetStringVal())
		})
	})

	t.Run("slice", func(t *testing.T) {
		t.Run("int", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]int, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Int())
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetLongData().GetData()
			// Stop here on a length mismatch so v[i] below cannot go out of range.
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("int32", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]int32, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Int31())
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetLongData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("int64", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]int64, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Int63())
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetLongData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("float32", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]float32, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Float32())
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetDoubleData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("float64", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]float64, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Float64())
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetDoubleData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("bool", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]bool, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, rand.Int()%2 == 0)
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetBoolData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})

		t.Run("string", func(t *testing.T) {
			l := rand.Intn(10) + 1
			v := make([]string, 0, l)
			for i := 0; i < l; i++ {
				v = append(v, fmt.Sprintf("%v", rand.Int()))
			}
			val, err := any2TmplValue(v)
			assert.NoError(t, err)
			data := val.GetArrayVal().GetStringData().GetData()
			if !assert.Equal(t, l, len(data)) {
				return
			}
			for i, elem := range data {
				assert.EqualValues(t, v[i], elem)
			}
		})
	})

	// The subtest must take its own *testing.T; asserting on the parent t
	// would attribute failures to the wrong test.
	t.Run("unsupported", func(t *testing.T) {
		_, err := any2TmplValue(struct{}{})
		assert.Error(t, err)

		_, err = any2TmplValue([]struct{}{})
		assert.Error(t, err)
	})
}

View File

@ -18,6 +18,8 @@ package milvusclient
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"github.com/cockroachdb/errors"
@ -25,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
)
@ -59,22 +62,24 @@ type searchOption struct {
// annRequest collects the settings of a single ANN search; its searchRequest
// method turns them into a *milvuspb.SearchRequest.
type annRequest struct {
vectors []entity.Vector
annField string
metricsType entity.MetricType
searchParam map[string]string
groupByField string
annParam index.AnnParam
ignoreGrowing bool
expr string
topK int
offset int
annField string
metricsType entity.MetricType
searchParam map[string]string
groupByField string
annParam index.AnnParam
ignoreGrowing bool
expr string
topK int
offset int
// templateParams maps expression-template names to raw Go values; each value
// is converted with any2TmplValue when the search request is built.
templateParams map[string]any
}
// NewAnnRequest creates an annRequest for annField with the given vectors,
// using limit as topK and starting with an empty templateParams map.
// NOTE(review): searchParam is not initialized here, unlike in NewSearchOption —
// confirm that setters writing to it handle a nil map.
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest {
return &annRequest{
annField: annField,
vectors: vectors,
topK: limit,
annField: annField,
vectors: vectors,
topK: limit,
templateParams: make(map[string]any),
}
}
@ -116,9 +121,98 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) {
}
request.SearchParams = entity.MapKvPairs(params)
request.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
for key, value := range r.templateParams {
tmplVal, err := any2TmplValue(value)
if err != nil {
return nil, err
}
request.ExprTemplateValues[key] = tmplVal
}
return request, nil
}
// any2TmplValue converts a Go value into the schemapb.TemplateValue used for
// expression template parameters.
//
// Supported inputs:
//   - int, int8, int16, int32, int64 -> Int64Val
//   - float32, float64               -> FloatVal (float32 is widened to float64)
//   - bool                           -> BoolVal
//   - string                         -> StringVal
//   - any slice                      -> delegated to slice2TmplValue
//
// Every other input, including an untyped nil, yields an error and a nil result.
func any2TmplValue(val any) (*schemapb.TemplateValue, error) {
	// reflect.TypeOf(nil) returns a nil Type, and calling Kind on it panics,
	// so nil has to be rejected before the reflection fallback below.
	if val == nil {
		return nil, fmt.Errorf("unsupported template value type: %T", val)
	}

	result := &schemapb.TemplateValue{}
	switch v := val.(type) {
	case int, int8, int16, int32:
		// v is still an interface value in a multi-type case, so use
		// reflection to read the signed integer uniformly.
		result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: reflect.ValueOf(v).Int()}
	case int64:
		result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: v}
	case float32:
		result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: float64(v)}
	case float64:
		result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: v}
	case bool:
		result.Val = &schemapb.TemplateValue_BoolVal{BoolVal: v}
	case string:
		result.Val = &schemapb.TemplateValue_StringVal{StringVal: v}
	default:
		if reflect.TypeOf(val).Kind() == reflect.Slice {
			return slice2TmplValue(val)
		}
		return nil, fmt.Errorf("unsupported template value type: %T", val)
	}
	return result, nil
}
// slice2TmplValue converts a slice into an array-typed schemapb.TemplateValue.
//
// The element kind picks the payload: signed integers become LongData, bools
// BoolData, floats DoubleData and strings StringData. Any other element kind
// produces an error. val is expected to be a slice; the caller checks this.
func slice2TmplValue(val any) (*schemapb.TemplateValue, error) {
	wrapper := &schemapb.TemplateValue_ArrayVal{
		ArrayVal: &schemapb.TemplateArrayValue{},
	}

	items := reflect.ValueOf(val)
	count := items.Len()

	switch t := reflect.TypeOf(val).Elem().Kind(); t {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		longs := make([]int64, count)
		for idx := range longs {
			longs[idx] = items.Index(idx).Int()
		}
		wrapper.ArrayVal.Data = &schemapb.TemplateArrayValue_LongData{
			LongData: &schemapb.LongArray{Data: longs},
		}
	case reflect.Bool:
		flags := make([]bool, count)
		for idx := range flags {
			flags[idx] = items.Index(idx).Bool()
		}
		wrapper.ArrayVal.Data = &schemapb.TemplateArrayValue_BoolData{
			BoolData: &schemapb.BoolArray{Data: flags},
		}
	case reflect.Float32, reflect.Float64:
		doubles := make([]float64, count)
		for idx := range doubles {
			doubles[idx] = items.Index(idx).Float()
		}
		wrapper.ArrayVal.Data = &schemapb.TemplateArrayValue_DoubleData{
			DoubleData: &schemapb.DoubleArray{Data: doubles},
		}
	case reflect.String:
		texts := make([]string, count)
		for idx := range texts {
			texts[idx] = items.Index(idx).String()
		}
		wrapper.ArrayVal.Data = &schemapb.TemplateArrayValue_StringData{
			StringData: &schemapb.StringArray{Data: texts},
		}
	default:
		return nil, fmt.Errorf("unsupported template type: slice of %v", t)
	}

	return &schemapb.TemplateValue{Val: wrapper}, nil
}
func (r *annRequest) WithANNSField(annsField string) *annRequest {
r.annField = annsField
return r
@ -144,6 +238,11 @@ func (r *annRequest) WithFilter(expr string) *annRequest {
return r
}
// WithTemplateParam stores val as the value for the expression template
// parameter named key (e.g. the "{tmpl_id}" placeholder in a filter) and
// returns r so calls can be chained.
//
// The map is allocated lazily so that an annRequest created without one of the
// constructors (templateParams == nil) does not panic on the write.
func (r *annRequest) WithTemplateParam(key string, val any) *annRequest {
	if r.templateParams == nil {
		r.templateParams = make(map[string]any)
	}
	r.templateParams[key] = val
	return r
}
func (r *annRequest) WithOffset(offset int) *annRequest {
r.offset = offset
return r
@ -179,6 +278,11 @@ func (opt *searchOption) WithFilter(expr string) *searchOption {
return opt
}
// WithTemplateParam forwards the key/value pair to the underlying annRequest
// and returns opt for chaining.
func (opt *searchOption) WithTemplateParam(key string, val any) *searchOption {
	inner := opt.annRequest
	inner.WithTemplateParam(key, val)
	return opt
}
func (opt *searchOption) WithOffset(offset int) *searchOption {
opt.annRequest.WithOffset(offset)
return opt
@ -223,9 +327,10 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
return &searchOption{
annRequest: &annRequest{
vectors: vectors,
searchParam: make(map[string]string),
topK: limit,
vectors: vectors,
searchParam: make(map[string]string),
topK: limit,
templateParams: make(map[string]any),
},
collectionName: collectionName,
useDefaultConsistencyLevel: true,
@ -293,6 +398,10 @@ type hybridSearchOption struct {
outputFields []string
useDefaultConsistency bool
consistencyLevel entity.ConsistencyLevel
limit int
offset int
reranker Reranker
}
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
@ -311,6 +420,16 @@ func (opt *hybridSearchOption) WithOutputFields(outputFields ...string) *hybridS
return opt
}
// WithReranker sets the Reranker whose GetParams output is placed into the
// hybrid search rank params, and returns opt for chaining.
func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOption {
	opt.reranker = reranker

	return opt
}
// WithOffset records the result offset for the hybrid search and returns opt
// for chaining. HybridRequest only emits the offset when it is greater than 0.
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
	opt.offset = offset

	return opt
}
func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, error) {
requests := make([]*milvuspb.SearchRequest, 0, len(opt.reqs))
for _, annRequest := range opt.reqs {
@ -321,6 +440,15 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
requests = append(requests, req)
}
var params []*commonpb.KeyValuePair
if opt.reranker != nil {
params = opt.reranker.GetParams()
}
params = append(params, &commonpb.KeyValuePair{Key: spLimit, Value: strconv.FormatInt(int64(opt.limit), 10)})
if opt.offset > 0 {
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
}
return &milvuspb.HybridSearchRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
@ -328,20 +456,22 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
UseDefaultConsistency: opt.useDefaultConsistency,
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
OutputFields: opt.outputFields,
RankParams: params,
}, nil
}
func NewHybridSearchOption(collectionName string, annRequests ...*annRequest) *hybridSearchOption {
// NewHybridSearchOption builds a hybridSearchOption for collectionName that
// combines the given ANN requests, caps the result count at limit and uses the
// default consistency level.
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*annRequest) *hybridSearchOption {
	opt := &hybridSearchOption{useDefaultConsistency: true}
	opt.collectionName = collectionName
	opt.reqs = annRequests
	opt.limit = limit
	return opt
}
// QueryOption is implemented by anything that can produce the request sent by
// Client.Query.
type QueryOption interface {
Request() *milvuspb.QueryRequest
// Request builds the QueryRequest, or returns an error when it cannot be
// built (queryOption does so for a template parameter of unsupported type).
Request() (*milvuspb.QueryRequest, error)
}
type queryOption struct {
@ -352,10 +482,11 @@ type queryOption struct {
consistencyLevel entity.ConsistencyLevel
useDefaultConsistencyLevel bool
expr string
templateParams map[string]any
}
func (opt *queryOption) Request() *milvuspb.QueryRequest {
return &milvuspb.QueryRequest{
func (opt *queryOption) Request() (*milvuspb.QueryRequest, error) {
req := &milvuspb.QueryRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
OutputFields: opt.outputFields,
@ -364,6 +495,17 @@ func (opt *queryOption) Request() *milvuspb.QueryRequest {
QueryParams: entity.MapKvPairs(opt.queryParams),
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
}
req.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
for key, value := range opt.templateParams {
tmplVal, err := any2TmplValue(value)
if err != nil {
return nil, err
}
req.ExprTemplateValues[key] = tmplVal
}
return req, nil
}
func (opt *queryOption) WithFilter(expr string) *queryOption {
@ -371,6 +513,11 @@ func (opt *queryOption) WithFilter(expr string) *queryOption {
return opt
}
// WithTemplateParam stores val as the value for the expression template
// parameter named key and returns opt so calls can be chained.
//
// Like WithOffset does for queryParams, the map is allocated lazily so that a
// queryOption not created through NewQueryOption does not panic on the write.
func (opt *queryOption) WithTemplateParam(key string, val any) *queryOption {
	if opt.templateParams == nil {
		opt.templateParams = make(map[string]any)
	}
	opt.templateParams[key] = val
	return opt
}
func (opt *queryOption) WithOffset(offset int) *queryOption {
if opt.queryParams == nil {
opt.queryParams = make(map[string]string)
@ -408,5 +555,6 @@ func NewQueryOption(collectionName string) *queryOption {
collectionName: collectionName,
useDefaultConsistencyLevel: true,
consistencyLevel: entity.ClBounded,
templateParams: make(map[string]any),
}
}

View File

@ -75,6 +75,8 @@ func (s *ReadSuite) TestSearch() {
return rand.Float32()
})),
}).WithPartitions(partitionName).
WithFilter("id > {tmpl_id}").
WithTemplateParam("tmpl_id", 100).
WithGroupByField("group_by").
WithSearchParam("ignore_growing", "true").
WithAnnParam(ap),
@ -178,11 +180,11 @@ func (s *ReadSuite) TestHybridSearch() {
}, nil
}).Once()
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithOutputFields("*"))
})))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithReranker(NewRRFReranker()).WithOutputFields("*"))
s.NoError(err)
})
@ -190,14 +192,14 @@ func (s *ReadSuite) TestHybridSearch() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collectionName, s.schemaDyn)
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, nonSupportData{})))
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, nonSupportData{})))
s.Error(err)
s.mock.EXPECT().HybridSearch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hsr *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
return nil, merr.WrapErrServiceInternal("mocked")
}).Once()
_, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
_, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
@ -224,6 +226,14 @@ func (s *ReadSuite) TestQuery() {
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName))
s.NoError(err)
})
s.Run("bad_request", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collectionName, s.schema)
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithFilter("id > {tmpl_id}").WithTemplateParam("tmpl_id", struct{}{}))
s.Error(err)
})
}
func TestRead(t *testing.T) {

View File

@ -0,0 +1,62 @@
package milvusclient
import (
"encoding/json"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
// Keys and values used when a Reranker is serialized into KeyValuePairs.
const (
// rerankType is the key whose value names the rerank strategy.
rerankType = "strategy"
// rerankParams is the key whose value holds the JSON-encoded reranker.
rerankParams = "params"
// rffParam and weightedParam match the JSON tags of rrfReranker.K and
// weightedReranker.Weights; neither is referenced in the code visible here.
// NOTE(review): "rff" looks like a typo for "rrf" — confirm before renaming.
rffParam = "k"
weightedParam = "weights"
// Strategy names emitted under the rerankType key.
rrfRerankType = `rrf`
weightedRerankType = `weighted`
)
// Reranker describes a rerank strategy for hybrid search.
type Reranker interface {
// GetParams returns the key/value pairs that HybridRequest places into the
// request's rank params.
GetParams() []*commonpb.KeyValuePair
}
// rrfReranker is the Reranker reported with the "rrf" strategy.
type rrfReranker struct {
// K is serialized under the "k" JSON key; because of omitempty a zero K is
// left out of the encoded params.
K float64 `json:"k,omitempty"`
}
// WithK overrides the K parameter and returns r for chaining.
func (r *rrfReranker) WithK(k float64) *rrfReranker {
	r.K = k

	return r
}
// GetParams returns two pairs: the strategy name ("rrf") and the reranker
// itself encoded as JSON.
func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair {
	// The marshal error is discarded because GetParams has no error return;
	// if encoding fails the params value ends up as an empty string.
	encoded, _ := json.Marshal(r)

	strategy := &commonpb.KeyValuePair{Key: rerankType, Value: rrfRerankType}
	params := &commonpb.KeyValuePair{Key: rerankParams, Value: string(encoded)}
	return []*commonpb.KeyValuePair{strategy, params}
}
// NewRRFReranker returns an rrfReranker whose K starts at 60.
func NewRRFReranker() *rrfReranker {
	rr := new(rrfReranker)
	rr.K = 60
	return rr
}
// weightedReranker is the Reranker reported with the "weighted" strategy.
type weightedReranker struct {
// Weights is serialized under the "weights" JSON key; because of omitempty
// an empty slice is left out of the encoded params.
Weights []float64 `json:"weights,omitempty"`
}
// GetParams returns two pairs: the strategy name ("weighted") and the
// reranker itself encoded as JSON.
func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair {
	// The marshal error is discarded because GetParams has no error return;
	// if encoding fails the params value ends up as an empty string.
	encoded, _ := json.Marshal(r)

	strategy := &commonpb.KeyValuePair{Key: rerankType, Value: weightedRerankType}
	params := &commonpb.KeyValuePair{Key: rerankParams, Value: string(encoded)}
	return []*commonpb.KeyValuePair{strategy, params}
}
// NewWeightedReranker returns a weightedReranker using the given weights.
// The slice is stored as passed in, without copying.
func NewWeightedReranker(weights []float64) *weightedReranker {
	wr := new(weightedReranker)
	wr.Weights = weights
	return wr
}

View File

@ -0,0 +1,55 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
// TestReranker verifies the key/value pairs produced by the RRF and weighted
// rerankers, including the default and overridden K of the RRF reranker.
func TestReranker(t *testing.T) {
	// hasParam reports whether params contains an entry with exactly this key and value.
	hasParam := func(params []*commonpb.KeyValuePair, key string, value string) bool {
		for idx := range params {
			if params[idx].Key == key && params[idx].Value == value {
				return true
			}
		}
		return false
	}

	t.Run("rffReranker", func(t *testing.T) {
		reranker := NewRRFReranker()

		got := reranker.GetParams()
		assert.True(t, hasParam(got, rerankType, rrfRerankType))
		assert.True(t, hasParam(got, rerankParams, `{"k":60}`), "default k shall be 60")

		reranker.WithK(50)
		got = reranker.GetParams()
		assert.True(t, hasParam(got, rerankType, rrfRerankType))
		assert.True(t, hasParam(got, rerankParams, `{"k":50}`))
	})

	t.Run("weightedReranker", func(t *testing.T) {
		reranker := NewWeightedReranker([]float64{1, 2, 1})

		got := reranker.GetParams()
		assert.True(t, hasParam(got, rerankType, weightedRerankType))
		assert.True(t, hasParam(got, rerankParams, `{"weights":[1,2,1]}`))
	})
}