influxdb/vault/secret.go

274 lines
6.6 KiB
Go

package vault
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/hashicorp/vault/api"
platform "github.com/influxdata/influxdb"
)
// Compile-time assertion that *SecretService implements platform.SecretService.
var (
	_ platform.SecretService = (*SecretService)(nil)
)

// SecretService implements platform.SecretService on top of a HashiCorp Vault
// server. All secrets of one organization are stored together in a single
// vault secret whose path is derived from the organization ID.
type SecretService struct {
	// Client is the vault API client used for every read and write.
	Client *api.Client
}
type (
	// Config holds optional settings used to set up the vault client. Any
	// field left at its zero value is ignored, and the client keeps its
	// default for that setting (defaults come from api.DefaultConfig, which
	// may read the standard VAULT_* environment variables).
	Config struct {
		Address       string        // vault server address
		AgentAddress  string        // vault agent address
		ClientTimeout time.Duration // client timeout; only used when > 0
		MaxRetries    int           // retry count; only used when > 0
		Token         string        // auth token set on the client when non-empty
		TLSConfig
	}

	// TLSConfig is the configuration for TLS. Its fields are handed to the
	// vault client's api.TLSConfig (InsecureSkipVerify maps to Insecure).
	TLSConfig struct {
		CACert             string
		CAPath             string
		ClientCert         string
		ClientKey          string
		InsecureSkipVerify bool
		TLSServerName      string
	}
)
// assign copies every non-zero field of c onto apiCFG. Fields that are zero in
// c leave apiCFG's existing value (typically the api.DefaultConfig default)
// untouched.
//
// TLS settings are forwarded to apiCFG.ConfigureTLS whenever at least one TLS
// field is set. They are not gated on TLSServerName alone, because that would
// silently drop e.g. a CA certificate or InsecureSkipVerify supplied without a
// server name.
//
// The only error returned is one reported by apiCFG.ConfigureTLS.
func (c Config) assign(apiCFG *api.Config) error {
	if c.Address != "" {
		apiCFG.Address = c.Address
	}
	if c.AgentAddress != "" {
		apiCFG.AgentAddress = c.AgentAddress
	}
	if c.ClientTimeout > 0 {
		apiCFG.Timeout = c.ClientTimeout
	}
	if c.MaxRetries > 0 {
		apiCFG.MaxRetries = c.MaxRetries
	}

	// TLSConfig holds only strings and a bool, so it is comparable; any
	// difference from the zero value means the caller set some TLS option.
	if c.TLSConfig == (TLSConfig{}) {
		return nil
	}
	return apiCFG.ConfigureTLS(&api.TLSConfig{
		CACert:        c.CACert,
		CAPath:        c.CAPath,
		ClientCert:    c.ClientCert,
		ClientKey:     c.ClientKey,
		TLSServerName: c.TLSServerName,
		Insecure:      c.InsecureSkipVerify,
	})
}
// ConfigOptFn is a functional option for NewSecretService: it receives the
// Config accumulated so far and returns the Config to continue with.
type ConfigOptFn func(Config) Config

// WithConfig returns an option that replaces the whole configuration with
// config, discarding whatever earlier options had set.
func WithConfig(config Config) ConfigOptFn {
	replace := func(_ Config) Config { return config }
	return replace
}

// WithTLSConfig returns an option that overwrites only the TLS portion of the
// configuration and leaves every other field as it was.
func WithTLSConfig(tlsCFG TLSConfig) ConfigOptFn {
	return func(current Config) Config {
		updated := current
		updated.TLSConfig = tlsCFG
		return updated
	}
}
// NewSecretService creates an instance of a SecretService.
//
// The client starts from api.DefaultConfig, so the standard vault environment
// variables are honoured:
// https://www.vaultproject.io/docs/commands/index.html#environment-variables
//
// The options in cfgOpts are applied in order to an empty Config. The result
// is then layered on top of those defaults, and a non-empty Token is set on
// the client once it has been constructed.
func NewSecretService(cfgOpts ...ConfigOptFn) (*SecretService, error) {
	var explicit Config
	for i := range cfgOpts {
		explicit = cfgOpts[i](explicit)
	}

	apiCFG := api.DefaultConfig()
	if apiCFG.Error != nil {
		return nil, apiCFG.Error
	}
	if err := explicit.assign(apiCFG); err != nil {
		return nil, err
	}

	client, err := api.NewClient(apiCFG)
	if err != nil {
		return nil, err
	}
	if token := explicit.Token; token != "" {
		client.SetToken(token)
	}

	return &SecretService{Client: client}, nil
}
// LoadSecret returns the value stored under key k for organization orgID.
// It fails if the organization's secrets cannot be read or if k is absent.
func (s *SecretService) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) {
	secrets, _, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return "", err
	}

	value, found := secrets[k]
	if !found {
		return "", fmt.Errorf("secret not found")
	}
	return value, nil
}
// loadSecrets reads the vault secret holding every secret of organization
// orgID. It returns the string key/value pairs along with the secret's current
// version. Callers pass that version back to putSecrets as a check-and-set
// value, so that concurrent updates do not overwrite one another.
//
// If nothing is stored for orgID, an empty map and version 0 are returned.
// Vault's KV v2 engine still returns metadata, but a null "data" field, when
// the latest version has been deleted or destroyed. That case is also treated
// as an empty map, paired with the current version, so the next write can
// succeed instead of every operation failing on a type assertion.
//
// Values in the secret that are not strings are skipped.
//
// ctx is currently unused: the vault client call made here does not take one.
func (s *SecretService) loadSecrets(ctx context.Context, orgID platform.ID) (map[string]string, int, error) {
	// TODO(desa): update url construction
	sec, err := s.Client.Logical().Read(fmt.Sprintf("/secret/data/%s", orgID))
	if err != nil {
		return nil, -1, err
	}

	m := map[string]string{}
	if sec == nil {
		return m, 0, nil
	}

	// A nil "data" entry means the latest version holds no data (deleted or
	// destroyed); fall through with an empty map and read the version below.
	if raw := sec.Data["data"]; raw != nil {
		data, ok := raw.(map[string]interface{})
		if !ok {
			return nil, -1, fmt.Errorf("value found in secret data is not map[string]interface{}")
		}
		for k, v := range data {
			if val, ok := v.(string); ok {
				m[k] = val
			}
		}
	}

	metadata, ok := sec.Data["metadata"].(map[string]interface{})
	if !ok {
		return nil, -1, fmt.Errorf("value found in secret metadata is not map[string]interface{}")
	}

	version, err := parseSecretVersion(metadata["version"])
	if err != nil {
		return nil, -1, err
	}
	return m, version, nil
}

