diff --git a/influxql/ast.go b/influxql/ast.go index 9fd1e69bac..dfa0398e69 100644 --- a/influxql/ast.go +++ b/influxql/ast.go @@ -1737,6 +1737,83 @@ type rewriterFunc func(Node) Node func (fn rewriterFunc) Rewrite(n Node) Node { return fn(n) } +// Eval evaluates expr against a map. +func Eval(expr Expr, m map[string]interface{}) interface{} { + if expr == nil { + return nil + } + + switch expr := expr.(type) { + case *BinaryExpr: + return evalBinaryExpr(expr, m) + case *BooleanLiteral: + return expr.Val + case *NumberLiteral: + return expr.Val + case *ParenExpr: + return Eval(expr.Expr, m) + case *StringLiteral: + return expr.Val + case *VarRef: + return m[expr.Val] + default: + return nil + } +} + +func evalBinaryExpr(expr *BinaryExpr, m map[string]interface{}) interface{} { + lhs := Eval(expr.LHS, m) + rhs := Eval(expr.RHS, m) + + // Evaluate if both sides are simple types. + switch lhs := lhs.(type) { + case bool: + rhs, _ := rhs.(bool) + switch expr.Op { + case AND: + return lhs && rhs + case OR: + return lhs || rhs + } + case float64: + rhs, _ := rhs.(float64) + switch expr.Op { + case EQ: + return lhs == rhs + case NEQ: + return lhs != rhs + case LT: + return lhs < rhs + case LTE: + return lhs <= rhs + case GT: + return lhs > rhs + case GTE: + return lhs >= rhs + case ADD: + return lhs + rhs + case SUB: + return lhs - rhs + case MUL: + return lhs * rhs + case DIV: + if rhs == 0 { + return float64(0) + } + return lhs / rhs + } + case string: + rhs, _ := rhs.(string) + switch expr.Op { + case EQ: + return lhs == rhs + case NEQ: + return lhs != rhs + } + } + return nil +} + // Reduce evaluates expr using the available values in valuer. // References that don't exist in valuer are ignored. func Reduce(expr Expr, valuer Valuer) Expr { diff --git a/influxql/ast_test.go b/influxql/ast_test.go index 261d975c57..33872e9082 100644 --- a/influxql/ast_test.go +++ b/influxql/ast_test.go @@ -1,6 +1,7 @@ package influxql_test import ( + "reflect" "strings" "testing" @@ -165,6 +166,50 @@ func TestRewrite(t *testing.T) { } } +// Ensure an expression can be reduced. +func TestEval(t *testing.T) { + for i, tt := range []struct { + in string + out interface{} + data map[string]interface{} + }{ + // Number literals. + {in: `1 + 2`, out: float64(3)}, + {in: `(foo*2) + ( (4/2) + (3 * 5) - 0.5 )`, out: float64(26.5), data: map[string]interface{}{"foo": float64(5)}}, + {in: `foo / 2`, out: float64(2), data: map[string]interface{}{"foo": float64(4)}}, + {in: `4 = 4`, out: true}, + {in: `4 <> 4`, out: false}, + {in: `6 > 4`, out: true}, + {in: `4 >= 4`, out: true}, + {in: `4 < 6`, out: true}, + {in: `4 <= 4`, out: true}, + {in: `4 AND 5`, out: nil}, + + // Boolean literals. + {in: `true AND false`, out: false}, + {in: `true OR false`, out: true}, + + // String literals. + {in: `'foo' = 'bar'`, out: false}, + {in: `'foo' = 'foo'`, out: true}, + + // Variable references. + {in: `foo`, out: "bar", data: map[string]interface{}{"foo": "bar"}}, + {in: `foo = 'bar'`, out: true, data: map[string]interface{}{"foo": "bar"}}, + {in: `foo = 'bar'`, out: nil, data: map[string]interface{}{"foo": nil}}, + {in: `foo <> 'bar'`, out: true, data: map[string]interface{}{"foo": "xxx"}}, + } { + // Evaluate expression. + out := influxql.Eval(MustParseExpr(tt.in), tt.data) + + // Compare with expected output. + if !reflect.DeepEqual(tt.out, out) { + t.Errorf("%d. %s: unexpected output:\n\nexp=%#v\n\ngot=%#v\n\n", i, tt.in, tt.out, out) + continue + } + } +} + // Ensure an expression can be reduced. func TestReduce(t *testing.T) { now := mustParseTime("2000-01-01T00:00:00Z") diff --git a/tx_test.go b/tx_test.go index 9a93b81483..fe97f3148e 100644 --- a/tx_test.go +++ b/tx_test.go @@ -34,7 +34,7 @@ func TestTx_CreateIterators(t *testing.T) { stmt := MustParseSelectStatement(` SELECT value FROM "db"."raw"."cpu" - WHERE (service = 'redis' AND (value > 10 OR value < 5)) AND (time >= '2000-01-01' AND time < '2000-01-02') + WHERE (service = 'redis' AND value < 100) AND (time >= '2000-01-01' AND time < '2000-01-02') GROUP BY time(1h), region`) // Retrieve iterators from server. @@ -60,7 +60,6 @@ func TestTx_CreateIterators(t *testing.T) { // Iterate over each one. if data := slurp(itrs); !reflect.DeepEqual(data, []keyValue{ - {key: 946684800000000000, value: float64(100), tags: "\x00\aus-east"}, {key: 946684800000000000, value: float64(2), tags: "\x00\aus-west"}, {key: 946684810000000000, value: float64(90), tags: "\x00\aus-east"}, {key: 946684830000000000, value: float64(70), tags: "\x00\aus-east"},