package main import ( "bytes" "flag" "fmt" "go/format" "io" "os" "strings" "text/template" "github.com/Masterminds/sprig" "github.com/influxdata/influxdb/v2/kit/feature" yaml "gopkg.in/yaml.v2" ) const tmpl = `// Code generated by the feature package; DO NOT EDIT. package feature {{ .Qualify | import }} {{ range $_, $flag := .Flags }} var {{ $flag.Key }} = {{ $.Qualify | package }}{{ $flag.Default | maker }}( {{ $flag.Name | quote }}, {{ $flag.Key | quote }}, {{ $flag.Contact | quote }}, {{ $flag.Default | conditionalQuote }}, {{ $.Qualify | package }}{{ $flag.Lifetime | lifetime }}, {{ $flag.Expose }}, ) // {{ $flag.Name | replace " " "_" | camelcase }} - {{ $flag.Description }} func {{ $flag.Name | replace " " "_" | camelcase }}() {{ $.Qualify | package }}{{ $flag.Default | flagType }} { return {{ $flag.Key }} } {{ end }} var all = []{{ .Qualify | package }}Flag{ {{ range $_, $flag := .Flags }} {{ $flag.Key }}, {{ end }}} var byKey = map[string]{{ $.Qualify | package }}Flag{ {{ range $_, $flag := .Flags }} {{ $flag.Key | quote }}: {{ $flag.Key }}, {{ end }}} ` type flagConfig struct { Name string Description string Key string Default interface{} Contact string Lifetime feature.Lifetime Expose bool } func (f flagConfig) Valid() error { var problems []string if f.Key == "" { problems = append(problems, "missing key") } if f.Contact == "" { problems = append(problems, "missing contact") } if f.Default == nil { problems = append(problems, "missing default") } if f.Description == "" { problems = append(problems, "missing description") } if len(problems) > 0 { name := f.Name if name == "" { if f.Key != "" { name = f.Key } else { name = "anonymous" } } // e.g. "my flag: missing key; missing default" return fmt.Errorf("%s: %s\n", name, strings.Join(problems, "; ")) } return nil } type flagValidationError struct { errs []error } func newFlagValidationError(errs []error) *flagValidationError { if len(errs) == 0 { return nil } return &flagValidationError{errs} } func (e *flagValidationError) Error() string { var s strings.Builder s.WriteString("flag validation error: \n") for _, err := range e.errs { s.WriteString(err.Error()) } return s.String() } func validate(flags []flagConfig) error { var ( errs []error seen = make(map[string]bool, len(flags)) ) for _, flag := range flags { if err := flag.Valid(); err != nil { errs = append(errs, err) } else if _, repeated := seen[flag.Key]; repeated { errs = append(errs, fmt.Errorf("duplicate flag key '%s'\n", flag.Key)) } seen[flag.Key] = true } if len(errs) != 0 { return newFlagValidationError(errs) } return nil } var argv = struct { in, out *string qualify *bool }{ in: flag.String("in", "", "flag configuration path"), out: flag.String("out", "", "flag generation destination path"), qualify: flag.Bool("qualify", false, "qualify types with imported package name"), } func main() { if err := run(); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } os.Exit(0) } func run() error { flag.Parse() in, err := os.Open(*argv.in) if err != nil { return err } defer in.Close() configuration, err := io.ReadAll(in) if err != nil { return err } var flags []flagConfig err = yaml.Unmarshal(configuration, &flags) if err != nil { return err } err = validate(flags) if err != nil { return err } t, err := template.New("flags").Funcs(templateFunctions()).Parse(tmpl) if err != nil { return err } out, err := os.Create(*argv.out) if err != nil { return err } defer out.Close() var ( buf = new(bytes.Buffer) vars = struct { Qualify bool Flags []flagConfig }{ Qualify: *argv.qualify, Flags: flags, } ) if err := t.Execute(buf, vars); err != nil { return err } raw, err := io.ReadAll(buf) if err != nil { return err } formatted, err := format.Source(raw) if err != nil { return err } _, err = out.Write(formatted) return err } func templateFunctions() template.FuncMap { functions := sprig.TxtFuncMap() functions["lifetime"] = func(t interface{}) string { switch t { case feature.Permanent: return "Permanent" default: return "Temporary" } } functions["conditionalQuote"] = func(t interface{}) string { switch t.(type) { case string: return fmt.Sprintf("%q", t) default: return fmt.Sprintf("%v", t) } } functions["flagType"] = func(t interface{}) string { switch t.(type) { case bool: return "BoolFlag" case float64: return "FloatFlag" case int: return "IntFlag" default: return "StringFlag" } } functions["maker"] = func(t interface{}) string { switch t.(type) { case bool: return "MakeBoolFlag" case float64: return "MakeFloatFlag" case int: return "MakeIntFlag" default: return "MakeStringFlag" } } functions["package"] = func(t interface{}) string { if t.(bool) { return "feature." } return "" } functions["import"] = func(t interface{}) string { if t.(bool) { return "import \"github.com/influxdata/influxdb/v2/kit/feature\"" } return "" } return functions }