mirror of https://github.com/milvus-io/milvus.git
304 lines
9.5 KiB
Go
304 lines
9.5 KiB
Go
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package milvusclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"testing"
|
|
|
|
"github.com/samber/lo"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
type SearchIteratorSuite struct {
|
|
MockSuiteBase
|
|
|
|
schema *entity.Schema
|
|
}
|
|
|
|
func (s *SearchIteratorSuite) SetupSuite() {
|
|
s.MockSuiteBase.SetupSuite()
|
|
s.schema = entity.NewSchema().
|
|
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
|
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
|
}
|
|
|
|
func (s *SearchIteratorSuite) TestSearchIteratorInit() {
|
|
ctx := context.Background()
|
|
s.Run("success", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
|
CollectionID: 1,
|
|
Schema: s.schema.ProtoMessage(),
|
|
}, nil).Once()
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
|
|
for _, kv := range kvs {
|
|
if kv.GetKey() == key && kv.GetValue() == value {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 1),
|
|
Topks: []int64{1},
|
|
Recalls: []float32{1},
|
|
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
|
|
Token: s.randString(16),
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))))
|
|
|
|
s.NoError(err)
|
|
_, ok := iter.(*searchIteratorV2)
|
|
s.True(ok)
|
|
})
|
|
|
|
s.Run("failure", func() {
|
|
s.Run("describe_fail", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock error")).Once()
|
|
_, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))))
|
|
s.Error(err)
|
|
})
|
|
|
|
s.Run("not_v2_result", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
|
CollectionID: 1,
|
|
Schema: s.schema.ProtoMessage(),
|
|
}, nil).Once()
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 1),
|
|
Topks: []int64{1},
|
|
Recalls: []float32{1},
|
|
SearchIteratorV2Results: nil, // nil v2 results
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
_, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))))
|
|
s.Error(err)
|
|
s.ErrorIs(err, ErrServerVersionIncompatible)
|
|
})
|
|
})
|
|
}
|
|
|
|
func (s *SearchIteratorSuite) TestNext() {
|
|
ctx := context.Background()
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
|
|
token := fmt.Sprintf("iter_token_%s", s.randString(8))
|
|
|
|
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
|
CollectionID: 1,
|
|
Schema: s.schema.ProtoMessage(),
|
|
}, nil).Once()
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
|
|
for _, kv := range kvs {
|
|
if kv.GetKey() == key && kv.GetValue() == value {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 1),
|
|
Topks: []int64{1},
|
|
Recalls: []float32{1},
|
|
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
|
|
Token: token,
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))))
|
|
s.Require().NoError(err)
|
|
s.Require().NotNil(iter)
|
|
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
|
|
for _, kv := range kvs {
|
|
if kv.GetKey() == key && kv.GetValue() == value {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1},
|
|
},
|
|
},
|
|
},
|
|
Scores: []float32{0.5},
|
|
Topks: []int64{1},
|
|
Recalls: []float32{1},
|
|
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
|
|
Token: token,
|
|
LastBound: 0.5,
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
rs, err := iter.Next(ctx)
|
|
s.NoError(err)
|
|
s.EqualValues(1, rs.IDs.Len())
|
|
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
|
|
for _, kv := range kvs {
|
|
if kv.GetKey() == key && kv.GetValue() == value {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchIDKey, token))
|
|
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchLastBoundKey, "0.5"))
|
|
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{},
|
|
},
|
|
},
|
|
},
|
|
Scores: []float32{},
|
|
Topks: []int64{0},
|
|
Recalls: []float32{1.0},
|
|
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
|
|
Token: token,
|
|
LastBound: 0.5,
|
|
},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
_, err = iter.Next(ctx)
|
|
s.Error(err)
|
|
s.ErrorIs(err, io.EOF)
|
|
}
|
|
|
|
func TestSearchIterator(t *testing.T) {
|
|
suite.Run(t, new(SearchIteratorSuite))
|
|
}
|