Add HasDefaultDatabase interface to several statements

pull/8293/head
Cory LaNou 2017-03-22 15:25:58 -05:00
parent db430cb434
commit 215b5fc89c
2 changed files with 175 additions and 0 deletions

View File

@ -593,6 +593,11 @@ func (s *DropRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivileges
return ExecutionPrivileges{{Admin: false, Name: s.Database, Privilege: WritePrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *DropRetentionPolicyStatement) DefaultDatabase() string {
return s.Database
}
// CreateUserStatement represents a command for creating a new user.
type CreateUserStatement struct {
// Name of the user to be created.
@ -703,6 +708,11 @@ func (s *GrantStatement) RequiredPrivileges() (ExecutionPrivileges, error) {
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *GrantStatement) DefaultDatabase() string {
return s.On
}
// GrantAdminStatement represents a command for granting admin privilege.
type GrantAdminStatement struct {
// Who to grant the privilege to.
@ -801,6 +811,11 @@ func (s *RevokeStatement) RequiredPrivileges() (ExecutionPrivileges, error) {
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *RevokeStatement) DefaultDatabase() string {
return s.On
}
// RevokeAdminStatement represents a command to revoke admin privilege from a user.
type RevokeAdminStatement struct {
// Who to revoke admin privilege from.
@ -867,6 +882,11 @@ func (s *CreateRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivileg
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *CreateRetentionPolicyStatement) DefaultDatabase() string {
return s.Database
}
// AlterRetentionPolicyStatement represents a command to alter an existing retention policy.
type AlterRetentionPolicyStatement struct {
// Name of policy to alter.
@ -923,6 +943,11 @@ func (s *AlterRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivilege
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *AlterRetentionPolicyStatement) DefaultDatabase() string {
return s.Database
}
// FillOption represents different options for filling aggregate windows.
type FillOption int
@ -2538,6 +2563,14 @@ func (s *DeleteStatement) RequiredPrivileges() (ExecutionPrivileges, error) {
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: WritePrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *DeleteStatement) DefaultDatabase() string {
if m, ok := s.Source.(*Measurement); ok {
return m.Database
}
return ""
}
// ShowSeriesStatement represents a command for listing series in the database.
type ShowSeriesStatement struct {
// Database to query. If blank, use the default database.
@ -2599,6 +2632,11 @@ func (s *ShowSeriesStatement) RequiredPrivileges() (ExecutionPrivileges, error)
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowSeriesStatement) DefaultDatabase() string {
return s.Database
}
// DropSeriesStatement represents a command for removing a series from the database.
type DropSeriesStatement struct {
// Data source that fields are extracted from (optional)
@ -2819,6 +2857,11 @@ func (s *DropContinuousQueryStatement) RequiredPrivileges() (ExecutionPrivileges
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: WritePrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *DropContinuousQueryStatement) DefaultDatabase() string {
return s.Database
}
// ShowMeasurementsStatement represents a command for listing measurements.
type ShowMeasurementsStatement struct {
// Database to query. If blank, use the default database.
@ -2883,6 +2926,11 @@ func (s *ShowMeasurementsStatement) RequiredPrivileges() (ExecutionPrivileges, e
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowMeasurementsStatement) DefaultDatabase() string {
return s.Database
}
// DropMeasurementStatement represents a command to drop a measurement.
type DropMeasurementStatement struct {
// Name of the measurement to be dropped.
@ -2937,6 +2985,11 @@ func (s *ShowRetentionPoliciesStatement) RequiredPrivileges() (ExecutionPrivileg
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowRetentionPoliciesStatement) DefaultDatabase() string {
return s.Database
}
// ShowStatsStatement displays statistics for a given module.
type ShowStatsStatement struct {
Module string
@ -3038,6 +3091,11 @@ func (s *CreateSubscriptionStatement) RequiredPrivileges() (ExecutionPrivileges,
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *CreateSubscriptionStatement) DefaultDatabase() string {
return s.Database
}
// DropSubscriptionStatement represents a command to drop a subscription to the incoming data stream.
type DropSubscriptionStatement struct {
Name string
@ -3055,6 +3113,11 @@ func (s *DropSubscriptionStatement) RequiredPrivileges() (ExecutionPrivileges, e
return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *DropSubscriptionStatement) DefaultDatabase() string {
return s.Database
}
// ShowSubscriptionsStatement represents a command to show a list of subscriptions.
type ShowSubscriptionsStatement struct {
}
@ -3142,6 +3205,11 @@ func (s *ShowTagKeysStatement) RequiredPrivileges() (ExecutionPrivileges, error)
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowTagKeysStatement) DefaultDatabase() string {
return s.Database
}
// ShowTagValuesStatement represents a command for listing tag values.
type ShowTagValuesStatement struct {
// Database to query. If blank, use the default database.
@ -3216,6 +3284,11 @@ func (s *ShowTagValuesStatement) RequiredPrivileges() (ExecutionPrivileges, erro
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowTagValuesStatement) DefaultDatabase() string {
return s.Database
}
// ShowUsersStatement represents a command for listing users.
type ShowUsersStatement struct{}
@ -3282,6 +3355,11 @@ func (s *ShowFieldKeysStatement) RequiredPrivileges() (ExecutionPrivileges, erro
return ExecutionPrivileges{{Admin: false, Name: "", Privilege: ReadPrivilege}}, nil
}
// DefaultDatabase returns the default database from the statement.
func (s *ShowFieldKeysStatement) DefaultDatabase() string {
return s.Database
}
// Fields represents a list of fields.
type Fields []*Field

View File

@ -2,6 +2,7 @@ package influxql_test
import (
"fmt"
"go/importer"
"reflect"
"strings"
"testing"
@ -1576,6 +1577,102 @@ func TestParse_Errors(t *testing.T) {
}
}
// This test checks to ensure that we have given thought to the database
// context required for security checks. If a new statement is added, this
// test will fail until it is categorized into the correct bucket below.
func Test_EnforceHasDefaultDatabase(t *testing.T) {
pkg, err := importer.Default().Import("github.com/influxdata/influxdb/influxql")
if err != nil {
fmt.Printf("error: %s\n", err.Error())
return
}
statements := []string{}
// this is a list of statements that do not have a database context
exemptStatements := []string{
"CreateDatabaseStatement",
"CreateUserStatement",
"DeleteSeriesStatement",
"DropDatabaseStatement",
"DropMeasurementStatement",
"DropSeriesStatement",
"DropShardStatement",
"DropUserStatement",
"GrantAdminStatement",
"KillQueryStatement",
"RevokeAdminStatement",
"SelectStatement",
"SetPasswordUserStatement",
"ShowContinuousQueriesStatement",
"ShowDatabasesStatement",
"ShowDiagnosticsStatement",
"ShowGrantsForUserStatement",
"ShowQueriesStatement",
"ShowShardGroupsStatement",
"ShowShardsStatement",
"ShowStatsStatement",
"ShowSubscriptionsStatement",
"ShowUsersStatement",
}
exists := func(stmt string) bool {
switch stmt {
// These are functions with the word statement in them, and can be ignored
case "Statement", "MustParseStatement", "ParseStatement", "RewriteStatement":
return true
default:
// check the exempt statements
for _, s := range exemptStatements {
if s == stmt {
return true
}
}
// check the statements that passed the interface test for HasDefaultDatabase
for _, s := range statements {
if s == stmt {
return true
}
}
return false
}
}
needsHasDefault := []interface{}{
&influxql.AlterRetentionPolicyStatement{},
&influxql.CreateContinuousQueryStatement{},
&influxql.CreateRetentionPolicyStatement{},
&influxql.CreateSubscriptionStatement{},
&influxql.DeleteStatement{},
&influxql.DropContinuousQueryStatement{},
&influxql.DropRetentionPolicyStatement{},
&influxql.DropSubscriptionStatement{},
&influxql.GrantStatement{},
&influxql.RevokeStatement{},
&influxql.ShowFieldKeysStatement{},
&influxql.ShowMeasurementsStatement{},
&influxql.ShowRetentionPoliciesStatement{},
&influxql.ShowSeriesStatement{},
&influxql.ShowTagKeysStatement{},
&influxql.ShowTagValuesStatement{},
}
for _, stmt := range needsHasDefault {
statements = append(statements, strings.TrimPrefix(fmt.Sprintf("%T", stmt), "*influxql."))
if _, ok := stmt.(influxql.HasDefaultDatabase); !ok {
t.Errorf("%T was expected to declare DefaultDatabase method", stmt)
}
}
for _, declName := range pkg.Scope().Names() {
if strings.HasSuffix(declName, "Statement") {
if !exists(declName) {
t.Errorf("unchecked statement %s. please update this test to determine if this statement needs to declare 'DefaultDatabase'", declName)
}
}
}
}
// Valuer represents a simple wrapper around a map to implement the influxql.Valuer interface.
type Valuer map[string]interface{}