test: add go case for groupby search (#38411)

issue: #33419

---------

Signed-off-by: ThreadDao <yufen.zong@zilliz.com>
pull/35589/head
ThreadDao 2024-12-16 10:44:43 +08:00 committed by GitHub
parent 4919ccf543
commit 8794ec966e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 497 additions and 25 deletions

View File

@ -44,7 +44,8 @@ func CheckErr(t *testing.T, actualErr error, expErrNil bool, expErrorMsg ...stri
func EqualColumn(t *testing.T, columnA column.Column, columnB column.Column) { func EqualColumn(t *testing.T, columnA column.Column, columnB column.Column) {
require.Equal(t, columnA.Name(), columnB.Name()) require.Equal(t, columnA.Name(), columnB.Name())
require.Equal(t, columnA.Type(), columnB.Type()) require.Equal(t, columnA.Type(), columnB.Type())
switch columnA.Type() { _type := columnA.Type()
switch _type {
case entity.FieldTypeBool: case entity.FieldTypeBool:
require.ElementsMatch(t, columnA.(*column.ColumnBool).Data(), columnB.(*column.ColumnBool).Data()) require.ElementsMatch(t, columnA.(*column.ColumnBool).Data(), columnB.(*column.ColumnBool).Data())
case entity.FieldTypeInt8: case entity.FieldTypeInt8:
@ -65,11 +66,13 @@ func EqualColumn(t *testing.T, columnA column.Column, columnB column.Column) {
log.Debug("data", zap.String("name", columnA.Name()), zap.Any("type", columnA.Type()), zap.Any("data", columnA.FieldData())) log.Debug("data", zap.String("name", columnA.Name()), zap.Any("type", columnA.Type()), zap.Any("data", columnA.FieldData()))
log.Debug("data", zap.String("name", columnB.Name()), zap.Any("type", columnB.Type()), zap.Any("data", columnB.FieldData())) log.Debug("data", zap.String("name", columnB.Name()), zap.Any("type", columnB.Type()), zap.Any("data", columnB.FieldData()))
require.Equal(t, reflect.TypeOf(columnA), reflect.TypeOf(columnB)) require.Equal(t, reflect.TypeOf(columnA), reflect.TypeOf(columnB))
switch columnA.(type) { switch _v := columnA.(type) {
case *column.ColumnDynamic: case *column.ColumnDynamic:
require.ElementsMatch(t, columnA.(*column.ColumnDynamic).Data(), columnB.(*column.ColumnDynamic).Data()) require.ElementsMatch(t, columnA.(*column.ColumnDynamic).Data(), columnB.(*column.ColumnDynamic).Data())
case *column.ColumnJSONBytes: case *column.ColumnJSONBytes:
require.ElementsMatch(t, columnA.(*column.ColumnJSONBytes).Data(), columnB.(*column.ColumnJSONBytes).Data()) require.ElementsMatch(t, columnA.(*column.ColumnJSONBytes).Data(), columnB.(*column.ColumnJSONBytes).Data())
default:
log.Warn("columnA type", zap.String("name", columnB.Name()), zap.Any("type", _v))
} }
case entity.FieldTypeFloatVector: case entity.FieldTypeFloatVector:
require.ElementsMatch(t, columnA.(*column.ColumnFloatVector).Data(), columnB.(*column.ColumnFloatVector).Data()) require.ElementsMatch(t, columnA.(*column.ColumnFloatVector).Data(), columnB.(*column.ColumnFloatVector).Data())
@ -98,7 +101,7 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
require.Equal(t, columnA.Name(), columnB.Name()) require.Equal(t, columnA.Name(), columnB.Name())
require.IsType(t, columnA.Type(), entity.FieldTypeArray) require.IsType(t, columnA.Type(), entity.FieldTypeArray)
require.IsType(t, columnB.Type(), entity.FieldTypeArray) require.IsType(t, columnB.Type(), entity.FieldTypeArray)
switch columnA.(type) { switch _type := columnA.(type) {
case *column.ColumnBoolArray: case *column.ColumnBoolArray:
require.ElementsMatch(t, columnA.(*column.ColumnBoolArray).Data(), columnB.(*column.ColumnBoolArray).Data()) require.ElementsMatch(t, columnA.(*column.ColumnBoolArray).Data(), columnB.(*column.ColumnBoolArray).Data())
case *column.ColumnInt8Array: case *column.ColumnInt8Array:
@ -116,6 +119,7 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
case *column.ColumnVarCharArray: case *column.ColumnVarCharArray:
require.ElementsMatch(t, columnA.(*column.ColumnVarCharArray).Data(), columnB.(*column.ColumnVarCharArray).Data()) require.ElementsMatch(t, columnA.(*column.ColumnVarCharArray).Data(), columnB.(*column.ColumnVarCharArray).Data())
default: default:
log.Debug("columnA type is", zap.Any("type", _type))
log.Info("Support array element type is:", zap.Any("FieldType", []entity.FieldType{ log.Info("Support array element type is:", zap.Any("FieldType", []entity.FieldType{
entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16,
entity.FieldTypeInt32, entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar, entity.FieldTypeInt32, entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar,
@ -124,16 +128,16 @@ func EqualArrayColumn(t *testing.T, columnA column.Column, columnB column.Column
} }
// CheckInsertResult check insert result, ids len (insert count), ids data (pks, but no auto ids) // CheckInsertResult check insert result, ids len (insert count), ids data (pks, but no auto ids)
func CheckInsertResult(t *testing.T, expIds column.Column, insertRes client.InsertResult) { func CheckInsertResult(t *testing.T, expIDs column.Column, insertRes client.InsertResult) {
require.Equal(t, expIds.Len(), insertRes.IDs.Len()) require.Equal(t, expIDs.Len(), insertRes.IDs.Len())
require.Equal(t, expIds.Len(), int(insertRes.InsertCount)) require.Equal(t, expIDs.Len(), int(insertRes.InsertCount))
actualIds := insertRes.IDs actualIDs := insertRes.IDs
switch expIds.Type() { switch expIDs.Type() {
// pk field support int64 and varchar type // pk field support int64 and varchar type
case entity.FieldTypeInt64: case entity.FieldTypeInt64:
require.ElementsMatch(t, actualIds.(*column.ColumnInt64).Data(), expIds.(*column.ColumnInt64).Data()) require.ElementsMatch(t, actualIDs.(*column.ColumnInt64).Data(), expIDs.(*column.ColumnInt64).Data())
case entity.FieldTypeVarChar: case entity.FieldTypeVarChar:
require.ElementsMatch(t, actualIds.(*column.ColumnVarChar).Data(), expIds.(*column.ColumnVarChar).Data()) require.ElementsMatch(t, actualIDs.(*column.ColumnVarChar).Data(), expIDs.(*column.ColumnVarChar).Data())
default: default:
log.Info("The primary field only support ", zap.Any("type", []entity.FieldType{entity.FieldTypeInt64, entity.FieldTypeVarChar})) log.Info("The primary field only support ", zap.Any("type", []entity.FieldType{entity.FieldTypeInt64, entity.FieldTypeVarChar}))
} }

View File

@ -545,8 +545,8 @@ func TestDeleteDuplicatedPks(t *testing.T) {
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// delete // delete
deleteIds := []int64{0, 0, 0, 0, 0} deleteIDs := []int64{0, 0, 0, 0, 0}
delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIds)) delRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithInt64IDs(common.DefaultInt64FieldName, deleteIDs))
common.CheckErr(t, err, true) common.CheckErr(t, err, true)
require.Equal(t, 5, int(delRes.DeleteCount)) require.Equal(t, 5, int(delRes.DeleteCount))

View File

@ -0,0 +1,466 @@
package testcases
import (
"context"
"fmt"
"log"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/milvus-io/milvus/tests/go_client/base"
"github.com/milvus-io/milvus/tests/go_client/common"
hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper"
)
// Generate groupBy-supported vector indexes
func genGroupByVectorIndex(metricType entity.MetricType) []index.Index {
nlist := 128
idxFlat := index.NewFlatIndex(metricType)
idxIvfFlat := index.NewIvfFlatIndex(metricType, nlist)
idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
idxIvfSq8 := index.NewIvfSQ8Index(metricType, 128)
allFloatIndex := []index.Index{
idxFlat,
idxIvfFlat,
idxHnsw,
idxIvfSq8,
}
return allFloatIndex
}
// Generate groupBy-supported vector indexes
func genGroupByBinaryIndex(metricType entity.MetricType) []index.Index {
nlist := 128
idxBinFlat := index.NewBinFlatIndex(metricType)
idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, nlist)
allFloatIndex := []index.Index{
idxBinFlat,
idxBinIvfFlat,
}
return allFloatIndex
}
func genUnsupportedFloatGroupByIndex() []index.Index {
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 16, 8)
idxScann := index.NewSCANNIndex(entity.L2, 16, false)
return []index.Index{
idxIvfPq,
idxScann,
}
}
func prepareDataForGroupBySearch(t *testing.T, loopInsert int, insertNi int, idx index.Index, withGrowing bool) (*base.MilvusClient, context.Context, string) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*5)
mc := createDefaultMilvusClient(ctx, t)
// create collection with all datatype
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < loopInsert; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(insertNi))
}
if !withGrowing {
prepare.FlushData(ctx, t, mc, schema.CollectionName)
}
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultFloatVecFieldName: idx}))
// create scalar index
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName, common.DefaultInt16FieldName,
common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
idxTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, groupByField, index.NewAutoIndex(entity.L2)))
common.CheckErr(t, err, true)
err = idxTask.Await(ctx)
common.CheckErr(t, err, true)
}
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
return mc, ctx, schema.CollectionName
}
// create coll with all datatype -> build all supported index
// -> search with WithGroupByField (int* + varchar + bool
// -> verify every top passage is the top of whole group
// output_fields: pk + groupBy
func TestSearchGroupByFloatDefault(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/38343")
t.Parallel()
for _, idx := range genGroupByVectorIndex(entity.L2) {
// prepare data
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// search with groupBy field
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName,
common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec[:1]).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
// waiting for fix https://github.com/milvus-io/milvus/issues/32630
require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
}
}
}
}
func TestSearchGroupByFloatDefaultCosine(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/38343")
t.Parallel()
for _, idx := range genGroupByVectorIndex(entity.COSINE) {
// prepare data
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// search with groupBy field without varchar
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName,
common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr := fmt.Sprintf("%s == %v", groupByField, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec[:1]).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
// waiting for fix https://github.com/milvus-io/milvus/issues/32630
require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
}
}
}
}
// test groupBy search sparse vector
func TestGroupBySearchSparseVector(t *testing.T) {
t.Parallel()
idxInverted := index.NewSparseInvertedIndex(entity.IP, 0.3)
idxWand := index.NewSparseWANDIndex(entity.IP, 0.2)
for _, idx := range []index.Index{idxInverted, idxWand} {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption().TWithMaxLen(common.TestMaxLen),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < 100; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(200))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultSparseVecFieldName: idx}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// groupBy search
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultSparseVecFieldName).
WithGroupByField(common.DefaultVarcharFieldName).WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
if resGroupBy[i].ResultCount > 0 {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, 1, []entity.Vector{queryVec[i]}).
WithANNSField(common.DefaultSparseVecFieldName).
WithGroupByField(common.DefaultVarcharFieldName).
WithFilter(expr).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
common.DefaultVarcharFieldName, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
}
// binary vector -> not supported
func TestSearchGroupByBinaryDefault(t *testing.T) {
t.Parallel()
for _, metricType := range hp.SupportBinIvfFlatMetricType {
for _, idx := range genGroupByBinaryIndex(metricType) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < 2; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idx}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// search params
queryVec := hp.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeBinaryVector)
t.Log("Waiting for support for specifying search parameters")
// sp, _ := index.NewBinIvfFlatIndexSearchParam(32)
supportedGroupByFields := []string{common.DefaultVarcharFieldName, common.DefaultBinaryVecFieldName}
// search with groupBy field
for _, groupByField := range supportedGroupByFields {
_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
WithOutputFields(common.DefaultVarcharFieldName, groupByField))
common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
}
}
}
}
// binary vector -> growing segments, maybe brute force
// default Bounded ConsistencyLevel -> succ ??
// strong ConsistencyLevel -> error
func TestSearchGroupByBinaryGrowing(t *testing.T) {
t.Skip("https://github.com/milvus-io/milvus/issues/38343")
t.Parallel()
for _, metricType := range hp.SupportBinIvfFlatMetricType {
idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, 128)
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idxBinIvfFlat}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
for i := 0; i < 2; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
}
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeBinaryVector)
t.Log("Waiting for support for specifying search parameters")
// sp, _ := index.NewBinIvfFlatIndexSearchParam(64)
supportedGroupByFields := []string{common.DefaultVarcharFieldName}
// search with groupBy field
for _, groupByField := range supportedGroupByFields {
_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
WithOutputFields(common.DefaultVarcharFieldName, groupByField).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
}
}
}
// groupBy in growing segments, maybe growing index or brute force
func TestSearchGroupByFloatGrowing(t *testing.T) {
for _, metricType := range hp.SupportFloatMetricType {
idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idxHnsw, true)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
supportedGroupByFields := []string{common.DefaultInt64FieldName, "int8", "int16", "int32", "varchar", "bool"}
// search with groupBy field
hitsNum := 0
total := 0
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithConsistencyLevel(entity.ClStrong))
// verify each topK entity is the top1 in the group
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
// search filter with groupByValue is the top1
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
groupByField, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
}
}
}
// groupBy + pagination
func TestSearchGroupByPagination(t *testing.T) {
// create index and load
idx := index.NewHNSWIndex(entity.COSINE, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
offset := 10
// search pagination & groupBy
resGroupByPagination, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithOffset(offset).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckSearchResult(t, resGroupByPagination, common.DefaultNq, common.DefaultLimit)
// search limit=origin limit + offset
resGroupByDefault, _ := mc.Search(ctx, client.NewSearchOption(collName, offset+common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
for i := 0; i < common.DefaultNq; i++ {
require.Equal(t, resGroupByDefault[i].IDs.(*column.ColumnInt64).Data()[10:], resGroupByPagination[i].IDs.(*column.ColumnInt64).Data())
}
}
// only support: "FLAT", "IVF_FLAT", "HNSW"
func TestSearchGroupByUnsupportedIndex(t *testing.T) {
t.Parallel()
for _, idx := range genUnsupportedFloatGroupByIndex() {
t.Run(string(idx.IndexType()), func(t *testing.T) {
mc, ctx, collName := prepareDataForGroupBySearch(t, 3, 1000, idx, false)
// groupBy search
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, false, "doesn't support")
})
}
}
// FLOAT, DOUBLE, JSON, ARRAY
func TestSearchGroupByUnsupportedDataType(t *testing.T) {
idxHnsw := index.NewHNSWIndex(entity.L2, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
// groupBy search with unsupported field type
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
for _, unsupportedField := range []string{
common.DefaultFloatFieldName, common.DefaultDoubleFieldName,
common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField,
} {
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, false, "unsupported data type")
}
}
// groupBy + iterator -> not supported
func TestSearchGroupByIterator(t *testing.T) {
// TODO: sdk support
}
// groupBy + range search -> not supported
func TestSearchGroupByRangeSearch(t *testing.T) {
t.Skipf("Waiting for support for specifying search parameters")
idxHnsw := index.NewHNSWIndex(entity.COSINE, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
// groupBy search with range
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// sp, _ := index.NewHNSWIndexSearchParam(50)
// sp.AddRadius(0)
// sp.AddRangeFilter(0.8)
// range search
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
}
// groupBy + advanced search
func TestSearchGroupByHybridSearch(t *testing.T) {
t.Skipf("Waiting for HybridSearch implemention")
}

View File

@ -272,9 +272,10 @@ func (cf FieldsAllFields) GenFields(option GenFieldsOption) []*entity.Field {
} }
// scalar fields and array fields // scalar fields and array fields
for _, fieldType := range GetAllScalarFieldType() { for _, fieldType := range GetAllScalarFieldType() {
if fieldType == entity.FieldTypeInt64 { switch fieldType {
case entity.FieldTypeInt64:
continue continue
} else if fieldType == entity.FieldTypeArray { case entity.FieldTypeArray:
for _, eleType := range GetAllArrayElementType() { for _, eleType := range GetAllArrayElementType() {
arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity) arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity)
if eleType == entity.FieldTypeVarChar { if eleType == entity.FieldTypeVarChar {
@ -282,10 +283,10 @@ func (cf FieldsAllFields) GenFields(option GenFieldsOption) []*entity.Field {
} }
fields = append(fields, arrayField) fields = append(fields, arrayField)
} }
} else if fieldType == entity.FieldTypeVarChar { case entity.FieldTypeVarChar:
varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength) varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength)
fields = append(fields, varcharField) fields = append(fields, varcharField)
} else { default:
scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType) scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType)
fields = append(fields, scalarField) fields = append(fields, scalarField)
} }
@ -312,9 +313,10 @@ func (cf FieldsInt64VecAllScalar) GenFields(option GenFieldsOption) []*entity.Fi
} }
// scalar fields and array fields // scalar fields and array fields
for _, fieldType := range GetAllScalarFieldType() { for _, fieldType := range GetAllScalarFieldType() {
if fieldType == entity.FieldTypeInt64 { switch fieldType {
case entity.FieldTypeInt64:
continue continue
} else if fieldType == entity.FieldTypeArray { case entity.FieldTypeArray:
for _, eleType := range GetAllArrayElementType() { for _, eleType := range GetAllArrayElementType() {
arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity) arrayField := entity.NewField().WithName(GetFieldNameByElementType(eleType)).WithDataType(entity.FieldTypeArray).WithElementType(eleType).WithMaxCapacity(option.MaxCapacity)
if eleType == entity.FieldTypeVarChar { if eleType == entity.FieldTypeVarChar {
@ -322,10 +324,10 @@ func (cf FieldsInt64VecAllScalar) GenFields(option GenFieldsOption) []*entity.Fi
} }
fields = append(fields, arrayField) fields = append(fields, arrayField)
} }
} else if fieldType == entity.FieldTypeVarChar { case entity.FieldTypeVarChar:
varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength) varcharField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType).WithMaxLength(option.MaxLength)
fields = append(fields, varcharField) fields = append(fields, varcharField)
} else { default:
scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType) scalarField := entity.NewField().WithName(GetFieldNameByFieldType(fieldType)).WithDataType(fieldType)
fields = append(fields, scalarField) fields = append(fields, scalarField)
} }

View File

@ -488,8 +488,8 @@ func TestInsertReadSparseEmptyVector(t *testing.T) {
// sparse vector: empty position and values // sparse vector: empty position and values
sparseVec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{}) sparseVec, err := entity.NewSliceSparseEmbedding([]uint32{}, []float32{})
common.CheckErr(t, err, true) common.CheckErr(t, err, true)
data2 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec})) data = append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data2...)) insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data...))
common.CheckErr(t, err, true) common.CheckErr(t, err, true)
require.EqualValues(t, 1, insertRes.InsertCount) require.EqualValues(t, 1, insertRes.InsertCount)
@ -526,8 +526,8 @@ func TestInsertSparseInvalidVector(t *testing.T) {
values = []float32{0.4} values = []float32{0.4}
sparseVec, err := entity.NewSliceSparseEmbedding(positions, values) sparseVec, err := entity.NewSliceSparseEmbedding(positions, values)
common.CheckErr(t, err, true) common.CheckErr(t, err, true)
data1 := append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec})) data = append(data, column.NewColumnSparseVectors(common.DefaultSparseVecFieldName, []entity.SparseEmbedding{sparseVec}))
_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data1...)) _, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, data...))
common.CheckErr(t, err, false, "invalid index in sparse float vector: must be less than 2^32-1") common.CheckErr(t, err, false, "invalid index in sparse float vector: must be less than 2^32-1")
} }

View File

@ -28,7 +28,7 @@ func teardown() {
defer cancel() defer cancel()
mc, err := base.NewMilvusClient(ctx, &defaultCfg) mc, err := base.NewMilvusClient(ctx, &defaultCfg)
if err != nil { if err != nil {
log.Fatal("teardown failed to connect milvus with error", zap.Error(err)) log.Error("teardown failed to connect milvus with error", zap.Error(err))
} }
defer mc.Close(ctx) defer mc.Close(ctx)