Support bool field filter in search and query expression (#7814)

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
pull/7893/head
yukun 2021-09-14 16:15:47 +08:00 committed by GitHub
parent e28db5600d
commit 4bf3c6889c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 1 deletions

View File

@ -14,6 +14,7 @@ package proxy
import (
"fmt"
"math"
"strings"
ant_ast "github.com/antonmedv/expr/ast"
ant_parser "github.com/antonmedv/expr/parser"
@ -256,7 +257,33 @@ func getLogicalOpType(opStr string) planpb.BinaryExpr_BinaryOp {
}
}
func parseBoolNode(nodeRaw *ant_ast.Node) *ant_ast.BoolNode {
switch node := (*nodeRaw).(type) {
case *ant_ast.IdentifierNode:
val := strings.ToLower(node.Value)
if val == "true" {
return &ant_ast.BoolNode{
Value: true,
}
} else if val == "false" {
return &ant_ast.BoolNode{
Value: false,
}
} else {
return nil
}
default:
return nil
}
}
func (context *ParserContext) createCmpExpr(left, right ant_ast.Node, operator string) (*planpb.Expr, error) {
if boolNode := parseBoolNode(&left); boolNode != nil {
left = boolNode
}
if boolNode := parseBoolNode(&right); boolNode != nil {
right = boolNode
}
idNodeLeft, leftIDNode := left.(*ant_ast.IdentifierNode)
idNodeRight, rightIDNode := right.(*ant_ast.IdentifierNode)
@ -540,11 +567,20 @@ func (context *ParserContext) handleLeafValue(nodeRaw *ant_ast.Node, dataType sc
Int64Val: int64(node.Value),
},
}
} else if dataType == schemapb.DataType_Bool {
gv = &planpb.GenericValue{
Val: &planpb.GenericValue_BoolVal{},
}
if node.Value == 1 {
gv.Val.(*planpb.GenericValue_BoolVal).BoolVal = true
} else {
gv.Val.(*planpb.GenericValue_BoolVal).BoolVal = false
}
} else {
return nil, fmt.Errorf("type mismatch")
}
case *ant_ast.BoolNode:
if typeutil.IsFloatingType(dataType) {
if typeutil.IsBoolType(dataType) {
gv = &planpb.GenericValue{
Val: &planpb.GenericValue_BoolVal{
BoolVal: node.Value,

View File

@ -15,6 +15,7 @@ import (
"fmt"
"testing"
ant_ast "github.com/antonmedv/expr/ast"
ant_parser "github.com/antonmedv/expr/parser"
"github.com/golang/protobuf/proto"
@ -148,6 +149,11 @@ func TestExprMultiRange_Str(t *testing.T) {
"0.1 ** 2 < FloatN < 2 ** 0.1",
"0.1 ** 1.1 < FloatN < 3.1 / 4",
"4.1 / 3 < FloatN < 0.0 / 5.0",
"BoolN1 == True",
"True == BoolN1",
"BoolN1 == False",
"BoolN1 == 1",
"BoolN1 == 0",
}
fields := []*schemapb.FieldSchema{
@ -156,6 +162,7 @@ func TestExprMultiRange_Str(t *testing.T) {
{FieldID: 102, Name: "age2", DataType: schemapb.DataType_Int64},
{FieldID: 103, Name: "FloatN", DataType: schemapb.DataType_Float},
{FieldID: 104, Name: "FloatN2", DataType: schemapb.DataType_Float},
{FieldID: 105, Name: "BoolN1", DataType: schemapb.DataType_Bool},
}
schema := &schemapb.CollectionSchema{
@ -214,3 +221,28 @@ func TestExprFieldCompare_Str(t *testing.T) {
println(dbgStr)
}
}
func Test_ParseBoolNode(t *testing.T) {
var nodeRaw1, nodeRaw2, nodeRaw3, nodeRaw4 ant_ast.Node
nodeRaw1 = &ant_ast.IdentifierNode{
Value: "True",
}
boolNode1 := parseBoolNode(&nodeRaw1)
assert.Equal(t, boolNode1.Value, true)
nodeRaw2 = &ant_ast.IdentifierNode{
Value: "False",
}
boolNode2 := parseBoolNode(&nodeRaw2)
assert.Equal(t, boolNode2.Value, false)
nodeRaw3 = &ant_ast.IdentifierNode{
Value: "abcd",
}
assert.Nil(t, parseBoolNode(&nodeRaw3))
nodeRaw4 = &ant_ast.BoolNode{
Value: true,
}
assert.Nil(t, parseBoolNode(&nodeRaw4))
}

View File

@ -161,3 +161,12 @@ func IsFloatingType(dataType schemapb.DataType) bool {
return false
}
}
func IsBoolType(dataType schemapb.DataType) bool {
switch dataType {
case schemapb.DataType_Bool:
return true
default:
return false
}
}