// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package milvusclient import ( "encoding/json" "fmt" "reflect" "strconv" "strings" "github.com/cockroachdb/errors" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/index" ) const ( spAnnsField = `anns_field` spTopK = `topk` spOffset = `offset` spLimit = `limit` spParams = `params` spMetricsType = `metric_type` spRoundDecimal = `round_decimal` spIgnoreGrowing = `ignore_growing` spGroupBy = `group_by_field` spGroupSize = `group_size` spStrictGroupSize = `strict_group_size` ) type SearchOption interface { Request() (*milvuspb.SearchRequest, error) } var _ SearchOption = (*searchOption)(nil) type searchOption struct { annRequest *annRequest collectionName string partitionNames []string outputFields []string consistencyLevel entity.ConsistencyLevel useDefaultConsistencyLevel bool } type annRequest struct { vectors []entity.Vector annField string metricsType entity.MetricType searchParam map[string]string groupByField string groupSize int strictGroupSize bool annParam index.AnnParam ignoreGrowing bool expr string topK int offset int templateParams map[string]any } func NewAnnRequest(annField string, limit int, vectors ...entity.Vector) *annRequest { return &annRequest{ annField: annField, vectors: vectors, topK: limit, searchParam: make(map[string]string), templateParams: make(map[string]any), } } func (r *annRequest) searchRequest() (*milvuspb.SearchRequest, error) { request := &milvuspb.SearchRequest{ Nq: int64(len(r.vectors)), Dsl: r.expr, DslType: commonpb.DslType_BoolExprV1, } var err error // placeholder group request.PlaceholderGroup, err = vector2PlaceholderGroupBytes(r.vectors) if err != nil { return nil, err } params := map[string]string{ spAnnsField: r.annField, spTopK: strconv.Itoa(r.topK), spOffset: strconv.Itoa(r.offset), spMetricsType: string(r.metricsType), spRoundDecimal: "-1", spIgnoreGrowing: strconv.FormatBool(r.ignoreGrowing), } if r.groupByField != "" { params[spGroupBy] = r.groupByField } if r.groupSize != 0 { params[spGroupSize] = strconv.Itoa(r.groupSize) } if r.strictGroupSize { params[spStrictGroupSize] = "true" } // ann param if r.annParam != nil { bs, _ := json.Marshal(r.annParam.Params()) params[spParams] = string(bs) } else { params[spParams] = "{}" } // use custom search param to overwrite for k, v := range r.searchParam { params[k] = v } request.SearchParams = entity.MapKvPairs(params) request.ExprTemplateValues = make(map[string]*schemapb.TemplateValue) for key, value := range r.templateParams { tmplVal, err := any2TmplValue(value) if err != nil { return nil, err } request.ExprTemplateValues[key] = tmplVal } return request, nil } func any2TmplValue(val any) (*schemapb.TemplateValue, error) { result := &schemapb.TemplateValue{} switch v := val.(type) { case int, int8, int16, int32: result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: reflect.ValueOf(v).Int()} case int64: result.Val = &schemapb.TemplateValue_Int64Val{Int64Val: v} case float32: result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: float64(v)} case float64: result.Val = &schemapb.TemplateValue_FloatVal{FloatVal: v} case bool: result.Val = &schemapb.TemplateValue_BoolVal{BoolVal: v} case string: result.Val = &schemapb.TemplateValue_StringVal{StringVal: v} default: if reflect.TypeOf(val).Kind() == reflect.Slice { return slice2TmplValue(val) } return nil, fmt.Errorf("unsupported template value type: %T", val) } return result, nil } func slice2TmplValue(val any) (*schemapb.TemplateValue, error) { arrVal := &schemapb.TemplateValue_ArrayVal{ ArrayVal: &schemapb.TemplateArrayValue{}, } rv := reflect.ValueOf(val) switch t := reflect.TypeOf(val).Elem().Kind(); t { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: data := make([]int64, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { data = append(data, rv.Index(i).Int()) } arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_LongData{ LongData: &schemapb.LongArray{ Data: data, }, } case reflect.Bool: data := make([]bool, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { data = append(data, rv.Index(i).Bool()) } arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_BoolData{ BoolData: &schemapb.BoolArray{ Data: data, }, } case reflect.Float32, reflect.Float64: data := make([]float64, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { data = append(data, rv.Index(i).Float()) } arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_DoubleData{ DoubleData: &schemapb.DoubleArray{ Data: data, }, } case reflect.String: data := make([]string, 0, rv.Len()) for i := 0; i < rv.Len(); i++ { data = append(data, rv.Index(i).String()) } arrVal.ArrayVal.Data = &schemapb.TemplateArrayValue_StringData{ StringData: &schemapb.StringArray{ Data: data, }, } default: return nil, fmt.Errorf("unsupported template type: slice of %v", t) } return &schemapb.TemplateValue{ Val: arrVal, }, nil } func (r *annRequest) WithANNSField(annsField string) *annRequest { r.annField = annsField return r } func (r *annRequest) WithGroupByField(groupByField string) *annRequest { r.groupByField = groupByField return r } func (r *annRequest) WithGroupSize(groupSize int) *annRequest { r.groupSize = groupSize return r } func (r *annRequest) WithStrictGroupSize(strictGroupSize bool) *annRequest { r.strictGroupSize = strictGroupSize return r } func (r *annRequest) WithSearchParam(key, value string) *annRequest { r.searchParam[key] = value return r } func (r *annRequest) WithAnnParam(ap index.AnnParam) *annRequest { r.annParam = ap return r } func (r *annRequest) WithFilter(expr string) *annRequest { r.expr = expr return r } func (r *annRequest) WithTemplateParam(key string, val any) *annRequest { r.templateParams[key] = val return r } func (r *annRequest) WithOffset(offset int) *annRequest { r.offset = offset return r } func (r *annRequest) WithIgnoreGrowing(ignoreGrowing bool) *annRequest { r.ignoreGrowing = ignoreGrowing return r } func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) { request, err := opt.annRequest.searchRequest() if err != nil { return nil, err } request.CollectionName = opt.collectionName request.PartitionNames = opt.partitionNames request.ConsistencyLevel = commonpb.ConsistencyLevel(opt.consistencyLevel) request.UseDefaultConsistency = opt.useDefaultConsistencyLevel request.OutputFields = opt.outputFields return request, nil } func (opt *searchOption) WithPartitions(partitionNames ...string) *searchOption { opt.partitionNames = partitionNames return opt } func (opt *searchOption) WithFilter(expr string) *searchOption { opt.annRequest.WithFilter(expr) return opt } func (opt *searchOption) WithTemplateParam(key string, val any) *searchOption { opt.annRequest.WithTemplateParam(key, val) return opt } func (opt *searchOption) WithOffset(offset int) *searchOption { opt.annRequest.WithOffset(offset) return opt } func (opt *searchOption) WithOutputFields(fieldNames ...string) *searchOption { opt.outputFields = fieldNames return opt } func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption { opt.consistencyLevel = consistencyLevel opt.useDefaultConsistencyLevel = false return opt } func (opt *searchOption) WithANNSField(annsField string) *searchOption { opt.annRequest.WithANNSField(annsField) return opt } func (opt *searchOption) WithGroupByField(groupByField string) *searchOption { opt.annRequest.WithGroupByField(groupByField) return opt } func (opt *searchOption) WithGroupSize(groupSize int) *searchOption { opt.annRequest.WithGroupSize(groupSize) return opt } func (opt *searchOption) WithStrictGroupSize(strictGroupSize bool) *searchOption { opt.annRequest.WithStrictGroupSize(strictGroupSize) return opt } func (opt *searchOption) WithIgnoreGrowing(ignoreGrowing bool) *searchOption { opt.annRequest.WithIgnoreGrowing(ignoreGrowing) return opt } func (opt *searchOption) WithAnnParam(ap index.AnnParam) *searchOption { opt.annRequest.WithAnnParam(ap) return opt } func (opt *searchOption) WithSearchParam(key, value string) *searchOption { opt.annRequest.WithSearchParam(key, value) return opt } func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption { return &searchOption{ annRequest: NewAnnRequest("", limit, vectors...), collectionName: collectionName, useDefaultConsistencyLevel: true, consistencyLevel: entity.ClBounded, } } func vector2PlaceholderGroupBytes(vectors []entity.Vector) ([]byte, error) { phv, err := vector2Placeholder(vectors) if err != nil { return nil, err } phg := &commonpb.PlaceholderGroup{ Placeholders: []*commonpb.PlaceholderValue{ phv, }, } bs, err := proto.Marshal(phg) return bs, err } func vector2Placeholder(vectors []entity.Vector) (*commonpb.PlaceholderValue, error) { var placeHolderType commonpb.PlaceholderType ph := &commonpb.PlaceholderValue{ Tag: "$0", Values: make([][]byte, 0, len(vectors)), } if len(vectors) == 0 { return ph, nil } switch vectors[0].(type) { case entity.FloatVector: placeHolderType = commonpb.PlaceholderType_FloatVector case entity.BinaryVector: placeHolderType = commonpb.PlaceholderType_BinaryVector case entity.BFloat16Vector: placeHolderType = commonpb.PlaceholderType_BFloat16Vector case entity.Float16Vector: placeHolderType = commonpb.PlaceholderType_Float16Vector case entity.SparseEmbedding: placeHolderType = commonpb.PlaceholderType_SparseFloatVector case entity.Int8Vector: placeHolderType = commonpb.PlaceholderType_Int8Vector case entity.Text: placeHolderType = commonpb.PlaceholderType_VarChar default: return nil, errors.Newf("unsupported search data type: %T", vectors[0]) } ph.Type = placeHolderType for _, vector := range vectors { ph.Values = append(ph.Values, vector.Serialize()) } return ph, nil } type HybridSearchOption interface { HybridRequest() (*milvuspb.HybridSearchRequest, error) } type hybridSearchOption struct { collectionName string partitionNames []string reqs []*annRequest outputFields []string useDefaultConsistency bool consistencyLevel entity.ConsistencyLevel limit int offset int reranker Reranker } func (opt *hybridSearchOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *hybridSearchOption { opt.consistencyLevel = cl opt.useDefaultConsistency = false return opt } func (opt *hybridSearchOption) WithPartitons(partitions ...string) *hybridSearchOption { opt.partitionNames = partitions return opt } func (opt *hybridSearchOption) WithOutputFields(outputFields ...string) *hybridSearchOption { opt.outputFields = outputFields return opt } func (opt *hybridSearchOption) WithReranker(reranker Reranker) *hybridSearchOption { opt.reranker = reranker return opt } func (opt *hybridSearchOption) WithOffset(offset int) *hybridSearchOption { opt.offset = offset return opt } func (opt *hybridSearchOption) HybridRequest() (*milvuspb.HybridSearchRequest, error) { requests := make([]*milvuspb.SearchRequest, 0, len(opt.reqs)) for _, annRequest := range opt.reqs { req, err := annRequest.searchRequest() if err != nil { return nil, err } requests = append(requests, req) } var params []*commonpb.KeyValuePair if opt.reranker != nil { params = opt.reranker.GetParams() } params = append(params, &commonpb.KeyValuePair{Key: spLimit, Value: strconv.FormatInt(int64(opt.limit), 10)}) if opt.offset > 0 { params = append(params, &commonpb.KeyValuePair{Key: spOffset, Value: strconv.FormatInt(int64(opt.offset), 10)}) } return &milvuspb.HybridSearchRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, Requests: requests, UseDefaultConsistency: opt.useDefaultConsistency, ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel), OutputFields: opt.outputFields, RankParams: params, }, nil } func NewHybridSearchOption(collectionName string, limit int, annRequests ...*annRequest) *hybridSearchOption { return &hybridSearchOption{ collectionName: collectionName, reqs: annRequests, useDefaultConsistency: true, limit: limit, } } type QueryOption interface { Request() (*milvuspb.QueryRequest, error) } type queryOption struct { collectionName string partitionNames []string queryParams map[string]string outputFields []string consistencyLevel entity.ConsistencyLevel useDefaultConsistencyLevel bool expr string templateParams map[string]any } func (opt *queryOption) Request() (*milvuspb.QueryRequest, error) { req := &milvuspb.QueryRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, OutputFields: opt.outputFields, Expr: opt.expr, QueryParams: entity.MapKvPairs(opt.queryParams), ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(), UseDefaultConsistency: opt.useDefaultConsistencyLevel, } req.ExprTemplateValues = make(map[string]*schemapb.TemplateValue) for key, value := range opt.templateParams { tmplVal, err := any2TmplValue(value) if err != nil { return nil, err } req.ExprTemplateValues[key] = tmplVal } return req, nil } func (opt *queryOption) WithFilter(expr string) *queryOption { opt.expr = expr return opt } func (opt *queryOption) WithTemplateParam(key string, val any) *queryOption { opt.templateParams[key] = val return opt } func (opt *queryOption) WithOffset(offset int) *queryOption { if opt.queryParams == nil { opt.queryParams = make(map[string]string) } opt.queryParams[spOffset] = strconv.Itoa(offset) return opt } func (opt *queryOption) WithLimit(limit int) *queryOption { if opt.queryParams == nil { opt.queryParams = make(map[string]string) } opt.queryParams[spLimit] = strconv.Itoa(limit) return opt } func (opt *queryOption) WithOutputFields(fieldNames ...string) *queryOption { opt.outputFields = fieldNames return opt } func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption { opt.consistencyLevel = consistencyLevel opt.useDefaultConsistencyLevel = false return opt } func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption { opt.partitionNames = partitionNames return opt } func (opt *queryOption) WithIDs(ids column.Column) *queryOption { opt.expr = pks2Expr(ids) return opt } func pks2Expr(ids column.Column) string { var expr string pkName := ids.Name() switch ids.Type() { case entity.FieldTypeInt64: expr = fmt.Sprintf("%s in %s", pkName, strings.Join(strings.Fields(fmt.Sprint(ids.FieldData().GetScalars().GetLongData().GetData())), ",")) case entity.FieldTypeVarChar: data := ids.FieldData().GetScalars().GetData().(*schemapb.ScalarField_StringData).StringData.GetData() for i := range data { data[i] = fmt.Sprintf("\"%s\"", data[i]) } expr = fmt.Sprintf("%s in [%s]", pkName, strings.Join(data, ",")) } return expr } func NewQueryOption(collectionName string) *queryOption { return &queryOption{ collectionName: collectionName, useDefaultConsistencyLevel: true, consistencyLevel: entity.ClBounded, templateParams: make(map[string]any), } }