Refactor the aggregation to use the more generic aggregator interface.
parent 647bb4710a
commit 4e6ce97627

@@ -0,0 +1,98 @@
package engine

import (
	"parser"
	"protocol"
	"time"
)

type Aggregator interface {
	AggregatePoint(series string, group interface{}, p *protocol.Point) error
	GetValue(series string, group interface{}) *protocol.FieldValue
	ColumnName() string
	ColumnType() protocol.FieldDefinition_Type
}

type AggregatorInitializer func(*parser.Query) (Aggregator, error)

var registeredAggregators = make(map[string]AggregatorInitializer)

func init() {
	registeredAggregators["count"] = NewCountAggregator
	registeredAggregators["__timestamp_aggregator"] = NewTimestampAggregator
}

type CountAggregator struct {
	counts map[string]map[interface{}]int32
}

func (self *CountAggregator) AggregatePoint(series string, group interface{}, p *protocol.Point) error {
	counts := self.counts[series]
	if counts == nil {
		counts = make(map[interface{}]int32)
		self.counts[series] = counts
	}
	counts[group]++
	return nil
}

func (self *CountAggregator) ColumnName() string {
	return "count"
}

func (self *CountAggregator) ColumnType() protocol.FieldDefinition_Type {
	return protocol.FieldDefinition_INT32
}

func (self *CountAggregator) GetValue(series string, group interface{}) *protocol.FieldValue {
	value := self.counts[series][group]
	return &protocol.FieldValue{IntValue: &value}
}

func NewCountAggregator(query *parser.Query) (Aggregator, error) {
	return &CountAggregator{make(map[string]map[interface{}]int32)}, nil
}

type TimestampAggregator struct {
	duration   *time.Duration
	timestamps map[string]map[interface{}]int64
}

func (self *TimestampAggregator) AggregatePoint(series string, group interface{}, p *protocol.Point) error {
	timestamps := self.timestamps[series]
	if timestamps == nil {
		timestamps = make(map[interface{}]int64)
		self.timestamps[series] = timestamps
	}
	if self.duration != nil {
		timestamps[group] = time.Unix(*p.Timestamp, 0).Round(*self.duration).Unix()
	} else {
		timestamps[group] = *p.Timestamp
	}
	return nil
}

func (self *TimestampAggregator) ColumnName() string {
	return "count"
}

func (self *TimestampAggregator) ColumnType() protocol.FieldDefinition_Type {
	return protocol.FieldDefinition_INT32
}

func (self *TimestampAggregator) GetValue(series string, group interface{}) *protocol.FieldValue {
	value := self.timestamps[series][group]
	return &protocol.FieldValue{Int64Value: &value}
}

func NewTimestampAggregator(query *parser.Query) (Aggregator, error) {
	duration, err := query.GetGroupByClause().GetGroupByTime()
	if err != nil {
		return nil, err
	}

	return &TimestampAggregator{
		timestamps: make(map[string]map[interface{}]int64),
		duration:   duration,
	}, nil
}
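
For illustration only, and not part of this commit: a minimal sketch of how a further aggregator could plug into the registry defined above. The LatestTimestampAggregator name is hypothetical, and the sketch deliberately reuses only identifiers that already appear in this diff (protocol.Point.Timestamp, protocol.FieldValue.Int64Value, protocol.FieldDefinition_INT32).

// Hypothetical aggregator (not part of this commit) that records the latest
// timestamp seen per series/group.
type LatestTimestampAggregator struct {
	latest map[string]map[interface{}]int64
}

func (self *LatestTimestampAggregator) AggregatePoint(series string, group interface{}, p *protocol.Point) error {
	groups := self.latest[series]
	if groups == nil {
		groups = make(map[interface{}]int64)
		self.latest[series] = groups
	}
	// Keep the largest timestamp observed for this series/group pair.
	if current, seen := groups[group]; !seen || *p.Timestamp > current {
		groups[group] = *p.Timestamp
	}
	return nil
}

func (self *LatestTimestampAggregator) ColumnName() string {
	return "latest_timestamp"
}

func (self *LatestTimestampAggregator) ColumnType() protocol.FieldDefinition_Type {
	// Mirrors TimestampAggregator above, which also advertises INT32 for an int64 value.
	return protocol.FieldDefinition_INT32
}

func (self *LatestTimestampAggregator) GetValue(series string, group interface{}) *protocol.FieldValue {
	value := self.latest[series][group]
	return &protocol.FieldValue{Int64Value: &value}
}

func NewLatestTimestampAggregator(query *parser.Query) (Aggregator, error) {
	return &LatestTimestampAggregator{make(map[string]map[interface{}]int64)}, nil
}

Wiring such an aggregator up would then be a one-line registration in init(), e.g. registeredAggregators["latest_timestamp"] = NewLatestTimestampAggregator.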

@@ -29,7 +29,7 @@ func (self *QueryEngine) RunQuery(query *parser.Query, yield func(*protocol.Seri
 		}
 	}()
 
-	if isCountQuery(query) {
+	if isAggregateQuery(query) {
 		return self.executeCountQueryWithGroupBy(query, yield)
 	} else {
 		self.coordinator.DistributeQuery(query, yield)

@@ -41,13 +41,12 @@ func NewQueryEngine(c coordinator.Coordinator) (EngineI, error) {
 	return &QueryEngine{c}, nil
 }
 
-func isCountQuery(query *parser.Query) bool {
+func isAggregateQuery(query *parser.Query) bool {
 	for _, column := range query.GetColumnNames() {
-		if column.IsFunctionCall() && column.Name == "count" {
+		if column.IsFunctionCall() {
 			return true
 		}
 	}
 
 	return false
 }

@@ -190,19 +189,22 @@ func createValuesToInterface(groupBy parser.GroupByClause, definitions []*protoc
 }
 
 func (self *QueryEngine) executeCountQueryWithGroupBy(query *parser.Query, yield func(*protocol.Series) error) error {
-	counts := make(map[string]map[interface{}]int32)
-	timestamps := make(map[string]map[interface{}]int64)
+	countAggregator, err := registeredAggregators["count"](query)
+	if err != nil {
+		return err
+	}
+	timestampAggregator, err := registeredAggregators["__timestamp_aggregator"](query)
+	if err != nil {
+		return err
+	}
+	aggregators := []Aggregator{countAggregator}
+
+	groups := make(map[string]map[interface{}]bool)
 	groupBy := query.GetGroupByClause()
 
 	fieldTypes := map[string]*protocol.FieldDefinition_Type{}
 	var inverse InverseMapper
 
-	duration, err := groupBy.GetGroupByTime()
-	if err != nil {
-		return err
-	}
-
 	err = self.coordinator.DistributeQuery(query, func(series *protocol.Series) error {
 		var mapper Mapper
 		mapper, inverse, err = createValuesToInterface(groupBy, series.Fields)

@@ -216,23 +218,17 @@ func (self *QueryEngine) executeCountQueryWithGroupBy(query *parser.Query, yield
 
 		for _, point := range series.Points {
 			value := mapper(point)
-			tableCounts := counts[*series.Name]
-			if tableCounts == nil {
-				tableCounts = make(map[interface{}]int32)
-				counts[*series.Name] = tableCounts
+			for _, aggregator := range aggregators {
+				aggregator.AggregatePoint(*series.Name, value, point)
 			}
-			tableCounts[value]++
 
-			tableTimestamps := timestamps[*series.Name]
-			if tableTimestamps == nil {
-				tableTimestamps = make(map[interface{}]int64)
-				timestamps[*series.Name] = tableTimestamps
-			}
-			if duration != nil {
-				tableTimestamps[value] = getTimestampFromPoint(*duration, point)
-			} else {
-				tableTimestamps[value] = point.GetTimestamp()
+			timestampAggregator.AggregatePoint(*series.Name, value, point)
+			seriesGroups := groups[*series.Name]
+			if seriesGroups == nil {
+				seriesGroups = make(map[interface{}]bool)
+				groups[*series.Name] = seriesGroups
 			}
+			seriesGroups[value] = true
 		}
 
 		return nil

@@ -242,13 +238,15 @@ func (self *QueryEngine) executeCountQueryWithGroupBy(query *parser.Query, yield
 		return err
 	}
 
-	expectedFieldType := protocol.FieldDefinition_INT32
-	expectedName := "count"
 	var sequenceNumber uint32 = 1
 
-	/* fields := []*protocol.FieldDefinition{} */
-	fields := []*protocol.FieldDefinition{
-		&protocol.FieldDefinition{Name: &expectedName, Type: &expectedFieldType},
+	fields := []*protocol.FieldDefinition{}
+
+	for _, aggregator := range aggregators {
+		columnName := aggregator.ColumnName()
+		columnType := aggregator.ColumnType()
+		fields = append(fields, &protocol.FieldDefinition{Name: &columnName, Type: &columnType})
 	}
 
 	for _, value := range groupBy {

@@ -260,27 +258,23 @@ func (self *QueryEngine) executeCountQueryWithGroupBy(query *parser.Query, yield
 		fields = append(fields, &protocol.FieldDefinition{Name: &tempName, Type: fieldTypes[tempName]})
 	}
 
-	for table, tableCounts := range counts {
+	for table, tableGroups := range groups {
 		tempTable := table
 		points := []*protocol.Point{}
-		for key, count := range tableCounts {
-			tempKey := key
-			tempCount := count
-
-			timestamp := timestamps[table][tempKey]
-
+		for groupId, _ := range tableGroups {
+			timestamp := *timestampAggregator.GetValue(table, groupId).Int64Value
 			point := &protocol.Point{
 				Timestamp:      &timestamp,
 				SequenceNumber: &sequenceNumber,
-				Values: []*protocol.FieldValue{
-					&protocol.FieldValue{
-						IntValue: &tempCount,
-					},
-				},
+				Values:         []*protocol.FieldValue{},
 			}
 
+			for _, aggregator := range aggregators {
+				point.Values = append(point.Values, aggregator.GetValue(table, groupId))
+			}
+
 			for idx, _ := range groupBy {
-				value := inverse(tempKey, idx)
+				value := inverse(groupId, idx)
 
 				switch x := value.(type) {
 				case string:
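
The hunks above still hard-code the "count" aggregator when building the aggregators slice. As a hedged sketch of where the registry could take this, and not code from this commit, the helper below builds the slice from whichever function calls the query names. The buildAggregators name is hypothetical; it only uses identifiers that appear elsewhere in this diff (query.GetColumnNames, column.IsFunctionCall, column.Name, registeredAggregators).

// Hypothetical helper (not part of this commit): one aggregator per function
// call named in the query. Unregistered function names are skipped here; a
// real implementation would more likely return an error for them.
func buildAggregators(query *parser.Query) ([]Aggregator, error) {
	aggregators := []Aggregator{}
	for _, column := range query.GetColumnNames() {
		if !column.IsFunctionCall() {
			continue
		}
		initializer, ok := registeredAggregators[column.Name]
		if !ok {
			continue
		}
		aggregator, err := initializer(query)
		if err != nil {
			return nil, err
		}
		aggregators = append(aggregators, aggregator)
	}
	return aggregators, nil
}

With such a helper, executeCountQueryWithGroupBy could replace its explicit registeredAggregators["count"](query) lookup with a single call, which is the direction the commit message points toward.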