enhance: [GoSDK] Add range & sparse ann param (#39751)

Related to #38846

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
pull/39781/head
congqixia 2025-02-12 14:54:46 +08:00 committed by GitHub
parent 5fdc7578bb
commit 95f3a248b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 4 deletions

View File

@ -32,6 +32,14 @@ func (b baseAnnParam) Params() map[string]any {
return b.params
}
func (b baseAnnParam) WithRadius(radius float64) {
b.WithExtraParam("radius", radius)
}
func (b baseAnnParam) WithRangeFilter(rangeFilter float64) {
b.WithExtraParam("range_filter", rangeFilter)
}
type CustomAnnParam struct {
baseAnnParam
}

View File

@ -61,3 +61,19 @@ func NewSparseWANDIndex(metricType MetricType, dropRatio float64) Index {
dropRatio: dropRatio,
}
}
type sparseAnnParam struct {
baseAnnParam
}
func NewSparseAnnParam() sparseAnnParam {
return sparseAnnParam{
baseAnnParam: baseAnnParam{
params: make(map[string]any),
},
}
}
func (b sparseAnnParam) WithDropRatio(dropRatio float64) {
b.WithExtraParam("drop_ratio_search", dropRatio)
}

View File

@ -1090,7 +1090,6 @@ func TestSearchSparseVectorNotSupported(t *testing.T) {
}
func TestRangeSearchSparseVector(t *testing.T) {
t.Skipf("https://github.com/milvus-io/milvus/issues/38846")
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
mc := createDefaultMilvusClient(ctx, t)
@ -1111,10 +1110,12 @@ func TestRangeSearchSparseVector(t *testing.T) {
log.Info("default search", zap.Any("score", res.Scores))
}
radius := 10
rangeFilter := 30
annParams := index.NewSparseAnnParam()
annParams.WithRadius(10)
annParams.WithRangeFilter(30)
annParams.WithDropRatio(0.2)
resRange, errSearch = mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).
WithSearchParam("drop_ratio_search", "0.2").WithSearchParam("radius", strconv.Itoa(radius)).WithSearchParam("range_filter", strconv.Itoa(rangeFilter)))
WithAnnParam(annParams))
common.CheckErr(t, errSearch, true)
common.CheckErr(t, errSearch, true)
require.Len(t, resRange, common.DefaultNq)