Move ApplyEnvOverrides to toml package

This code has been duplicated to other projects and its implementations
have grown out of sync. Now the code can live as a package-level
function rather than a method coupled with particular structs.
pull/9376/head
Mark Rushakoff 2018-03-14 15:23:46 -07:00
parent c9e7c5ad2b
commit fbddcf7cad
3 changed files with 207 additions and 120 deletions

View File

@ -1,16 +1,13 @@
package run
import (
"encoding"
"fmt"
"io/ioutil"
"log"
"os"
"os/user"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/BurntSushi/toml"
@ -29,6 +26,7 @@ import (
"github.com/influxdata/influxdb/services/storage"
"github.com/influxdata/influxdb/services/subscriber"
"github.com/influxdata/influxdb/services/udp"
itoml "github.com/influxdata/influxdb/toml"
"github.com/influxdata/influxdb/tsdb"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
@ -195,123 +193,7 @@ func (c *Config) Validate() error {
// ApplyEnvOverrides apply the environment configuration on top of the config.
func (c *Config) ApplyEnvOverrides(getenv func(string) string) error {
if getenv == nil {
getenv = os.Getenv
}
return c.applyEnvOverrides(getenv, "INFLUXDB", reflect.ValueOf(c), "")
}
func (c *Config) applyEnvOverrides(getenv func(string) string, prefix string, spec reflect.Value, structKey string) error {
element := spec
// If spec is a named type and is addressable,
// check the address to see if it implements encoding.TextUnmarshaler.
if spec.Kind() != reflect.Ptr && spec.Type().Name() != "" && spec.CanAddr() {
v := spec.Addr()
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
value := getenv(prefix)
return u.UnmarshalText([]byte(value))
}
}
// If we have a pointer, dereference it
if spec.Kind() == reflect.Ptr {
element = spec.Elem()
}
value := getenv(prefix)
switch element.Kind() {
case reflect.String:
if len(value) == 0 {
return nil
}
element.SetString(value)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(value, 0, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
}
element.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
intValue, err := strconv.ParseUint(value, 0, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
}
element.SetUint(intValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
}
element.SetBool(boolValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(value, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
}
element.SetFloat(floatValue)
case reflect.Slice:
// If the type is s slice, apply to each using the index as a suffix, e.g. GRAPHITE_0, GRAPHITE_0_TEMPLATES_0 or GRAPHITE_0_TEMPLATES="item1,item2"
for j := 0; j < element.Len(); j++ {
f := element.Index(j)
if err := c.applyEnvOverrides(getenv, prefix, f, structKey); err != nil {
return err
}
if err := c.applyEnvOverrides(getenv, fmt.Sprintf("%s_%d", prefix, j), f, structKey); err != nil {
return err
}
}
// If the type is s slice but have value not parsed as slice e.g. GRAPHITE_0_TEMPLATES="item1,item2"
if element.Len() == 0 && len(value) > 0 {
rules := strings.Split(value, ",")
for _, rule := range rules {
element.Set(reflect.Append(element, reflect.ValueOf(rule)))
}
}
case reflect.Struct:
typeOfSpec := element.Type()
for i := 0; i < element.NumField(); i++ {
field := element.Field(i)
// Skip any fields that we cannot set
if !field.CanSet() && field.Kind() != reflect.Slice {
continue
}
fieldName := typeOfSpec.Field(i).Name
configName := typeOfSpec.Field(i).Tag.Get("toml")
// Replace hyphens with underscores to avoid issues with shells
configName = strings.Replace(configName, "-", "_", -1)
envKey := strings.ToUpper(configName)
if prefix != "" {
envKey = strings.ToUpper(fmt.Sprintf("%s_%s", prefix, configName))
}
// If it's a sub-config, recursively apply
if field.Kind() == reflect.Struct || field.Kind() == reflect.Ptr ||
field.Kind() == reflect.Slice || field.Kind() == reflect.Array {
if err := c.applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
return err
}
continue
}
value := getenv(envKey)
// Skip any fields we don't have a value to set
if len(value) == 0 {
continue
}
if err := c.applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
return err
}
}
}
return nil
return itoml.ApplyEnvOverrides(getenv, "INFLUXDB", c)
}
// Diagnostics returns a diagnostics representation of Config.

View File

@ -2,9 +2,13 @@
package toml // import "github.com/influxdata/influxdb/toml"
import (
"encoding"
"fmt"
"math"
"os"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)
@ -89,3 +93,123 @@ func (s *Size) UnmarshalText(text []byte) error {
*s = Size(size)
return nil
}
func ApplyEnvOverrides(getenv func(string) string, prefix string, val interface{}) error {
if getenv == nil {
getenv = os.Getenv
}
return applyEnvOverrides(getenv, prefix, reflect.ValueOf(val), "")
}
func applyEnvOverrides(getenv func(string) string, prefix string, spec reflect.Value, structKey string) error {
element := spec
// If spec is a named type and is addressable,
// check the address to see if it implements encoding.TextUnmarshaler.
if spec.Kind() != reflect.Ptr && spec.Type().Name() != "" && spec.CanAddr() {
v := spec.Addr()
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
value := getenv(prefix)
return u.UnmarshalText([]byte(value))
}
}
// If we have a pointer, dereference it
if spec.Kind() == reflect.Ptr {
element = spec.Elem()
}
value := getenv(prefix)
switch element.Kind() {
case reflect.String:
if len(value) == 0 {
return nil
}
element.SetString(value)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(value, 0, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
}
element.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
intValue, err := strconv.ParseUint(value, 0, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
}
element.SetUint(intValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
}
element.SetBool(boolValue)
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(value, element.Type().Bits())
if err != nil {
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
}
element.SetFloat(floatValue)
case reflect.Slice:
// If the type is s slice, apply to each using the index as a suffix, e.g. GRAPHITE_0, GRAPHITE_0_TEMPLATES_0 or GRAPHITE_0_TEMPLATES="item1,item2"
for j := 0; j < element.Len(); j++ {
f := element.Index(j)
if err := applyEnvOverrides(getenv, prefix, f, structKey); err != nil {
return err
}
if err := applyEnvOverrides(getenv, fmt.Sprintf("%s_%d", prefix, j), f, structKey); err != nil {
return err
}
}
// If the type is s slice but have value not parsed as slice e.g. GRAPHITE_0_TEMPLATES="item1,item2"
if element.Len() == 0 && len(value) > 0 {
rules := strings.Split(value, ",")
for _, rule := range rules {
element.Set(reflect.Append(element, reflect.ValueOf(rule)))
}
}
case reflect.Struct:
typeOfSpec := element.Type()
for i := 0; i < element.NumField(); i++ {
field := element.Field(i)
// Skip any fields that we cannot set
if !field.CanSet() && field.Kind() != reflect.Slice {
continue
}
fieldName := typeOfSpec.Field(i).Name
configName := typeOfSpec.Field(i).Tag.Get("toml")
// Replace hyphens with underscores to avoid issues with shells
configName = strings.Replace(configName, "-", "_", -1)
envKey := strings.ToUpper(configName)
if prefix != "" {
envKey = strings.ToUpper(fmt.Sprintf("%s_%s", prefix, configName))
}
// If it's a sub-config, recursively apply
if field.Kind() == reflect.Struct || field.Kind() == reflect.Ptr ||
field.Kind() == reflect.Slice || field.Kind() == reflect.Array {
if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
return err
}
continue
}
value := getenv(envKey)
// Skip any fields we don't have a value to set
if len(value) == 0 {
continue
}
if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
return err
}
}
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/BurntSushi/toml"
"github.com/google/go-cmp/cmp"
"github.com/influxdata/influxdb/cmd/influxd/run"
itoml "github.com/influxdata/influxdb/toml"
)
@ -73,3 +74,83 @@ func TestConfig_Encode(t *testing.T) {
t.Fatalf("Encoding config failed.\nfailed to find %s in:\n%s\n", search, got)
}
}
func TestEnvOverride_Builtins(t *testing.T) {
envMap := map[string]string{
"X_STRING": "a string",
"X_DURATION": "1m1s",
"X_INT": "1",
"X_INT8": "2",
"X_INT16": "3",
"X_INT32": "4",
"X_INT64": "5",
"X_UINT": "6",
"X_UINT8": "7",
"X_UINT16": "8",
"X_UINT32": "9",
"X_UINT64": "10",
"X_BOOL": "true",
"X_FLOAT32": "11.5",
"X_FLOAT64": "12.5",
"X_NESTED_STRING": "a nested string",
"X_NESTED_INT": "13",
}
env := func(s string) string {
return envMap[s]
}
type nested struct {
Str string `toml:"string"`
Int int `toml:"int"`
}
type all struct {
Str string `toml:"string"`
Dur itoml.Duration `toml:"duration"`
Int int `toml:"int"`
Int8 int8 `toml:"int8"`
Int16 int16 `toml:"int16"`
Int32 int32 `toml:"int32"`
Int64 int64 `toml:"int64"`
Uint uint `toml:"uint"`
Uint8 uint8 `toml:"uint8"`
Uint16 uint16 `toml:"uint16"`
Uint32 uint32 `toml:"uint32"`
Uint64 uint64 `toml:"uint64"`
Bool bool `toml:"bool"`
Float32 float32 `toml:"float32"`
Float64 float64 `toml:"float64"`
Nested nested `toml:"nested"`
}
var got all
if err := itoml.ApplyEnvOverrides(env, "X", &got); err != nil {
t.Fatal(err)
}
exp := all{
Str: "a string",
Dur: itoml.Duration(time.Minute + time.Second),
Int: 1,
Int8: 2,
Int16: 3,
Int32: 4,
Int64: 5,
Uint: 6,
Uint8: 7,
Uint16: 8,
Uint32: 9,
Uint64: 10,
Bool: true,
Float32: 11.5,
Float64: 12.5,
Nested: nested{
Str: "a nested string",
Int: 13,
},
}
if diff := cmp.Diff(got, exp); diff != "" {
t.Fatal(diff)
}
}