mirror of https://github.com/milvus-io/milvus.git
611 lines
16 KiB
Go
611 lines
16 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package milvusclient
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/client/v2/column"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/client/v2/index"
|
|
)
|
|
|
|
// Keys of the key/value parameters produced by the option builders in this
// file: search parameters of SearchRequest, rank parameters of
// HybridSearchRequest ("limit", "offset") and query parameters of QueryRequest
// ("limit", "offset").
const (
	spAnnsField       = "anns_field"
	spTopK            = "topk"
	spOffset          = "offset"
	spLimit           = "limit"
	spParams          = "params"
	spMetricsType     = "metric_type"
	spRoundDecimal    = "round_decimal"
	spIgnoreGrowing   = "ignore_growing"
	spGroupBy         = "group_by_field"
	spGroupSize       = "group_size"
	spStrictGroupSize = "strict_group_size"
)
|
|
|
|
type SearchOption interface {
|
|
Request() (*milvuspb.SearchRequest, error)
|
|
}
|
|
|
|
var _ SearchOption = (*searchOption)(nil)
|
|
|
|
type searchOption struct {
|
|
annRequest *annRequest
|
|
collectionName string
|
|
partitionNames []string
|
|
outputFields []string
|
|
consistencyLevel entity.ConsistencyLevel
|
|
useDefaultConsistencyLevel bool
|
|
}
|
|
|
|
type annRequest struct {
|
|
vectors []entity.Vector
|
|
|
|
annField string
|
|
metricsType entity.MetricType
|
|
searchParam map[string]string
|
|
groupByField string
|
|
groupSize int
|
|
strictGroupSize bool
|
|
annParam index.AnnParam
|
|
ignoreGrowing bool
|
|
expr string
|
|
topK int
|
|
offset int
|
|
templateParams map[string]any
|
|
}
|
|
|
|
func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest {
|
|
return &annRequest{
|
|
annField: annField,
|
|
vectors: vectors,
|
|
topK: limit,
|
|
searchParam: make(map[string]string),
|
|
templateParams: make(map[string]any),
|
|
}
|
|
}
|
|
|
|
func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) {
|
|
request := &milvuspb.SearchRequest{
|
|
Nq: int64(len(r.vectors)),
|
|
Dsl: r.expr,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
}
|
|
|
|
var err error
|
|
// placeholder group
|
|
request.PlaceholderGroup, err = vector2PlaceholderGroupBytes(r.vectors)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params := map[string]string{
|
|
spAnnsField: r.annField,
|
|
spTopK: strconv.Itoa(r.topK),
|
|
spOffset: strconv.Itoa(r.offset),
|
|
spMetricsType: string(r.metricsType),
|
|
spRoundDecimal: "-1",
|
|
spIgnoreGrowing: strconv.FormatBool(r.ignoreGrowing),
|
|
}
|
|
if r.groupByField != "" {
|
|
params[spGroupBy] = r.groupByField
|
|
}
|
|
if r.groupSize != 0 {
|
|
params[spGroupSize] = strconv.Itoa(r.groupSize)
|
|
}
|
|
if r.strictGroupSize {
|
|
params[spStrictGroupSize] = "true"
|
|
}
|
|
// ann param
|
|
if r.annParam != nil {
|
|
bs, _ := json.Marshal(r.annParam.Params())
|
|
params[spParams] = string(bs)
|
|
} else {
|
|
params[spParams] = "{}"
|
|
}
|
|
// use custom search param to overwrite
|
|
for k, v := range r.searchParam {
|
|
params[k] = v
|
|
}
|
|
request.SearchParams = entity.MapKvPairs(params)
|
|
|
|
request.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
|
|
for key, value := range r.templateParams {
|
|
tmplVal, err := any2TmplValue(value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.ExprTemplateValues[key] = tmplVal
|
|
}
|
|
|
|
return request, nil
|
|
}
|
|
|
|
// any2TmplValue converts a Go value supplied through WithTemplateParam into a
// schemapb.TemplateValue.
//
// The following inputs are supported:
//   - int, int8, int16, int32, int64: sent as int64;
//   - float32, float64: sent as float64;
//   - bool and string;
//   - slices and fixed-size arrays, converted element-wise by slice2TmplValue.
//
// Any other type, including an untyped nil, yields an error.
func any2TmplValue(val any) (*schemapb.TemplateValue, error) {
	result := &schemapb.TemplateValue{}
	switch v := val.(type) {
	case nil:
		// reflect.TypeOf(nil) returns a nil Type, so the Kind() call in the
		// default branch would panic; reject untyped nil explicitly.
		return nil, fmt.Errorf("unsupported template value type: %T", val)
	case int, int8, int16, int32:
		// v is still an interface in a multi-type case; reflect widens it.
		result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: reflect.ValueOf(v).Int()}
	case int64:
		result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: v}
	case float32:
		result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: float64(v)}
	case float64:
		result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: v}
	case bool:
		result.Val = &schemapb.TemplateValue_BoolVal{BoolVal: v}
	case string:
		result.Val = &schemapb.TemplateValue_StringVal{StringVal: v}
	default:
		// Slices and arrays are both indexable through reflect, which is all
		// slice2TmplValue relies on.
		if kind := reflect.TypeOf(val).Kind(); kind == reflect.Slice || kind == reflect.Array {
			return slice2TmplValue(val)
		}
		return nil, fmt.Errorf("unsupported template value type: %T", val)
	}
	return result, nil
}
|
|
|
|
// slice2TmplValue converts a slice value into an array-valued TemplateValue.
//
// The element kind selects the encoding:
//   - signed integers are widened to int64 (LongData);
//   - float32/float64 become float64 (DoubleData);
//   - bool and string are copied as-is.
//
// Any other element kind, including nested slices, yields an error.
func slice2TmplValue(val any) (*schemapb.TemplateValue, error) {
	items := reflect.ValueOf(val)
	array := &schemapb.TemplateArrayValue{}

	switch kind := reflect.TypeOf(val).Elem().Kind(); kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		longs := make([]int64, items.Len())
		for i := range longs {
			longs[i] = items.Index(i).Int()
		}
		array.Data = &schemapb.TemplateArrayValue_LongData{LongData: &schemapb.LongArray{Data: longs}}
	case reflect.Bool:
		bools := make([]bool, items.Len())
		for i := range bools {
			bools[i] = items.Index(i).Bool()
		}
		array.Data = &schemapb.TemplateArrayValue_BoolData{BoolData: &schemapb.BoolArray{Data: bools}}
	case reflect.Float32, reflect.Float64:
		doubles := make([]float64, items.Len())
		for i := range doubles {
			doubles[i] = items.Index(i).Float()
		}
		array.Data = &schemapb.TemplateArrayValue_DoubleData{DoubleData: &schemapb.DoubleArray{Data: doubles}}
	case reflect.String:
		strs := make([]string, items.Len())
		for i := range strs {
			strs[i] = items.Index(i).String()
		}
		array.Data = &schemapb.TemplateArrayValue_StringData{StringData: &schemapb.StringArray{Data: strs}}
	default:
		return nil, fmt.Errorf("unsupported template type: slice of %v", kind)
	}

	return &schemapb.TemplateValue{
		Val: &schemapb.TemplateValue_ArrayVal{ArrayVal: array},
	}, nil
}
|
|
|
|
func (r *annRequest) WithANNSField(annsField string) *annRequest {
|
|
r.annField = annsField
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithGroupByField(groupByField string) *annRequest {
|
|
r.groupByField = groupByField
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithGroupSize(groupSize int) *annRequest {
|
|
r.groupSize = groupSize
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithStrictGroupSize(strictGroupSize bool) *annRequest {
|
|
r.strictGroupSize = strictGroupSize
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithSearchParam(key, value string) *annRequest {
|
|
r.searchParam[key] = value
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithAnnParam(ap index.AnnParam) *annRequest {
|
|
r.annParam = ap
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithFilter(expr string) *annRequest {
|
|
r.expr = expr
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithTemplateParam(key string, val any) *annRequest {
|
|
r.templateParams[key] = val
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithOffset(offset int) *annRequest {
|
|
r.offset = offset
|
|
return r
|
|
}
|
|
|
|
func (r *annRequest) WithIgnoreGrowing(ignoreGrowing bool) *annRequest {
|
|
r.ignoreGrowing = ignoreGrowing
|
|
return r
|
|
}
|
|
|
|
func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) {
|
|
request, err := opt.annRequest.searchRequest()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
request.CollectionName = opt.collectionName
|
|
request.PartitionNames = opt.partitionNames
|
|
request.ConsistencyLevel = commonpb.ConsistencyLevel(opt.consistencyLevel)
|
|
request.UseDefaultConsistency = opt.useDefaultConsistencyLevel
|
|
request.OutputFields = opt.outputFields
|
|
|
|
return request, nil
|
|
}
|
|
|
|
func (opt *searchOption) WithPartitions(partitionNames ...string) *searchOption {
|
|
opt.partitionNames = partitionNames
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithFilter(expr string) *searchOption {
|
|
opt.annRequest.WithFilter(expr)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithTemplateParam(key string, val any) *searchOption {
|
|
opt.annRequest.WithTemplateParam(key, val)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithOffset(offset int) *searchOption {
|
|
opt.annRequest.WithOffset(offset)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithOutputFields(fieldNames ...string) *searchOption {
|
|
opt.outputFields = fieldNames
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption {
|
|
opt.consistencyLevel = consistencyLevel
|
|
opt.useDefaultConsistencyLevel = false
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithANNSField(annsField string) *searchOption {
|
|
opt.annRequest.WithANNSField(annsField)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithGroupByField(groupByField string) *searchOption {
|
|
opt.annRequest.WithGroupByField(groupByField)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithGroupSize(groupSize int) *searchOption {
|
|
opt.annRequest.WithGroupSize(groupSize)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithStrictGroupSize(strictGroupSize bool) *searchOption {
|
|
opt.annRequest.WithStrictGroupSize(strictGroupSize)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithIgnoreGrowing(ignoreGrowing bool) *searchOption {
|
|
opt.annRequest.WithIgnoreGrowing(ignoreGrowing)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithAnnParam(ap index.AnnParam) *searchOption {
|
|
opt.annRequest.WithAnnParam(ap)
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithSearchParam(key, value string) *searchOption {
|
|
opt.annRequest.WithSearchParam(key, value)
|
|
return opt
|
|
}
|
|
|
|
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
|
return &searchOption{
|
|
annRequest: NewAnnRequest("", limit, vectors...),
|
|
collectionName: collectionName,
|
|
useDefaultConsistencyLevel: true,
|
|
consistencyLevel: entity.ClBounded,
|
|
}
|
|
}
|
|
|
|
func vector2PlaceholderGroupBytes(vectors []entity.Vector) ([]byte, error) {
|
|
phv, err := vector2Placeholder(vectors)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
phg := &commonpb.PlaceholderGroup{
|
|
Placeholders: []*commonpb.PlaceholderValue{
|
|
phv,
|
|
},
|
|
}
|
|
|
|
bs, err := proto.Marshal(phg)
|
|
return bs, err
|
|
}
|
|
|
|
// vector2Placeholder encodes the query vectors as a single PlaceholderValue
// tagged "$0". The placeholder type is derived from the first vector's concrete
// type, and each vector contributes its Serialize() bytes.
//
// An empty input yields a placeholder with no values and the zero type. All
// vectors must share the same concrete type. A mixed batch, a nil element or
// an unsupported type is reported as an error.
func vector2Placeholder(vectors []entity.Vector) (*commonpb.PlaceholderValue, error) {
	ph := &commonpb.PlaceholderValue{
		Tag:    "$0",
		Values: make([][]byte, 0, len(vectors)),
	}
	if len(vectors) == 0 {
		return ph, nil
	}

	var placeHolderType commonpb.PlaceholderType
	switch vectors[0].(type) {
	case entity.FloatVector:
		placeHolderType = commonpb.PlaceholderType_FloatVector
	case entity.BinaryVector:
		placeHolderType = commonpb.PlaceholderType_BinaryVector
	case entity.BFloat16Vector:
		placeHolderType = commonpb.PlaceholderType_BFloat16Vector
	case entity.Float16Vector:
		placeHolderType = commonpb.PlaceholderType_Float16Vector
	case entity.SparseEmbedding:
		placeHolderType = commonpb.PlaceholderType_SparseFloatVector
	case entity.Text:
		placeHolderType = commonpb.PlaceholderType_VarChar
	default:
		return nil, errors.Newf("unsupported search data type: %T", vectors[0])
	}
	ph.Type = placeHolderType

	// The placeholder carries one type for every value, so a differently typed
	// vector would be serialized under the wrong type. A nil element would
	// otherwise panic on Serialize.
	firstType := reflect.TypeOf(vectors[0])
	for i, vector := range vectors {
		if reflect.TypeOf(vector) != firstType {
			return nil, errors.Newf("mixed search data types: vector 0 is %T but vector %d is %T", vectors[0], i, vector)
		}
		ph.Values = append(ph.Values, vector.Serialize())
	}
	return ph, nil
}
|
|
|
|
type HybridSearchOption interface {
|
|
HybridRequest() (*milvuspb.HybridSearchRequest, error)
|
|
}
|
|
|
|
type hybridSearchOption struct {
|
|
collectionName string
|
|
partitionNames []string
|
|
|
|
reqs []*annRequest
|
|
|
|
outputFields []string
|
|
useDefaultConsistency bool
|
|
consistencyLevel entity.ConsistencyLevel
|
|
|
|
limit int
|
|
offset int
|
|
reranker Reranker
|
|
}
|
|
|
|
func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption {
|
|
opt.consistencyLevel = cl
|
|
opt.useDefaultConsistency = false
|
|
return opt
|
|
}
|
|
|
|
func (opt *hybridSearchOption) WithPartitons(partitions ...string) *hybridSearchOption {
|
|
opt.partitionNames = partitions
|
|
return opt
|
|
}
|
|
|
|
func (opt *hybridSearchOption) WithOutputFields(outputFields ...string) *hybridSearchOption {
|
|
opt.outputFields = outputFields
|
|
return opt
|
|
}
|
|
|
|
func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOption {
|
|
opt.reranker = reranker
|
|
return opt
|
|
}
|
|
|
|
func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption {
|
|
opt.offset = offset
|
|
return opt
|
|
}
|
|
|
|
// HybridRequest implements HybridSearchOption. Each ANN sub-request is
// converted with searchRequest, and the first failure is returned.
//
// RankParams holds, in order:
//   - the reranker's params, when a reranker is set;
//   - "limit";
//   - "offset", only when the offset is positive.
func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, error) {
	requests := make([]*milvuspb.SearchRequest, 0, len(opt.reqs))
	for _, annRequest := range opt.reqs {
		req, err := annRequest.searchRequest()
		if err != nil {
			return nil, err
		}
		requests = append(requests, req)
	}

	var rerankParams []*commonpb.KeyValuePair
	if opt.reranker != nil {
		rerankParams = opt.reranker.GetParams()
	}
	// Copy into a slice we own. Appending straight onto the reranker's slice
	// could write into spare capacity of a backing array the reranker keeps
	// and returns again, so two requests would share and clobber entries.
	params := make([]*commonpb.KeyValuePair, 0, len(rerankParams)+2)
	params = append(params, rerankParams...)
	params = append(params, &commonpb.KeyValuePair{Key: spLimit, Value: strconv.Itoa(opt.limit)})
	if opt.offset > 0 {
		params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.Itoa(opt.offset)})
	}

	return &milvuspb.HybridSearchRequest{
		CollectionName:        opt.collectionName,
		PartitionNames:        opt.partitionNames,
		Requests:              requests,
		UseDefaultConsistency: opt.useDefaultConsistency,
		ConsistencyLevel:      commonpb.ConsistencyLevel(opt.consistencyLevel),
		OutputFields:          opt.outputFields,
		RankParams:            params,
	}, nil
}
|
|
|
|
func NewHybridSearchOption(collectionName string, limit int, annRequests ...*annRequest) *hybridSearchOption {
|
|
return &hybridSearchOption{
|
|
collectionName: collectionName,
|
|
|
|
reqs: annRequests,
|
|
useDefaultConsistency: true,
|
|
limit: limit,
|
|
}
|
|
}
|
|
|
|
type QueryOption interface {
|
|
Request() (*milvuspb.QueryRequest, error)
|
|
}
|
|
|
|
type queryOption struct {
|
|
collectionName string
|
|
partitionNames []string
|
|
queryParams map[string]string
|
|
outputFields []string
|
|
consistencyLevel entity.ConsistencyLevel
|
|
useDefaultConsistencyLevel bool
|
|
expr string
|
|
templateParams map[string]any
|
|
}
|
|
|
|
func (opt *queryOption) Request() (*milvuspb.QueryRequest, error) {
|
|
req := &milvuspb.QueryRequest{
|
|
CollectionName: opt.collectionName,
|
|
PartitionNames: opt.partitionNames,
|
|
OutputFields: opt.outputFields,
|
|
|
|
Expr: opt.expr,
|
|
QueryParams: entity.MapKvPairs(opt.queryParams),
|
|
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
|
|
UseDefaultConsistency: opt.useDefaultConsistencyLevel,
|
|
}
|
|
|
|
req.ExprTemplateValues = make(map[string]*schemapb.TemplateValue)
|
|
for key, value := range opt.templateParams {
|
|
tmplVal, err := any2TmplValue(value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.ExprTemplateValues[key] = tmplVal
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// WithFilter sets the filter expression, replacing any expression set earlier
// (including one produced by WithIDs).
func (opt *queryOption) WithFilter(expr string) *queryOption {
	opt.expr = expr
	return opt
}

// WithTemplateParam binds val to the template placeholder key used in the
// filter expression. The value is converted when Request is called.
func (opt *queryOption) WithTemplateParam(key string, val any) *queryOption {
	// Allocate lazily, like the queryParams setters, so a queryOption not built
	// by NewQueryOption does not panic on a nil-map write.
	if opt.templateParams == nil {
		opt.templateParams = make(map[string]any)
	}
	opt.templateParams[key] = val
	return opt
}

// WithOffset sets the "offset" query parameter.
func (opt *queryOption) WithOffset(offset int) *queryOption {
	opt.setQueryParam(spOffset, strconv.Itoa(offset))
	return opt
}

// WithLimit sets the "limit" query parameter.
func (opt *queryOption) WithLimit(limit int) *queryOption {
	opt.setQueryParam(spLimit, strconv.Itoa(limit))
	return opt
}

// setQueryParam stores one query parameter, allocating the map on first use.
func (opt *queryOption) setQueryParam(key, value string) {
	if opt.queryParams == nil {
		opt.queryParams = make(map[string]string)
	}
	opt.queryParams[key] = value
}

// WithOutputFields sets the output field names sent with the request.
func (opt *queryOption) WithOutputFields(fieldNames ...string) *queryOption {
	opt.outputFields = fieldNames
	return opt
}

// WithConsistencyLevel sets an explicit consistency level and stops requesting
// the default one.
func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption {
	opt.consistencyLevel = consistencyLevel
	opt.useDefaultConsistencyLevel = false
	return opt
}

// WithPartitions sets the partition names sent with the request.
func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption {
	opt.partitionNames = partitionNames
	return opt
}

// WithIDs replaces the filter expression with a primary-key `in` expression
// built from ids by pks2Expr.
func (opt *queryOption) WithIDs(ids column.Column) *queryOption {
	opt.expr = pks2Expr(ids)
	return opt
}
|
|
|
|
// pks2Expr builds a filter of the form `<name> in [v1,v2,...]`, using
// ids.Name() as the field name.
//
// Int64 keys are written as decimal literals. VarChar keys are written as
// double-quoted literals, with embedded quotes, backslashes and control
// characters escaped. Any other column type yields the empty string.
func pks2Expr(ids column.Column) string {
	pkName := ids.Name()

	var literals []string
	switch ids.Type() {
	case entity.FieldTypeInt64:
		data := ids.FieldData().GetScalars().GetLongData().GetData()
		literals = make([]string, len(data))
		for i, id := range data {
			literals[i] = strconv.FormatInt(id, 10)
		}
	case entity.FieldTypeVarChar:
		// Use the nil-safe getters instead of an unchecked type assertion, and
		// quote into a fresh slice. FieldData may share its backing array with
		// the column, so rewriting it in place could corrupt the caller's data.
		data := ids.FieldData().GetScalars().GetStringData().GetData()
		literals = make([]string, len(data))
		for i, id := range data {
			literals[i] = strconv.Quote(id)
		}
	default:
		return ""
	}
	return fmt.Sprintf("%s in [%s]", pkName, strings.Join(literals, ","))
}
|
|
|
|
func NewQueryOption(collectionName string) *queryOption {
|
|
return &queryOption{
|
|
collectionName: collectionName,
|
|
useDefaultConsistencyLevel: true,
|
|
consistencyLevel: entity.ClBounded,
|
|
templateParams: make(map[string]any),
|
|
}
|
|
}
|