milvus/tests/go_client/testcases/groupby_search_test.go

506 lines
20 KiB
Go

package testcases
import (
"context"
"fmt"
"log"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
client "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/milvus-io/milvus/tests/go_client/base"
"github.com/milvus-io/milvus/tests/go_client/common"
hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper"
)
// Generate groupBy-supported vector indexes
func genGroupByVectorIndex(metricType entity.MetricType) []index.Index {
nlist := 128
idxFlat := index.NewFlatIndex(metricType)
idxIvfFlat := index.NewIvfFlatIndex(metricType, nlist)
idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
idxIvfSq8 := index.NewIvfSQ8Index(metricType, 128)
allFloatIndex := []index.Index{
idxFlat,
idxIvfFlat,
idxHnsw,
idxIvfSq8,
}
return allFloatIndex
}
// Generate groupBy-supported vector indexes
func genGroupByBinaryIndex(metricType entity.MetricType) []index.Index {
nlist := 128
idxBinFlat := index.NewBinFlatIndex(metricType)
idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, nlist)
allFloatIndex := []index.Index{
idxBinFlat,
idxBinIvfFlat,
}
return allFloatIndex
}
func genUnsupportedFloatGroupByIndex() []index.Index {
idxIvfPq := index.NewIvfPQIndex(entity.L2, 128, 16, 8)
return []index.Index{
idxIvfPq,
}
}
func prepareDataForGroupBySearch(t *testing.T, loopInsert int, insertNi int, idx index.Index, withGrowing bool) (*base.MilvusClient, context.Context, string) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*5)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create collection with all datatype
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < loopInsert; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(insertNi))
}
if !withGrowing {
prepare.FlushData(ctx, t, mc, schema.CollectionName)
}
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultFloatVecFieldName: idx}))
// create scalar index
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName, common.DefaultInt16FieldName,
common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
idxTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(schema.CollectionName, groupByField, index.NewAutoIndex(entity.L2)))
common.CheckErr(t, err, true)
err = idxTask.Await(ctx)
common.CheckErr(t, err, true)
}
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
return mc, ctx, schema.CollectionName
}
// create coll with all datatype -> build all supported index
// -> search with WithGroupByField (int* + varchar + bool
// -> verify every top passage is the top of whole group
// output_fields: pk + groupBy
func TestSearchGroupByFloatDefault(t *testing.T) {
t.Parallel()
concurrency := 10
ch := make(chan struct{}, concurrency)
wg := sync.WaitGroup{}
testFunc := func(idx index.Index) {
defer func() {
wg.Done()
<-ch
}()
// prepare data
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// search with groupBy field
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName,
common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec[:1]).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
// waiting for fix https://github.com/milvus-io/milvus/issues/32630
require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
}
}
}
for _, idx := range genGroupByVectorIndex(entity.L2) {
ch <- struct{}{}
wg.Add(1)
go testFunc(idx)
}
wg.Wait()
}
func TestSearchGroupByFloatDefaultCosine(t *testing.T) {
t.Parallel()
concurrency := 10
ch := make(chan struct{}, concurrency)
wg := sync.WaitGroup{}
testFunc := func(idx index.Index) {
defer func() {
wg.Done()
<-ch
}()
// prepare data
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// search with groupBy field without varchar
supportedGroupByFields := []string{
common.DefaultInt64FieldName, common.DefaultInt8FieldName,
common.DefaultInt16FieldName, common.DefaultInt32FieldName, common.DefaultBoolFieldName,
}
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithOutputFields(common.DefaultInt64FieldName, groupByField))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr := fmt.Sprintf("%s == %v", groupByField, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec[:1]).WithANNSField(common.DefaultFloatVecFieldName).
WithGroupByField(groupByField).WithFilter(expr).WithOutputFields(common.DefaultInt64FieldName, groupByField))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
// waiting for fix https://github.com/milvus-io/milvus/issues/32630
require.GreaterOrEqualf(t, hitsRate, float32(0.1), _str)
}
}
}
for _, idx := range genGroupByVectorIndex(entity.COSINE) {
ch <- struct{}{}
wg.Add(1)
go testFunc(idx)
}
wg.Wait()
}
// test groupBy search sparse vector
func TestGroupBySearchSparseVector(t *testing.T) {
t.Parallel()
idxInverted := index.NewSparseInvertedIndex(entity.IP, 0.3)
idxWand := index.NewSparseWANDIndex(entity.IP, 0.2)
for _, idx := range []index.Index{idxInverted, idxWand} {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption().TWithMaxLen(common.TestMaxLen),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < 100; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(200))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultSparseVecFieldName: idx}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// groupBy search
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultSparseVecFieldName).
WithGroupByField(common.DefaultVarcharFieldName).WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName))
// verify each topK entity is the top1 of the whole group
hitsNum := 0
total := 0
for i := 0; i < common.DefaultNq; i++ {
if resGroupBy[i].ResultCount > 0 {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue)
// search filter with groupByValue is the top1
resFilter, _ := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, 1, []entity.Vector{queryVec[i]}).
WithANNSField(common.DefaultSparseVecFieldName).
WithGroupByField(common.DefaultVarcharFieldName).
WithFilter(expr).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName))
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
common.DefaultVarcharFieldName, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
}
// binary vector -> not supported
func TestSearchGroupByBinaryDefault(t *testing.T) {
t.Parallel()
for _, metricType := range hp.SupportBinIvfFlatMetricType {
for _, idx := range genGroupByBinaryIndex(metricType) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
for i := 0; i < 2; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idx}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// search params
queryVec := hp.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeBinaryVector)
t.Log("Waiting for support for specifying search parameters")
// sp, _ := index.NewBinIvfFlatIndexSearchParam(32)
supportedGroupByFields := []string{common.DefaultVarcharFieldName, common.DefaultBinaryVecFieldName}
// search with groupBy field
for _, groupByField := range supportedGroupByFields {
_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
WithOutputFields(common.DefaultVarcharFieldName, groupByField))
common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
}
}
}
}
// binary vector -> growing segments, maybe brute force
// default Bounded ConsistencyLevel -> succ ??
// strong ConsistencyLevel -> error
func TestSearchGroupByBinaryGrowing(t *testing.T) {
t.Parallel()
for _, metricType := range hp.SupportBinIvfFlatMetricType {
idxBinIvfFlat := index.NewBinIvfFlatIndex(metricType, 128)
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.VarcharBinary), hp.TNewFieldsOption(),
hp.TNewSchemaOption().TWithEnableDynamicField(true))
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultBinaryVecFieldName: idxBinIvfFlat}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
for i := 0; i < 2; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(1000))
}
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeBinaryVector)
t.Log("Waiting for support for specifying search parameters")
// sp, _ := index.NewBinIvfFlatIndexSearchParam(64)
supportedGroupByFields := []string{common.DefaultVarcharFieldName}
// search with groupBy field
for _, groupByField := range supportedGroupByFields {
_, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(groupByField).
WithOutputFields(common.DefaultVarcharFieldName, groupByField).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column")
}
}
}
// groupBy in growing segments, maybe growing index or brute force
func TestSearchGroupByFloatGrowing(t *testing.T) {
t.Parallel()
concurrency := 10
ch := make(chan struct{}, concurrency)
wg := sync.WaitGroup{}
testFunc := func(metricType entity.MetricType) {
defer func() {
wg.Done()
<-ch
}()
idxHnsw := index.NewHNSWIndex(metricType, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 100, 200, idxHnsw, true)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
supportedGroupByFields := []string{common.DefaultInt64FieldName, "int8", "int16", "int32", "varchar", "bool"}
// search with groupBy field
hitsNum := 0
total := 0
for _, groupByField := range supportedGroupByFields {
resGroupBy, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithConsistencyLevel(entity.ClStrong))
// verify each topK entity is the top1 in the group
for i := 0; i < common.DefaultNq; i++ {
for j := 0; j < resGroupBy[i].ResultCount; j++ {
groupByValue, _ := resGroupBy[i].GroupByValue.Get(j)
pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j)
var expr string
if groupByField == "varchar" {
expr = fmt.Sprintf("%s == '%v' ", groupByField, groupByValue)
} else {
expr = fmt.Sprintf("%s == %v", groupByField, groupByValue)
}
resFilter, _ := mc.Search(ctx, client.NewSearchOption(collName, 1, queryVec).WithANNSField(common.DefaultFloatVecFieldName).
WithOutputFields(common.DefaultInt64FieldName, groupByField).WithGroupByField(groupByField).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
// search filter with groupByValue is the top1
filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0)
log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d",
groupByField, groupByValue, pkValue, filterTop1Pk)
if filterTop1Pk == pkValue {
hitsNum += 1
}
total += 1
}
}
// verify hits rate
hitsRate := float32(hitsNum) / float32(total)
_str := fmt.Sprintf("GroupBy search with field %s, nq=%d and limit=%d , then hitsNum= %d, hitsRate=%v\n",
groupByField, common.DefaultNq, common.DefaultLimit, hitsNum, hitsRate)
log.Println(_str)
if groupByField != "bool" {
require.GreaterOrEqualf(t, hitsRate, float32(0.8), _str)
}
}
}
for _, metricType := range hp.SupportFloatMetricType {
ch <- struct{}{}
wg.Add(1)
go testFunc(metricType)
}
wg.Wait()
}
// groupBy + pagination
func TestSearchGroupByPagination(t *testing.T) {
t.Parallel()
// create index and load
idx := index.NewHNSWIndex(entity.COSINE, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 10, 1000, idx, false)
// search params
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
offset := 10
// search pagination & groupBy
resGroupByPagination, _ := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithOffset(offset).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckSearchResult(t, resGroupByPagination, common.DefaultNq, common.DefaultLimit)
// search limit=origin limit + offset
resGroupByDefault, _ := mc.Search(ctx, client.NewSearchOption(collName, offset+common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
WithOutputFields(common.DefaultInt64FieldName, common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
for i := 0; i < common.DefaultNq; i++ {
require.Equal(t, resGroupByDefault[i].IDs.(*column.ColumnInt64).Data()[10:], resGroupByPagination[i].IDs.(*column.ColumnInt64).Data())
}
}
// only support: "FLAT", "IVF_FLAT", "HNSW"
func TestSearchGroupByUnsupportedIndex(t *testing.T) {
t.Parallel()
for _, idx := range genUnsupportedFloatGroupByIndex() {
t.Run(string(idx.IndexType()), func(t *testing.T) {
mc, ctx, collName := prepareDataForGroupBySearch(t, 3, 1000, idx, false)
// groupBy search
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, false, "doesn't support")
})
}
}
// FLOAT, DOUBLE, JSON, ARRAY
func TestSearchGroupByUnsupportedDataType(t *testing.T) {
idxHnsw := index.NewHNSWIndex(entity.L2, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
// groupBy search with unsupported field type
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
for _, unsupportedField := range []string{
common.DefaultFloatFieldName, common.DefaultDoubleFieldName,
common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField,
} {
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, false, "unsupported data type")
}
}
// groupBy + iterator -> not supported
func TestSearchGroupByIterator(t *testing.T) {
// TODO: sdk support
}
// groupBy + range search -> not supported
func TestSearchGroupByRangeSearch(t *testing.T) {
t.Skipf("https://github.com/milvus-io/milvus/issues/38846")
idxHnsw := index.NewHNSWIndex(entity.COSINE, 8, 96)
mc, ctx, collName := prepareDataForGroupBySearch(t, 1, 1000, idxHnsw, true)
// groupBy search with range
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
// range search
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8"))
common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
}
// groupBy + advanced search
func TestSearchGroupByHybridSearch(t *testing.T) {
t.Skipf("Waiting for HybridSearch implemention")
}