mirror of https://github.com/milvus-io/milvus.git
enhance: [GoSDK] support expression template (#38568)
Related to #36672 This PR add - Expression template for search, query & hybrid search - fix hybrid search rerank param - add reranker interface(migrate from go sdk old repo) --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>pull/38577/head
parent
78438ef41e
commit
01cfb1fd97
|
@ -158,8 +158,11 @@ func (c *Client) parseSearchResult(sch *entity.Schema, outputFields []string, fi
|
|||
}
|
||||
|
||||
func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
|
||||
req := option.Request()
|
||||
var resultSet ResultSet
|
||||
req, err := option.Request()
|
||||
if err != nil {
|
||||
return resultSet, err
|
||||
}
|
||||
|
||||
collection, err := c.getCollection(ctx, req.GetCollectionName())
|
||||
if err != nil {
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
package milvusclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
|
@ -137,3 +139,170 @@ func (s *SearchOptionSuite) TestPlaceHolder() {
|
|||
func TestSearchOption(t *testing.T) {
|
||||
suite.Run(t, new(SearchOptionSuite))
|
||||
}
|
||||
|
||||
func TestAny2TmplValue(t *testing.T) {
|
||||
t.Run("primitives", func(t *testing.T) {
|
||||
t.Run("int", func(t *testing.T) {
|
||||
v := rand.Int()
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetInt64Val())
|
||||
})
|
||||
|
||||
t.Run("int32", func(t *testing.T) {
|
||||
v := rand.Int31()
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetInt64Val())
|
||||
})
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
v := rand.Int63()
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetInt64Val())
|
||||
})
|
||||
|
||||
t.Run("float32", func(t *testing.T) {
|
||||
v := rand.Float32()
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetFloatVal())
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
v := rand.Float64()
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetFloatVal())
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
val, err := any2TmplValue(true)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, val.GetBoolVal())
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
v := fmt.Sprintf("%v", rand.Int())
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, v, val.GetStringVal())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("slice", func(t *testing.T) {
|
||||
t.Run("int", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]int, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Int())
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetLongData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int32", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]int32, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Int31())
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetLongData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]int64, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Int63())
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetLongData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float32", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]float32, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Float32())
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetDoubleData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]float64, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Float64())
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetDoubleData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]bool, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, rand.Int()%2 == 0)
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetBoolData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
l := rand.Intn(10) + 1
|
||||
v := make([]string, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
v = append(v, fmt.Sprintf("%v", rand.Int()))
|
||||
}
|
||||
val, err := any2TmplValue(v)
|
||||
assert.NoError(t, err)
|
||||
data := val.GetArrayVal().GetStringData().GetData()
|
||||
assert.Equal(t, l, len(data))
|
||||
for i, val := range data {
|
||||
assert.EqualValues(t, v[i], val)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unsupported", func(*testing.T) {
|
||||
_, err := any2TmplValue(struct{}{})
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = any2TmplValue([]struct{}{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -18,6 +18,8 @@ package milvusclient
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
@ -25,6 +27,7 @@ import (
|
|||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
)
|
||||
|
@ -59,22 +62,24 @@ type searchOption struct {
|
|||
type annRequest struct {
|
||||
vectors []entity.Vector
|
||||
|
||||
annField string
|
||||
metricsType entity.MetricType
|
||||
searchParam map[string]string
|
||||
groupByField string
|
||||
annParam index.AnnParam
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
topK int
|
||||
offset int
|
||||
annField string
|
||||
metricsType entity.MetricType
|
||||
searchParam map[string]string
|
||||
groupByField string
|
||||
annParam index.AnnParam
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
topK int
|
||||
offset int
|
||||
templateParams map[string]any
|
||||
}
|
||||
|
||||
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest {
|
||||
return &annRequest{
|
||||
annField: annField,
|
||||
vectors: vectors,
|
||||
topK: limit,
|
||||
annField: annField,
|
||||
vectors: vectors,
|
||||
topK: limit,
|
||||
templateParams: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,9 +121,98 @@ func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) {
|
|||
}
|
||||
request.SearchParams = entity.MapKvPairs(params)
|
||||
|
||||
request.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
|
||||
for key, value := range r.templateParams {
|
||||
tmplVal, err := any2TmplValue(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.ExprTemplateValues[key] = tmplVal
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func any2TmplValue(val any) (*schemapb.TemplateValue, error) {
|
||||
result := &schemapb.TemplateValue{}
|
||||
switch v := val.(type) {
|
||||
case int, int8, int16, int32:
|
||||
result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: reflect.ValueOf(v).Int()}
|
||||
case int64:
|
||||
result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: v}
|
||||
case float32:
|
||||
result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: float64(v)}
|
||||
case float64:
|
||||
result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: v}
|
||||
case bool:
|
||||
result.Val = &schemapb.TemplateValue_BoolVal{BoolVal: v}
|
||||
case string:
|
||||
result.Val = &schemapb.TemplateValue_StringVal{StringVal: v}
|
||||
default:
|
||||
if reflect.TypeOf(val).Kind() == reflect.Slice {
|
||||
return slice2TmplValue(val)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported template value type: %T", val)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func slice2TmplValue(val any) (*schemapb.TemplateValue, error) {
|
||||
arrVal := &schemapb.TemplateValue_ArrayVal{
|
||||
ArrayVal: &schemapb.TemplateArrayValue{},
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(val)
|
||||
switch t := reflect.TypeOf(val).Elem().Kind(); t {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
data := make([]int64, 0, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
data = append(data, rv.Index(i).Int())
|
||||
}
|
||||
arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
case reflect.Bool:
|
||||
data := make([]bool, 0, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
data = append(data, rv.Index(i).Bool())
|
||||
}
|
||||
arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
data := make([]float64, 0, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
data = append(data, rv.Index(i).Float())
|
||||
}
|
||||
arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
case reflect.String:
|
||||
data := make([]string, 0, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
data = append(data, rv.Index(i).String())
|
||||
}
|
||||
arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported template type: slice of %v", t)
|
||||
}
|
||||
|
||||
return &schemapb.TemplateValue{
|
||||
Val: arrVal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *annRequest) WithANNSField(annsField string) *annRequest {
|
||||
r.annField = annsField
|
||||
return r
|
||||
|
@ -144,6 +238,11 @@ func (r *annRequest) WithFilter(expr string) *annRequest {
|
|||
return r
|
||||
}
|
||||
|
||||
func (r *annRequest) WithTemplateParam(key string, val any) *annRequest {
|
||||
r.templateParams[key] = val
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *annRequest) WithOffset(offset int) *annRequest {
|
||||
r.offset = offset
|
||||
return r
|
||||
|
@ -179,6 +278,11 @@ func (opt *searchOption) WithFilter(expr string) *searchOption {
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithTemplateParam(key string, val any) *searchOption {
|
||||
opt.annRequest.WithTemplateParam(key, val)
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithOffset(offset int) *searchOption {
|
||||
opt.annRequest.WithOffset(offset)
|
||||
return opt
|
||||
|
@ -223,9 +327,10 @@ func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
|
|||
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
||||
return &searchOption{
|
||||
annRequest: &annRequest{
|
||||
vectors: vectors,
|
||||
searchParam: make(map[string]string),
|
||||
topK: limit,
|
||||
vectors: vectors,
|
||||
searchParam: make(map[string]string),
|
||||
topK: limit,
|
||||
templateParams: make(map[string]any),
|
||||
},
|
||||
collectionName: collectionName,
|
||||
useDefaultConsistencyLevel: true,
|
||||
|
@ -293,6 +398,10 @@ type hybridSearchOption struct {
|
|||
outputFields []string
|
||||
useDefaultConsistency bool
|
||||
consistencyLevel entity.ConsistencyLevel
|
||||
|
||||
limit int
|
||||
offset int
|
||||
reranker Reranker
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
|
||||
|
@ -311,6 +420,16 @@ func (opt *hybridSearchOption) WithOutputFields(outputFields ...string) *hybridS
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOption {
|
||||
opt.reranker = reranker
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
|
||||
opt.offset = offset
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, error) {
|
||||
requests := make([]*milvuspb.SearchRequest, 0, len(opt.reqs))
|
||||
for _, annRequest := range opt.reqs {
|
||||
|
@ -321,6 +440,15 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
|||
requests = append(requests, req)
|
||||
}
|
||||
|
||||
var params []*commonpb.KeyValuePair
|
||||
if opt.reranker != nil {
|
||||
params = opt.reranker.GetParams()
|
||||
}
|
||||
params = append(params, &commonpb.KeyValuePair{Key: spLimit, Value: strconv.FormatInt(int64(opt.limit), 10)})
|
||||
if opt.offset > 0 {
|
||||
params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)})
|
||||
}
|
||||
|
||||
return &milvuspb.HybridSearchRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
|
@ -328,20 +456,22 @@ func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, e
|
|||
UseDefaultConsistency: opt.useDefaultConsistency,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
||||
OutputFields: opt.outputFields,
|
||||
RankParams: params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewHybridSearchOption(collectionName string, annRequests ...*annRequest) *hybridSearchOption {
|
||||
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*annRequest) *hybridSearchOption {
|
||||
return &hybridSearchOption{
|
||||
collectionName: collectionName,
|
||||
|
||||
reqs: annRequests,
|
||||
useDefaultConsistency: true,
|
||||
limit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOption interface {
|
||||
Request() *milvuspb.QueryRequest
|
||||
Request() (*milvuspb.QueryRequest, error)
|
||||
}
|
||||
|
||||
type queryOption struct {
|
||||
|
@ -352,10 +482,11 @@ type queryOption struct {
|
|||
consistencyLevel entity.ConsistencyLevel
|
||||
useDefaultConsistencyLevel bool
|
||||
expr string
|
||||
templateParams map[string]any
|
||||
}
|
||||
|
||||
func (opt *queryOption) Request() *milvuspb.QueryRequest {
|
||||
return &milvuspb.QueryRequest{
|
||||
func (opt *queryOption) Request() (*milvuspb.QueryRequest, error) {
|
||||
req := &milvuspb.QueryRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
OutputFields: opt.outputFields,
|
||||
|
@ -364,6 +495,17 @@ func (opt *queryOption) Request() *milvuspb.QueryRequest {
|
|||
QueryParams: entity.MapKvPairs(opt.queryParams),
|
||||
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
|
||||
}
|
||||
|
||||
req.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
|
||||
for key, value := range opt.templateParams {
|
||||
tmplVal, err := any2TmplValue(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.ExprTemplateValues[key] = tmplVal
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithFilter(expr string) *queryOption {
|
||||
|
@ -371,6 +513,11 @@ func (opt *queryOption) WithFilter(expr string) *queryOption {
|
|||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithTemplateParam(key string, val any) *queryOption {
|
||||
opt.templateParams[key] = val
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithOffset(offset int) *queryOption {
|
||||
if opt.queryParams == nil {
|
||||
opt.queryParams = make(map[string]string)
|
||||
|
@ -408,5 +555,6 @@ func NewQueryOption(collectionName string) *queryOption {
|
|||
collectionName: collectionName,
|
||||
useDefaultConsistencyLevel: true,
|
||||
consistencyLevel: entity.ClBounded,
|
||||
templateParams: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -75,6 +75,8 @@ func (s *ReadSuite) TestSearch() {
|
|||
return rand.Float32()
|
||||
})),
|
||||
}).WithPartitions(partitionName).
|
||||
WithFilter("id > {tmpl_id}").
|
||||
WithTemplateParam("tmpl_id", 100).
|
||||
WithGroupByField("group_by").
|
||||
WithSearchParam("ignore_growing", "true").
|
||||
WithAnnParam(ap),
|
||||
|
@ -178,11 +180,11 @@ func (s *ReadSuite) TestHybridSearch() {
|
|||
}, nil
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
})))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithOutputFields("*"))
|
||||
})))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithReranker(NewRRFReranker()).WithOutputFields("*"))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
|
@ -190,14 +192,14 @@ func (s *ReadSuite) TestHybridSearch() {
|
|||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schemaDyn)
|
||||
|
||||
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, nonSupportData{})))
|
||||
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, nonSupportData{})))
|
||||
s.Error(err)
|
||||
|
||||
s.mock.EXPECT().HybridSearch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hsr *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
|
||||
return nil, merr.WrapErrServiceInternal("mocked")
|
||||
}).Once()
|
||||
|
||||
_, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
_, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
|
@ -224,6 +226,14 @@ func (s *ReadSuite) TestQuery() {
|
|||
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("bad_request", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schema)
|
||||
|
||||
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithFilter("id > {tmpl_id}").WithTemplateParam("tmpl_id", struct{}{}))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
package milvusclient
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
const (
|
||||
rerankType = "strategy"
|
||||
rerankParams = "params"
|
||||
rffParam = "k"
|
||||
weightedParam = "weights"
|
||||
|
||||
rrfRerankType = `rrf`
|
||||
weightedRerankType = `weighted`
|
||||
)
|
||||
|
||||
type Reranker interface {
|
||||
GetParams() []*commonpb.KeyValuePair
|
||||
}
|
||||
|
||||
type rrfReranker struct {
|
||||
K float64 `json:"k,omitempty"`
|
||||
}
|
||||
|
||||
func (r *rrfReranker) WithK(k float64) *rrfReranker {
|
||||
r.K = k
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair {
|
||||
bs, _ := json.Marshal(r)
|
||||
|
||||
return []*commonpb.KeyValuePair{
|
||||
{Key: rerankType, Value: rrfRerankType},
|
||||
{Key: rerankParams, Value: string(bs)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRRFReranker() *rrfReranker {
|
||||
return &rrfReranker{K: 60}
|
||||
}
|
||||
|
||||
type weightedReranker struct {
|
||||
Weights []float64 `json:"weights,omitempty"`
|
||||
}
|
||||
|
||||
func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair {
|
||||
bs, _ := json.Marshal(r)
|
||||
|
||||
return []*commonpb.KeyValuePair{
|
||||
{Key: rerankType, Value: weightedRerankType},
|
||||
{Key: rerankParams, Value: string(bs)},
|
||||
}
|
||||
}
|
||||
|
||||
func NewWeightedReranker(weights []float64) *weightedReranker {
|
||||
return &weightedReranker{
|
||||
Weights: weights,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package milvusclient
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
func TestReranker(t *testing.T) {
|
||||
checkParam := func(params []*commonpb.KeyValuePair, key string, value string) bool {
|
||||
for _, kv := range params {
|
||||
if kv.Key == key && kv.Value == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
t.Run("rffReranker", func(t *testing.T) {
|
||||
rr := NewRRFReranker()
|
||||
params := rr.GetParams()
|
||||
assert.True(t, checkParam(params, rerankType, rrfRerankType))
|
||||
assert.True(t, checkParam(params, rerankParams, `{"k":60}`), "default k shall be 60")
|
||||
|
||||
rr.WithK(50)
|
||||
params = rr.GetParams()
|
||||
assert.True(t, checkParam(params, rerankType, rrfRerankType))
|
||||
assert.True(t, checkParam(params, rerankParams, `{"k":50}`))
|
||||
})
|
||||
|
||||
t.Run("weightedReranker", func(t *testing.T) {
|
||||
rr := NewWeightedReranker([]float64{1, 2, 1})
|
||||
params := rr.GetParams()
|
||||
assert.True(t, checkParam(params, rerankType, weightedRerankType))
|
||||
assert.True(t, checkParam(params, rerankParams, `{"weights":[1,2,1]}`))
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue