// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package milvusclient import ( "fmt" "math/rand" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/client/v2/entity" ) type SearchOptionSuite struct { suite.Suite } type nonSupportData struct{} func (d nonSupportData) Serialize() []byte { return []byte{} } func (d nonSupportData) Dim() int { return 0 } func (d nonSupportData) FieldType() entity.FieldType { return entity.FieldType(0) } func (s *SearchOptionSuite) TestBasic() { collName := "search_opt_basic" topK := rand.Intn(100) + 1 opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})}) opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000").WithGroupByField("group_field").WithGroupSize(10).WithStrictGroupSize(true) req, err := opt.Request() s.Require().NoError(err) s.Equal(collName, req.GetCollectionName()) s.Equal("ID > 1000", req.GetDsl()) s.ElementsMatch([]string{"ID", "Value"}, req.GetOutputFields()) searchParams := entity.KvPairsMap(req.GetSearchParams()) annField, ok := searchParams[spAnnsField] s.Require().True(ok) s.Equal("test_field", annField) groupField, ok := searchParams[spGroupBy] s.Require().True(ok) s.Equal("group_field", groupField) groupSize, ok := searchParams[spGroupSize] s.Require().True(ok) s.Equal("10", groupSize) spStrictGroupSize, ok := searchParams[spStrictGroupSize] s.Require().True(ok) s.Equal("true", spStrictGroupSize) opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}}) _, err = opt.Request() s.Error(err) } func (s *SearchOptionSuite) TestPlaceHolder() { type testCase struct { tag string input []entity.Vector expectError bool expectType commonpb.PlaceholderType } sparse, err := entity.NewSliceSparseEmbedding([]uint32{0, 10, 12}, []float32{0.1, 0.2, 0.3}) s.Require().NoError(err) cases := []*testCase{ { tag: "empty_input", input: nil, expectType: commonpb.PlaceholderType_None, }, { tag: "float_vector", input: []entity.Vector{entity.FloatVector([]float32{0.1, 0.2, 0.3})}, expectType: commonpb.PlaceholderType_FloatVector, }, { tag: "sparse_vector", input: []entity.Vector{sparse}, expectType: commonpb.PlaceholderType_SparseFloatVector, }, { tag: "fp16_vector", input: []entity.Vector{entity.Float16Vector([]byte{})}, expectType: commonpb.PlaceholderType_Float16Vector, }, { tag: "bf16_vector", input: []entity.Vector{entity.BFloat16Vector([]byte{})}, expectType: commonpb.PlaceholderType_BFloat16Vector, }, { tag: "binary_vector", input: []entity.Vector{entity.BinaryVector([]byte{})}, expectType: commonpb.PlaceholderType_BinaryVector, }, { tag: "int8_vector", input: []entity.Vector{entity.Int8Vector([]int8{})}, expectType: commonpb.PlaceholderType_Int8Vector, }, { tag: "text", input: []entity.Vector{entity.Text("abc")}, expectType: commonpb.PlaceholderType_VarChar, }, { tag: "non_supported", input: []entity.Vector{nonSupportData{}}, expectError: true, }, } for _, tc := range cases { s.Run(tc.tag, func() { phv, err := vector2Placeholder(tc.input) if tc.expectError { s.Error(err) } else { s.NoError(err) s.Equal(tc.expectType, phv.GetType()) } }) } } func TestSearchOption(t *testing.T) { suite.Run(t, new(SearchOptionSuite)) } func TestAny2TmplValue(t *testing.T) { t.Run("primitives", func(t *testing.T) { t.Run("int", func(t *testing.T) { v := rand.Int() val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetInt64Val()) }) t.Run("int32", func(t *testing.T) { v := rand.Int31() val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetInt64Val()) }) t.Run("int64", func(t *testing.T) { v := rand.Int63() val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetInt64Val()) }) t.Run("float32", func(t *testing.T) { v := rand.Float32() val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetFloatVal()) }) t.Run("float64", func(t *testing.T) { v := rand.Float64() val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetFloatVal()) }) t.Run("bool", func(t *testing.T) { val, err := any2TmplValue(true) assert.NoError(t, err) assert.True(t, val.GetBoolVal()) }) t.Run("string", func(t *testing.T) { v := fmt.Sprintf("%v", rand.Int()) val, err := any2TmplValue(v) assert.NoError(t, err) assert.EqualValues(t, v, val.GetStringVal()) }) }) t.Run("slice", func(t *testing.T) { t.Run("int", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]int, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Int()) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetLongData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("int32", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]int32, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Int31()) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetLongData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("int64", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]int64, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Int63()) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetLongData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("float32", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]float32, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Float32()) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetDoubleData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("float64", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]float64, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Float64()) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetDoubleData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("bool", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]bool, 0, l) for i := 0; i < l; i++ { v = append(v, rand.Int()%2 == 0) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetBoolData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) t.Run("string", func(t *testing.T) { l := rand.Intn(10) + 1 v := make([]string, 0, l) for i := 0; i < l; i++ { v = append(v, fmt.Sprintf("%v", rand.Int())) } val, err := any2TmplValue(v) assert.NoError(t, err) data := val.GetArrayVal().GetStringData().GetData() assert.Equal(t, l, len(data)) for i, val := range data { assert.EqualValues(t, v[i], val) } }) }) t.Run("unsupported", func(*testing.T) { _, err := any2TmplValue(struct{}{}) assert.Error(t, err) _, err = any2TmplValue([]struct{}{}) assert.Error(t, err) }) }