package vault import ( "context" "encoding/json" "fmt" "strconv" "github.com/hashicorp/vault/api" platform "github.com/influxdata/influxdb" ) var _ platform.SecretService = (*SecretService)(nil) // SecretService is service for storing user secrets type SecretService struct { Client *api.Client } // NewSecretService creates an instance of a SecretService. // The service is configured using the standard vault environment variables. // https://www.vaultproject.io/docs/commands/index.html#environment-variables func NewSecretService() (*SecretService, error) { cfg := api.DefaultConfig() if err := cfg.ReadEnvironment(); err != nil { return nil, err } c, err := api.NewClient(cfg) if err != nil { return nil, err } return &SecretService{ Client: c, }, nil } // LoadSecret retrieves the secret value v found at key k for organization orgID. func (s *SecretService) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) { data, _, err := s.loadSecrets(ctx, orgID) if err != nil { return "", err } if v, ok := data[k]; ok { return v, nil } return "", fmt.Errorf("secret not found") } // loadSecrets retrieves a map of secrets for an organization and the version of the secrets retrieved. // The version is used to ensure that concurrent updates will not overwrite one another. func (s *SecretService) loadSecrets(ctx context.Context, orgID platform.ID) (map[string]string, int, error) { // TODO(desa): update url construction sec, err := s.Client.Logical().Read(fmt.Sprintf("/secret/data/%s", orgID)) if err != nil { return nil, -1, err } m := map[string]string{} if sec == nil { return m, 0, nil } data, ok := sec.Data["data"].(map[string]interface{}) if !ok { return nil, -1, fmt.Errorf("value found in secret data is not map[string]interface{}") } for k, v := range data { val, ok := v.(string) if !ok { continue } m[k] = val } metadata, ok := sec.Data["metadata"].(map[string]interface{}) if !ok { return nil, -1, fmt.Errorf("value found in secret metadata is not map[string]interface{}") } var version int switch v := metadata["version"].(type) { case json.Number: ver, err := v.Int64() if err != nil { return nil, -1, err } version = int(ver) case string: ver, err := strconv.Atoi(v) if err != nil { return nil, -1, fmt.Errorf("version provided is not a valid integer: %v", err) } version = ver case int: version = v default: return nil, -1, fmt.Errorf("version provided is %T not a string or int", v) } return m, version, nil } // GetSecretKeys retrieves all secret keys that are stored for the organization orgID. func (s *SecretService) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) { data, _, err := s.loadSecrets(ctx, orgID) if err != nil { return nil, err } keys := make([]string, 0, len(data)) for k := range data { keys = append(keys, k) } return keys, nil } // PutSecret stores the secret pair (k,v) for the organization orgID. func (s *SecretService) PutSecret(ctx context.Context, orgID platform.ID, k string, v string) error { data, ver, err := s.loadSecrets(ctx, orgID) if err != nil { return err } data[k] = v return s.putSecrets(ctx, orgID, data, ver) } // putSecrets will set all provided data values for the organization orgID. // If version is negative, the write will overwrite all specified values. // If version is 0, the write will only be allowed if the keys do not exists. // If version is non-zero, the write will only be allowed if the keys current // version in vault matches the version specified. func (s *SecretService) putSecrets(ctx context.Context, orgID platform.ID, data map[string]string, version int) error { m := map[string]interface{}{"data": data} if version >= 0 { m["options"] = map[string]interface{}{"cas": version} } if _, err := s.Client.Logical().Write(fmt.Sprintf("/secret/data/%s", orgID), m); err != nil { return err } return nil } // PutSecrets puts all provided secrets and overwrites any previous values. func (s *SecretService) PutSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error { return s.putSecrets(ctx, orgID, m, -1) } // PatchSecrets patches all provided secrets and updates any previous values. func (s *SecretService) PatchSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error { data, ver, err := s.loadSecrets(ctx, orgID) if err != nil { return err } for k, v := range m { data[k] = v } return s.putSecrets(ctx, orgID, data, ver) } // DeleteSecret removes a single secret from the secret store. func (s *SecretService) DeleteSecret(ctx context.Context, orgID platform.ID, ks ...string) error { data, ver, err := s.loadSecrets(ctx, orgID) if err != nil { return err } for _, k := range ks { delete(data, k) } return s.putSecrets(ctx, orgID, data, ver) }