// parseSecretVersion converts the "version" entry of a secret's metadata into
// an int. It accepts a json.Number, a decimal string, or an int; any other type
// (including a missing entry) is an error.
func parseSecretVersion(raw interface{}) (int, error) {
	switch v := raw.(type) {
	case json.Number:
		ver, err := v.Int64()
		if err != nil {
			return -1, fmt.Errorf("version provided is not a valid integer: %v", err)
		}
		return int(ver), nil
	case string:
		ver, err := strconv.Atoi(v)
		if err != nil {
			return -1, fmt.Errorf("version provided is not a valid integer: %v", err)
		}
		return ver, nil
	case int:
		return v, nil
	default:
		return -1, fmt.Errorf("version provided is %T not a json.Number, string or int", v)
	}
}
// GetSecretKeys lists the keys of every secret stored for organization orgID.
// The keys are returned in no particular order.
func (s *SecretService) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) {
	secrets, _, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return nil, err
	}

	keys := make([]string, len(secrets))
	i := 0
	for key := range secrets {
		keys[i] = key
		i++
	}
	return keys, nil
}
// PutSecret stores value v under key k for organization orgID and keeps the
// organization's other secrets. The write is a check-and-set against the
// version read beforehand, so a concurrent update is not silently clobbered.
func (s *SecretService) PutSecret(ctx context.Context, orgID platform.ID, k string, v string) error {
	existing, version, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	existing[k] = v
	return s.putSecrets(ctx, orgID, existing, version)
}
// putSecrets writes data as the secret content for organization orgID.
//
// version selects the check-and-set ("cas") option sent to vault:
//   - a negative version sends no cas option, so the write is unconditional;
//   - version 0 asks vault to accept the write only if the secret does not
//     exist yet;
//   - a positive version asks vault to accept the write only if it matches
//     the secret's current version.
//
// ctx is currently unused.
func (s *SecretService) putSecrets(ctx context.Context, orgID platform.ID, data map[string]string, version int) error {
	payload := map[string]interface{}{"data": data}
	if version >= 0 {
		payload["options"] = map[string]interface{}{"cas": version}
	}

	path := fmt.Sprintf("/secret/data/%s", orgID)
	_, err := s.Client.Logical().Write(path, payload)
	return err
}
// PutSecrets writes m as organization orgID's secret data without any version
// check, overwriting whatever was stored before.
func (s *SecretService) PutSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
	const unconditional = -1 // negative version: putSecrets sends no cas option
	return s.putSecrets(ctx, orgID, m, unconditional)
}
// PatchSecrets merges m into organization orgID's existing secrets. Keys in m
// are added or overwritten, and all other stored keys are kept. The write is a
// check-and-set against the version that was read.
func (s *SecretService) PatchSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
	merged, version, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	for key, value := range m {
		merged[key] = value
	}
	return s.putSecrets(ctx, orgID, merged, version)
}
// DeleteSecret removes the secrets stored under each key in ks for
// organization orgID. Keys that are not present are ignored.
//
// If none of the keys are present there is nothing to change, so no write is
// issued. This avoids creating a needless new vault version and, for an
// organization with no secrets, an empty secret. Otherwise the remaining
// secrets are written back with a check-and-set against the version that was
// read.
func (s *SecretService) DeleteSecret(ctx context.Context, orgID platform.ID, ks ...string) error {
	data, ver, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	removed := false
	for _, k := range ks {
		if _, ok := data[k]; ok {
			delete(data, k)
			removed = true
		}
	}
	if !removed {
		return nil
	}
	return s.putSecrets(ctx, orgID, data, ver)
}