mirror of https://github.com/milvus-io/milvus.git
214 lines
6.1 KiB
Go
214 lines
6.1 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package milvusclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"google.golang.org/grpc"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
// Search param keys used by the search iterator.
const (
	// IteratorKey is the const search param key in indicating enabling iterator.
	IteratorKey = "iterator"
	// IteratorSessionTsKey is the search param key "iterator_session_ts".
	// NOTE(review): not referenced by the code in this file section; presumably
	// carries the iterator session timestamp — confirm against the server.
	IteratorSessionTsKey = "iterator_session_ts"
	// IteratorSearchV2Key is the search param key "search_iter_v2".
	// NOTE(review): presumably switches the server to search iterator v2 — confirm.
	IteratorSearchV2Key = "search_iter_v2"
	// IteratorSearchBatchSizeKey is the search param key holding the iterator batch size.
	IteratorSearchBatchSizeKey = "search_iter_batch_size"
	// IteratorSearchLastBoundKey is the search param key holding the last bound
	// reported by the previous iterator response.
	IteratorSearchLastBoundKey = "search_iter_last_bound"
	// IteratorSearchIDKey is the search param key holding the iterator token
	// reported by the previous iterator response.
	IteratorSearchIDKey = "search_iter_id"
	// CollectionIDKey is the search param key holding the collection ID.
	CollectionIDKey = "collection_id"
)

// Unlimited is the limit value meaning the iterator does not cap the total
// number of returned results.
const Unlimited int64 = -1
|
|
|
|
// ErrServerVersionIncompatible is the sentinel error reported when the server
// does not offer a search iterator implementation usable by this client, e.g.
// when a search response carries no search iterator v2 info or token.
// Compare against it with errors.Is.
var ErrServerVersionIncompatible = errors.New("server version incompatible")
|
|
|
|
// SearchIterator is the interface for search iterator.
|
|
type SearchIterator interface {
|
|
// Next returns next batch of iterator
|
|
// when iterator reaches the end, return `io.EOF`.
|
|
Next(ctx context.Context) (ResultSet, error)
|
|
}
|
|
|
|
type searchIteratorV2 struct {
|
|
client *Client
|
|
option SearchIteratorOption
|
|
schema *entity.Schema
|
|
limit int64
|
|
}
|
|
|
|
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
|
|
// limit reached, return EOF
|
|
if it.limit == 0 {
|
|
return ResultSet{}, io.EOF
|
|
}
|
|
|
|
rs, err := it.next(ctx)
|
|
if err != nil {
|
|
return rs, err
|
|
}
|
|
|
|
if it.limit == Unlimited {
|
|
return rs, err
|
|
}
|
|
|
|
if int64(rs.Len()) > it.limit {
|
|
rs = rs.Slice(0, int(it.limit))
|
|
}
|
|
it.limit -= int64(rs.Len())
|
|
return rs, nil
|
|
}
|
|
|
|
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
|
|
opt := it.option.SearchOption()
|
|
req, err := opt.Request()
|
|
if err != nil {
|
|
return ResultSet{}, err
|
|
}
|
|
|
|
var rs ResultSet
|
|
|
|
err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.Search(ctx, req)
|
|
err = merr.CheckRPCCall(resp, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
iteratorInfo := resp.GetResults().GetSearchIteratorV2Results()
|
|
opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken())
|
|
opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound()))
|
|
|
|
resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rs = resultSets[0]
|
|
|
|
if rs.IDs.Len() == 0 {
|
|
return io.EOF
|
|
}
|
|
|
|
return nil
|
|
})
|
|
return rs, err
|
|
}
|
|
|
|
func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error {
|
|
opt := it.option.SearchOption()
|
|
|
|
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
|
|
CollectionName: opt.collectionName,
|
|
})
|
|
if merr.CheckRPCCall(resp, err) != nil {
|
|
return err
|
|
}
|
|
|
|
opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID()))
|
|
schema := &entity.Schema{}
|
|
it.schema = schema.ReadProto(resp.GetSchema())
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// probeCompatiblity checks if the server support SearchIteratorV2.
|
|
// It checks if the search result contains search iterator v2 results info and token.
|
|
func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error {
|
|
opt := it.option.SearchOption()
|
|
opt.annRequest.topK = 1 // ok to leave it here, will be overwritten in next iteration
|
|
opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1")
|
|
req, err := opt.Request()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.Search(ctx, req)
|
|
err = merr.CheckRPCCall(resp, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if resp.GetResults().GetSearchIteratorV2Results() == nil || resp.GetResults().GetSearchIteratorV2Results().GetToken() == "" {
|
|
return ErrServerVersionIncompatible
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// newSearchIteratorV2 creates a new search iterator V2.
|
|
//
|
|
// It sets up the collection ID and checks if the server supports search iterator V2.
|
|
// If the server does not support search iterator V2, it returns an error.
|
|
func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) {
|
|
iter := &searchIteratorV2{
|
|
client: client,
|
|
option: option,
|
|
limit: option.Limit(),
|
|
}
|
|
if err := iter.setupCollectionID(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := iter.probeCompatiblity(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
type searchIteratorV1 struct {
|
|
client *Client
|
|
}
|
|
|
|
func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) {
|
|
return ResultSet{}, errors.New("not implemented")
|
|
}
|
|
|
|
func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) {
|
|
// search iterator v1 is not supported
|
|
return nil, ErrServerVersionIncompatible
|
|
}
|
|
|
|
// SearchIterator creates a search iterator from a collection.
|
|
//
|
|
// If the server supports search iterator V2, it creates a search iterator V2.
|
|
func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) {
|
|
if err := option.ValidateParams(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
iter, err := newSearchIteratorV2(ctx, c, option)
|
|
if err == nil {
|
|
return iter, nil
|
|
}
|
|
|
|
if !errors.Is(err, ErrServerVersionIncompatible) {
|
|
return nil, err
|
|
}
|
|
|
|
return newSearchIteratorV1(c)
|
|
}
|