// Package toml adds support to marshal and unmarshal types not in the official TOML spec. package toml import ( "encoding" "errors" "fmt" "os" "os/user" "reflect" "strconv" "strings" "time" "github.com/dustin/go-humanize" "github.com/spf13/pflag" ) // Duration is a TOML wrapper type for time.Duration. type Duration time.Duration var _ pflag.Value = (*Duration)(nil) func (d *Duration) Set(s string) error { return d.UnmarshalText([]byte(s)) } func (d Duration) Type() string { return "Duration" } // String returns the string representation of the duration. func (d Duration) String() string { return time.Duration(d).String() } // UnmarshalText parses a TOML value into a duration value. func (d *Duration) UnmarshalText(text []byte) error { // Ignore if there is no value set. if len(text) == 0 { return nil } // Otherwise parse as a duration formatted string. duration, err := time.ParseDuration(string(text)) if err != nil { return err } // Set duration and return. *d = Duration(duration) return nil } // MarshalText converts a duration to a string for decoding toml func (d Duration) MarshalText() (text []byte, err error) { return []byte(d.String()), nil } // Size represents a TOML parsable file size. // Users can specify size using "k" or "K" for kibibytes, "m" or "M" for mebibytes, // and "g" or "G" for gibibytes. If a size suffix isn't specified then bytes are assumed. type Size uint64 var _ pflag.Value = (*Size)(nil) func (s Size) String() string { return humanize.IBytes(uint64(s)) } func (s *Size) Set(d string) error { return s.UnmarshalText([]byte(d)) } func (s Size) Type() string { return "Size" } // UnmarshalText parses a byte size from text. func (s *Size) UnmarshalText(text []byte) error { if len(text) == 0 { return fmt.Errorf("size was empty") } v, err := humanize.ParseBytes(string(text)) if err != nil { return err } *s = Size(v) return nil } type FileMode uint32 func (m *FileMode) UnmarshalText(text []byte) error { // Ignore if there is no value set. if len(text) == 0 { return nil } mode, err := strconv.ParseUint(string(text), 8, 32) if err != nil { return err } else if mode == 0 { return errors.New("file mode cannot be zero") } *m = FileMode(mode) return nil } func (m FileMode) MarshalText() (text []byte, err error) { if m != 0 { return []byte(fmt.Sprintf("%04o", m)), nil } return nil, nil } type Group int func (g *Group) UnmarshalTOML(data interface{}) error { if grpName, ok := data.(string); ok { group, err := user.LookupGroup(grpName) if err != nil { return err } gid, err := strconv.Atoi(group.Gid) if err != nil { return err } *g = Group(gid) return nil } else if gid, ok := data.(int64); ok { *g = Group(gid) return nil } return errors.New("group must be a name (string) or id (int)") } func ApplyEnvOverrides(getenv func(string) string, prefix string, val interface{}) error { if getenv == nil { getenv = os.Getenv } return applyEnvOverrides(getenv, prefix, reflect.ValueOf(val), "") } func applyEnvOverrides(getenv func(string) string, prefix string, spec reflect.Value, structKey string) error { element := spec // If spec is a named type and is addressable, // check the address to see if it implements encoding.TextUnmarshaler. if spec.Kind() != reflect.Pointer && spec.Type().Name() != "" && spec.CanAddr() { v := spec.Addr() if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { value := getenv(prefix) return u.UnmarshalText([]byte(value)) } } // If we have a pointer, dereference it if spec.Kind() == reflect.Pointer { element = spec.Elem() } value := getenv(prefix) switch element.Kind() { case reflect.String: if len(value) == 0 { return nil } element.SetString(value) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: intValue, err := strconv.ParseInt(value, 0, element.Type().Bits()) if err != nil { return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err) } element.SetInt(intValue) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: intValue, err := strconv.ParseUint(value, 0, element.Type().Bits()) if err != nil { return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err) } element.SetUint(intValue) case reflect.Bool: boolValue, err := strconv.ParseBool(value) if err != nil { return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err) } element.SetBool(boolValue) case reflect.Float32, reflect.Float64: floatValue, err := strconv.ParseFloat(value, element.Type().Bits()) if err != nil { return fmt.Errorf("failed to apply %v to %v using type %v and value '%v': %s", prefix, structKey, element.Type().String(), value, err) } element.SetFloat(floatValue) case reflect.Slice: // If the type is s slice, apply to each using the index as a suffix, e.g. GRAPHITE_0, GRAPHITE_0_TEMPLATES_0 or GRAPHITE_0_TEMPLATES="item1,item2" for j := 0; j < element.Len(); j++ { f := element.Index(j) if err := applyEnvOverrides(getenv, prefix, f, structKey); err != nil { return err } if err := applyEnvOverrides(getenv, fmt.Sprintf("%s_%d", prefix, j), f, structKey); err != nil { return err } } // If the type is s slice but have value not parsed as slice e.g. GRAPHITE_0_TEMPLATES="item1,item2" if element.Len() == 0 && len(value) > 0 { rules := strings.Split(value, ",") for _, rule := range rules { element.Set(reflect.Append(element, reflect.ValueOf(rule))) } } case reflect.Struct: typeOfSpec := element.Type() for i := 0; i < element.NumField(); i++ { field := element.Field(i) // Skip any fields that we cannot set if !field.CanSet() && field.Kind() != reflect.Slice { continue } structField := typeOfSpec.Field(i) fieldName := structField.Name configName := structField.Tag.Get("toml") if configName == "-" { // Skip fields with tag `toml:"-"`. continue } if configName == "" && structField.Anonymous { // Embedded field without a toml tag. // Don't modify prefix. if err := applyEnvOverrides(getenv, prefix, field, fieldName); err != nil { return err } continue } // Replace hyphens with underscores to avoid issues with shells configName = strings.Replace(configName, "-", "_", -1) envKey := strings.ToUpper(configName) if prefix != "" { envKey = strings.ToUpper(fmt.Sprintf("%s_%s", prefix, configName)) } // If it's a sub-config, recursively apply if field.Kind() == reflect.Struct || field.Kind() == reflect.Pointer || field.Kind() == reflect.Slice || field.Kind() == reflect.Array { if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil { return err } continue } value := getenv(envKey) // Skip any fields we don't have a value to set if len(value) == 0 { continue } if err := applyEnvOverrides(getenv, envKey, field, fieldName); err != nil { return err } } } return nil }