package vault

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"

	"github.com/hashicorp/vault/api"
	platform "github.com/influxdata/influxdb"
)

var _ platform.SecretService = (*SecretService)(nil)

// SecretService is service for storing user secrets
type SecretService struct {
	Client *api.Client
}

// NewSecretService creates an instance of a SecretService.
// The service is configured using the standard vault environment variables.
// https://www.vaultproject.io/docs/commands/index.html#environment-variables
func NewSecretService() (*SecretService, error) {
	cfg := api.DefaultConfig()
	if err := cfg.ReadEnvironment(); err != nil {
		return nil, err
	}

	c, err := api.NewClient(cfg)
	if err != nil {
		return nil, err
	}

	return &SecretService{
		Client: c,
	}, nil
}

// LoadSecret retrieves the secret value v found at key k for organization orgID.
func (s *SecretService) LoadSecret(ctx context.Context, orgID platform.ID, k string) (string, error) {
	data, _, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return "", err
	}

	if v, ok := data[k]; ok {
		return v, nil
	}

	return "", fmt.Errorf("secret not found")
}

// loadSecrets retrieves a map of secrets for an organization and the version of the secrets retrieved.
// The version is used to ensure that concurrent updates will not overwrite one another.
func (s *SecretService) loadSecrets(ctx context.Context, orgID platform.ID) (map[string]string, int, error) {
	// TODO(desa): update url construction
	sec, err := s.Client.Logical().Read(fmt.Sprintf("/secret/data/%s", orgID))
	if err != nil {
		return nil, -1, err
	}

	m := map[string]string{}
	if sec == nil {
		return m, 0, nil
	}

	data, ok := sec.Data["data"].(map[string]interface{})
	if !ok {
		return nil, -1, fmt.Errorf("value found in secret data is not map[string]interface{}")
	}

	for k, v := range data {
		val, ok := v.(string)
		if !ok {
			continue
		}
		m[k] = val
	}

	metadata, ok := sec.Data["metadata"].(map[string]interface{})
	if !ok {
		return nil, -1, fmt.Errorf("value found in secret metadata is not map[string]interface{}")
	}

	var version int
	switch v := metadata["version"].(type) {
	case json.Number:
		ver, err := v.Int64()
		if err != nil {
			return nil, -1, err
		}
		version = int(ver)
	case string:
		ver, err := strconv.Atoi(v)
		if err != nil {
			return nil, -1, fmt.Errorf("version provided is not a valid integer: %v", err)
		}
		version = ver
	case int:
		version = v
	default:
		return nil, -1, fmt.Errorf("version provided is %T not a string or int", v)
	}

	return m, version, nil
}

// GetSecretKeys retrieves all secret keys that are stored for the organization orgID.
func (s *SecretService) GetSecretKeys(ctx context.Context, orgID platform.ID) ([]string, error) {
	data, _, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return nil, err
	}

	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}

	return keys, nil
}

// PutSecret stores the secret pair (k,v) for the organization orgID.
func (s *SecretService) PutSecret(ctx context.Context, orgID platform.ID, k string, v string) error {
	data, ver, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	data[k] = v

	return s.putSecrets(ctx, orgID, data, ver)
}

// putSecrets will set all provided data values for the organization orgID.
// If version is negative, the write will overwrite all specified values.
// If version is 0, the write will only be allowed if the keys do not exists.
// If version is non-zero, the write will only be allowed if the keys current
// version in vault matches the version specified.
func (s *SecretService) putSecrets(ctx context.Context, orgID platform.ID, data map[string]string, version int) error {
	m := map[string]interface{}{"data": data}

	if version >= 0 {
		m["options"] = map[string]interface{}{"cas": version}
	}

	if _, err := s.Client.Logical().Write(fmt.Sprintf("/secret/data/%s", orgID), m); err != nil {
		return err
	}

	return nil
}

// PutSecrets puts all provided secrets and overwrites any previous values.
func (s *SecretService) PutSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
	return s.putSecrets(ctx, orgID, m, -1)
}

// PatchSecrets patches all provided secrets and updates any previous values.
func (s *SecretService) PatchSecrets(ctx context.Context, orgID platform.ID, m map[string]string) error {
	data, ver, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	for k, v := range m {
		data[k] = v
	}

	return s.putSecrets(ctx, orgID, data, ver)
}

// DeleteSecret removes a single secret from the secret store.
func (s *SecretService) DeleteSecret(ctx context.Context, orgID platform.ID, ks ...string) error {
	data, ver, err := s.loadSecrets(ctx, orgID)
	if err != nil {
		return err
	}

	for _, k := range ks {
		delete(data, k)
	}

	return s.putSecrets(ctx, orgID, data, ver)
}