Merge pull request #4858 from viru/fix-nested-aggr

Validate nested aggregations in queries
pull/4988/head
Philip O'Toole 2015-12-04 05:58:29 -08:00
commit c76b109ef4
2 changed files with 73 additions and 32 deletions

View File

@ -1158,6 +1158,45 @@ func (s *SelectStatement) validSelectWithAggregate() error {
return nil
}
// validTopBottomAggr determines if TOP or BOTTOM aggregates have valid arguments.
func (s *SelectStatement) validTopBottomAggr(expr *Call) error {
if exp, got := 2, len(expr.Args); got < exp {
return fmt.Errorf("invalid number of arguments for %s, expected at least %d, got %d", expr.Name, exp, got)
}
if len(expr.Args) > 1 {
callLimit, ok := expr.Args[len(expr.Args)-1].(*NumberLiteral)
if !ok {
return fmt.Errorf("expected integer as last argument in %s(), found %s", expr.Name, expr.Args[len(expr.Args)-1])
}
// Check if they asked for a limit smaller than what they passed into the call
if int64(callLimit.Val) > int64(s.Limit) && s.Limit != 0 {
return fmt.Errorf("limit (%d) in %s function can not be larger than the LIMIT (%d) in the select statement", int64(callLimit.Val), expr.Name, int64(s.Limit))
}
for _, v := range expr.Args[:len(expr.Args)-1] {
if _, ok := v.(*VarRef); !ok {
return fmt.Errorf("only fields or tags are allowed in %s(), found %s", expr.Name, v)
}
}
}
return nil
}
// validPercentileAggr determines if PERCENTILE have valid arguments.
func (s *SelectStatement) validPercentileAggr(expr *Call) error {
if err := s.validSelectWithAggregate(); err != nil {
return err
}
if exp, got := 2, len(expr.Args); got != exp {
return fmt.Errorf("invalid number of arguments for %s, expected %d, got %d", expr.Name, exp, got)
}
_, ok := expr.Args[1].(*NumberLiteral)
if !ok {
return fmt.Errorf("expected float argument in percentile()")
}
return nil
}
func (s *SelectStatement) validateAggregates(tr targetRequirement) error {
for _, f := range s.Fields {
for _, expr := range walkFunctionCalls(f.Expr) {
@ -1170,44 +1209,38 @@ func (s *SelectStatement) validateAggregates(tr targetRequirement) error {
return fmt.Errorf("invalid number of arguments for %s, expected at least %d but no more than %d, got %d", expr.Name, min, max, got)
}
// Validate that if they have grouping by time, they need a sub-call like min/max, etc.
groupByInterval, _ := s.GroupByInterval()
groupByInterval, err := s.GroupByInterval()
if err != nil {
return fmt.Errorf("invalid group interval: %v", err)
}
if groupByInterval > 0 {
if _, ok := expr.Args[0].(*Call); !ok {
c, ok := expr.Args[0].(*Call)
if !ok {
return fmt.Errorf("aggregate function required inside the call to %s", expr.Name)
}
}
case "percentile":
if err := s.validSelectWithAggregate(); err != nil {
return err
}
if exp, got := 2, len(expr.Args); got != exp {
return fmt.Errorf("invalid number of arguments for %s, expected %d, got %d", expr.Name, exp, got)
}
_, ok := expr.Args[1].(*NumberLiteral)
if !ok {
return fmt.Errorf("expected float argument in percentile()")
}
case "top", "bottom":
if exp, got := 2, len(expr.Args); got < exp {
return fmt.Errorf("invalid number of arguments for %s, expected at least %d, got %d", expr.Name, exp, got)
}
if len(expr.Args) > 1 {
callLimit, ok := expr.Args[len(expr.Args)-1].(*NumberLiteral)
if !ok {
return fmt.Errorf("expected integer as last argument in %s(), found %s", expr.Name, expr.Args[len(expr.Args)-1])
}
// Check if they asked for a limit smaller than what they passed into the call
if int64(callLimit.Val) > int64(s.Limit) && s.Limit != 0 {
return fmt.Errorf("limit (%d) in %s function can not be larger than the LIMIT (%d) in the select statement", int64(callLimit.Val), expr.Name, int64(s.Limit))
}
for _, v := range expr.Args[:len(expr.Args)-1] {
if _, ok := v.(*VarRef); !ok {
return fmt.Errorf("only fields or tags are allowed in %s(), found %s", expr.Name, v)
switch c.Name {
case "top", "bottom":
if err := s.validTopBottomAggr(c); err != nil {
return err
}
case "percentile":
if err := s.validPercentileAggr(c); err != nil {
return err
}
default:
if exp, got := 1, len(c.Args); got != exp {
return fmt.Errorf("invalid number of arguments for %s, expected %d, got %d", c.Name, exp, got)
}
}
}
case "top", "bottom":
if err := s.validTopBottomAggr(expr); err != nil {
return err
}
case "percentile":
if err := s.validPercentileAggr(expr); err != nil {
return err
}
default:
if err := s.validSelectWithAggregate(); err != nil {
return err

View File

@ -1663,10 +1663,18 @@ func TestParser_ParseStatement(t *testing.T) {
{s: `select derivative() from myseries`, err: `invalid number of arguments for derivative, expected at least 1 but no more than 2, got 0`},
{s: `select derivative(mean(value), 1h, 3) from myseries`, err: `invalid number of arguments for derivative, expected at least 1 but no more than 2, got 3`},
{s: `SELECT derivative(value) FROM myseries group by time(1h)`, err: `aggregate function required inside the call to derivative`},
{s: `SELECT derivative(top(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for top, expected at least 2, got 1`},
{s: `SELECT derivative(bottom(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for bottom, expected at least 2, got 1`},
{s: `SELECT derivative(max()) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for max, expected 1, got 0`},
{s: `SELECT derivative(percentile(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for percentile, expected 2, got 1`},
{s: `SELECT non_negative_derivative(), field1 FROM myseries`, err: `mixing aggregate and non-aggregate queries is not supported`},
{s: `select non_negative_derivative() from myseries`, err: `invalid number of arguments for non_negative_derivative, expected at least 1 but no more than 2, got 0`},
{s: `select non_negative_derivative(mean(value), 1h, 3) from myseries`, err: `invalid number of arguments for non_negative_derivative, expected at least 1 but no more than 2, got 3`},
{s: `SELECT non_negative_derivative(value) FROM myseries group by time(1h)`, err: `aggregate function required inside the call to non_negative_derivative`},
{s: `SELECT non_negative_derivative(top(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for top, expected at least 2, got 1`},
{s: `SELECT non_negative_derivative(bottom(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for bottom, expected at least 2, got 1`},
{s: `SELECT non_negative_derivative(max()) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for max, expected 1, got 0`},
{s: `SELECT non_negative_derivative(percentile(value)) FROM myseries where time < now() and time > now() - 1d group by time(1h)`, err: `invalid number of arguments for percentile, expected 2, got 1`},
{s: `SELECT field1 from myseries WHERE host =~ 'asd' LIMIT 1`, err: `found asd, expected regex at line 1, char 42`},
{s: `SELECT value > 2 FROM cpu`, err: `invalid operator > in SELECT clause at line 1, char 8; operator is intended for WHERE clause`},
{s: `SELECT value = 2 FROM cpu`, err: `invalid operator = in SELECT clause at line 1, char 8; operator is intended for WHERE clause`},