mirror of https://github.com/milvus-io/milvus.git
266 lines
7.3 KiB
Go
266 lines
7.3 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package client
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strconv"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
)
|
|
|
|
const (
|
|
spAnnsField = `anns_field`
|
|
spTopK = `topk`
|
|
spOffset = `offset`
|
|
spLimit = `limit`
|
|
spParams = `params`
|
|
spMetricsType = `metric_type`
|
|
spRoundDecimal = `round_decimal`
|
|
spIgnoreGrowing = `ignore_growing`
|
|
spGroupBy = `group_by_field`
|
|
)
|
|
|
|
type SearchOption interface {
|
|
Request() *milvuspb.SearchRequest
|
|
}
|
|
|
|
var _ SearchOption = (*searchOption)(nil)
|
|
|
|
type searchOption struct {
|
|
collectionName string
|
|
partitionNames []string
|
|
topK int
|
|
offset int
|
|
outputFields []string
|
|
consistencyLevel entity.ConsistencyLevel
|
|
useDefaultConsistencyLevel bool
|
|
ignoreGrowing bool
|
|
expr string
|
|
|
|
// normal search request
|
|
request *annRequest
|
|
// TODO add sub request when support hybrid search
|
|
}
|
|
|
|
type annRequest struct {
|
|
vectors []entity.Vector
|
|
|
|
annField string
|
|
metricsType entity.MetricType
|
|
searchParam map[string]string
|
|
groupByField string
|
|
}
|
|
|
|
func (opt *searchOption) Request() *milvuspb.SearchRequest {
|
|
// TODO check whether search is hybrid after logic merged
|
|
return opt.prepareSearchRequest(opt.request)
|
|
}
|
|
|
|
func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest {
|
|
request := &milvuspb.SearchRequest{
|
|
CollectionName: opt.collectionName,
|
|
PartitionNames: opt.partitionNames,
|
|
Dsl: opt.expr,
|
|
DslType: commonpb.DslType_BoolExprV1,
|
|
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
|
OutputFields: opt.outputFields,
|
|
}
|
|
if annRequest != nil {
|
|
// nq
|
|
request.Nq = int64(len(annRequest.vectors))
|
|
|
|
// search param
|
|
bs, _ := json.Marshal(annRequest.searchParam)
|
|
params := map[string]string{
|
|
spAnnsField: annRequest.annField,
|
|
spTopK: strconv.Itoa(opt.topK),
|
|
spOffset: strconv.Itoa(opt.offset),
|
|
spParams: string(bs),
|
|
spMetricsType: string(annRequest.metricsType),
|
|
spRoundDecimal: "-1",
|
|
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
|
|
}
|
|
if annRequest.groupByField != "" {
|
|
params[spGroupBy] = annRequest.groupByField
|
|
}
|
|
request.SearchParams = entity.MapKvPairs(params)
|
|
|
|
// placeholder group
|
|
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
|
|
}
|
|
|
|
return request
|
|
}
|
|
|
|
func (opt *searchOption) WithFilter(expr string) *searchOption {
|
|
opt.expr = expr
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithOffset(offset int) *searchOption {
|
|
opt.offset = offset
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithOutputFields(fieldNames []string) *searchOption {
|
|
opt.outputFields = fieldNames
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption {
|
|
opt.consistencyLevel = consistencyLevel
|
|
opt.useDefaultConsistencyLevel = false
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithANNSField(annsField string) *searchOption {
|
|
opt.request.annField = annsField
|
|
return opt
|
|
}
|
|
|
|
func (opt *searchOption) WithPartitions(partitionNames []string) *searchOption {
|
|
opt.partitionNames = partitionNames
|
|
return opt
|
|
}
|
|
|
|
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
|
return &searchOption{
|
|
collectionName: collectionName,
|
|
topK: limit,
|
|
request: &annRequest{
|
|
vectors: vectors,
|
|
},
|
|
useDefaultConsistencyLevel: true,
|
|
consistencyLevel: entity.ClBounded,
|
|
}
|
|
}
|
|
|
|
func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
|
|
phg := &commonpb.PlaceholderGroup{
|
|
Placeholders: []*commonpb.PlaceholderValue{
|
|
vector2Placeholder(vectors),
|
|
},
|
|
}
|
|
|
|
bs, _ := proto.Marshal(phg)
|
|
return bs
|
|
}
|
|
|
|
func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
|
|
var placeHolderType commonpb.PlaceholderType
|
|
ph := &commonpb.PlaceholderValue{
|
|
Tag: "$0",
|
|
Values: make([][]byte, 0, len(vectors)),
|
|
}
|
|
if len(vectors) == 0 {
|
|
return ph
|
|
}
|
|
switch vectors[0].(type) {
|
|
case entity.FloatVector:
|
|
placeHolderType = commonpb.PlaceholderType_FloatVector
|
|
case entity.BinaryVector:
|
|
placeHolderType = commonpb.PlaceholderType_BinaryVector
|
|
case entity.BFloat16Vector:
|
|
placeHolderType = commonpb.PlaceholderType_BFloat16Vector
|
|
case entity.Float16Vector:
|
|
placeHolderType = commonpb.PlaceholderType_Float16Vector
|
|
case entity.SparseEmbedding:
|
|
placeHolderType = commonpb.PlaceholderType_SparseFloatVector
|
|
}
|
|
ph.Type = placeHolderType
|
|
for _, vector := range vectors {
|
|
ph.Values = append(ph.Values, vector.Serialize())
|
|
}
|
|
return ph
|
|
}
|
|
|
|
type QueryOption interface {
|
|
Request() *milvuspb.QueryRequest
|
|
}
|
|
|
|
type queryOption struct {
|
|
collectionName string
|
|
partitionNames []string
|
|
queryParams map[string]string
|
|
outputFields []string
|
|
consistencyLevel entity.ConsistencyLevel
|
|
useDefaultConsistencyLevel bool
|
|
expr string
|
|
}
|
|
|
|
func (opt *queryOption) Request() *milvuspb.QueryRequest {
|
|
return &milvuspb.QueryRequest{
|
|
CollectionName: opt.collectionName,
|
|
PartitionNames: opt.partitionNames,
|
|
OutputFields: opt.outputFields,
|
|
|
|
Expr: opt.expr,
|
|
QueryParams: entity.MapKvPairs(opt.queryParams),
|
|
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
|
|
}
|
|
}
|
|
|
|
func (opt *queryOption) WithFilter(expr string) *queryOption {
|
|
opt.expr = expr
|
|
return opt
|
|
}
|
|
|
|
func (opt *queryOption) WithOffset(offset int) *queryOption {
|
|
if opt.queryParams == nil {
|
|
opt.queryParams = make(map[string]string)
|
|
}
|
|
opt.queryParams[spOffset] = strconv.Itoa(offset)
|
|
return opt
|
|
}
|
|
|
|
func (opt *queryOption) WithLimit(limit int) *queryOption {
|
|
if opt.queryParams == nil {
|
|
opt.queryParams = make(map[string]string)
|
|
}
|
|
opt.queryParams[spLimit] = strconv.Itoa(limit)
|
|
return opt
|
|
}
|
|
|
|
func (opt *queryOption) WithOutputFields(fieldNames []string) *queryOption {
|
|
opt.outputFields = fieldNames
|
|
return opt
|
|
}
|
|
|
|
func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption {
|
|
opt.consistencyLevel = consistencyLevel
|
|
opt.useDefaultConsistencyLevel = false
|
|
return opt
|
|
}
|
|
|
|
func (opt *queryOption) WithPartitions(partitionNames []string) *queryOption {
|
|
opt.partitionNames = partitionNames
|
|
return opt
|
|
}
|
|
|
|
func NewQueryOption(collectionName string) *queryOption {
|
|
return &queryOption{
|
|
collectionName: collectionName,
|
|
useDefaultConsistencyLevel: true,
|
|
consistencyLevel: entity.ClBounded,
|
|
}
|
|
}
|