diff --git a/influxql/ast.go b/influxql/ast.go index fee0ccdf29..e5d90f46aa 100644 --- a/influxql/ast.go +++ b/influxql/ast.go @@ -1062,12 +1062,6 @@ func foldDurationNumberLiterals(expr *BinaryExpr) Expr { } } -// Visitor can be called by Walk to traverse an AST hierarchy. -// The Visit() function is called once per node. -type Visitor interface { - Visit(Node) Visitor -} - // TimeRange returns the minimum and maximum times specified by an expression. // Returns zero times if there is no bound. func TimeRange(expr Expr) (min, max time.Time) { @@ -1137,6 +1131,12 @@ func timeExprValue(ref Expr, lit Expr) time.Time { return time.Time{} } +// Visitor can be called by Walk to traverse an AST hierarchy. +// The Visit() function is called once per node. +type Visitor interface { + Visit(Node) Visitor +} + // Walk traverses a node hierarchy in depth-first order. func Walk(v Visitor, node Node) { if v = v.Visit(node); v == nil { @@ -1196,3 +1196,68 @@ func WalkFunc(node Node, fn func(Node)) { type walkFuncVisitor func(Node) func (fn walkFuncVisitor) Visit(n Node) Visitor { fn(n); return fn } + +// Rewriter can be called by Rewrite to replace nodes in the AST hierarchy. +// The Rewrite() function is called once per node. +type Rewriter interface { + Rewrite(Node) Node +} + +// Rewrite recursively invokes the rewriter to replace each node. +// Nodes are traversed depth-first and rewritten from leaf to root. +func Rewrite(r Rewriter, node Node) Node { + switch n := node.(type) { + case *Query: + n.Statements = Rewrite(r, n.Statements).(Statements) + + case Statements: + for i, s := range n { + n[i] = Rewrite(r, s).(Statement) + } + + case *SelectStatement: + n.Fields = Rewrite(r, n.Fields).(Fields) + n.Dimensions = Rewrite(r, n.Dimensions).(Dimensions) + n.Source = Rewrite(r, n.Source).(Source) + n.Condition = Rewrite(r, n.Condition).(Expr) + + case Fields: + for i, f := range n { + n[i] = Rewrite(r, f).(*Field) + } + + case *Field: + n.Expr = Rewrite(r, n.Expr).(Expr) + + case Dimensions: + for i, d := range n { + n[i] = Rewrite(r, d).(*Dimension) + } + + case *Dimension: + n.Expr = Rewrite(r, n.Expr).(Expr) + + case *BinaryExpr: + n.LHS = Rewrite(r, n.LHS).(Expr) + n.RHS = Rewrite(r, n.RHS).(Expr) + + case *ParenExpr: + n.Expr = Rewrite(r, n.Expr).(Expr) + + case *Call: + for i, expr := range n.Args { + n.Args[i] = Rewrite(r, expr).(Expr) + } + } + + return r.Rewrite(node) +} + +// RewriteFunc rewrites a node hierarchy. +func RewriteFunc(node Node, fn func(Node) Node) Node { + return Rewrite(rewriterFunc(fn), node) +} + +type rewriterFunc func(Node) Node + +func (fn rewriterFunc) Rewrite(n Node) Node { return fn(n) } diff --git a/influxql/ast_test.go b/influxql/ast_test.go index 9b4bd7b000..a2fb9eda17 100644 --- a/influxql/ast_test.go +++ b/influxql/ast_test.go @@ -212,3 +212,23 @@ func TestTimeRange(t *testing.T) { } } } + +// Ensure an AST node can be rewritten. +func TestRewrite(t *testing.T) { + expr := MustParseExpr(`time > 1 OR foo = 2`) + + // Flip LHS & RHS in all binary expressions. + act := influxql.RewriteFunc(expr, func(n influxql.Node) influxql.Node { + switch n := n.(type) { + case *influxql.BinaryExpr: + return &influxql.BinaryExpr{Op: n.Op, LHS: n.RHS, RHS: n.LHS} + default: + return n + } + }) + + // Verify that everything is flipped. + if act := act.String(); act != `2.000 = foo OR 1.000 > time` { + t.Fatalf("unexpected result: %s", act) + } +}