Move ApplyEnvOverrides to toml package
This code has been duplicated to other projects and its implementations have grown out of sync. Now the code can live as a package-level function rather than a method coupled with particular structs.pull/9376/head
parent
c9e7c5ad2b
commit
fbddcf7cad
|
@ -1,16 +1,13 @@
|
|||
package run
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
|
@ -29,6 +26,7 @@ import (
|
|||
"github.com/influxdata/influxdb/services/storage"
|
||||
"github.com/influxdata/influxdb/services/subscriber"
|
||||
"github.com/influxdata/influxdb/services/udp"
|
||||
itoml "github.com/influxdata/influxdb/toml"
|
||||
"github.com/influxdata/influxdb/tsdb"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
|
@ -195,123 +193,7 @@ func (c *Config) Validate() error {
|
|||
|
||||
// ApplyEnvOverrides apply the environment configuration on top of the config.
|
||||
func (c *Config) ApplyEnvOverrides(getenv func(string) string) error {
|
||||
if getenv == nil {
|
||||
getenv = os.Getenv
|
||||
}
|
||||
return c.applyEnvOverrides(getenv, "INFLUXDB", reflect.ValueOf(c), "")
|
||||
}
|
||||
|
||||
func (c *Config) applyEnvOverrides(getenv func(string) string, prefix string, spec reflect.Value, structKey string) error {
|
||||
element := spec
|
||||
// If spec is a named type and is addressable,
|
||||
// check the address to see if it implements encoding.TextUnmarshaler.
|
||||
if spec.Kind() != reflect.Ptr && spec.Type().Name() != "" && spec.CanAddr() {
|
||||
v := spec.Addr()
|
||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||
value := getenv(prefix)
|
||||
return u.UnmarshalText([]byte(value))
|
||||
}
|
||||
}
|
||||
// If we have a pointer, dereference it
|
||||
if spec.Kind() == reflect.Ptr {
|
||||
element = spec.Elem()
|
||||
}
|
||||
|
||||
value := getenv(prefix)
|
||||
|
||||
switch element.Kind() {
|
||||
case reflect.String:
|
||||
if len(value) == 0 {
|
||||
return nil
|
||||
}
|
||||
element.SetString(value)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(value, 0, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
|
||||
}
|
||||
element.SetInt(intValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
intValue, err := strconv.ParseUint(value, 0, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
|
||||
}
|
||||
element.SetUint(intValue)
|
||||
case reflect.Bool:
|
||||
boolValue, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
|
||||
}
|
||||
element.SetBool(boolValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
floatValue, err := strconv.ParseFloat(value, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v'", prefix, structKey, element.Type().String(), value)
|
||||
}
|
||||
element.SetFloat(floatValue)
|
||||
case reflect.Slice:
|
||||
// If the type is s slice, apply to each using the index as a suffix, e.g. GRAPHITE_0, GRAPHITE_0_TEMPLATES_0 or GRAPHITE_0_TEMPLATES="item1,item2"
|
||||
for j := 0; j < element.Len(); j++ {
|
||||
f := element.Index(j)
|
||||
if err := c.applyEnvOverrides(getenv, prefix, f, structKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.applyEnvOverrides(getenv, fmt.Sprintf("%s_%d", prefix, j), f, structKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If the type is s slice but have value not parsed as slice e.g. GRAPHITE_0_TEMPLATES="item1,item2"
|
||||
if element.Len() == 0 && len(value) > 0 {
|
||||
rules := strings.Split(value, ",")
|
||||
|
||||
for _, rule := range rules {
|
||||
element.Set(reflect.Append(element, reflect.ValueOf(rule)))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
typeOfSpec := element.Type()
|
||||
for i := 0; i < element.NumField(); i++ {
|
||||
field := element.Field(i)
|
||||
|
||||
// Skip any fields that we cannot set
|
||||
if !field.CanSet() && field.Kind() != reflect.Slice {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := typeOfSpec.Field(i).Name
|
||||
|
||||
configName := typeOfSpec.Field(i).Tag.Get("toml")
|
||||
// Replace hyphens with underscores to avoid issues with shells
|
||||
configName = strings.Replace(configName, "-", "_", -1)
|
||||
|
||||
envKey := strings.ToUpper(configName)
|
||||
if prefix != "" {
|
||||
envKey = strings.ToUpper(fmt.Sprintf("%s_%s", prefix, configName))
|
||||
}
|
||||
|
||||
// If it's a sub-config, recursively apply
|
||||
if field.Kind() == reflect.Struct || field.Kind() == reflect.Ptr ||
|
||||
field.Kind() == reflect.Slice || field.Kind() == reflect.Array {
|
||||
if err := c.applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
value := getenv(envKey)
|
||||
// Skip any fields we don't have a value to set
|
||||
if len(value) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return itoml.ApplyEnvOverrides(getenv, "INFLUXDB", c)
|
||||
}
|
||||
|
||||
// Diagnostics returns a diagnostics representation of Config.
|
||||
|
|
124
toml/toml.go
124
toml/toml.go
|
@ -2,9 +2,13 @@
|
|||
package toml // import "github.com/influxdata/influxdb/toml"
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
@ -89,3 +93,123 @@ func (s *Size) UnmarshalText(text []byte) error {
|
|||
*s = Size(size)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ApplyEnvOverrides(getenv func(string) string, prefix string, val interface{}) error {
|
||||
if getenv == nil {
|
||||
getenv = os.Getenv
|
||||
}
|
||||
return applyEnvOverrides(getenv, prefix, reflect.ValueOf(val), "")
|
||||
}
|
||||
|
||||
func applyEnvOverrides(getenv func(string) string, prefix string, spec reflect.Value, structKey string) error {
|
||||
element := spec
|
||||
// If spec is a named type and is addressable,
|
||||
// check the address to see if it implements encoding.TextUnmarshaler.
|
||||
if spec.Kind() != reflect.Ptr && spec.Type().Name() != "" && spec.CanAddr() {
|
||||
v := spec.Addr()
|
||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||
value := getenv(prefix)
|
||||
return u.UnmarshalText([]byte(value))
|
||||
}
|
||||
}
|
||||
// If we have a pointer, dereference it
|
||||
if spec.Kind() == reflect.Ptr {
|
||||
element = spec.Elem()
|
||||
}
|
||||
|
||||
value := getenv(prefix)
|
||||
|
||||
switch element.Kind() {
|
||||
case reflect.String:
|
||||
if len(value) == 0 {
|
||||
return nil
|
||||
}
|
||||
element.SetString(value)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(value, 0, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
|
||||
}
|
||||
element.SetInt(intValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
intValue, err := strconv.ParseUint(value, 0, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
|
||||
}
|
||||
element.SetUint(intValue)
|
||||
case reflect.Bool:
|
||||
boolValue, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
|
||||
}
|
||||
element.SetBool(boolValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
floatValue, err := strconv.ParseFloat(value, element.Type().Bits())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err)
|
||||
}
|
||||
element.SetFloat(floatValue)
|
||||
case reflect.Slice:
|
||||
// If the type is s slice, apply to each using the index as a suffix, e.g. GRAPHITE_0, GRAPHITE_0_TEMPLATES_0 or GRAPHITE_0_TEMPLATES="item1,item2"
|
||||
for j := 0; j < element.Len(); j++ {
|
||||
f := element.Index(j)
|
||||
if err := applyEnvOverrides(getenv, prefix, f, structKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := applyEnvOverrides(getenv, fmt.Sprintf("%s_%d", prefix, j), f, structKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If the type is s slice but have value not parsed as slice e.g. GRAPHITE_0_TEMPLATES="item1,item2"
|
||||
if element.Len() == 0 && len(value) > 0 {
|
||||
rules := strings.Split(value, ",")
|
||||
|
||||
for _, rule := range rules {
|
||||
element.Set(reflect.Append(element, reflect.ValueOf(rule)))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
typeOfSpec := element.Type()
|
||||
for i := 0; i < element.NumField(); i++ {
|
||||
field := element.Field(i)
|
||||
|
||||
// Skip any fields that we cannot set
|
||||
if !field.CanSet() && field.Kind() != reflect.Slice {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := typeOfSpec.Field(i).Name
|
||||
|
||||
configName := typeOfSpec.Field(i).Tag.Get("toml")
|
||||
// Replace hyphens with underscores to avoid issues with shells
|
||||
configName = strings.Replace(configName, "-", "_", -1)
|
||||
|
||||
envKey := strings.ToUpper(configName)
|
||||
if prefix != "" {
|
||||
envKey = strings.ToUpper(fmt.Sprintf("%s_%s", prefix, configName))
|
||||
}
|
||||
|
||||
// If it's a sub-config, recursively apply
|
||||
if field.Kind() == reflect.Struct || field.Kind() == reflect.Ptr ||
|
||||
field.Kind() == reflect.Slice || field.Kind() == reflect.Array {
|
||||
if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
value := getenv(envKey)
|
||||
// Skip any fields we don't have a value to set
|
||||
if len(value) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/influxdata/influxdb/cmd/influxd/run"
|
||||
itoml "github.com/influxdata/influxdb/toml"
|
||||
)
|
||||
|
@ -73,3 +74,83 @@ func TestConfig_Encode(t *testing.T) {
|
|||
t.Fatalf("Encoding config failed.\nfailed to find %s in:\n%s\n", search, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverride_Builtins(t *testing.T) {
|
||||
envMap := map[string]string{
|
||||
"X_STRING": "a string",
|
||||
"X_DURATION": "1m1s",
|
||||
"X_INT": "1",
|
||||
"X_INT8": "2",
|
||||
"X_INT16": "3",
|
||||
"X_INT32": "4",
|
||||
"X_INT64": "5",
|
||||
"X_UINT": "6",
|
||||
"X_UINT8": "7",
|
||||
"X_UINT16": "8",
|
||||
"X_UINT32": "9",
|
||||
"X_UINT64": "10",
|
||||
"X_BOOL": "true",
|
||||
"X_FLOAT32": "11.5",
|
||||
"X_FLOAT64": "12.5",
|
||||
"X_NESTED_STRING": "a nested string",
|
||||
"X_NESTED_INT": "13",
|
||||
}
|
||||
|
||||
env := func(s string) string {
|
||||
return envMap[s]
|
||||
}
|
||||
|
||||
type nested struct {
|
||||
Str string `toml:"string"`
|
||||
Int int `toml:"int"`
|
||||
}
|
||||
type all struct {
|
||||
Str string `toml:"string"`
|
||||
Dur itoml.Duration `toml:"duration"`
|
||||
Int int `toml:"int"`
|
||||
Int8 int8 `toml:"int8"`
|
||||
Int16 int16 `toml:"int16"`
|
||||
Int32 int32 `toml:"int32"`
|
||||
Int64 int64 `toml:"int64"`
|
||||
Uint uint `toml:"uint"`
|
||||
Uint8 uint8 `toml:"uint8"`
|
||||
Uint16 uint16 `toml:"uint16"`
|
||||
Uint32 uint32 `toml:"uint32"`
|
||||
Uint64 uint64 `toml:"uint64"`
|
||||
Bool bool `toml:"bool"`
|
||||
Float32 float32 `toml:"float32"`
|
||||
Float64 float64 `toml:"float64"`
|
||||
Nested nested `toml:"nested"`
|
||||
}
|
||||
|
||||
var got all
|
||||
if err := itoml.ApplyEnvOverrides(env, "X", &got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
exp := all{
|
||||
Str: "a string",
|
||||
Dur: itoml.Duration(time.Minute + time.Second),
|
||||
Int: 1,
|
||||
Int8: 2,
|
||||
Int16: 3,
|
||||
Int32: 4,
|
||||
Int64: 5,
|
||||
Uint: 6,
|
||||
Uint8: 7,
|
||||
Uint16: 8,
|
||||
Uint32: 9,
|
||||
Uint64: 10,
|
||||
Bool: true,
|
||||
Float32: 11.5,
|
||||
Float64: 12.5,
|
||||
Nested: nested{
|
||||
Str: "a nested string",
|
||||
Int: 13,
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, exp); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue