diff --git a/internal/proxy/plan_parser.go b/internal/proxy/plan_parser.go index ae0d43edd4..5644105019 100644 --- a/internal/proxy/plan_parser.go +++ b/internal/proxy/plan_parser.go @@ -202,44 +202,38 @@ func isSameOrder(opStr1, opStr2 string) bool { return isLess1 == isLess2 } -func getCompareOpType(opStr string, reverse bool) planpb.OpType { - type OpType = planpb.OpType - var op planpb.OpType - - if !reverse { - switch opStr { - case "<": +func getCompareOpType(opStr string, reverse bool) (op planpb.OpType) { + switch opStr { + case ">": + if reverse { op = planpb.OpType_LessThan - case ">": + } else { op = planpb.OpType_GreaterThan - case "<=": - op = planpb.OpType_LessEqual - case ">=": - op = planpb.OpType_GreaterEqual - case "==": - op = planpb.OpType_Equal - case "!=": - op = planpb.OpType_NotEqual - default: - op = planpb.OpType_Invalid } - } else { - switch opStr { - case ">": + case "<": + if reverse { + op = planpb.OpType_GreaterThan + } else { op = planpb.OpType_LessThan - case "<": - op = planpb.OpType_GreaterThan - case ">=": - op = planpb.OpType_LessEqual - case "<=": - op = planpb.OpType_GreaterEqual - case "==": - op = planpb.OpType_Equal - case "!=": - op = planpb.OpType_NotEqual - default: - op = planpb.OpType_Invalid } + case ">=": + if reverse { + op = planpb.OpType_LessEqual + } else { + op = planpb.OpType_GreaterEqual + } + case "<=": + if reverse { + op = planpb.OpType_GreaterEqual + } else { + op = planpb.OpType_LessEqual + } + case "==": + op = planpb.OpType_Equal + case "!=": + op = planpb.OpType_NotEqual + default: + op = planpb.OpType_Invalid } return op } @@ -283,10 +277,10 @@ func (pc *ParserContext) createCmpExpr(left, right ant_ast.Node, operator string if boolNode := parseBoolNode(&right); boolNode != nil { right = boolNode } - idNodeLeft, leftIDNode := left.(*ant_ast.IdentifierNode) - idNodeRight, rightIDNode := right.(*ant_ast.IdentifierNode) + idNodeLeft, okLeft := left.(*ant_ast.IdentifierNode) + idNodeRight, okRight := right.(*ant_ast.IdentifierNode) - if leftIDNode && rightIDNode { + if okLeft && okRight { leftField, err := pc.handleIdentifier(idNodeLeft) if err != nil { return nil, err @@ -312,15 +306,15 @@ func (pc *ParserContext) createCmpExpr(left, right ant_ast.Node, operator string } var idNode *ant_ast.IdentifierNode - var isReversed bool + var reverse bool var valueNode *ant_ast.Node - if leftIDNode { + if okLeft { idNode = idNodeLeft - isReversed = false + reverse = false valueNode = &right - } else if rightIDNode { + } else if okRight { idNode = idNodeRight - isReversed = true + reverse = true valueNode = &left } else { return nil, fmt.Errorf("compare expr has no identifier") @@ -336,7 +330,7 @@ func (pc *ParserContext) createCmpExpr(left, right ant_ast.Node, operator string return nil, err } - op := getCompareOpType(operator, isReversed) + op := getCompareOpType(operator, reverse) if op == planpb.OpType_Invalid { return nil, fmt.Errorf("invalid binary operator(%s)", operator) } diff --git a/internal/proxy/plan_parser_test.go b/internal/proxy/plan_parser_test.go index f51d9cf80b..4e0c6242f2 100644 --- a/internal/proxy/plan_parser_test.go +++ b/internal/proxy/plan_parser_test.go @@ -315,27 +315,66 @@ func TestExprFieldCompare_Str(t *testing.T) { } } -func Test_ParseBoolNode(t *testing.T) { - var nodeRaw1, nodeRaw2, nodeRaw3, nodeRaw4 ant_ast.Node - nodeRaw1 = &ant_ast.IdentifierNode{ - Value: "True", - } - boolNode1 := parseBoolNode(&nodeRaw1) - assert.Equal(t, boolNode1.Value, true) +func TestPlanParseAPIs(t *testing.T) { + t.Run("get compare op type", func(t *testing.T) { + var op planpb.OpType + var reverse bool - nodeRaw2 = &ant_ast.IdentifierNode{ - Value: "False", - } - boolNode2 := parseBoolNode(&nodeRaw2) - assert.Equal(t, boolNode2.Value, false) + reverse = false + op = getCompareOpType(">", reverse) + assert.Equal(t, planpb.OpType_GreaterThan, op) + op = getCompareOpType(">=", reverse) + assert.Equal(t, planpb.OpType_GreaterEqual, op) + op = getCompareOpType("<", reverse) + assert.Equal(t, planpb.OpType_LessThan, op) + op = getCompareOpType("<=", reverse) + assert.Equal(t, planpb.OpType_LessEqual, op) + op = getCompareOpType("==", reverse) + assert.Equal(t, planpb.OpType_Equal, op) + op = getCompareOpType("!=", reverse) + assert.Equal(t, planpb.OpType_NotEqual, op) + op = getCompareOpType("*", reverse) + assert.Equal(t, planpb.OpType_Invalid, op) - nodeRaw3 = &ant_ast.IdentifierNode{ - Value: "abcd", - } - assert.Nil(t, parseBoolNode(&nodeRaw3)) + reverse = true + op = getCompareOpType(">", reverse) + assert.Equal(t, planpb.OpType_LessThan, op) + op = getCompareOpType(">=", reverse) + assert.Equal(t, planpb.OpType_LessEqual, op) + op = getCompareOpType("<", reverse) + assert.Equal(t, planpb.OpType_GreaterThan, op) + op = getCompareOpType("<=", reverse) + assert.Equal(t, planpb.OpType_GreaterEqual, op) + op = getCompareOpType("==", reverse) + assert.Equal(t, planpb.OpType_Equal, op) + op = getCompareOpType("!=", reverse) + assert.Equal(t, planpb.OpType_NotEqual, op) + op = getCompareOpType("*", reverse) + assert.Equal(t, planpb.OpType_Invalid, op) + }) - nodeRaw4 = &ant_ast.BoolNode{ - Value: true, - } - assert.Nil(t, parseBoolNode(&nodeRaw4)) + t.Run("parse bool node", func(t *testing.T) { + var nodeRaw1, nodeRaw2, nodeRaw3, nodeRaw4 ant_ast.Node + nodeRaw1 = &ant_ast.IdentifierNode{ + Value: "True", + } + boolNode1 := parseBoolNode(&nodeRaw1) + assert.Equal(t, boolNode1.Value, true) + + nodeRaw2 = &ant_ast.IdentifierNode{ + Value: "False", + } + boolNode2 := parseBoolNode(&nodeRaw2) + assert.Equal(t, boolNode2.Value, false) + + nodeRaw3 = &ant_ast.IdentifierNode{ + Value: "abcd", + } + assert.Nil(t, parseBoolNode(&nodeRaw3)) + + nodeRaw4 = &ant_ast.BoolNode{ + Value: true, + } + assert.Nil(t, parseBoolNode(&nodeRaw4)) + }) }