diff --git a/bolt/bbolt.go b/bolt/bbolt.go index b86047a741..5ba9eed679 100644 --- a/bolt/bbolt.go +++ b/bolt/bbolt.go @@ -162,6 +162,11 @@ func (c *Client) initialize(ctx context.Context) error { return err } + // Always create SecretService bucket. + if err := c.initializeSecretService(ctx, tx); err != nil { + return err + } + return nil }); err != nil { return err diff --git a/bolt/secret.go b/bolt/secret.go new file mode 100644 index 0000000000..f97adfdc00 --- /dev/null +++ b/bolt/secret.go @@ -0,0 +1,191 @@ +package bolt + +import ( + "context" + "encoding/base64" + "fmt" + + bolt "github.com/coreos/bbolt" + "github.com/influxdata/platform" +) + +var ( + secretBucket = []byte("secretsv1") +) + +var _ platform.SecretService = (*Client)(nil) + +func (c *Client) initializeSecretService(ctx context.Context, tx *bolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists([]byte(secretBucket)); err != nil { + return err + } + return nil +} + +// LoadSecret retrieves the secret value v found at key k for organization orgID. +func (c *Client) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) { + var v string + err := c.db.View(func(tx *bolt.Tx) error { + val, err := c.loadSecret(ctx, tx, orgID, k) + if err != nil { + return err + } + + v = val + return nil + }) + + if err != nil { + return "", err + } + + return v, nil +} + +func (c *Client) loadSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k string) (string, error) { + key, err := encodeSecretKey(orgID, k) + if err != nil { + return "", err + } + + val := tx.Bucket(secretBucket).Get(key) + if len(val) == 0 { + return "", fmt.Errorf("secret not found") + } + + v, err := decodeSecretValue(val) + if err != nil { + return "", err + } + + return v, nil +} + +// GetSecretKeys retrieves all secret keys that are stored for the organization orgID. +func (c *Client) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) { + var vs []string + err := c.db.View(func(tx *bolt.Tx) error { + vals, err := c.getSecretKeys(ctx, tx, orgID) + if err != nil { + return err + } + + vs = vals + return nil + }) + + if err != nil { + return nil, err + } + + return vs, nil +} + +func (c *Client) getSecretKeys(ctx context.Context, tx *bolt.Tx, orgID platform.ID) ([]string, error) { + cur := tx.Bucket(secretBucket).Cursor() + prefix, err := orgID.Encode() + if err != nil { + return nil, err + } + k, _ := cur.Seek(prefix) + + id, key, err := decodeSecretKey(k) + if err != nil { + return nil, err + } + + if id != orgID { + return nil, fmt.Errorf("organization has no secret keys") + } + + keys := []string{key} + + for { + k, _ = cur.Next() + + if len(k) == 0 { + // We've reached the end of the keys so we're done + break + } + + id, key, err = decodeSecretKey(k) + if err != nil { + return nil, err + } + + if id != orgID { + // We've reached the end of the keyspace for the provided orgID + break + } + + keys = append(keys, key) + } + + return keys, nil +} + +// PutSecret stores the secret pair (k,v) for the organization orgID. +func (c *Client) PutSecret(ctx context.Context, orgID platform.ID, k, v string) error { + return c.db.Update(func(tx *bolt.Tx) error { + return c.putSecret(ctx, tx, orgID, k, v) + }) +} + +func (c *Client) putSecret(ctx context.Context, tx *bolt.Tx, orgID platform.ID, k, v string) error { + key, err := encodeSecretKey(orgID, k) + if err != nil { + return err + } + + val := encodeSecretValue(v) + + if err := tx.Bucket(secretBucket).Put(key, val); err != nil { + return err + } + return nil +} + +func encodeSecretKey(orgID platform.ID, k string) ([]byte, error) { + buf, err := orgID.Encode() + if err != nil { + return nil, err + } + + key := make([]byte, 0, platform.IDLength+len(k)) + key = append(key, buf...) + key = append(key, k...) + + return key, nil +} + +func decodeSecretKey(key []byte) (platform.ID, string, error) { + if len(key) < platform.IDLength { + // This should not happen. + return platform.InvalidID(), "", fmt.Errorf("Provided key is too short to contain an ID. Please report this error.") + } + + var id platform.ID + if err := id.Decode(key[:platform.IDLength]); err != nil { + return platform.InvalidID(), "", err + } + + k := string(key[platform.IDLength:]) + + return id, k, nil +} + +func decodeSecretValue(val []byte) (string, error) { + // store the secret value base64 encoded so that it's marginally better than plaintext + v := make([]byte, base64.StdEncoding.DecodedLen(len(val))) + if _, err := base64.StdEncoding.Decode(v, val); err != nil { + return "", err + } + + return string(v), nil +} + +func encodeSecretValue(v string) []byte { + val := make([]byte, base64.StdEncoding.EncodedLen(len(v))) + base64.StdEncoding.Encode(val, []byte(v)) + return val +} diff --git a/bolt/secret_test.go b/bolt/secret_test.go new file mode 100644 index 0000000000..b548322996 --- /dev/null +++ b/bolt/secret_test.go @@ -0,0 +1,31 @@ +package bolt_test + +import ( + "context" + "testing" + + "github.com/influxdata/platform" + platformtesting "github.com/influxdata/platform/testing" +) + +func initSecretService(f platformtesting.SecretServiceFields, t *testing.T) (platform.SecretService, func()) { + c, closeFn, err := NewTestClient() + if err != nil { + t.Fatalf("failed to create new bolt client: %v", err) + } + ctx := context.TODO() + for _, s := range f.Secrets { + for k, v := range s.Env { + if err := c.PutSecret(ctx, s.OrganizationID, k, v); err != nil { + t.Fatalf("failed to populate secrets") + } + } + } + return c, func() { + defer closeFn() + } +} + +func TestSecretService(t *testing.T) { + platformtesting.SecretService(initSecretService, t) +} diff --git a/secret.go b/secret.go new file mode 100644 index 0000000000..fcc7d9d1c5 --- /dev/null +++ b/secret.go @@ -0,0 +1,15 @@ +package platform + +import "context" + +// SecretService a service for storing and retrieving secrets. +type SecretService interface { + // LoadSecret retrieves the secret value v found at key k for organization orgID. + LoadSecret(ctx context.Context, orgID ID, k string) (string, error) + + // GetSecretKeys retrieves all secret keys that are stored for the organization orgID. + GetSecretKeys(ctx context.Context, orgID ID) ([]string, error) + + // PutSecret stores the secret pair (k,v) for the organization orgID. + PutSecret(ctx context.Context, orgID ID, k string, v string) error +} diff --git a/testing/secret.go b/testing/secret.go new file mode 100644 index 0000000000..9039e06851 --- /dev/null +++ b/testing/secret.go @@ -0,0 +1,250 @@ +package testing + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/influxdata/platform" +) + +// A secret is a comparable data structure that is used for testing +type Secret struct { + OrganizationID platform.ID + Env map[string]string +} + +// SecretServiceFields contain the +type SecretServiceFields struct { + Secrets []Secret +} + +// SecretService will test all methods for the secrets service. +func SecretService( + init func(SecretServiceFields, *testing.T) (platform.SecretService, func()), + t *testing.T, +) { + + tests := []struct { + name string + fn func( + init func(SecretServiceFields, *testing.T) (platform.SecretService, func()), + t *testing.T, + ) + }{ + { + name: "LoadSecret", + fn: LoadSecret, + }, + { + name: "PutSecret", + fn: PutSecret, + }, + { + name: "GetSecretKeys", + fn: GetSecretKeys, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.fn(init, t) + }) + } +} + +// LoadSecret tests the LoadSecret method for the SecretService interface. +func LoadSecret( + init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()), + t *testing.T, +) { + type args struct { + orgID platform.ID + key string + } + type wants struct { + value string + err error + } + + tests := []struct { + name string + fields SecretServiceFields + args args + wants wants + }{ + { + name: "load secret field", + fields: SecretServiceFields{ + Secrets: []Secret{ + { + OrganizationID: platform.ID(1), + Env: map[string]string{ + "api_key": "abc123xyz", + }, + }, + }, + }, + args: args{ + orgID: platform.ID(1), + key: "api_key", + }, + wants: wants{ + value: "abc123xyz", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, done := init(tt.fields, t) + defer done() + ctx := context.Background() + + val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key) + if (err != nil) != (tt.wants.err != nil) { + t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err) + } + + if err != nil && tt.wants.err != nil { + if err.Error() != tt.wants.err.Error() { + t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error()) + } + } + + if want, got := tt.wants.value, val; want != got { + t.Errorf("expected value to be %s, got %s", want, got) + } + }) + } +} + +// PutSecret tests the PutSecret method for the SecretService interface. +func PutSecret( + init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()), + t *testing.T, +) { + type args struct { + orgID platform.ID + key string + value string + } + type wants struct { + err error + } + + tests := []struct { + name string + fields SecretServiceFields + args args + wants wants + }{ + { + name: "put secret", + fields: SecretServiceFields{}, + args: args{ + orgID: platform.ID(1), + key: "api_key", + value: "abc123xyz", + }, + wants: wants{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, done := init(tt.fields, t) + defer done() + ctx := context.Background() + + err := s.PutSecret(ctx, tt.args.orgID, tt.args.key, tt.args.value) + if (err != nil) != (tt.wants.err != nil) { + t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err) + } + + if err != nil && tt.wants.err != nil { + if err.Error() != tt.wants.err.Error() { + t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error()) + } + } + + val, err := s.LoadSecret(ctx, tt.args.orgID, tt.args.key) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if want, got := tt.args.value, val; want != got { + t.Errorf("expected value to be %s, got %s", want, got) + } + }) + } +} + +// GetSecretKeys tests the GetSecretKeys method for the SecretService interface. +func GetSecretKeys( + init func(f SecretServiceFields, t *testing.T) (platform.SecretService, func()), + t *testing.T, +) { + type args struct { + orgID platform.ID + } + type wants struct { + keys []string + err error + } + + tests := []struct { + name string + fields SecretServiceFields + args args + wants wants + }{ + { + name: "get secret keys for one org", + fields: SecretServiceFields{ + Secrets: []Secret{ + { + OrganizationID: platform.ID(1), + Env: map[string]string{ + "api_key": "abc123xyz", + }, + }, + { + OrganizationID: platform.ID(2), + Env: map[string]string{ + "api_key": "zyx321cba", + }, + }, + }, + }, + args: args{ + orgID: platform.ID(1), + }, + wants: wants{ + keys: []string{"api_key"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, done := init(tt.fields, t) + defer done() + ctx := context.Background() + + keys, err := s.GetSecretKeys(ctx, tt.args.orgID) + if (err != nil) != (tt.wants.err != nil) { + t.Fatalf("expected error '%v' got '%v'", tt.wants.err, err) + } + + if err != nil && tt.wants.err != nil { + if err.Error() != tt.wants.err.Error() { + t.Fatalf("expected error messages to match '%v' got '%v'", tt.wants.err, err.Error()) + } + } + + if diff := cmp.Diff(keys, tt.wants.keys); diff != "" { + t.Errorf("keys are different -got/+want\ndiff %s", diff) + } + }) + } +}