Merge pull request #1366 from influxdata/feat/secret-service

Add secret service interface and boltdb implementation
pull/10616/head
Michael Desa 2018-11-15 13:20:52 -05:00 committed by GitHub
commit 904d70ba05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 492 additions and 0 deletions

View File

@ -162,6 +162,11 @@ func (c *Client) initialize(ctx context.Context) error {
return err
}
// Always create SecretService bucket.
if err := c.initializeSecretService(ctx, tx); err != nil {
return err
}
return nil
}); err != nil {
return err

191
bolt/secret.go Normal file
View File

@ -0,0 +1,191 @@
package bolt
import (
"context"
"encoding/base64"
"fmt"
bolt "github.com/coreos/bbolt"
"github.com/influxdata/platform"
)
var (
secretBucket = []byte("secretsv1")
)
var _ platform.SecretService = (*Client)(nil)
func (c *Client) initializeSecretService(ctx context.Context, tx *bolt.Tx) error {
if _, err := tx.CreateBucketIfNotExists([]byte(secretBucket)); err != nil {
return err
}
return nil
}
// LoadSecret retrieves the secret value v found at key k for organization orgID.
func (c *Client) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) {
var v string
err := c.db.View(func(tx *bolt.Tx) error {
val, err := c.loadSecret(ctx, tx, orgID, k)
if err != nil {
return err
}
v = val
return nil
})
if err != nil {
return "", err
}
return v, nil
}
func (c *Client) loadSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k string) (string, error) {
key, err := encodeSecretKey(orgID, k)
if err != nil {
return "", err
}
val := tx.Bucket(secretBucket).Get(key)
if len(val) == 0 {
return "", fmt.Errorf("secret not found")
}
v, err := decodeSecretValue(val)
if err != nil {
return "", err
}
return v, nil
}
// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
func (c *Client) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) {
var vs []string
err := c.db.View(func(tx *bolt.Tx) error {
vals, err := c.getSecretKeys(ctx, tx, orgID)
if err != nil {
return err
}
vs = vals
return nil
})
if err != nil {
return nil, err
}
return vs, nil
}
func (c *Client) getSecretKeys(ctx context.Context, tx *bolt.Tx, orgID platform.ID) ([]string, error) {
cur := tx.Bucket(secretBucket).Cursor()
prefix, err := orgID.Encode()
if err != nil {
return nil, err
}
k, _ := cur.Seek(prefix)
id, key, err := decodeSecretKey(k)
if err != nil {
return nil, err
}
if id != orgID {
return nil, fmt.Errorf("organization has no secret keys")
}
keys := []string{key}
for {
k, _ = cur.Next()
if len(k) == 0 {
// We've reached the end of the keys so we're done
break
}
id, key, err = decodeSecretKey(k)
if err != nil {
return nil, err
}
if id != orgID {
// We've reached the end of the keyspace for the provided orgID
break
}
keys = append(keys, key)
}
return keys, nil
}
// PutSecret stores the secret pair (k,v) for the organization orgID.
func (c *Client) PutSecret(ctx context.Context, orgID platform.ID, k, v string) error {
return c.db.Update(func(tx *bolt.Tx) error {
return c.putSecret(ctx, tx, orgID, k, v)
})
}
func (c *Client) putSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k, v string) error {
key, err := encodeSecretKey(orgID, k)
if err != nil {
return err
}
val := encodeSecretValue(v)
if err := tx.Bucket(secretBucket).Put(key, val); err != nil {
return err
}
return nil
}
func encodeSecretKey(orgID platform.ID, k string) ([]byte, error) {
buf, err := orgID.Encode()
if err != nil {
return nil, err
}
key := make([]byte, 0, platform.IDLength+len(k))
key = append(key, buf...)
key = append(key, k...)
return key, nil
}
func decodeSecretKey(key []byte) (platform.ID, string, error) {
if len(key) < platform.IDLength {
// This should not happen.
return platform.InvalidID(), "", fmt.Errorf("Provided key is too short to contain an ID. Please report this error.")
}
var id platform.ID
if err := id.Decode(key[:platform.IDLength]); err != nil {
return platform.InvalidID(), "", err
}
k := string(key[platform.IDLength:])
return id, k, nil
}
func decodeSecretValue(val []byte) (string, error) {
// store the secret value base64 encoded so that it's marginally better than plaintext
v := make([]byte, base64.StdEncoding.DecodedLen(len(val)))
if _, err := base64.StdEncoding.Decode(v, val); err != nil {
return "", err
}
return string(v), nil
}
func encodeSecretValue(v string) []byte {
val := make([]byte, base64.StdEncoding.EncodedLen(len(v)))
base64.StdEncoding.Encode(val, []byte(v))
return val
}

31
bolt/secret_test.go Normal file
View File

@ -0,0 +1,31 @@
package bolt_test
import (
"context"
"testing"
"github.com/influxdata/platform"
platformtesting "github.com/influxdata/platform/testing"
)
func initSecretService(f platformtesting.SecretServiceFields, t *testing.T) (platform.SecretService, func()) {
c, closeFn, err := NewTestClient()
if err != nil {
t.Fatalf("failed to create new bolt client: %v", err)
}
ctx := context.TODO()
for _, s := range f.Secrets {
for k, v := range s.Env {
if err := c.PutSecret(ctx, s.OrganizationID, k, v); err != nil {
t.Fatalf("failed to populate secrets")
}
}
}
return c, func() {
defer closeFn()
}
}
func TestSecretService(t *testing.T) {
platformtesting.SecretService(initSecretService, t)
}

15
secret.go Normal file
View File

@ -0,0 +1,15 @@
package platform
import "context"
// SecretService a service for storing and retrieving secrets.
type SecretService interface {
// LoadSecret retrieves the secret value v found at key k for organization orgID.
LoadSecret(ctx context.Context, orgID ID, k string) (string, error)
// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
GetSecretKeys(ctx context.Context, orgID ID) ([]string, error)
// PutSecret stores the secret pair (k,v) for the organization orgID.
PutSecret(ctx context.Context, orgID ID, k string, v string) error
}

250
testing/secret.go Normal file
View File

@ -0,0 +1,250 @@
package testing
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/influxdata/platform"
)
// A secret is a comparable data structure that is used for testing
type Secret struct {
OrganizationID platform.ID
Env map[string]string
}
// SecretServiceFields contain the
type SecretServiceFields struct {
Secrets []Secret
}
// SecretService will test all methods for the secrets service.
func SecretService(
init func(SecretServiceFields, *testing.T) (platform.SecretService, func()),
t *testing.T,
) {
tests := []struct {
name string
fn func(
init func(SecretServiceFields, *testing.T) (platform.SecretService, func()),
t *testing.T,
)
}{
{
name: "LoadSecret",
fn: LoadSecret,
},
{
name: "PutSecret",
fn: PutSecret,
},
{
name: "GetSecretKeys",
fn: GetSecretKeys,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.fn(init, t)
})
}
}
// LoadSecret tests the LoadSecret method for the SecretService interface.
func LoadSecret(
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
t *testing.T,
) {
type args struct {
orgID platform.ID
key string
}
type wants struct {
value string
err error
}
tests := []struct {
name string
fields SecretServiceFields
args args
wants wants
}{
{
name: "load secret field",
fields: SecretServiceFields{
Secrets: []Secret{
{
OrganizationID: platform.ID(1),
Env: map[string]string{
"api_key": "abc123xyz",
},
},
},
},
args: args{
orgID: platform.ID(1),
key: "api_key",
},
wants: wants{
value: "abc123xyz",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, done := init(tt.fields, t)
defer done()
ctx := context.Background()
val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key)
if (err != nil) != (tt.wants.err != nil) {
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
}
if err != nil && tt.wants.err != nil {
if err.Error() != tt.wants.err.Error() {
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
}
}
if want, got := tt.wants.value, val; want != got {
t.Errorf("expected value to be %s, got %s", want, got)
}
})
}
}
// PutSecret tests the PutSecret method for the SecretService interface.
func PutSecret(
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
t *testing.T,
) {
type args struct {
orgID platform.ID
key string
value string
}
type wants struct {
err error
}
tests := []struct {
name string
fields SecretServiceFields
args args
wants wants
}{
{
name: "put secret",
fields: SecretServiceFields{},
args: args{
orgID: platform.ID(1),
key: "api_key",
value: "abc123xyz",
},
wants: wants{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, done := init(tt.fields, t)
defer done()
ctx := context.Background()
err := s.PutSecret(ctx, tt.args.orgID, tt.args.key, tt.args.value)
if (err != nil) != (tt.wants.err != nil) {
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
}
if err != nil && tt.wants.err != nil {
if err.Error() != tt.wants.err.Error() {
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
}
}
val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key)
if err != nil {
t.Fatalf("unexpected error %v", err)
}
if want, got := tt.args.value, val; want != got {
t.Errorf("expected value to be %s, got %s", want, got)
}
})
}
}
// GetSecretKeys tests the GetSecretKeys method for the SecretService interface.
func GetSecretKeys(
init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()),
t *testing.T,
) {
type args struct {
orgID platform.ID
}
type wants struct {
keys []string
err error
}
tests := []struct {
name string
fields SecretServiceFields
args args
wants wants
}{
{
name: "get secret keys for one org",
fields: SecretServiceFields{
Secrets: []Secret{
{
OrganizationID: platform.ID(1),
Env: map[string]string{
"api_key": "abc123xyz",
},
},
{
OrganizationID: platform.ID(2),
Env: map[string]string{
"api_key": "zyx321cba",
},
},
},
},
args: args{
orgID: platform.ID(1),
},
wants: wants{
keys: []string{"api_key"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, done := init(tt.fields, t)
defer done()
ctx := context.Background()
keys, err := s.GetSecretKeys(ctx, tt.args.orgID)
if (err != nil) != (tt.wants.err != nil) {
t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err)
}
if err != nil && tt.wants.err != nil {
if err.Error() != tt.wants.err.Error() {
t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error())
}
}
if diff := cmp.Diff(keys, tt.wants.keys); diff != "" {
t.Errorf("keys are different -got/+want\ndiff %s", diff)
}
})
}
}