Merge pull request #1366 from influxdata/feat/secret-service
Add secret service interface and boltdb implementationpull/10616/head
commit
904d70ba05
|
@ -162,6 +162,11 @@ func (c *Client) initialize(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Always create SecretService bucket.
|
||||
if err := c.initializeSecretService(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
|
|
|
@ -0,0 +1,191 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
bolt "github.com/coreos/bbolt"
|
||||
"github.com/influxdata/platform"
|
||||
)
|
||||
|
||||
var (
|
||||
secretBucket = []byte("secretsv1")
|
||||
)
|
||||
|
||||
var _ platform.SecretService = (*Client)(nil)
|
||||
|
||||
func (c *Client) initializeSecretService(ctx context.Context, tx *bolt.Tx) error {
|
||||
if _, err := tx.CreateBucketIfNotExists([]byte(secretBucket)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSecret retrieves the secret value v found at key k for organization orgID.
|
||||
func (c *Client) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) {
|
||||
var v string
|
||||
err := c.db.View(func(tx *bolt.Tx) error {
|
||||
val, err := c.loadSecret(ctx, tx, orgID, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v = val
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (c *Client) loadSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k string) (string, error) {
|
||||
key, err := encodeSecretKey(orgID, k)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
val := tx.Bucket(secretBucket).Get(key)
|
||||
if len(val) == 0 {
|
||||
return "", fmt.Errorf("secret not found")
|
||||
}
|
||||
|
||||
v, err := decodeSecretValue(val)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
|
||||
func (c *Client) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) {
|
||||
var vs []string
|
||||
err := c.db.View(func(tx *bolt.Tx) error {
|
||||
vals, err := c.getSecretKeys(ctx, tx, orgID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vs = vals
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
func (c *Client) getSecretKeys(ctx context.Context, tx *bolt.Tx, orgID platform.ID) ([]string, error) {
|
||||
cur := tx.Bucket(secretBucket).Cursor()
|
||||
prefix, err := orgID.Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k, _ := cur.Seek(prefix)
|
||||
|
||||
id, key, err := decodeSecretKey(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if id != orgID {
|
||||
return nil, fmt.Errorf("organization has no secret keys")
|
||||
}
|
||||
|
||||
keys := []string{key}
|
||||
|
||||
for {
|
||||
k, _ = cur.Next()
|
||||
|
||||
if len(k) == 0 {
|
||||
// We've reached the end of the keys so we're done
|
||||
break
|
||||
}
|
||||
|
||||
id, key, err = decodeSecretKey(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if id != orgID {
|
||||
// We've reached the end of the keyspace for the provided orgID
|
||||
break
|
||||
}
|
||||
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// PutSecret stores the secret pair (k,v) for the organization orgID.
|
||||
func (c *Client) PutSecret(ctx context.Context, orgID platform.ID, k, v string) error {
|
||||
return c.db.Update(func(tx *bolt.Tx) error {
|
||||
return c.putSecret(ctx, tx, orgID, k, v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) putSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k, v string) error {
|
||||
key, err := encodeSecretKey(orgID, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val := encodeSecretValue(v)
|
||||
|
||||
if err := tx.Bucket(secretBucket).Put(key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSecretKey(orgID platform.ID, k string) ([]byte, error) {
|
||||
buf, err := orgID.Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := make([]byte, 0, platform.IDLength+len(k))
|
||||
key = append(key, buf...)
|
||||
key = append(key, k...)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func decodeSecretKey(key []byte) (platform.ID, string, error) {
|
||||
if len(key) < platform.IDLength {
|
||||
// This should not happen.
|
||||
return platform.InvalidID(), "", fmt.Errorf("Provided key is too short to contain an ID. Please report this error.")
|
||||
}
|
||||
|
||||
var id platform.ID
|
||||
if err := id.Decode(key[:platform.IDLength]); err != nil {
|
||||
return platform.InvalidID(), "", err
|
||||
}
|
||||
|
||||
k := string(key[platform.IDLength:])
|
||||
|
||||
return id, k, nil
|
||||
}
|
||||
|
||||
func decodeSecretValue(val []byte) (string, error) {
|
||||
// store the secret value base64 encoded so that it's marginally better than plaintext
|
||||
v := make([]byte, base64.StdEncoding.DecodedLen(len(val)))
|
||||
if _, err := base64.StdEncoding.Decode(v, val); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
func encodeSecretValue(v string) []byte {
|
||||
val := make([]byte, base64.StdEncoding.EncodedLen(len(v)))
|
||||
base64.StdEncoding.Encode(val, []byte(v))
|
||||
return val
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/influxdata/platform"
|
||||
platformtesting "github.com/influxdata/platform/testing"
|
||||
)
|
||||
|
||||
func initSecretService(f platformtesting.SecretServiceFields, t *testing.T) (platform.SecretService, func()) {
|
||||
c, closeFn, err := NewTestClient()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new bolt client: %v", err)
|
||||
}
|
||||
ctx := context.TODO()
|
||||
for _, s := range f.Secrets {
|
||||
for k, v := range s.Env {
|
||||
if err := c.PutSecret(ctx, s.OrganizationID, k, v); err != nil {
|
||||
t.Fatalf("failed to populate secrets")
|
||||
}
|
||||
}
|
||||
}
|
||||
return c, func() {
|
||||
defer closeFn()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecretService(t *testing.T) {
|
||||
platformtesting.SecretService(initSecretService, t)
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package platform
|
||||
|
||||
import "context"
|
||||
|
||||
// SecretService a service for storing and retrieving secrets.
|
||||
type SecretService interface {
|
||||
// LoadSecret retrieves the secret value v found at key k for organization orgID.
|
||||
LoadSecret(ctx context.Context, orgID ID, k string) (string, error)
|
||||
|
||||
// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
|
||||
GetSecretKeys(ctx context.Context, orgID ID) ([]string, error)
|
||||
|
||||
// PutSecret stores the secret pair (k,v) for the organization orgID.
|
||||
PutSecret(ctx context.Context, orgID ID, k string, v string) error
|
||||
}
|
|
@ -0,0 +1,250 @@
|
|||
package testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/influxdata/platform"
|
||||
)
|
||||
|
||||
// A secret is a comparable data structure that is used for testing
|
||||
type Secret struct {
|
||||
OrganizationID platform.ID
|
||||
Env map[string]string
|
||||
}
|
||||
|
||||
// SecretServiceFields contain the
|
||||
type SecretServiceFields struct {
|
||||
Secrets []Secret
|
||||
}
|
||||
|
||||
// SecretService will test all methods for the secrets service.
|
||||
func SecretService(
|
||||
init func(SecretServiceFields, *testing.T) (platform.SecretService, func()),
|
||||
t *testing.T,
|
||||
) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(
|
||||
init func(SecretServiceFields, *testing.T) (platform.SecretService, func()),
|
||||
t *testing.T,
|
||||
)
|
||||
}{
|
||||
{
|
||||
name: "LoadSecret",
|
||||
fn: LoadSecret,
|
||||
},
|
||||
{
|
||||
name: "PutSecret",
|
||||
fn: PutSecret,
|
||||
},
|
||||
{
|
||||
name: "GetSecretKeys",
|
||||
fn: GetSecretKeys,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.fn(init, t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSecret tests the LoadSecret method for the SecretService interface.
|
||||
func LoadSecret(
|
||||
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
|
||||
t *testing.T,
|
||||
) {
|
||||
type args struct {
|
||||
orgID platform.ID
|
||||
key string
|
||||
}
|
||||
type wants struct {
|
||||
value string
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields SecretServiceFields
|
||||
args args
|
||||
wants wants
|
||||
}{
|
||||
{
|
||||
name: "load secret field",
|
||||
fields: SecretServiceFields{
|
||||
Secrets: []Secret{
|
||||
{
|
||||
OrganizationID: platform.ID(1),
|
||||
Env: map[string]string{
|
||||
"api_key": "abc123xyz",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
orgID: platform.ID(1),
|
||||
key: "api_key",
|
||||
},
|
||||
wants: wants{
|
||||
value: "abc123xyz",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, done := init(tt.fields, t)
|
||||
defer done()
|
||||
ctx := context.Background()
|
||||
|
||||
val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key)
|
||||
if (err != nil) != (tt.wants.err != nil) {
|
||||
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
|
||||
}
|
||||
|
||||
if err != nil && tt.wants.err != nil {
|
||||
if err.Error() != tt.wants.err.Error() {
|
||||
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if want, got := tt.wants.value, val; want != got {
|
||||
t.Errorf("expected value to be %s, got %s", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PutSecret tests the PutSecret method for the SecretService interface.
|
||||
func PutSecret(
|
||||
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
|
||||
t *testing.T,
|
||||
) {
|
||||
type args struct {
|
||||
orgID platform.ID
|
||||
key string
|
||||
value string
|
||||
}
|
||||
type wants struct {
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields SecretServiceFields
|
||||
args args
|
||||
wants wants
|
||||
}{
|
||||
{
|
||||
name: "put secret",
|
||||
fields: SecretServiceFields{},
|
||||
args: args{
|
||||
orgID: platform.ID(1),
|
||||
key: "api_key",
|
||||
value: "abc123xyz",
|
||||
},
|
||||
wants: wants{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, done := init(tt.fields, t)
|
||||
defer done()
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.PutSecret(ctx, tt.args.orgID, tt.args.key, tt.args.value)
|
||||
if (err != nil) != (tt.wants.err != nil) {
|
||||
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
|
||||
}
|
||||
|
||||
if err != nil && tt.wants.err != nil {
|
||||
if err.Error() != tt.wants.err.Error() {
|
||||
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
if want, got := tt.args.value, val; want != got {
|
||||
t.Errorf("expected value to be %s, got %s", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetSecretKeys tests the GetSecretKeys method for the SecretService interface.
|
||||
func GetSecretKeys(
|
||||
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
|
||||
t *testing.T,
|
||||
) {
|
||||
type args struct {
|
||||
orgID platform.ID
|
||||
}
|
||||
type wants struct {
|
||||
keys []string
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields SecretServiceFields
|
||||
args args
|
||||
wants wants
|
||||
}{
|
||||
{
|
||||
name: "get secret keys for one org",
|
||||
fields: SecretServiceFields{
|
||||
Secrets: []Secret{
|
||||
{
|
||||
OrganizationID: platform.ID(1),
|
||||
Env: map[string]string{
|
||||
"api_key": "abc123xyz",
|
||||
},
|
||||
},
|
||||
{
|
||||
OrganizationID: platform.ID(2),
|
||||
Env: map[string]string{
|
||||
"api_key": "zyx321cba",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
orgID: platform.ID(1),
|
||||
},
|
||||
wants: wants{
|
||||
keys: []string{"api_key"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, done := init(tt.fields, t)
|
||||
defer done()
|
||||
ctx := context.Background()
|
||||
|
||||
keys, err := s.GetSecretKeys(ctx, tt.args.orgID)
|
||||
if (err != nil) != (tt.wants.err != nil) {
|
||||
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
|
||||
}
|
||||
|
||||
if err != nil && tt.wants.err != nil {
|
||||
if err.Error() != tt.wants.err.Error() {
|
||||
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(keys, tt.wants.keys); diff != "" {
|
||||
t.Errorf("keys are different -got/+want\ndiff %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue