influxdb/kit/feature/_codegen/main.go

272 lines
5.1 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"io/ioutil"
"os"
"strings"
"text/template"
"github.com/Masterminds/sprig"
"github.com/influxdata/influxdb/v2/kit/feature"
yaml "gopkg.in/yaml.v2"
)
// tmpl is the text/template source for the generated Go file. run parses it
// with the helpers returned by templateFunctions and executes it against a
// value exposing Qualify (bool) and Flags ([]flagConfig).
//
// For every flag it emits a package-level variable (named after the flag's
// key) built by the Make*Flag constructor selected from the default value's
// type, plus an accessor function whose name is derived from the flag's name.
// It then emits the `all` slice and the `byKey` map listing every flag.
// The rendered text is passed through go/format before it is written, so the
// layout inside the template does not need to be gofmt-clean.
const tmpl = `// Code generated by the feature package; DO NOT EDIT.
package feature
{{ .Qualify | import }}
{{ range $_, $flag := .Flags }}
var {{ $flag.Key }} = {{ $.Qualify | package }}{{ $flag.Default | maker }}(
{{ $flag.Name | quote }},
{{ $flag.Key | quote }},
{{ $flag.Contact | quote }},
{{ $flag.Default | conditionalQuote }},
{{ $.Qualify | package }}{{ $flag.Lifetime | lifetime }},
{{ $flag.Expose }},
)
// {{ $flag.Name | replace " " "_" | camelcase }} - {{ $flag.Description }}
func {{ $flag.Name | replace " " "_" | camelcase }}() {{ $.Qualify | package }}{{ $flag.Default | flagType }} {
return {{ $flag.Key }}
}
{{ end }}
var all = []{{ .Qualify | package }}Flag{
{{ range $_, $flag := .Flags }} {{ $flag.Key }},
{{ end }}}
var byKey = map[string]{{ $.Qualify | package }}Flag{
{{ range $_, $flag := .Flags }} {{ $flag.Key | quote }}: {{ $flag.Key }},
{{ end }}}
`
// flagConfig describes a single feature flag as decoded from the YAML
// configuration file. Every field is referenced by the code generation
// template.
type flagConfig struct {
// Name is the flag's display name. The template passes it to the generated
// constructor call and also derives the accessor function's name from it.
Name string
// Description is emitted in the comment above the generated accessor.
// Valid requires it to be non-empty.
Description string
// Key identifies the flag. It is used as the generated variable name and as
// the byKey map key; validate rejects duplicates.
Key string
// Default is the flag's default value. Its dynamic type (bool, float64, int,
// or anything else, which is treated as a string) selects the generated
// flag type and constructor.
Default interface{}
// Contact is passed, quoted, to the generated constructor call. Valid
// requires it to be non-empty.
Contact string
// Lifetime is rendered as "Permanent" when equal to feature.Permanent and as
// "Temporary" otherwise.
Lifetime feature.Lifetime
// Expose is emitted verbatim as the last argument of the generated
// constructor call.
Expose bool
}
// Valid reports whether the flag configuration carries every field needed to
// generate code for it.
//
// It returns nil when name, key, contact, default and description are all
// present. Otherwise it returns a single error that lists every missing
// field, prefixed with the flag's name (falling back to its key and then to
// "anonymous"), e.g. "my flag: missing key; missing default".
//
// The message deliberately ends with a newline: flagValidationError.Error
// concatenates these messages without inserting a separator of its own.
func (f flagConfig) Valid() error {
	var problems []string
	// The template derives the generated accessor's identifier from Name, so
	// an empty name would produce source that cannot be parsed or formatted.
	if f.Name == "" {
		problems = append(problems, "missing name")
	}
	if f.Key == "" {
		problems = append(problems, "missing key")
	}
	if f.Contact == "" {
		problems = append(problems, "missing contact")
	}
	if f.Default == nil {
		problems = append(problems, "missing default")
	}
	if f.Description == "" {
		problems = append(problems, "missing description")
	}
	if len(problems) == 0 {
		return nil
	}

	// Pick the most descriptive label available for the offending flag.
	name := f.Name
	switch {
	case name != "":
		// The flag has a name; use it as is.
	case f.Key != "":
		name = f.Key
	default:
		name = "anonymous"
	}
	return fmt.Errorf("%s: %s\n", name, strings.Join(problems, "; "))
}
// flagValidationError aggregates the validation failures found across all
// flag configurations into a single error value.
type flagValidationError struct {
	errs []error
}

// newFlagValidationError wraps errs in a flagValidationError. It returns a nil
// pointer when errs is empty, so callers that return the result as an error
// interface should check for an empty slice first.
func newFlagValidationError(errs []error) *flagValidationError {
	if len(errs) > 0 {
		return &flagValidationError{errs: errs}
	}
	return nil
}

// Error renders a fixed header line followed by each wrapped error's message.
// No separator is inserted between messages; each wrapped error is expected
// to carry its own trailing newline.
func (e *flagValidationError) Error() string {
	parts := make([]string, 0, len(e.errs)+1)
	parts = append(parts, "flag validation error: \n")
	for i := range e.errs {
		parts = append(parts, e.errs[i].Error())
	}
	return strings.Join(parts, "")
}
// validate checks every flag configuration and returns one aggregated error
// describing all problems found, or nil when every flag is acceptable.
//
// A flag that fails its own Valid check is reported with that error. A flag
// that passes is reported as a duplicate when its key has already been seen
// on an earlier entry (valid or not).
func validate(flags []flagConfig) error {
	keys := make(map[string]bool, len(flags))
	var failures []error

	for i := range flags {
		cfg := flags[i]
		err := cfg.Valid()
		switch {
		case err != nil:
			failures = append(failures, err)
		case keys[cfg.Key]:
			failures = append(failures, fmt.Errorf("duplicate flag key '%s'\n", cfg.Key))
		}
		// Record the key regardless of the outcome above so that later
		// entries reusing it are flagged.
		keys[cfg.Key] = true
	}

	if len(failures) == 0 {
		return nil
	}
	return newFlagValidationError(failures)
}
// argv holds the command-line options. The flags are registered on the
// default flag set when this package-level variable is initialized; the
// pointed-to values are only populated once flag.Parse has been called, which
// run does as its first step.
var argv = struct {
// in is the path of the flag configuration to read; out is the path the
// generated source is written to.
in, out *string
// qualify selects whether generated identifiers are prefixed with the
// feature package name and whether that package is imported.
qualify *bool
}{
in: flag.String("in", "", "flag configuration path"),
out: flag.String("out", "", "flag generation destination path"),
qualify: flag.Bool("qualify", false, "qualify types with imported package name"),
}
// main runs the generator and maps its outcome to the process exit status:
// 0 on success, 1 (with the error printed to standard error) on failure.
func main() {
	err := run()
	if err == nil {
		os.Exit(0)
	}
	fmt.Fprintf(os.Stderr, "Error: %v\n", err)
	os.Exit(1)
}
func run() error {
flag.Parse()
in, err := os.Open(*argv.in)
if err != nil {
return err
}
defer in.Close()
configuration, err := ioutil.ReadAll(in)
if err != nil {
return err
}
var flags []flagConfig
err = yaml.Unmarshal(configuration, &flags)
if err != nil {
return err
}
err = validate(flags)
if err != nil {
return err
}
t, err := template.New("flags").Funcs(templateFunctions()).Parse(tmpl)
if err != nil {
return err
}
out, err := os.Create(*argv.out)
if err != nil {
return err
}
defer out.Close()
var (
buf = new(bytes.Buffer)
vars = struct {
Qualify bool
Flags []flagConfig
}{
Qualify: *argv.qualify,
Flags: flags,
}
)
if err := t.Execute(buf, vars); err != nil {
return err
}
raw, err := ioutil.ReadAll(buf)
if err != nil {
return err
}
formatted, err := format.Source(raw)
if err != nil {
return err
}
_, err = out.Write(formatted)
return err
}
// templateFunctions returns the function map used to render tmpl: sprig's
// text functions extended with the generator-specific helpers below.
//
//   - lifetime:         "Permanent" for feature.Permanent, otherwise "Temporary".
//   - conditionalQuote: Go-quoted text for strings, default formatting otherwise.
//   - flagType:         generated flag type name chosen from the value's type.
//   - maker:            generated constructor name chosen from the value's type.
//   - package:          "feature." when qualification is requested, else "".
//   - import:           the feature package import line when qualification is
//     requested, else "".
//
// The package and import helpers panic if their argument is not a bool.
func templateFunctions() template.FuncMap {
	// flagKind maps a default value's dynamic type to the stem shared by the
	// generated flag type and its constructor; unknown types fall back to
	// "String".
	flagKind := func(v interface{}) string {
		switch v.(type) {
		case bool:
			return "Bool"
		case float64:
			return "Float"
		case int:
			return "Int"
		}
		return "String"
	}

	// whenQualified yields text only when the qualify value is true.
	whenQualified := func(q interface{}, text string) string {
		if qualified := q.(bool); !qualified {
			return ""
		}
		return text
	}

	fm := sprig.TxtFuncMap()

	fm["lifetime"] = func(t interface{}) string {
		if t == feature.Permanent {
			return "Permanent"
		}
		return "Temporary"
	}

	fm["conditionalQuote"] = func(t interface{}) string {
		if s, isString := t.(string); isString {
			return fmt.Sprintf("%q", s)
		}
		return fmt.Sprintf("%v", t)
	}

	fm["flagType"] = func(t interface{}) string {
		return flagKind(t) + "Flag"
	}

	fm["maker"] = func(t interface{}) string {
		return "Make" + flagKind(t) + "Flag"
	}

	fm["package"] = func(t interface{}) string {
		return whenQualified(t, "feature.")
	}

	fm["import"] = func(t interface{}) string {
		return whenQualified(t, `import "github.com/influxdata/influxdb/v2/kit/feature"`)
	}

	return fm
}