274 lines
6.6 KiB
Go
274 lines
6.6 KiB
Go
package vault
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/api"
|
|
platform "github.com/influxdata/influxdb/v2"
|
|
)
|
|
|
|
var _ platform.SecretService = (*SecretService)(nil)
|
|
|
|
// SecretService is service for storing user secrets
|
|
type SecretService struct {
|
|
Client *api.Client
|
|
}
|
|
|
|
// Config may setup the vault client configuration. If any field is a zero
|
|
// value, it will be ignored and the default used.
|
|
type Config struct {
|
|
Address string
|
|
AgentAddress string
|
|
ClientTimeout time.Duration
|
|
MaxRetries int
|
|
Token string
|
|
TLSConfig
|
|
}
|
|
|
|
// TLSConfig describes the TLS settings used when talking to Vault. Each field
// is forwarded to the field of the same meaning on the Vault client's
// api.TLSConfig (InsecureSkipVerify maps to Insecure).
type TLSConfig struct {
	// CACert and CAPath select the certificate authority to trust
	// (presumably a PEM file and a directory respectively; verify against
	// the Vault api docs).
	CACert, CAPath string
	// ClientCert and ClientKey identify the client certificate and its key.
	ClientCert, ClientKey string
	// InsecureSkipVerify disables verification of the server certificate.
	InsecureSkipVerify bool
	// TLSServerName overrides the server name expected in the certificate.
	TLSServerName string
}
|
|
|
|
func (c Config) assign(apiCFG *api.Config) error {
|
|
if c.Address != "" {
|
|
apiCFG.Address = c.Address
|
|
}
|
|
|
|
if c.AgentAddress != "" {
|
|
apiCFG.AgentAddress = c.AgentAddress
|
|
}
|
|
|
|
if c.ClientTimeout > 0 {
|
|
apiCFG.Timeout = c.ClientTimeout
|
|
}
|
|
|
|
if c.MaxRetries > 0 {
|
|
apiCFG.MaxRetries = c.MaxRetries
|
|
}
|
|
|
|
if c.TLSServerName != "" {
|
|
err := apiCFG.ConfigureTLS(&api.TLSConfig{
|
|
CACert: c.CACert,
|
|
CAPath: c.CAPath,
|
|
ClientCert: c.ClientCert,
|
|
ClientKey: c.ClientKey,
|
|
TLSServerName: c.TLSServerName,
|
|
Insecure: c.InsecureSkipVerify,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ConfigOptFn is a functional input option to configure a vault service.
|
|
type ConfigOptFn func(Config) Config
|
|
|
|
// WithConfig provides a configuration to the service constructor.
|
|
func WithConfig(config Config) ConfigOptFn {
|
|
return func(Config) Config {
|
|
return config
|
|
}
|
|
}
|
|
|
|
// WithTLSConfig allows one to set the TLS config only.
|
|
func WithTLSConfig(tlsCFG TLSConfig) ConfigOptFn {
|
|
return func(cfg Config) Config {
|
|
cfg.TLSConfig = tlsCFG
|
|
return cfg
|
|
}
|
|
}
|
|
|
|
// NewSecretService creates an instance of a SecretService.
|
|
// The service is configured using the standard vault environment variables.
|
|
// https://www.vaultproject.io/docs/commands/index.html#environment-variables
|
|
func NewSecretService(cfgOpts ...ConfigOptFn) (*SecretService, error) {
|
|
explicitConfig := Config{}
|
|
for _, o := range cfgOpts {
|
|
explicitConfig = o(explicitConfig)
|
|
}
|
|
|
|
cfg := api.DefaultConfig()
|
|
if cfg.Error != nil {
|
|
return nil, cfg.Error
|
|
}
|
|
|
|
err := explicitConfig.assign(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c, err := api.NewClient(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if explicitConfig.Token != "" {
|
|
c.SetToken(explicitConfig.Token)
|
|
}
|
|
|
|
return &SecretService{
|
|
Client: c,
|
|
}, nil
|
|
}
|
|
|
|
// LoadSecret retrieves the secret value v found at key k for organization orgID.
|
|
func (s *SecretService) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) {
|
|
data, _, err := s.loadSecrets(ctx, orgID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if v, ok := data[k]; ok {
|
|
return v, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("secret not found")
|
|
}
|
|
|
|
// loadSecrets retrieves a map of secrets for an organization and the version of the secrets retrieved.
|
|
// The version is used to ensure that concurrent updates will not overwrite one another.
|
|
func (s *SecretService) loadSecrets(ctx context.Context, orgID platform.ID) (map[string]string, int, error) {
|
|
// TODO(desa): update url construction
|
|
sec, err := s.Client.Logical().Read(fmt.Sprintf("/secret/data/%s", orgID))
|
|
if err != nil {
|
|
return nil, -1, err
|
|
}
|
|
|
|
m := map[string]string{}
|
|
if sec == nil {
|
|
return m, 0, nil
|
|
}
|
|
|
|
data, ok := sec.Data["data"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, -1, fmt.Errorf("value found in secret data is not map[string]interface{}")
|
|
}
|
|
|
|
for k, v := range data {
|
|
val, ok := v.(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
m[k] = val
|
|
}
|
|
|
|
metadata, ok := sec.Data["metadata"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, -1, fmt.Errorf("value found in secret metadata is not map[string]interface{}")
|
|
}
|
|
|
|
var version int
|
|
switch v := metadata["version"].(type) {
|
|
case json.Number:
|
|
ver, err := v.Int64()
|
|
if err != nil {
|
|
return nil, -1, err
|
|
}
|
|
version = int(ver)
|
|
case string:
|
|
ver, err := strconv.Atoi(v)
|
|
if err != nil {
|
|
return nil, -1, fmt.Errorf("version provided is not a valid integer: %v", err)
|
|
}
|
|
version = ver
|
|
case int:
|
|
version = v
|
|
default:
|
|
return nil, -1, fmt.Errorf("version provided is %T not a string or int", v)
|
|
}
|
|
|
|
return m, version, nil
|
|
}
|
|
|
|
// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
|
|
func (s *SecretService) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) {
|
|
data, _, err := s.loadSecrets(ctx, orgID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
keys := make([]string, 0, len(data))
|
|
for k := range data {
|
|
keys = append(keys, k)
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
// PutSecret stores the secret pair (k,v) for the organization orgID.
|
|
func (s *SecretService) PutSecret(ctx context.Context, orgID platform.ID, k string, v string) error {
|
|
data, ver, err := s.loadSecrets(ctx, orgID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
data[k] = v
|
|
|
|
return s.putSecrets(ctx, orgID, data, ver)
|
|
}
|
|
|
|
// putSecrets will set all provided data values for the organization orgID.
|
|
// If version is negative, the write will overwrite all specified values.
|
|
// If version is 0, the write will only be allowed if the keys do not exists.
|
|
// If version is non-zero, the write will only be allowed if the keys current
|
|
// version in vault matches the version specified.
|
|
func (s *SecretService) putSecrets(ctx context.Context, orgID platform.ID, data map[string]string, version int) error {
|
|
m := map[string]interface{}{"data": data}
|
|
|
|
if version >= 0 {
|
|
m["options"] = map[string]interface{}{"cas": version}
|
|
}
|
|
|
|
if _, err := s.Client.Logical().Write(fmt.Sprintf("/secret/data/%s", orgID), m); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PutSecrets puts all provided secrets and overwrites any previous values.
|
|
func (s *SecretService) PutSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
|
|
return s.putSecrets(ctx, orgID, m, -1)
|
|
}
|
|
|
|
// PatchSecrets patches all provided secrets and updates any previous values.
|
|
func (s *SecretService) PatchSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
|
|
data, ver, err := s.loadSecrets(ctx, orgID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for k, v := range m {
|
|
data[k] = v
|
|
}
|
|
|
|
return s.putSecrets(ctx, orgID, data, ver)
|
|
}
|
|
|
|
// DeleteSecret removes a single secret from the secret store.
|
|
func (s *SecretService) DeleteSecret(ctx context.Context, orgID platform.ID, ks ...string) error {
|
|
data, ver, err := s.loadSecrets(ctx, orgID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, k := range ks {
|
|
delete(data, k)
|
|
}
|
|
|
|
return s.putSecrets(ctx, orgID, data, ver)
|
|
}
|