milvus/client/milvusclient/iterator.go

214 lines
6.1 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"fmt"
"io"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// Search param keys and sentinel values used by the search iterator.
const (
	// IteratorKey is the search param key that indicates the iterator is enabled.
	IteratorKey = "iterator"
	// IteratorSessionTsKey is the search param key "iterator_session_ts".
	IteratorSessionTsKey = "iterator_session_ts"
	// IteratorSearchV2Key is the search param key "search_iter_v2".
	IteratorSearchV2Key = "search_iter_v2"

	// IteratorSearchBatchSizeKey is the search param key carrying the iterator batch size.
	IteratorSearchBatchSizeKey = "search_iter_batch_size"
	// IteratorSearchLastBoundKey is the search param key carrying the last bound
	// reported by the previous search response.
	IteratorSearchLastBoundKey = "search_iter_last_bound"
	// IteratorSearchIDKey is the search param key carrying the iterator token
	// reported by the previous search response.
	IteratorSearchIDKey = "search_iter_id"

	// CollectionIDKey is the search param key carrying the collection ID.
	CollectionIDKey = "collection_id"

	// Unlimited is the limit value meaning the iterator has no cap on the
	// total number of returned results.
	Unlimited = int64(-1)
)
// ErrServerVersionIncompatible is the sentinel error reported when the server
// cannot serve the requested search iterator: the V2 probe returns it when the
// search response carries no iterator V2 info or an empty token, and the V1
// constructor always returns it.
var ErrServerVersionIncompatible = errors.New("server version incompatible")
// SearchIterator is the interface implemented by search iterators.
type SearchIterator interface {
// Next returns the next batch of results.
// When the iterator reaches the end, it returns `io.EOF`.
Next(ctx context.Context) (ResultSet, error)
}
// searchIteratorV2 is the SearchIterator implementation backed by the
// server-side search iterator V2.
type searchIteratorV2 struct {
// client is used to issue the DescribeCollection and Search calls.
client *Client
// option supplies the search option used to build every request.
option SearchIteratorOption
// schema is the collection schema read by setupCollectionID and passed to
// handleSearchResult when decoding search responses.
schema *entity.Schema
// limit is the number of results still allowed to be returned;
// 0 means exhausted and Unlimited means no cap.
limit int64
}
// Next returns the next batch of results.
//
// It returns io.EOF once the configured limit has been used up. Errors from
// the underlying fetch (including its io.EOF) are passed through together
// with whatever result set the fetch produced. When a limit is set, the batch
// is truncated so that the total never exceeds it.
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
	// Remaining budget is zero: the iterator is finished.
	if it.limit == 0 {
		return ResultSet{}, io.EOF
	}

	batch, fetchErr := it.next(ctx)
	switch {
	case fetchErr != nil:
		return batch, fetchErr
	case it.limit == Unlimited:
		// No cap to enforce, hand the batch over untouched.
		return batch, nil
	}

	// Trim the batch to the remaining budget, then charge it.
	if int64(batch.Len()) > it.limit {
		batch = batch.Slice(0, int(it.limit))
	}
	it.limit -= int64(batch.Len())
	return batch, nil
}
// next issues one Search call and returns the first result set of the
// response.
//
// After a successful call, the iterator token and last bound reported by the
// server are written back into the search params of the option's ANN request
// (presumably so the following request continues from that position; this
// relies on SearchOption returning shared state, which is not visible here).
//
// It returns io.EOF when the returned result set contains no IDs, and an
// error when the response could not be decoded into at least one result set.
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
	opt := it.option.SearchOption()
	req, err := opt.Request()
	if err != nil {
		return ResultSet{}, err
	}

	var rs ResultSet
	err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
		resp, err := milvusService.Search(ctx, req)
		if err = merr.CheckRPCCall(resp, err); err != nil {
			return err
		}

		// Record the iterator state reported by the server on the ANN request.
		iteratorInfo := resp.GetResults().GetSearchIteratorV2Results()
		opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken())
		opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound()))

		resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp)
		if err != nil {
			return err
		}
		// Guard the index below: an empty slice would otherwise panic.
		if len(resultSets) == 0 {
			return errors.New("search iterator: search response contains no result set")
		}
		rs = resultSets[0]

		// An empty batch marks the end of the iteration.
		if rs.IDs.Len() == 0 {
			return io.EOF
		}
		return nil
	})
	return rs, err
}
// setupCollectionID describes the collection named in the search option,
// stores its ID in the search params under CollectionIDKey, and caches the
// collection schema on the iterator.
//
// It returns any error produced by the DescribeCollection call or by the
// response check.
func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error {
	opt := it.option.SearchOption()

	return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
		resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
			CollectionName: opt.collectionName,
		})
		// Propagate the error returned by CheckRPCCall itself. The raw RPC
		// err may be nil while the check still fails (failure carried by the
		// response), and returning it would silently swallow that failure.
		if err = merr.CheckRPCCall(resp, err); err != nil {
			return err
		}

		opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID()))
		schema := &entity.Schema{}
		it.schema = schema.ReadProto(resp.GetSchema())
		return nil
	})
}
// probeCompatiblity checks whether the server supports SearchIteratorV2.
//
// It sends a search with topK and batch size set to 1 and inspects the
// response: if it carries no search iterator V2 info, or the info has an
// empty token, ErrServerVersionIncompatible is returned.
func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error {
	searchOpt := it.option.SearchOption()
	// ok to leave it here, will be overwritten in next iteration
	// NOTE(review): the overwrite happens outside this file — confirm.
	searchOpt.annRequest.topK = 1
	searchOpt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1")

	probeReq, err := searchOpt.Request()
	if err != nil {
		return err
	}

	probe := func(svc milvuspb.MilvusServiceClient) error {
		resp, rpcErr := svc.Search(ctx, probeReq)
		if checkErr := merr.CheckRPCCall(resp, rpcErr); checkErr != nil {
			return checkErr
		}

		info := resp.GetResults().GetSearchIteratorV2Results()
		if info == nil || info.GetToken() == "" {
			return ErrServerVersionIncompatible
		}
		return nil
	}
	return it.client.callService(probe)
}
// newSearchIteratorV2 creates a new search iterator V2.
//
// The iterator is initialised with the given client, option and the option's
// limit. Two setup steps then run in order: resolving the collection ID, and
// probing whether the server supports search iterator V2. The first step that
// fails aborts construction and its error is returned with a nil iterator.
func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) {
	it := &searchIteratorV2{
		client: client,
		option: option,
		limit:  option.Limit(),
	}

	steps := []func(context.Context) error{
		it.setupCollectionID,
		it.probeCompatiblity,
	}
	for _, step := range steps {
		if err := step(ctx); err != nil {
			return nil, err
		}
	}
	return it, nil
}
// searchIteratorV1 is a placeholder for the V1 search iterator; its Next
// method is not implemented and its constructor always fails.
type searchIteratorV1 struct {
// client is the owning client; no method of this type reads it.
client *Client
}
// Next is not implemented for the V1 iterator: it always returns an empty
// ResultSet together with a "not implemented" error.
func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) {
	var empty ResultSet
	return empty, errors.New("not implemented")
}
// newSearchIteratorV1 always fails with ErrServerVersionIncompatible, since
// search iterator v1 is not supported. The returned iterator is always nil.
func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) {
	var unsupported *searchIteratorV1 // stays nil: v1 is not supported
	return unsupported, ErrServerVersionIncompatible
}
// SearchIterator creates a search iterator from a collection.
//
// The option is validated first. A search iterator V2 is then attempted; if
// that fails with ErrServerVersionIncompatible, the call falls back to the V1
// constructor. Any other error is returned as is.
//
// Whenever the returned error is non-nil, the returned SearchIterator is a
// plain nil interface. callOptions is accepted but not referenced by this
// method.
func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) {
	if err := option.ValidateParams(); err != nil {
		return nil, err
	}

	iterV2, err := newSearchIteratorV2(ctx, c, option)
	if err == nil {
		return iterV2, nil
	}
	if !errors.Is(err, ErrServerVersionIncompatible) {
		return nil, err
	}

	// Do not forward the constructor's results directly: a nil
	// *searchIteratorV1 converted to SearchIterator yields a non-nil
	// interface value, which would mislead callers checking `iter != nil`.
	iterV1, err := newSearchIteratorV1(c)
	if err != nil {
		return nil, err
	}
	return iterV1, nil
}