mirror of https://github.com/milvus-io/milvus.git
				
				
				
			
		
			
				
	
	
		
			139 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
			
		
		
	
	
			139 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
package proxy
 | 
						|
 | 
						|
import (
 | 
						|
	"testing"
 | 
						|
 | 
						|
	"github.com/stretchr/testify/assert"
 | 
						|
 | 
						|
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 | 
						|
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
 | 
						|
	"github.com/milvus-io/milvus/internal/proto/planpb"
 | 
						|
	"github.com/milvus-io/milvus/pkg/common"
 | 
						|
	"github.com/milvus-io/milvus/pkg/util/funcutil"
 | 
						|
)
 | 
						|
 | 
						|
func TestParsePartitionKeys(t *testing.T) {
 | 
						|
	prefix := "TestParsePartitionKeys"
 | 
						|
	collectionName := prefix + funcutil.GenRandomStr()
 | 
						|
 | 
						|
	fieldName2Type := make(map[string]schemapb.DataType)
 | 
						|
	fieldName2Type["int64_field"] = schemapb.DataType_Int64
 | 
						|
	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
 | 
						|
	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
 | 
						|
	schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
 | 
						|
	partitionKeyField := &schemapb.FieldSchema{
 | 
						|
		Name:           "partition_key_field",
 | 
						|
		DataType:       schemapb.DataType_Int64,
 | 
						|
		IsPartitionKey: true,
 | 
						|
	}
 | 
						|
	schema.Fields = append(schema.Fields, partitionKeyField)
 | 
						|
	fieldID := common.StartOfUserFieldID
 | 
						|
	for _, field := range schema.Fields {
 | 
						|
		field.FieldID = int64(fieldID)
 | 
						|
		fieldID++
 | 
						|
	}
 | 
						|
 | 
						|
	queryInfo := &planpb.QueryInfo{
 | 
						|
		Topk:         10,
 | 
						|
		MetricType:   "L2",
 | 
						|
		SearchParams: "",
 | 
						|
		RoundDecimal: -1,
 | 
						|
	}
 | 
						|
 | 
						|
	type testCase struct {
 | 
						|
		name                 string
 | 
						|
		expr                 string
 | 
						|
		expected             int
 | 
						|
		validPartitionKeys   []int64
 | 
						|
		invalidPartitionKeys []int64
 | 
						|
	}
 | 
						|
	cases := []testCase{
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with term",
 | 
						|
			expr:                 "partition_key_field in [7, 8] && int64_field >= 10",
 | 
						|
			expected:             2,
 | 
						|
			validPartitionKeys:   []int64{7, 8},
 | 
						|
			invalidPartitionKeys: []int64{},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with equal",
 | 
						|
			expr:                 "partition_key_field == 7 && int64_field >= 10",
 | 
						|
			expected:             1,
 | 
						|
			validPartitionKeys:   []int64{7},
 | 
						|
			invalidPartitionKeys: []int64{},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with term2",
 | 
						|
			expr:                 "partition_key_field in [7, 8] && int64_field == 10",
 | 
						|
			expected:             2,
 | 
						|
			validPartitionKeys:   []int64{7, 8},
 | 
						|
			invalidPartitionKeys: []int64{10},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with partition key in range",
 | 
						|
			expr:                 "partition_key_field in [7, 8] && partition_key_field > 9",
 | 
						|
			expected:             2,
 | 
						|
			validPartitionKeys:   []int64{7, 8},
 | 
						|
			invalidPartitionKeys: []int64{9},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with partition key in range2",
 | 
						|
			expr:                 "int64_field == 10 && partition_key_field > 9",
 | 
						|
			expected:             0,
 | 
						|
			validPartitionKeys:   []int64{},
 | 
						|
			invalidPartitionKeys: []int64{},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_and with term and not",
 | 
						|
			expr:                 "partition_key_field in [7, 8] && partition_key_field not in [10, 20]",
 | 
						|
			expected:             2,
 | 
						|
			validPartitionKeys:   []int64{7, 8},
 | 
						|
			invalidPartitionKeys: []int64{10, 20},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_or with term and not",
 | 
						|
			expr:                 "partition_key_field in [7, 8] or partition_key_field not in [10, 20]",
 | 
						|
			expected:             0,
 | 
						|
			validPartitionKeys:   []int64{},
 | 
						|
			invalidPartitionKeys: []int64{},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:                 "binary_expr_or with term and not 2",
 | 
						|
			expr:                 "partition_key_field in [7, 8] or int64_field not in [10, 20]",
 | 
						|
			expected:             2,
 | 
						|
			validPartitionKeys:   []int64{7, 8},
 | 
						|
			invalidPartitionKeys: []int64{10, 20},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tc := range cases {
 | 
						|
		t.Run(tc.name, func(t *testing.T) {
 | 
						|
			// test search plan
 | 
						|
			searchPlan, err := planparserv2.CreateSearchPlan(schema, tc.expr, "fvec_field", queryInfo)
 | 
						|
			assert.NoError(t, err)
 | 
						|
			expr, err := ParseExprFromPlan(searchPlan)
 | 
						|
			assert.NoError(t, err)
 | 
						|
			partitionKeys := ParsePartitionKeys(expr)
 | 
						|
			assert.Equal(t, tc.expected, len(partitionKeys))
 | 
						|
			for _, key := range partitionKeys {
 | 
						|
				int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
 | 
						|
				assert.Contains(t, tc.validPartitionKeys, int64Val)
 | 
						|
				assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
 | 
						|
			}
 | 
						|
 | 
						|
			// test query plan
 | 
						|
			queryPlan, err := planparserv2.CreateRetrievePlan(schema, tc.expr)
 | 
						|
			assert.NoError(t, err)
 | 
						|
			expr, err = ParseExprFromPlan(queryPlan)
 | 
						|
			assert.NoError(t, err)
 | 
						|
			partitionKeys = ParsePartitionKeys(expr)
 | 
						|
			assert.Equal(t, tc.expected, len(partitionKeys))
 | 
						|
			for _, key := range partitionKeys {
 | 
						|
				int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val
 | 
						|
				assert.Contains(t, tc.validPartitionKeys, int64Val)
 | 
						|
				assert.NotContains(t, tc.invalidPartitionKeys, int64Val)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 |