package engine

import (
	"fmt"
	"sort"

	"github.com/influxdb/influxdb/protocol"
)

// Trie maps groups to aggregator states. Each level of the trie corresponds
// to one field of the group by clause, so the path of values from the root to
// a node at depth numLevels identifies a single group. The children at each
// level are kept sorted (by FieldValue.GreaterOrEqual) in the order
// nil, boolean, int64, float64, string.
type Trie struct {
	numLevels int   // number of group-by levels; leaves live at this depth
	numStates int   // number of state slots allocated for each leaf
	rootNode  *Node // root of the trie; itself the only group when numLevels == 0
}

// MaxInt is the largest value representable by int. TraverseLevel uses it to
// mean "descend all the way to the leaves".
const MaxInt = int(^uint(0) >> 1)

// Nodes is a slice of child nodes kept sorted by value (see findNode).
type Nodes []*Node

// Node is a single entry of a Trie.
type Node struct {
	isLeaf     bool
	value      *protocol.FieldValue // the value of the group by column corresponding to this level
	states     []interface{}        // the aggregator state
	childNodes Nodes                // the slice of the next level
}

// NewTrie returns an empty Trie with numLevels group-by levels whose leaves
// each carry numStates state slots.
func NewTrie(numLevels, numStates int) *Trie {
	trie := &Trie{numLevels, numStates, nil}
	trie.Clear()
	return trie
}

// Clear discards every group by replacing the root with a fresh node. The root
// is marked as a leaf and given its own states so that it can act as the one
// and only group when numLevels is 0.
func (self *Trie) Clear() {
	self.rootNode = &Node{true, nil, make([]interface{}, self.numStates), nil}
}

// CountLeafNodes returns the number of leaf nodes below the root. The root
// itself is not counted, so a trie with numLevels == 0 reports 0 even though
// GetNode(nil) returns the root.
func (self *Trie) CountLeafNodes() int {
	return self.rootNode.CountLeafNodes()
}

// Traverse calls f for every leaf node, passing the group values on the path
// from the root to that leaf. It stops at, and returns, the first non-nil
// error returned by f.
func (self *Trie) Traverse(f func([]*protocol.FieldValue, *Node) error) error {
	return self.TraverseLevel(-1, f)
}

// TraverseLevel calls f for every node at the given level (the root is level
// 0), or -1 to visit the nodes at the bottom-most level. Leaf nodes found
// above the requested level are also passed to f. When the trie has no
// levels, f is called once with nil values and the root node regardless of
// level.
func (self *Trie) TraverseLevel(level int, f func([]*protocol.FieldValue, *Node) error) error {
	if self.numLevels == 0 {
		return f(nil, self.rootNode)
	}

	if level == -1 {
		level = MaxInt
	}

	return self.rootNode.traverse(level, nil, f)
}

// GetNode returns the node for the group identified by values, creating any
// missing nodes along the way. Nodes at the last level get numStates state
// slots and are always marked as leaves. It panics if len(values) does not
// equal the number of levels, which indicates a caller bug.
func (self *Trie) GetNode(values []*protocol.FieldValue) *Node {
	if len(values) != self.numLevels {
		panic(fmt.Errorf("number of levels doesn't match values. Expected: %d, Actual: %d", self.numLevels, len(values)))
	}

	if self.numLevels == 0 {
		return self.rootNode
	}

	node := self.rootNode
	for idx, v := range values {
		if idx < self.numLevels-1 {
			// intermediate levels only route to the next level; they hold no states
			node = node.findOrCreateNode(v, 0)
			continue
		}

		node = node.findOrCreateNode(v, self.numStates)
		// createChildNode only flags a node as a leaf when it has states, so
		// with numStates == 0 the bottom level would otherwise never be seen
		// as leaves by CountLeafNodes and traverse.
		node.isLeaf = true
	}
	return node
}

// CountLeafNodes returns the number of leaf nodes among the descendants of
// self (self is not counted).
func (self *Node) CountLeafNodes() int {
	size := 0
	for _, child := range self.childNodes {
		size += child.CountLeafNodes()
		if child.isLeaf {
			size++
		}
	}
	return size
}

// GetChildNode returns the direct child whose value equals value, or nil if
// there is no such child. It never creates nodes.
func (self *Node) GetChildNode(value *protocol.FieldValue) *Node {
	idx := self.childNodes.findNode(value)
	if idx == len(self.childNodes) || !self.childNodes[idx].value.Equals(value) {
		return nil
	}
	return self.childNodes[idx]
}

// traverse calls f for every node exactly level levels below self, and for
// any leaf reached before that depth. values holds the group values on the
// path leading to self. It returns the first non-nil error returned by f.
func (self *Node) traverse(level int, values []*protocol.FieldValue, f func([]*protocol.FieldValue, *Node) error) error {
	if level == 0 {
		return f(values, self)
	}

	// Cap values at its length so every append below allocates a new backing
	// array. Otherwise siblings would write into the same spare slot, and a
	// slice handed to f could later be overwritten with another group's value.
	values = values[:len(values):len(values)]

	for _, node := range self.childNodes {
		if node.isLeaf {
			if err := f(append(values, node.value), node); err != nil {
				return err
			}
			continue
		}

		if err := node.traverse(level-1, append(values, node.value), f); err != nil {
			return err
		}
	}
	return nil
}

// findOrCreateNode returns the direct child whose value equals value. If no
// such child exists, it creates one with numOfStates state slots and inserts
// it at the position that keeps childNodes sorted.
func (self *Node) findOrCreateNode(value *protocol.FieldValue, numOfStates int) *Node {
	idx := self.childNodes.findNode(value)
	if idx == len(self.childNodes) || !self.childNodes[idx].value.Equals(value) {
		// add the new node
		node := self.createChildNode(value, numOfStates)
		self.childNodes = append(self.childNodes, node)
		// if idx is equal to the previous length, then leave it at the
		// end, otherwise, move it to that index.
		if idx != len(self.childNodes)-1 {
			// shift all values to the right by one
			copy(self.childNodes[idx+1:], self.childNodes[idx:])
			self.childNodes[idx] = node
		}
		return node
	}

	return self.childNodes[idx]
}

// createChildNode builds a detached node for value. When numOfStates is
// positive the node gets that many state slots and is marked as a leaf.
func (self *Node) createChildNode(value *protocol.FieldValue, numOfStates int) *Node {
	node := &Node{value: value}
	if numOfStates > 0 {
		node.states = make([]interface{}, numOfStates)
		node.isLeaf = true
	}
	return node
}

// findNode returns the index of the first node whose value is greater than or
// equal to value, or len(self) if there is none. self must be sorted.
func (self Nodes) findNode(value *protocol.FieldValue) int {
	return sort.Search(len(self), func(i int) bool {
		return self[i].value.GreaterOrEqual(value)
	})
}