chore(code): reduce the code duplication EE-7278 (#11969)
parent
39bdfa4512
commit
9ee092aa5e
|
@ -10,7 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/internal/url"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
)
|
||||
|
||||
// GetAgentVersionAndPlatform returns the agent version and platform
|
||||
|
|
|
@ -3,7 +3,6 @@ package apikey
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/internal/securecookie"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -34,7 +33,7 @@ func Test_generateRandomKey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := securecookie.GenerateRandomKey(tt.wantLenth)
|
||||
got := GenerateRandomKey(tt.wantLenth)
|
||||
is.Equal(tt.wantLenth, len(got))
|
||||
})
|
||||
}
|
||||
|
@ -42,7 +41,7 @@ func Test_generateRandomKey(t *testing.T) {
|
|||
t.Run("Generated keys are unique", func(t *testing.T) {
|
||||
keys := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
key := securecookie.GenerateRandomKey(8)
|
||||
key := GenerateRandomKey(8)
|
||||
_, ok := keys[string(key)]
|
||||
is.False(ok)
|
||||
keys[string(key)] = true
|
||||
|
|
|
@ -1,69 +1,79 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
)
|
||||
|
||||
const defaultAPIKeyCacheSize = 1024
|
||||
const DefaultAPIKeyCacheSize = 1024
|
||||
|
||||
// entry is a tuple containing the user and API key associated to an API key digest
|
||||
type entry struct {
|
||||
user portainer.User
|
||||
type entry[T any] struct {
|
||||
user T
|
||||
apiKey portainer.APIKey
|
||||
}
|
||||
|
||||
// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
|
||||
type UserCompareFn[T any] func(T, portainer.UserID) bool
|
||||
|
||||
// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
|
||||
// We store the api-key digest (keys) and the associated user and key-data (values) in the cache.
|
||||
// This is required because HTTP requests will contain only the api-key digest in the x-api-key request header;
|
||||
// digest value must be mapped to a portainer user (and respective key data) for validation.
|
||||
// This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest.
|
||||
type apiKeyCache struct {
|
||||
type ApiKeyCache[T any] struct {
|
||||
// cache type [string]entry cache (key: string(digest), value: user/key entry)
|
||||
// note: []byte keys are not supported by golang-lru Cache
|
||||
cache *lru.Cache
|
||||
cache *lru.Cache
|
||||
userCmpFn UserCompareFn[T]
|
||||
}
|
||||
|
||||
// NewAPIKeyCache creates a new cache for API keys
|
||||
func NewAPIKeyCache(cacheSize int) *apiKeyCache {
|
||||
func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] {
|
||||
cache, _ := lru.New(cacheSize)
|
||||
return &apiKeyCache{cache: cache}
|
||||
|
||||
return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn}
|
||||
}
|
||||
|
||||
// Get returns the user/key associated to an api-key's digest
|
||||
// This is required because HTTP requests will contain the digest of the API key in header,
|
||||
// the digest value must be mapped to a portainer user.
|
||||
func (c *apiKeyCache) Get(digest string) (portainer.User, portainer.APIKey, bool) {
|
||||
func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) {
|
||||
val, ok := c.cache.Get(digest)
|
||||
if !ok {
|
||||
return portainer.User{}, portainer.APIKey{}, false
|
||||
var t T
|
||||
|
||||
return t, portainer.APIKey{}, false
|
||||
}
|
||||
tuple := val.(entry)
|
||||
|
||||
tuple := val.(entry[T])
|
||||
|
||||
return tuple.user, tuple.apiKey, true
|
||||
}
|
||||
|
||||
// Set persists a user/key entry to the cache
|
||||
func (c *apiKeyCache) Set(digest string, user portainer.User, apiKey portainer.APIKey) {
|
||||
c.cache.Add(digest, entry{
|
||||
func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) {
|
||||
c.cache.Add(digest, entry[T]{
|
||||
user: user,
|
||||
apiKey: apiKey,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete evicts a digest's user/key entry key from the cache
|
||||
func (c *apiKeyCache) Delete(digest string) {
|
||||
func (c *ApiKeyCache[T]) Delete(digest string) {
|
||||
c.cache.Remove(digest)
|
||||
}
|
||||
|
||||
// InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache
|
||||
func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool {
|
||||
func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool {
|
||||
present := false
|
||||
|
||||
for _, k := range c.cache.Keys() {
|
||||
user, _, _ := c.Get(k.(string))
|
||||
if user.ID == userId {
|
||||
if c.userCmpFn(user, userId) {
|
||||
present = c.cache.Remove(k)
|
||||
}
|
||||
}
|
||||
|
||||
return present
|
||||
}
|
||||
|
|
|
@ -10,11 +10,11 @@ import (
|
|||
func Test_apiKeyCacheGet(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
keyCache := NewAPIKeyCache(10)
|
||||
keyCache := NewAPIKeyCache(10, compareUser)
|
||||
|
||||
// pre-populate cache
|
||||
keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
|
||||
|
||||
tests := []struct {
|
||||
digest string
|
||||
|
@ -45,7 +45,7 @@ func Test_apiKeyCacheGet(t *testing.T) {
|
|||
func Test_apiKeyCacheSet(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
keyCache := NewAPIKeyCache(10)
|
||||
keyCache := NewAPIKeyCache(10, compareUser)
|
||||
|
||||
// pre-populate cache
|
||||
keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{})
|
||||
|
@ -57,23 +57,23 @@ func Test_apiKeyCacheSet(t *testing.T) {
|
|||
val, ok := keyCache.cache.Get(string("bar"))
|
||||
is.True(ok)
|
||||
|
||||
tuple := val.(entry)
|
||||
tuple := val.(entry[portainer.User])
|
||||
is.Equal(portainer.User{ID: 2}, tuple.user)
|
||||
|
||||
val, ok = keyCache.cache.Get(string("foo"))
|
||||
is.True(ok)
|
||||
|
||||
tuple = val.(entry)
|
||||
tuple = val.(entry[portainer.User])
|
||||
is.Equal(portainer.User{ID: 3}, tuple.user)
|
||||
}
|
||||
|
||||
func Test_apiKeyCacheDelete(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
keyCache := NewAPIKeyCache(10)
|
||||
keyCache := NewAPIKeyCache(10, compareUser)
|
||||
|
||||
t.Run("Delete an existing entry", func(t *testing.T) {
|
||||
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
keyCache.Delete("foo")
|
||||
|
||||
_, ok := keyCache.cache.Get(string("foo"))
|
||||
|
@ -128,7 +128,7 @@ func Test_apiKeyCacheLRU(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
keyCache := NewAPIKeyCache(test.cacheLen)
|
||||
keyCache := NewAPIKeyCache(test.cacheLen, compareUser)
|
||||
|
||||
for _, key := range test.key {
|
||||
keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{})
|
||||
|
@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) {
|
|||
func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
keyCache := NewAPIKeyCache(10)
|
||||
keyCache := NewAPIKeyCache(10, compareUser)
|
||||
|
||||
t.Run("Removes users keys from cache", func(t *testing.T) {
|
||||
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
|
||||
ok := keyCache.InvalidateUserKeyCache(1)
|
||||
is.True(ok)
|
||||
|
@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Does not affect other keys", func(t *testing.T) {
|
||||
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
|
||||
keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
|
||||
|
||||
ok := keyCache.InvalidateUserKeyCache(1)
|
||||
is.True(ok)
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/securecookie"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
@ -20,30 +21,45 @@ var ErrInvalidAPIKey = errors.New("Invalid API key")
|
|||
type apiKeyService struct {
|
||||
apiKeyRepository dataservices.APIKeyRepository
|
||||
userRepository dataservices.UserService
|
||||
cache *apiKeyCache
|
||||
cache *ApiKeyCache[portainer.User]
|
||||
}
|
||||
|
||||
// GenerateRandomKey generates a random key of specified length
|
||||
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
|
||||
func GenerateRandomKey(length int) []byte {
|
||||
k := make([]byte, length)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
func compareUser(u portainer.User, id portainer.UserID) bool {
|
||||
return u.ID == id
|
||||
}
|
||||
|
||||
func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService {
|
||||
return &apiKeyService{
|
||||
apiKeyRepository: apiKeyRepository,
|
||||
userRepository: userRepository,
|
||||
cache: NewAPIKeyCache(defaultAPIKeyCacheSize),
|
||||
cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser),
|
||||
}
|
||||
}
|
||||
|
||||
// HashRaw computes a hash digest of provided raw API key.
|
||||
func (a *apiKeyService) HashRaw(rawKey string) string {
|
||||
hashDigest := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
return base64.StdEncoding.EncodeToString(hashDigest[:])
|
||||
}
|
||||
|
||||
// GenerateApiKey generates a raw API key for a user (for one-time display).
|
||||
// The generated API key is stored in the cache and database.
|
||||
func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) {
|
||||
randKey := securecookie.GenerateRandomKey(32)
|
||||
randKey := GenerateRandomKey(32)
|
||||
encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey)
|
||||
prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey
|
||||
|
||||
hashDigest := a.HashRaw(prefixedAPIKey)
|
||||
|
||||
apiKey := &portainer.APIKey{
|
||||
|
@ -54,8 +70,7 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
|
|||
Digest: hashDigest,
|
||||
}
|
||||
|
||||
err := a.apiKeyRepository.Create(apiKey)
|
||||
if err != nil {
|
||||
if err := a.apiKeyRepository.Create(apiKey); err != nil {
|
||||
return "", nil, errors.Wrap(err, "Unable to create API key")
|
||||
}
|
||||
|
||||
|
@ -78,7 +93,6 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey,
|
|||
// GetDigestUserAndKey returns the user and api-key associated to a specified hash digest.
|
||||
// A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed.
|
||||
func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) {
|
||||
// get api key from cache if possible
|
||||
cachedUser, cachedKey, ok := a.cache.Get(digest)
|
||||
if ok {
|
||||
return cachedUser, cachedKey, nil
|
||||
|
@ -106,20 +120,21 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error {
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to retrieve API key")
|
||||
}
|
||||
|
||||
a.cache.Set(apiKey.Digest, user, *apiKey)
|
||||
|
||||
return a.apiKeyRepository.Update(apiKey.ID, apiKey)
|
||||
}
|
||||
|
||||
// DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache.
|
||||
func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error {
|
||||
// get api-key digest to remove from cache
|
||||
apiKey, err := a.apiKeyRepository.Read(apiKeyID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID))
|
||||
}
|
||||
|
||||
// delete the user/api-key from cache
|
||||
a.cache.Delete(apiKey.Digest)
|
||||
|
||||
return a.apiKeyRepository.Delete(apiKeyID)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,17 +17,14 @@ import (
|
|||
type Service struct{}
|
||||
|
||||
var (
|
||||
errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
|
||||
errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
|
||||
errInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
|
||||
errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
|
||||
ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
|
||||
ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
|
||||
ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
|
||||
ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
|
||||
)
|
||||
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
flags := &portainer.CLIFlags{
|
||||
func CLIFlags() *portainer.CLIFlags {
|
||||
return &portainer.CLIFlags{
|
||||
Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(),
|
||||
AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(),
|
||||
TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(),
|
||||
|
@ -63,6 +60,13 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
|||
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
|
||||
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
||||
}
|
||||
}
|
||||
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
|
@ -82,18 +86,16 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
|||
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
displayDeprecationWarnings(flags)
|
||||
|
||||
err := validateEndpointURL(*flags.EndpointURL)
|
||||
if err != nil {
|
||||
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateSnapshotInterval(*flags.SnapshotInterval)
|
||||
if err != nil {
|
||||
if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" {
|
||||
return errAdminPassExcludeAdminPassFile
|
||||
return ErrAdminPassExcludeAdminPassFile
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -115,15 +117,16 @@ func validateEndpointURL(endpointURL string) error {
|
|||
}
|
||||
|
||||
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
|
||||
return errInvalidEndpointProtocol
|
||||
return ErrInvalidEndpointProtocol
|
||||
}
|
||||
|
||||
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
|
||||
socketPath := strings.TrimPrefix(endpointURL, "unix://")
|
||||
socketPath = strings.TrimPrefix(socketPath, "npipe://")
|
||||
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return errSocketOrNamedPipeNotFound
|
||||
return ErrSocketOrNamedPipeNotFound
|
||||
}
|
||||
|
||||
return err
|
||||
|
@ -138,9 +141,8 @@ func validateSnapshotInterval(snapshotInterval string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
_, err := time.ParseDuration(snapshotInterval)
|
||||
if err != nil {
|
||||
return errInvalidSnapshotInterval
|
||||
if _, err := time.ParseDuration(snapshotInterval); err != nil {
|
||||
return ErrInvalidSnapshotInterval
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -56,14 +56,14 @@ import (
|
|||
)
|
||||
|
||||
func initCLI() *portainer.CLIFlags {
|
||||
var cliService portainer.CLIService = &cli.Service{}
|
||||
cliService := &cli.Service{}
|
||||
|
||||
flags, err := cliService.ParseFlags(portainer.APIVersion)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed parsing flags")
|
||||
}
|
||||
|
||||
err = cliService.ValidateFlags(flags)
|
||||
if err != nil {
|
||||
if err := cliService.ValidateFlags(flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed validating flags")
|
||||
}
|
||||
|
||||
|
@ -94,14 +94,14 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
|||
}
|
||||
|
||||
store := datastore.NewStore(*flags.Data, fileService, connection)
|
||||
|
||||
isNew, err := store.Open()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed opening store")
|
||||
}
|
||||
|
||||
if *flags.Rollback {
|
||||
err := store.Rollback(false)
|
||||
if err != nil {
|
||||
if err := store.Rollback(false); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed rolling back")
|
||||
}
|
||||
|
||||
|
@ -110,8 +110,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
|||
}
|
||||
|
||||
// Init sets some defaults - it's basically a migration
|
||||
err = store.Init()
|
||||
if err != nil {
|
||||
if err := store.Init(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing data store")
|
||||
}
|
||||
|
||||
|
@ -133,25 +132,23 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
|||
}
|
||||
store.VersionService.UpdateVersion(&v)
|
||||
|
||||
err = updateSettingsFromFlags(store, flags)
|
||||
if err != nil {
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed updating settings from flags")
|
||||
}
|
||||
} else {
|
||||
err = store.MigrateData()
|
||||
if err != nil {
|
||||
if err := store.MigrateData(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed migration")
|
||||
}
|
||||
}
|
||||
|
||||
err = updateSettingsFromFlags(store, flags)
|
||||
if err != nil {
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed updating settings from flags")
|
||||
}
|
||||
|
||||
// this is for the db restore functionality - needs more tests.
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
|
||||
defer connection.Close()
|
||||
}()
|
||||
|
||||
|
@ -205,36 +202,16 @@ func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore)
|
|||
userSessionTimeout = portainer.DefaultUserSessionTimeout
|
||||
}
|
||||
|
||||
jwtService, err := jwt.NewService(userSessionTimeout, dataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jwtService, nil
|
||||
return jwt.NewService(userSessionTimeout, dataStore)
|
||||
}
|
||||
|
||||
func initDigitalSignatureService() portainer.DigitalSignatureService {
|
||||
return crypto.NewECDSAService(os.Getenv("AGENT_SECRET"))
|
||||
}
|
||||
|
||||
func initCryptoService() portainer.CryptoService {
|
||||
return &crypto.Service{}
|
||||
}
|
||||
|
||||
func initLDAPService() portainer.LDAPService {
|
||||
return &ldap.Service{}
|
||||
}
|
||||
|
||||
func initOAuthService() portainer.OAuthService {
|
||||
return oauth.NewService()
|
||||
}
|
||||
|
||||
func initGitService(ctx context.Context) portainer.GitService {
|
||||
return git.NewService(ctx)
|
||||
}
|
||||
|
||||
func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) {
|
||||
slices := strings.Split(addr, ":")
|
||||
|
||||
host := slices[0]
|
||||
if host == "" {
|
||||
host = "0.0.0.0"
|
||||
|
@ -242,22 +219,13 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe
|
|||
|
||||
sslService := ssl.NewService(fileService, dataStore, shutdownTrigger)
|
||||
|
||||
err := sslService.Init(host, certPath, keyPath)
|
||||
if err != nil {
|
||||
if err := sslService.Init(host, certPath, keyPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sslService, nil
|
||||
}
|
||||
|
||||
func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *dockerclient.ClientFactory {
|
||||
return dockerclient.NewClientFactory(signatureService, reverseTunnelService)
|
||||
}
|
||||
|
||||
func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, dataStore dataservices.DataStore, instanceID, addrHTTPS, userSessionTimeout string) (*kubecli.ClientFactory, error) {
|
||||
return kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, addrHTTPS, userSessionTimeout)
|
||||
}
|
||||
|
||||
func initSnapshotService(
|
||||
snapshotIntervalFromFlag string,
|
||||
dataStore dataservices.DataStore,
|
||||
|
@ -310,14 +278,12 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
|
|||
settings.BlackListedLabels = *flags.Labels
|
||||
}
|
||||
|
||||
settings.AgentSecret = ""
|
||||
if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok {
|
||||
settings.AgentSecret = agentKey
|
||||
} else {
|
||||
settings.AgentSecret = ""
|
||||
}
|
||||
|
||||
err = dataStore.Settings().UpdateSettings(settings)
|
||||
if err != nil {
|
||||
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -340,6 +306,7 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return signatureService.ParseKeyPair(private, public)
|
||||
}
|
||||
|
||||
|
@ -348,7 +315,9 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
privateHeader, publicHeader := signatureService.PEMHeaders()
|
||||
|
||||
return fileService.StoreKeyPair(private, public, privateHeader, publicHeader)
|
||||
}
|
||||
|
||||
|
@ -361,6 +330,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
|
|||
if existingKeyPair {
|
||||
return loadAndParseKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
return generateAndStoreKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
|
@ -378,6 +348,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
|
|||
|
||||
// return a 32 byte hash of the secret (required for AES)
|
||||
hash := sha256.Sum256(content)
|
||||
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
|
@ -422,17 +393,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
log.Fatal().Err(err).Msg("failed initializing JWT service")
|
||||
}
|
||||
|
||||
ldapService := initLDAPService()
|
||||
ldapService := &ldap.Service{}
|
||||
|
||||
oauthService := initOAuthService()
|
||||
oauthService := oauth.NewService()
|
||||
|
||||
gitService := initGitService(shutdownCtx)
|
||||
gitService := git.NewService(shutdownCtx)
|
||||
|
||||
openAMTService := openamt.NewService()
|
||||
|
||||
cryptoService := initCryptoService()
|
||||
cryptoService := &crypto.Service{}
|
||||
|
||||
digitalSignatureService := initDigitalSignatureService()
|
||||
signatureService := initDigitalSignatureService()
|
||||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
|
@ -446,15 +417,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
log.Fatal().Err(err).Msg("failed to get SSL settings")
|
||||
}
|
||||
|
||||
err = initKeyPair(fileService, digitalSignatureService)
|
||||
if err != nil {
|
||||
if err := initKeyPair(fileService, signatureService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing key pair")
|
||||
}
|
||||
|
||||
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService)
|
||||
|
||||
dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService)
|
||||
kubernetesClientFactory, err := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
|
||||
dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService)
|
||||
|
||||
kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service")
|
||||
}
|
||||
|
||||
authorizationService := authorization.NewService(dataStore)
|
||||
authorizationService.K8sClientFactory = kubernetesClientFactory
|
||||
|
@ -476,12 +450,12 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
|
||||
composeStackManager := initComposeStackManager(composeDeployer, proxyManager)
|
||||
|
||||
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore)
|
||||
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
|
||||
}
|
||||
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets)
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
|
||||
|
||||
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
|
||||
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
|
||||
|
@ -492,17 +466,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing snapshot service")
|
||||
}
|
||||
|
||||
snapshotService.Start()
|
||||
|
||||
proxyManager.NewProxyFactory(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
|
||||
|
||||
helmPackageManager, err := initHelmPackageManager(*flags.Assets)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed initializing helm package manager")
|
||||
}
|
||||
|
||||
err = edge.LoadEdgeJobs(dataStore, reverseTunnelService)
|
||||
if err != nil {
|
||||
if err := edge.LoadEdgeJobs(dataStore, reverseTunnelService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed loading edge jobs from database")
|
||||
}
|
||||
|
||||
|
@ -514,6 +488,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService)
|
||||
|
||||
adminPasswordHash := ""
|
||||
|
||||
if *flags.AdminPasswordFile != "" {
|
||||
content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "")
|
||||
if err != nil {
|
||||
|
@ -536,14 +511,14 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
|
||||
if len(users) == 0 {
|
||||
log.Info().Msg("created admin user with the given password.")
|
||||
|
||||
user := &portainer.User{
|
||||
Username: "admin",
|
||||
Role: portainer.AdministratorRole,
|
||||
Password: adminPasswordHash,
|
||||
}
|
||||
|
||||
err := dataStore.User().Create(user)
|
||||
if err != nil {
|
||||
if err := dataStore.User().Create(user); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed creating admin user")
|
||||
}
|
||||
|
||||
|
@ -554,8 +529,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
}
|
||||
}
|
||||
|
||||
err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService)
|
||||
if err != nil {
|
||||
if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed starting tunnel server")
|
||||
}
|
||||
|
||||
|
@ -613,7 +587,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
ProxyManager: proxyManager,
|
||||
KubernetesTokenCacheManager: kubernetesTokenCacheManager,
|
||||
KubeClusterAccessService: kubeClusterAccessService,
|
||||
SignatureService: digitalSignatureService,
|
||||
SignatureService: signatureService,
|
||||
SnapshotService: snapshotService,
|
||||
SSLService: sslService,
|
||||
DockerClientFactory: dockerClientFactory,
|
||||
|
@ -639,6 +613,7 @@ func main() {
|
|||
|
||||
for {
|
||||
server := buildServer(flags)
|
||||
|
||||
log.Info().
|
||||
Str("version", portainer.APIVersion).
|
||||
Str("build_number", build.BuildNumber).
|
||||
|
|
|
@ -203,6 +203,7 @@ func (connection *DbConnection) ExportRaw(filename string) error {
|
|||
func (connection *DbConnection) ConvertToKey(v int) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(v))
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
|
@ -46,8 +46,8 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
|
|||
return errors.Wrap(err, "Failed decrypting object")
|
||||
}
|
||||
}
|
||||
e := json.Unmarshal(data, object)
|
||||
if e != nil {
|
||||
|
||||
if e := json.Unmarshal(data, object); e != nil {
|
||||
// Special case for the VERSION bucket. Here we're not using json
|
||||
// So we need to return it as a string
|
||||
s, ok := object.(*string)
|
||||
|
@ -57,6 +57,7 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
|
|||
|
||||
*s = string(data)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -71,7 +72,7 @@ func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error)
|
|||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
|
|
|
@ -78,6 +78,7 @@ func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
|
|||
id, err := bucket.NextSequence()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
|
|
@ -111,5 +111,6 @@ func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) e
|
|||
store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey))
|
||||
store.connection.DeleteObject(bucketName, []byte(legacyEditionKey))
|
||||
store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -39,20 +39,19 @@ func (m *Migrator) Migrate() error {
|
|||
latestMigrations := m.LatestMigrations()
|
||||
if latestMigrations.Version.Equal(schemaVersion) &&
|
||||
version.MigratorCount != len(latestMigrations.MigrationFuncs) {
|
||||
err := runMigrations(latestMigrations.MigrationFuncs)
|
||||
if err != nil {
|
||||
if err := runMigrations(latestMigrations.MigrationFuncs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newMigratorCount = len(latestMigrations.MigrationFuncs)
|
||||
}
|
||||
} else {
|
||||
// regular path when major/minor/patch versions differ
|
||||
for _, migration := range m.migrations {
|
||||
if schemaVersion.LessThan(migration.Version) {
|
||||
|
||||
log.Info().Msgf("migrating data to %s", migration.Version.String())
|
||||
err := runMigrations(migration.MigrationFuncs)
|
||||
if err != nil {
|
||||
|
||||
if err := runMigrations(migration.MigrationFuncs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -63,16 +62,14 @@ func (m *Migrator) Migrate() error {
|
|||
}
|
||||
}
|
||||
|
||||
err = m.Always()
|
||||
if err != nil {
|
||||
if err := m.Always(); err != nil {
|
||||
return migrationError(err, "Always migrations returned error")
|
||||
}
|
||||
|
||||
version.SchemaVersion = portainer.APIVersion
|
||||
version.MigratorCount = newMigratorCount
|
||||
|
||||
err = m.versionService.UpdateVersion(version)
|
||||
if err != nil {
|
||||
if err := m.versionService.UpdateVersion(version); err != nil {
|
||||
return migrationError(err, "StoreDBVersion")
|
||||
}
|
||||
|
||||
|
@ -99,6 +96,7 @@ func (m *Migrator) NeedsMigration() bool {
|
|||
// In this particular instance we should log a fatal error
|
||||
if m.CurrentDBEdition() != portainer.PortainerCE {
|
||||
log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/chisel/crypto"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
@ -37,9 +38,11 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
|
|||
log.Info().Msg("ServerInfo object not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("Failed to read ServerInfo from DB")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -49,14 +52,15 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
|
|||
log.Error().
|
||||
Err(err).
|
||||
Msg("Failed to read ServerInfo from DB")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.fileService.StoreChiselPrivateKey(key)
|
||||
if err != nil {
|
||||
if err := m.fileService.StoreChiselPrivateKey(key); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("Failed to save Chisel private key to disk")
|
||||
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
|
@ -64,14 +68,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
|
|||
}
|
||||
|
||||
serverInfo.PrivateKeySeed = ""
|
||||
err = m.TunnelServerService.UpdateInfo(serverInfo)
|
||||
if err != nil {
|
||||
if err := m.TunnelServerService.UpdateInfo(serverInfo); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Msg("Failed to clean private key seed in DB")
|
||||
} else {
|
||||
log.Info().Msg("Success to migrate private key seed to private key file")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -84,9 +88,8 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
|
|||
}
|
||||
|
||||
for _, edgeStack := range edgeStacks {
|
||||
|
||||
for environmentID, environmentStatus := range edgeStack.Status {
|
||||
// skip if status is already updated
|
||||
// Skip if status is already updated
|
||||
if len(environmentStatus.Status) > 0 {
|
||||
continue
|
||||
}
|
||||
|
@ -146,8 +149,7 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
|
|||
edgeStack.Status[environmentID] = environmentStatus
|
||||
}
|
||||
|
||||
err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack)
|
||||
if err != nil {
|
||||
if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,8 +32,8 @@ func (m *Migrator) updateStacksToDB24() error {
|
|||
for idx := range stacks {
|
||||
stack := &stacks[idx]
|
||||
stack.Status = portainer.StackStatusActive
|
||||
err := m.stackService.Update(stack.ID, stack)
|
||||
if err != nil {
|
||||
|
||||
if err := m.stackService.Update(stack.ID, stack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -583,7 +583,6 @@
|
|||
"AuthenticationMethod": 1,
|
||||
"BlackListedLabels": [],
|
||||
"Edge": {
|
||||
"AsyncMode": false,
|
||||
"CommandInterval": 0,
|
||||
"PingInterval": 0,
|
||||
"SnapshotInterval": 0
|
||||
|
|
|
@ -52,27 +52,24 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
|
|||
}
|
||||
|
||||
if init {
|
||||
err = store.Init()
|
||||
if err != nil {
|
||||
if err := store.Init(); err != nil {
|
||||
return newStore, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if newStore {
|
||||
// from MigrateData
|
||||
// From MigrateData
|
||||
v := models.Version{
|
||||
SchemaVersion: portainer.APIVersion,
|
||||
Edition: int(portainer.PortainerCE),
|
||||
}
|
||||
err = store.VersionService.UpdateVersion(&v)
|
||||
if err != nil {
|
||||
if err := store.VersionService.UpdateVersion(&v); err != nil {
|
||||
return newStore, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
teardown := func() {
|
||||
err := store.Close()
|
||||
if err != nil {
|
||||
if err := store.Close(); err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,6 @@ func (c *ContainerService) Recreate(ctx context.Context, endpoint *portainer.End
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create client error")
|
||||
}
|
||||
|
||||
defer cli.Close()
|
||||
|
||||
log.Debug().Str("container_id", containerId).Msg("starting to fetch container information")
|
||||
|
|
|
@ -5,10 +5,10 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
|
||||
gorillacsrf "github.com/gorilla/csrf"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
|
@ -16,8 +16,7 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
|
|||
handler = withSendCSRFToken(handler)
|
||||
|
||||
token := make([]byte, 32)
|
||||
_, err := rand.Read(token)
|
||||
if err != nil {
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
|
||||
}
|
||||
|
||||
|
@ -32,7 +31,6 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
|
|||
|
||||
func withSendCSRFToken(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sw := negroni.NewResponseWriter(w)
|
||||
|
||||
sw.Before(func(sw negroni.ResponseWriter) {
|
||||
|
@ -44,16 +42,15 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
|
|||
})
|
||||
|
||||
handler.ServeHTTP(sw, r)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func withSkipCSRF(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
skip, err := security.ShouldSkipCSRFCheck(r)
|
||||
if err != nil {
|
||||
httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -56,8 +56,7 @@ func (payload *authenticatePayload) Validate(r *http.Request) error {
|
|||
// @router /auth [post]
|
||||
func (handler *Handler) authenticate(rw http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload authenticatePayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
|
@ -104,8 +103,7 @@ func isUserInitialAdmin(user *portainer.User) bool {
|
|||
}
|
||||
|
||||
func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portainer.User, password string) *httperror.HandlerError {
|
||||
err := handler.CryptoService.CompareHashAndData(user.Password, password)
|
||||
if err != nil {
|
||||
if err := handler.CryptoService.CompareHashAndData(user.Password, password); err != nil {
|
||||
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
|
||||
}
|
||||
|
||||
|
@ -115,8 +113,7 @@ func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portai
|
|||
}
|
||||
|
||||
func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.User, username, password string, ldapSettings *portainer.LDAPSettings) *httperror.HandlerError {
|
||||
err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings)
|
||||
if err != nil {
|
||||
if err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings); err != nil {
|
||||
if errors.Is(err, httperrors.ErrUnauthorized) {
|
||||
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
|
||||
}
|
||||
|
@ -131,14 +128,12 @@ func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.
|
|||
PortainerAuthorizations: authorization.DefaultPortainerAuthorizations(),
|
||||
}
|
||||
|
||||
err = handler.DataStore.User().Create(user)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.User().Create(user); err != nil {
|
||||
return httperror.InternalServerError("Unable to persist user inside the database", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = handler.syncUserTeamsWithLDAPGroups(user, ldapSettings)
|
||||
if err != nil {
|
||||
if err := handler.syncUserTeamsWithLDAPGroups(user, ldapSettings); err != nil {
|
||||
log.Warn().Err(err).Msg("unable to automatically sync user teams with ldap")
|
||||
}
|
||||
|
||||
|
@ -186,7 +181,6 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
|
|||
|
||||
for _, team := range teams {
|
||||
if teamExists(team.Name, userGroups) {
|
||||
|
||||
if teamMembershipExists(team.ID, userMemberships) {
|
||||
continue
|
||||
}
|
||||
|
@ -197,8 +191,7 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
|
|||
Role: portainer.TeamMember,
|
||||
}
|
||||
|
||||
err := handler.DataStore.TeamMembership().Create(membership)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.TeamMembership().Create(membership); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,5 +41,6 @@ func NewHandler(bouncer security.BouncerService, rateLimiter *security.RateLimit
|
|||
rateLimiter.LimitAccess(bouncer.PublicAccess(httperror.LoggerHandler(h.authenticate)))).Methods(http.MethodPost)
|
||||
h.Handle("/auth/logout",
|
||||
bouncer.PublicAccess(httperror.LoggerHandler(h.logout))).Methods(http.MethodPost)
|
||||
|
||||
return h
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/logoutcontext"
|
||||
"github.com/portainer/portainer/api/logoutcontext"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
|
|
@ -4,14 +4,15 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
@ -70,7 +71,7 @@ func (handler *Handler) customTemplateList(w http.ResponseWriter, r *http.Reques
|
|||
customTemplates = filterByType(customTemplates, templateTypes)
|
||||
|
||||
if edge != nil {
|
||||
customTemplates = slices.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool {
|
||||
customTemplates = slicesx.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool {
|
||||
return customTemplate.EdgeTemplate == *edge
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/portainer/portainer/api/docker/client"
|
||||
"github.com/portainer/portainer/api/http/handler/docker/utils"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
)
|
||||
|
||||
// filterByResourceControl filters a list of items based on the user's role and the resource control associated to the item.
|
||||
|
@ -16,7 +16,7 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
|
|||
return items, nil
|
||||
}
|
||||
|
||||
userTeamIDs := slices.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID {
|
||||
userTeamIDs := slicesx.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID {
|
||||
return membership.TeamID
|
||||
})
|
||||
|
||||
|
@ -32,5 +32,6 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
return filteredItems, nil
|
||||
}
|
||||
|
|
|
@ -36,23 +36,25 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error {
|
|||
func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error {
|
||||
if edgeGroup.Dynamic {
|
||||
edgeGroup.TagIDs = tagIDs
|
||||
} else {
|
||||
endpointIDs := []portainer.EndpointID{}
|
||||
|
||||
for _, endpointID := range endpoints {
|
||||
endpoint, err := tx.Endpoint().Endpoint(endpointID)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
endpointIDs = append(endpointIDs, endpoint.ID)
|
||||
}
|
||||
endpointIDs := []portainer.EndpointID{}
|
||||
|
||||
for _, endpointID := range endpoints {
|
||||
endpoint, err := tx.Endpoint().Endpoint(endpointID)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
|
||||
}
|
||||
|
||||
edgeGroup.Endpoints = endpointIDs
|
||||
if endpointutils.IsEdgeEndpoint(endpoint) {
|
||||
endpointIDs = append(endpointIDs, endpoint.ID)
|
||||
}
|
||||
}
|
||||
|
||||
edgeGroup.Endpoints = endpointIDs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -71,13 +73,13 @@ func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.
|
|||
// @router /edge_groups [post]
|
||||
func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload edgeGroupCreatePayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
var edgeGroup *portainer.EdgeGroup
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
|
||||
err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
edgeGroups, err := tx.EdgeGroup().ReadAll()
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err)
|
||||
|
@ -101,8 +103,7 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
|
|||
return err
|
||||
}
|
||||
|
||||
err = tx.EdgeGroup().Create(edgeGroup)
|
||||
if err != nil {
|
||||
if err := tx.EdgeGroup().Create(edgeGroup); err != nil {
|
||||
return httperror.InternalServerError("Unable to persist the Edge group inside the database", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/unique"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
|
||||
|
@ -113,7 +113,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups)
|
||||
endpointsToUpdate := unique.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...))
|
||||
endpointsToUpdate := slicesx.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...))
|
||||
|
||||
edgeJobs, err := tx.EdgeJob().ReadAll()
|
||||
if err != nil {
|
||||
|
|
|
@ -31,8 +31,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
|
|||
}
|
||||
|
||||
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
|
||||
err = store.User().Create(user)
|
||||
if err != nil {
|
||||
if err := store.User().Create(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -66,8 +65,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
|
|||
}
|
||||
settings.EnableEdgeComputeFeatures = true
|
||||
|
||||
err = handler.DataStore.Settings().UpdateSettings(settings)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.Settings().UpdateSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -88,8 +86,7 @@ func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID
|
|||
LastCheckInDate: time.Now().Unix(),
|
||||
}
|
||||
|
||||
err := store.Endpoint().Create(&endpoint)
|
||||
if err != nil {
|
||||
if err := store.Endpoint().Create(&endpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -112,8 +109,7 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
|
|||
PartialMatch: false,
|
||||
}
|
||||
|
||||
err := store.EdgeGroup().Create(&edgeGroup)
|
||||
if err != nil {
|
||||
if err := store.EdgeGroup().Create(&edgeGroup); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -138,13 +134,11 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
|
|||
},
|
||||
}
|
||||
|
||||
err = store.EdgeStack().Create(edgeStack.ID, &edgeStack)
|
||||
if err != nil {
|
||||
if err := store.EdgeStack().Create(edgeStack.ID, &edgeStack); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = store.EndpointRelation().Create(&endpointRelation)
|
||||
if err != nil {
|
||||
if err := store.EndpointRelation().Create(&endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/segmentio/encoding/json"
|
||||
)
|
||||
|
@ -24,8 +25,7 @@ func TestUpdateAndInspect(t *testing.T) {
|
|||
endpointID := portainer.EndpointID(6)
|
||||
newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID)
|
||||
|
||||
err := handler.DataStore.Endpoint().Create(&newEndpoint)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.Endpoint().Create(&newEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -36,8 +36,7 @@ func TestUpdateAndInspect(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
err = handler.DataStore.EndpointRelation().Create(&endpointRelation)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.EndpointRelation().Create(&endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -50,8 +49,7 @@ func TestUpdateAndInspect(t *testing.T) {
|
|||
PartialMatch: false,
|
||||
}
|
||||
|
||||
err = handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -96,8 +94,7 @@ func TestUpdateAndInspect(t *testing.T) {
|
|||
}
|
||||
|
||||
updatedStack := portainer.EdgeStack{}
|
||||
err = json.NewDecoder(rec.Body).Decode(&updatedStack)
|
||||
if err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&updatedStack); err != nil {
|
||||
t.Fatal("error decoding response:", err)
|
||||
}
|
||||
|
||||
|
@ -120,7 +117,6 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
|
|||
endpoint := createEndpoint(t, handler.DataStore)
|
||||
edgeStack := createEdgeStack(t, handler.DataStore, endpoint.ID)
|
||||
|
||||
//newEndpoint := createEndpoint(t, handler.DataStore)
|
||||
newEdgeGroup := portainer.EdgeGroup{
|
||||
ID: 2,
|
||||
Name: "EdgeGroup 2",
|
||||
|
@ -130,7 +126,8 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
|
|||
PartialMatch: false,
|
||||
}
|
||||
|
||||
handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
|
||||
err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type endpointTestCase struct {
|
||||
|
@ -99,8 +100,7 @@ func mustSetupHandler(t *testing.T) *Handler {
|
|||
}
|
||||
settings.TrustOnFirstConnect = true
|
||||
|
||||
err = store.Settings().UpdateSettings(settings)
|
||||
if err != nil {
|
||||
if err = store.Settings().UpdateSettings(settings); err != nil {
|
||||
t.Fatalf("could not update settings: %s", err)
|
||||
}
|
||||
|
||||
|
@ -122,8 +122,7 @@ func createEndpoint(handler *Handler, endpoint portainer.Endpoint, endpointRelat
|
|||
return nil
|
||||
}
|
||||
|
||||
err = handler.DataStore.Endpoint().Create(&endpoint)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.Endpoint().Create(&endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -134,14 +133,13 @@ func TestMissingEdgeIdentifier(t *testing.T) {
|
|||
handler := mustSetupHandler(t)
|
||||
endpointID := portainer.EndpointID(45)
|
||||
|
||||
err := createEndpoint(handler, portainer.Endpoint{
|
||||
if err := createEndpoint(handler, portainer.Endpoint{
|
||||
ID: endpointID,
|
||||
Name: "endpoint-id-45",
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
URL: "https://portainer.io:9443",
|
||||
EdgeID: "edge-id",
|
||||
}, portainer.EndpointRelation{EndpointID: endpointID})
|
||||
if err != nil {
|
||||
}, portainer.EndpointRelation{EndpointID: endpointID}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -201,8 +199,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
|
|||
EndpointID: endpoint.ID,
|
||||
}
|
||||
|
||||
err := createEndpoint(handler, endpoint, endpointRelation)
|
||||
if err != nil {
|
||||
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -212,6 +209,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal("request error:", err)
|
||||
}
|
||||
|
||||
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
|
||||
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
|
||||
|
||||
|
@ -246,8 +244,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
|
|||
EndpointID: endpoint.ID,
|
||||
}
|
||||
|
||||
err := createEndpoint(handler, endpoint, endpointRelation)
|
||||
if err != nil {
|
||||
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -255,6 +252,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal("request error:", err)
|
||||
}
|
||||
|
||||
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, edgeId)
|
||||
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
|
||||
|
||||
|
@ -308,10 +306,11 @@ func TestEdgeStackStatus(t *testing.T) {
|
|||
edgeStack.ID: true,
|
||||
},
|
||||
}
|
||||
handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
|
||||
|
||||
err := createEndpoint(handler, endpoint, endpointRelation)
|
||||
if err != nil {
|
||||
err := handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -319,6 +318,7 @@ func TestEdgeStackStatus(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal("request error:", err)
|
||||
}
|
||||
|
||||
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
|
||||
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
|
||||
|
||||
|
@ -330,8 +330,7 @@ func TestEdgeStackStatus(t *testing.T) {
|
|||
}
|
||||
|
||||
var data endpointEdgeStatusInspectResponse
|
||||
err = json.NewDecoder(rec.Body).Decode(&data)
|
||||
if err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
|
||||
t.Fatal("error decoding response:", err)
|
||||
}
|
||||
|
||||
|
@ -357,8 +356,7 @@ func TestEdgeJobsResponse(t *testing.T) {
|
|||
EndpointID: endpoint.ID,
|
||||
}
|
||||
|
||||
err := createEndpoint(handler, endpoint, endpointRelation)
|
||||
if err != nil {
|
||||
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -384,6 +382,7 @@ func TestEdgeJobsResponse(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal("request error:", err)
|
||||
}
|
||||
|
||||
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
|
||||
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
|
||||
|
||||
|
@ -395,8 +394,7 @@ func TestEdgeJobsResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
var data endpointEdgeStatusInspectResponse
|
||||
err = json.NewDecoder(rec.Body).Decode(&data)
|
||||
if err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
|
||||
t.Fatal("error decoding response:", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/tag"
|
||||
"github.com/portainer/portainer/api/pendingactions/handlers"
|
||||
"github.com/portainer/portainer/api/tag"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
|
|
@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
payload.GroupID = groupID
|
||||
|
||||
var tagIDs []portainer.TagID
|
||||
err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true)
|
||||
if err != nil {
|
||||
if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil {
|
||||
return errors.New("invalid TagIds parameter")
|
||||
}
|
||||
payload.TagIDs = tagIDs
|
||||
|
@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly")
|
||||
}
|
||||
|
||||
payload.TLSCACertFile = caCert
|
||||
}
|
||||
|
||||
|
@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid key file. Ensure that the file is uploaded correctly")
|
||||
}
|
||||
|
||||
payload.TLSKeyFile = key
|
||||
}
|
||||
}
|
||||
|
@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil {
|
||||
return errors.New("invalid Azure application ID")
|
||||
}
|
||||
|
||||
payload.AzureApplicationID = azureApplicationID
|
||||
|
||||
azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false)
|
||||
|
@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) {
|
||||
return errors.New("URL cannot be empty")
|
||||
}
|
||||
|
||||
payload.URL = endpointURL
|
||||
|
||||
publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true)
|
||||
|
@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
}
|
||||
|
||||
gpus := make([]portainer.Pair, 0)
|
||||
err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true)
|
||||
if err != nil {
|
||||
if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil {
|
||||
return errors.New("invalid Gpus parameter")
|
||||
}
|
||||
|
||||
payload.Gpus = gpus
|
||||
|
||||
edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true)
|
||||
|
@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
|
|||
// @router /endpoints [post]
|
||||
func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
payload := &endpointCreatePayload{}
|
||||
err := payload.Validate(r)
|
||||
if err != nil {
|
||||
if err := payload.Validate(r); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
|
@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
|
|||
)
|
||||
}
|
||||
|
||||
err = handler.DataStore.EndpointRelation().Create(relationObject)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil {
|
||||
return httperror.InternalServerError("Unable to persist the relation object inside the database", err)
|
||||
}
|
||||
|
||||
|
@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
|
|||
|
||||
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
|
||||
var err error
|
||||
|
||||
switch payload.EndpointCreationType {
|
||||
case azureEnvironment:
|
||||
return handler.createAzureEndpoint(tx, payload)
|
||||
|
@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
|
|||
}
|
||||
|
||||
httpClient := client.NewHTTPClient()
|
||||
_, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials)
|
||||
if err != nil {
|
||||
if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to authenticate against Azure", err)
|
||||
}
|
||||
|
||||
|
@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
|
|||
endpoint.EdgeID = edgeID.String()
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa
|
|||
Kubernetes: portainer.KubernetesDefault(),
|
||||
}
|
||||
|
||||
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
|
|||
|
||||
endpoint.Agent.Version = agentVersion
|
||||
|
||||
err := handler.storeTLSFiles(endpoint, payload)
|
||||
if err != nil {
|
||||
if err := handler.storeTLSFiles(endpoint, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = handler.snapshotAndPersistEndpoint(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
|
|||
}
|
||||
|
||||
func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError {
|
||||
err := handler.SnapshotService.SnapshotEndpoint(endpoint)
|
||||
if err != nil {
|
||||
if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil {
|
||||
if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) ||
|
||||
(endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) {
|
||||
err = errors.New("agent already paired with another Portainer instance")
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unable to initiate communications with environment", err)
|
||||
}
|
||||
|
||||
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
|
||||
return httperror.InternalServerError("An error occurred while trying to create the environment", err)
|
||||
}
|
||||
|
||||
|
@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data
|
|||
AllowStackManagementForRegularUsers: true,
|
||||
}
|
||||
|
||||
err := tx.Endpoint().Create(endpoint)
|
||||
if err != nil {
|
||||
if err := tx.Endpoint().Create(endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tagID := range endpoint.TagIDs {
|
||||
err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
|
||||
if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
|
||||
tag.Endpoints[endpoint.ID] = true
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
|
|||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err)
|
||||
}
|
||||
|
||||
endpoint.TLSConfig.TLSCACertPath = caCertPath
|
||||
}
|
||||
|
||||
if !payload.TLSSkipClientVerify {
|
||||
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSCertPath = certPath
|
||||
|
||||
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSKeyPath = keyPath
|
||||
if payload.TLSSkipClientVerify {
|
||||
return nil
|
||||
}
|
||||
|
||||
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
|
||||
}
|
||||
|
||||
endpoint.TLSConfig.TLSCertPath = certPath
|
||||
|
||||
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
|
||||
}
|
||||
endpoint.TLSConfig.TLSKeyPath = keyPath
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
|
|||
for i := 0; i < endpointsCount; i++ {
|
||||
endpointID := portainer.EndpointID(i) + 1
|
||||
|
||||
err := store.Endpoint().Create(&portainer.Endpoint{
|
||||
if err := store.Endpoint().Create(&portainer.Endpoint{
|
||||
ID: endpointID,
|
||||
Name: "env-" + strconv.Itoa(int(endpointID)),
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create endpoint:", err)
|
||||
}
|
||||
|
||||
endpointIDs = append(endpointIDs, endpointID)
|
||||
}
|
||||
|
||||
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "edgegroup-1",
|
||||
Endpoints: endpointIDs,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create edge group:", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_endpointList_edgeFilter(t *testing.T) {
|
||||
|
||||
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
|
||||
|
@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions)
|
|||
}
|
||||
|
||||
resp := []portainer.Endpoint{}
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re
|
|||
}
|
||||
|
||||
var registries []portainer.Registry
|
||||
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID))
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
|
@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
|
|||
}
|
||||
|
||||
if namespaceParam != "" {
|
||||
authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin)
|
||||
if err != nil {
|
||||
if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil {
|
||||
return nil, httperror.NotFound("Unable to check for namespace authorization", err)
|
||||
}
|
||||
if !authorized {
|
||||
} else if !authorized {
|
||||
return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace"))
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"github.com/portainer/portainer/api/http/handler/edgegroups"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/unique"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
|
||||
}
|
||||
|
||||
if edgeGroup.Dynamic {
|
||||
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
|
||||
if err != nil {
|
||||
|
@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
}
|
||||
edgeGroup.Endpoints = endpointIDs
|
||||
}
|
||||
|
||||
envIds = append(envIds, edgeGroup.Endpoints...)
|
||||
}
|
||||
|
||||
|
@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
|
|||
envIds = envIds[:n]
|
||||
}
|
||||
|
||||
uniqueIds := unique.Unique(envIds)
|
||||
uniqueIds := slicesx.Unique(envIds)
|
||||
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
|
||||
|
||||
return filteredEndpoints, nil
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) {
|
|||
func Test_Filter_excludeIDs(t *testing.T) {
|
||||
ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
|
||||
environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
|
||||
return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment}
|
||||
})
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/internal/slices"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) {
|
|||
}
|
||||
|
||||
func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID {
|
||||
return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
|
||||
return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
|
||||
return environment.ID
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
// updateEdgeRelations updates the edge stacks associated to an edge endpoint
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
)
|
||||
|
||||
func Test_updateEdgeGroups(t *testing.T) {
|
||||
|
||||
createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) {
|
||||
groups := make([]portainer.EdgeGroup, len(names))
|
||||
for index, name := range names {
|
||||
|
@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
Endpoints: make([]portainer.EndpointID, 0),
|
||||
}
|
||||
|
||||
err := store.EdgeGroup().Create(group)
|
||||
if err != nil {
|
||||
if err := store.EdgeGroup().Create(group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
is.Fail("expected endpoint to be in group")
|
||||
}
|
||||
}
|
||||
|
@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
for j, tag := range groups {
|
||||
if tag.Name == tagName {
|
||||
result[i] = groups[j]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) {
|
|||
}
|
||||
|
||||
expectedGroups := groupsByName(groups, testCase.groupsToApply)
|
||||
|
||||
expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups))
|
||||
for i, tag := range expectedGroups {
|
||||
expectedIDs[i] = tag.ID
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/set"
|
||||
"github.com/portainer/portainer/api/set"
|
||||
)
|
||||
|
||||
// updateEnvironmentTags updates the tags associated to an environment
|
||||
|
|
|
@ -10,14 +10,14 @@ import (
|
|||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/exec/exectest"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
helper "github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/jwt"
|
||||
"github.com/portainer/portainer/api/kubernetes"
|
||||
"github.com/portainer/portainer/pkg/libhelm/binary/test"
|
||||
"github.com/portainer/portainer/pkg/libhelm/options"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
helper "github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_helmDelete(t *testing.T) {
|
||||
|
|
|
@ -97,13 +97,13 @@ func (handler *Handler) userHasRegistryAccess(r *http.Request) (hasAccess bool,
|
|||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID))
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
|
|
|
@ -71,6 +71,7 @@ func (handler *Handler) settingsPublic(w http.ResponseWriter, r *http.Request) *
|
|||
}
|
||||
|
||||
publicSettings := generatePublicSettings(settings)
|
||||
|
||||
return response.JSON(w, publicSettings)
|
||||
}
|
||||
|
||||
|
@ -96,7 +97,7 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
|
|||
|
||||
publicSettings.IsDockerDesktopExtension = appSettings.IsDockerDesktopExtension
|
||||
|
||||
//if OAuth authentication is on, compose the related fields from application settings
|
||||
// If OAuth authentication is on, compose the related fields from application settings
|
||||
if publicSettings.AuthenticationMethod == portainer.AuthenticationOAuth {
|
||||
publicSettings.OAuthLogoutURI = appSettings.OAuthSettings.LogoutURI
|
||||
publicSettings.OAuthLoginURI = fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s&scope=%s",
|
||||
|
@ -104,16 +105,18 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
|
|||
appSettings.OAuthSettings.ClientID,
|
||||
appSettings.OAuthSettings.RedirectURI,
|
||||
appSettings.OAuthSettings.Scopes)
|
||||
//control prompt=login param according to the SSO setting
|
||||
|
||||
// Control prompt=login param according to the SSO setting
|
||||
if !appSettings.OAuthSettings.SSO {
|
||||
publicSettings.OAuthLoginURI += "&prompt=login"
|
||||
}
|
||||
}
|
||||
//if LDAP authentication is on, compose the related fields from application settings
|
||||
// If LDAP authentication is on, compose the related fields from application settings
|
||||
if publicSettings.AuthenticationMethod == portainer.AuthenticationLDAP && appSettings.LDAPSettings.GroupSearchSettings != nil {
|
||||
if len(appSettings.LDAPSettings.GroupSearchSettings) > 0 {
|
||||
publicSettings.TeamSync = len(appSettings.LDAPSettings.GroupSearchSettings[0].GroupBaseDN) > 0
|
||||
}
|
||||
}
|
||||
|
||||
return publicSettings
|
||||
}
|
||||
|
|
|
@ -40,14 +40,17 @@ func setup() {
|
|||
|
||||
func TestGeneratePublicSettingsWithSSO(t *testing.T) {
|
||||
setup()
|
||||
|
||||
mockAppSettings.OAuthSettings.SSO = true
|
||||
publicSettings := generatePublicSettings(mockAppSettings)
|
||||
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
|
||||
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
|
||||
}
|
||||
|
||||
if publicSettings.OAuthLoginURI != dummyOAuthLoginURI {
|
||||
t.Errorf("wrong OAuthLoginURI when SSO is switched on, want: %s, got: %s", dummyOAuthLoginURI, publicSettings.OAuthLoginURI)
|
||||
}
|
||||
|
||||
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
|
||||
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
|
||||
}
|
||||
|
@ -55,15 +58,18 @@ func TestGeneratePublicSettingsWithSSO(t *testing.T) {
|
|||
|
||||
func TestGeneratePublicSettingsWithoutSSO(t *testing.T) {
|
||||
setup()
|
||||
|
||||
mockAppSettings.OAuthSettings.SSO = false
|
||||
publicSettings := generatePublicSettings(mockAppSettings)
|
||||
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
|
||||
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
|
||||
}
|
||||
|
||||
expectedOAuthLoginURI := dummyOAuthLoginURI + "&prompt=login"
|
||||
if publicSettings.OAuthLoginURI != expectedOAuthLoginURI {
|
||||
t.Errorf("wrong OAuthLoginURI when SSO is switched off, want: %s, got: %s", expectedOAuthLoginURI, publicSettings.OAuthLoginURI)
|
||||
}
|
||||
|
||||
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
|
||||
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
|
||||
}
|
||||
|
|
|
@ -89,8 +89,7 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
|
|||
}
|
||||
|
||||
if !isOrphaned {
|
||||
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
|
||||
return httperror.Forbidden("Permission denied to access endpoint", err)
|
||||
}
|
||||
|
||||
|
@ -119,25 +118,21 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
|
|||
deployments.StopAutoupdate(stack.ID, stack.AutoUpdate.JobID, handler.Scheduler)
|
||||
}
|
||||
|
||||
err = handler.deleteStack(securityContext.UserID, stack, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
|
||||
return httperror.InternalServerError(err.Error(), err)
|
||||
}
|
||||
|
||||
err = handler.DataStore.Stack().Delete(portainer.StackID(id))
|
||||
if err != nil {
|
||||
if err := handler.DataStore.Stack().Delete(portainer.StackID(id)); err != nil {
|
||||
return httperror.InternalServerError("Unable to remove the stack from the database", err)
|
||||
}
|
||||
|
||||
if resourceControl != nil {
|
||||
err = handler.DataStore.ResourceControl().Delete(resourceControl.ID)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.ResourceControl().Delete(resourceControl.ID); err != nil {
|
||||
return httperror.InternalServerError("Unable to remove the associated resource control from the database", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = handler.FileService.RemoveDirectory(stack.ProjectPath)
|
||||
if err != nil {
|
||||
if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
|
||||
log.Warn().Err(err).Msg("Unable to remove stack files from disk")
|
||||
}
|
||||
|
||||
|
@ -169,8 +164,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
|
|||
return httperror.InternalServerError("Unable to find the endpoint associated to the stack inside the database", err)
|
||||
}
|
||||
|
||||
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
|
||||
return httperror.Forbidden("Permission denied to access endpoint", err)
|
||||
}
|
||||
|
||||
|
@ -179,8 +173,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
|
|||
Type: portainer.DockerSwarmStack,
|
||||
}
|
||||
|
||||
err = handler.deleteStack(securityContext.UserID, stack, endpoint)
|
||||
if err != nil {
|
||||
if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
|
||||
return httperror.InternalServerError("Unable to delete stack", err)
|
||||
}
|
||||
|
||||
|
@ -255,6 +248,7 @@ func (handler *Handler) deleteStack(userID portainer.UserID, stack *portainer.St
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.WithMessagef(err, "failed to remove kubernetes resources: %q", out)
|
||||
}
|
||||
|
||||
|
@ -369,18 +363,18 @@ func (handler *Handler) stackDeleteKubernetesByName(w http.ResponseWriter, r *ht
|
|||
if err != nil {
|
||||
log.Err(err).Msgf("Unable to delete Kubernetes stack `%d`", stack.ID)
|
||||
errors = append(errors, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = handler.DataStore.Stack().Delete(stack.ID)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.Stack().Delete(stack.ID); err != nil {
|
||||
errors = append(errors, err)
|
||||
log.Err(err).Msgf("Unable to remove the stack `%d` from the database", stack.ID)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = handler.FileService.RemoveDirectory(stack.ProjectPath)
|
||||
if err != nil {
|
||||
if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
|
||||
errors = append(errors, err)
|
||||
log.Warn().Err(err).Msg("Unable to remove stack files from disk")
|
||||
}
|
||||
|
|
|
@ -18,8 +18,7 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
|
|||
_, store := datastore.MustNewTestStore(t, true, false)
|
||||
|
||||
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
|
||||
err := store.User().Create(user)
|
||||
if err != nil {
|
||||
if err := store.User().Create(user); err != nil {
|
||||
t.Fatal("could not create admin user:", err)
|
||||
}
|
||||
|
||||
|
@ -33,29 +32,28 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
|
|||
for i := 0; i < tagsCount; i++ {
|
||||
tagID := portainer.TagID(i) + 1
|
||||
|
||||
err = store.Tag().Create(&portainer.Tag{
|
||||
if err := store.Tag().Create(&portainer.Tag{
|
||||
ID: tagID,
|
||||
Name: "tag-" + strconv.Itoa(int(tagID)),
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create tag:", err)
|
||||
}
|
||||
|
||||
tagIDs = append(tagIDs, tagID)
|
||||
}
|
||||
|
||||
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "edgegroup-1",
|
||||
TagIDs: tagIDs,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
t.Fatal("could not create edge group:", err)
|
||||
}
|
||||
|
||||
// Remove the tags concurrently
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(len(tagIDs))
|
||||
|
||||
for _, tagID := range tagIDs {
|
||||
|
|
|
@ -27,6 +27,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
|
|||
if payload.Role != 1 && payload.Role != 2 {
|
||||
return errors.New("Invalid role value. Value must be one of: 1 (administrator) or 2 (regular user)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -49,8 +50,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
|
|||
// @router /users [post]
|
||||
func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload userCreatePayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
|
@ -89,11 +89,11 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http
|
|||
}
|
||||
}
|
||||
|
||||
err = handler.DataStore.User().Create(user)
|
||||
if err != nil {
|
||||
if err := handler.DataStore.User().Create(user); err != nil {
|
||||
return httperror.InternalServerError("Unable to persist user inside the database", err)
|
||||
}
|
||||
|
||||
hideFields(user)
|
||||
|
||||
return response.JSON(w, user)
|
||||
}
|
||||
|
|
|
@ -26,12 +26,12 @@ func Test_userList(t *testing.T) {
|
|||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
// create admin and standard user(s)
|
||||
// Create admin and standard user(s)
|
||||
adminUser := &portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}
|
||||
err := store.User().Create(adminUser)
|
||||
is.NoError(err, "error creating admin user")
|
||||
|
||||
// setup services
|
||||
// Setup services
|
||||
jwtService, err := jwt.NewService("1h", store)
|
||||
is.NoError(err, "Error initiating jwt service")
|
||||
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
@ -42,7 +42,7 @@ func Test_userList(t *testing.T) {
|
|||
h := NewHandler(requestBouncer, rateLimiter, apiKeyService, passwordChecker)
|
||||
h.DataStore = store
|
||||
|
||||
// generate admin user tokens
|
||||
// Generate admin user tokens
|
||||
adminJWT, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: adminUser.ID, Username: adminUser.Username, Role: adminUser.Role})
|
||||
|
||||
// Case 1: the user is given the endpoint access directly
|
||||
|
@ -54,12 +54,12 @@ func Test_userList(t *testing.T) {
|
|||
err = store.User().Create(userWithoutEndpointAccess)
|
||||
is.NoError(err, "error creating user")
|
||||
|
||||
// create environment group
|
||||
// Create environment group
|
||||
endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "default-endpoint-group"}
|
||||
err = store.EndpointGroup().Create(endpointGroup)
|
||||
is.NoError(err, "error creating endpoint group")
|
||||
|
||||
// create endpoint and user access policies
|
||||
// Create endpoint and user access policies
|
||||
userAccessPolicies := make(portainer.UserAccessPolicies, 0)
|
||||
userAccessPolicies[userWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userWithEndpointAccess.Role)}
|
||||
|
||||
|
@ -129,7 +129,7 @@ func Test_userList(t *testing.T) {
|
|||
err = store.User().Create(userUnderGroup)
|
||||
is.NoError(err, "error creating user")
|
||||
|
||||
// create environment group including a user
|
||||
// Create environment group including a user
|
||||
userAccessPoliciesUnderGroup := make(portainer.UserAccessPolicies, 0)
|
||||
userAccessPoliciesUnderGroup[userUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderGroup.Role)}
|
||||
|
||||
|
@ -137,7 +137,7 @@ func Test_userList(t *testing.T) {
|
|||
err = store.EndpointGroup().Create(endpointGroupWithUser)
|
||||
is.NoError(err, "error creating endpoint group")
|
||||
|
||||
// create endpoint
|
||||
// Create endpoint
|
||||
endpointUnderGroupWithUser := &portainer.Endpoint{ID: 2, GroupID: endpointGroupWithUser.ID}
|
||||
err = store.Endpoint().Create(endpointUnderGroupWithUser)
|
||||
is.NoError(err, "error creating endpoint")
|
||||
|
@ -182,7 +182,7 @@ func Test_userList(t *testing.T) {
|
|||
err = store.TeamMembership().Create(teamMembership)
|
||||
is.NoError(err, "error creating team membership")
|
||||
|
||||
// create environment group including a team
|
||||
// Create environment group including a team
|
||||
teamAccessPoliciesUnderGroup := make(portainer.TeamAccessPolicies, 0)
|
||||
teamAccessPoliciesUnderGroup[teamUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeam.Role)}
|
||||
|
||||
|
@ -190,7 +190,7 @@ func Test_userList(t *testing.T) {
|
|||
err = store.EndpointGroup().Create(endpointGroupWithTeam)
|
||||
is.NoError(err, "error creating endpoint group")
|
||||
|
||||
// create endpoint
|
||||
// Create endpoint
|
||||
endpointUnderGroupWithTeam := &portainer.Endpoint{ID: 3, GroupID: endpointGroupWithTeam.ID}
|
||||
err = store.Endpoint().Create(endpointUnderGroupWithTeam)
|
||||
is.NoError(err, "error creating endpoint")
|
||||
|
@ -233,12 +233,12 @@ func Test_userList(t *testing.T) {
|
|||
err = store.TeamMembership().Create(teamMembershipWithEndpointAccess)
|
||||
is.NoError(err, "error creating team membership")
|
||||
|
||||
// create environment group
|
||||
// Create environment group
|
||||
endpointGroupWithoutTeam := &portainer.EndpointGroup{ID: 4, Name: "endpoint-group-without-team"}
|
||||
err = store.EndpointGroup().Create(endpointGroupWithoutTeam)
|
||||
is.NoError(err, "error creating endpoint group")
|
||||
|
||||
// create endpoint and team access policies
|
||||
// Create endpoint and team access policies
|
||||
teamAccessPolicies := make(portainer.TeamAccessPolicies, 0)
|
||||
teamAccessPolicies[teamWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeamWithEndpointAccess.Role)}
|
||||
|
||||
|
|
|
@ -19,12 +19,12 @@ func Test_updateUserRemovesAccessTokens(t *testing.T) {
|
|||
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
// create standard user
|
||||
// Create standard user
|
||||
user := &portainer.User{ID: 2, Username: "standard", Role: portainer.StandardUserRole}
|
||||
err := store.User().Create(user)
|
||||
is.NoError(err, "error creating user")
|
||||
|
||||
// setup services
|
||||
// Setup services
|
||||
jwtService, err := jwt.NewService("1h", store)
|
||||
is.NoError(err, "Error initiating jwt service")
|
||||
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())
|
||||
|
|
|
@ -8,12 +8,12 @@ import (
|
|||
"net/url"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/logoutcontext"
|
||||
"github.com/portainer/portainer/api/logoutcontext"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/koding/websocketproxy"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/agent"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/url"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
|
|
@ -54,6 +54,7 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
|
|||
|
||||
portainerMetadata := object["Portainer"].(map[string]interface{})
|
||||
portainerMetadata["ResourceControl"] = resourceControl
|
||||
|
||||
return object
|
||||
}
|
||||
|
||||
|
@ -64,8 +65,7 @@ func (transport *Transport) createPrivateResourceControl(
|
|||
|
||||
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
|
||||
|
||||
err := transport.dataStore.ResourceControl().Create(resourceControl)
|
||||
if err != nil {
|
||||
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
|
||||
log.Error().
|
||||
Str("resource", resourceIdentifier).
|
||||
Err(err).
|
||||
|
@ -84,6 +84,7 @@ func (transport *Transport) userCanDeleteContainerGroup(request *http.Request, c
|
|||
|
||||
resourceIdentifier := request.URL.Path
|
||||
resourceControl := transport.findResourceControl(resourceIdentifier, context)
|
||||
|
||||
return authorization.UserCanAccessResource(context.userID, context.userTeamIDs, resourceControl)
|
||||
}
|
||||
|
||||
|
@ -136,20 +137,19 @@ func (transport *Transport) filterContainerGroups(containerGroups []interface{},
|
|||
|
||||
func (transport *Transport) removeResourceControl(containerGroup map[string]interface{}, context *azureRequestContext) error {
|
||||
containerGroupID, ok := containerGroup["id"].(string)
|
||||
if ok {
|
||||
resourceControl := transport.findResourceControl(containerGroupID, context)
|
||||
if resourceControl != nil {
|
||||
err := transport.dataStore.ResourceControl().Delete(resourceControl.ID)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if !ok {
|
||||
log.Debug().Msg("missing ID in container group")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if resourceControl := transport.findResourceControl(containerGroupID, context); resourceControl != nil {
|
||||
return transport.dataStore.ResourceControl().Delete(resourceControl.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (transport *Transport) findResourceControl(containerGroupId string, context *azureRequestContext) *portainer.ResourceControl {
|
||||
resourceControl := authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls)
|
||||
return resourceControl
|
||||
return authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/docker"
|
||||
"github.com/portainer/portainer/api/internal/url"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
|
|
@ -105,8 +105,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
|
|||
|
||||
resourceControl := authorization.NewRestrictedResourceControl(resourceID, resourceType, userIDs, teamIDs)
|
||||
|
||||
err := transport.dataStore.ResourceControl().Create(resourceControl)
|
||||
if err != nil {
|
||||
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -119,8 +118,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
|
|||
func (transport *Transport) createPrivateResourceControl(resourceIdentifier string, resourceType portainer.ResourceControlType, userID portainer.UserID) (*portainer.ResourceControl, error) {
|
||||
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
|
||||
|
||||
err := transport.dataStore.ResourceControl().Create(resourceControl)
|
||||
if err != nil {
|
||||
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
|
||||
log.Error().
|
||||
Str("resource", resourceIdentifier).
|
||||
Err(err).
|
||||
|
@ -170,6 +168,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
|
|||
systemResourceControl := findSystemNetworkResourceControl(responseObject)
|
||||
if systemResourceControl != nil {
|
||||
responseObject = decorateObject(responseObject, systemResourceControl)
|
||||
|
||||
return utils.RewriteResponse(response, responseObject, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
@ -188,6 +187,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
|
|||
|
||||
if executor.operationContext.isAdmin || (resourceControl != nil && authorization.UserCanAccessResource(executor.operationContext.userID, executor.operationContext.userTeamIDs, resourceControl)) {
|
||||
responseObject = decorateObject(responseObject, resourceControl)
|
||||
|
||||
return utils.RewriteResponse(response, responseObject, http.StatusOK)
|
||||
}
|
||||
|
||||
|
@ -221,6 +221,7 @@ func (transport *Transport) decorateResourceList(parameters *resourceOperationPa
|
|||
if systemResourceControl != nil {
|
||||
resourceObject = decorateObject(resourceObject, systemResourceControl)
|
||||
decoratedResourceData = append(decoratedResourceData, resourceObject)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -264,6 +265,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
|
|||
if systemResourceControl != nil {
|
||||
resourceObject = decorateObject(resourceObject, systemResourceControl)
|
||||
filteredResourceData = append(filteredResourceData, resourceObject)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -277,6 +279,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
|
|||
if context.isAdmin {
|
||||
filteredResourceData = append(filteredResourceData, resourceObject)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -334,11 +337,13 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
|
|||
func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string {
|
||||
if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName]
|
||||
|
||||
return stackutils.ResourceControlID(endpointID, stackName)
|
||||
}
|
||||
|
||||
if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" {
|
||||
stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName]
|
||||
|
||||
return stackutils.ResourceControlID(endpointID, stackName)
|
||||
}
|
||||
|
||||
|
@ -352,5 +357,6 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
|
|||
|
||||
portainerMetadata := object["Portainer"].(map[string]interface{})
|
||||
portainerMetadata["ResourceControl"] = resourceControl
|
||||
|
||||
return object
|
||||
}
|
||||
|
|
|
@ -11,9 +11,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
)
|
||||
|
||||
const (
|
||||
configObjectIdentifier = "ID"
|
||||
)
|
||||
const configObjectIdentifier = "ID"
|
||||
|
||||
func getInheritedResourceControlFromConfigLabels(dockerClient *client.Client, endpointID portainer.EndpointID, configID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
|
||||
config, _, err := dockerClient.ConfigInspectWithRaw(context.Background(), configID)
|
||||
|
@ -78,10 +76,9 @@ func (transport *Transport) configInspectOperation(response *http.Response, exec
|
|||
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigList
|
||||
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigInspect
|
||||
func selectorConfigLabels(responseObject map[string]interface{}) map[string]interface{} {
|
||||
secretSpec := utils.GetJSONObject(responseObject, "Spec")
|
||||
if secretSpec != nil {
|
||||
secretLabelsObject := utils.GetJSONObject(secretSpec, "Labels")
|
||||
return secretLabelsObject
|
||||
if secretSpec := utils.GetJSONObject(responseObject, "Spec"); secretSpec != nil {
|
||||
return utils.GetJSONObject(secretSpec, "Labels")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,9 +7,7 @@ import (
|
|||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
taskServiceObjectIdentifier = "ServiceID"
|
||||
)
|
||||
const taskServiceObjectIdentifier = "ServiceID"
|
||||
|
||||
// taskListOperation extracts the response as a JSON array, loop through the tasks array
|
||||
// and filter the containers based on resource controls before rewriting the response.
|
||||
|
@ -46,5 +44,6 @@ func selectorTaskLabels(responseObject map[string]interface{}) map[string]interf
|
|||
return utils.GetJSONObject(containerSpecObject, "Labels")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,19 +7,17 @@ import (
|
|||
"net/http"
|
||||
"path"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/utils"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/internal/snapshot"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
volumeObjectIdentifier = "ResourceID"
|
||||
)
|
||||
const volumeObjectIdentifier = "ResourceID"
|
||||
|
||||
func getInheritedResourceControlFromVolumeLabels(dockerClient *client.Client, endpointID portainer.EndpointID, volumeID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
|
||||
volume, err := dockerClient.VolumeInspect(context.Background(), volumeID)
|
||||
|
@ -57,14 +55,13 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
|
|||
Msg("snapshot is not filled into the endpoint.")
|
||||
}
|
||||
}
|
||||
|
||||
for _, volumeObject := range volumeData {
|
||||
volume := volumeObject.(map[string]interface{})
|
||||
|
||||
err = transport.decorateVolumeResponseWithResourceID(volume)
|
||||
if err != nil {
|
||||
if err := transport.decorateVolumeResponseWithResourceID(volume); err != nil {
|
||||
return fmt.Errorf("failed decorating volume response: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
resourceOperationParameters := &resourceOperationParameters{
|
||||
|
@ -77,6 +74,7 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Overwrite the original volume list
|
||||
responseObject["Volumes"] = volumeData
|
||||
}
|
||||
|
@ -94,8 +92,7 @@ func (transport *Transport) volumeInspectOperation(response *http.Response, exec
|
|||
return err
|
||||
}
|
||||
|
||||
err = transport.decorateVolumeResponseWithResourceID(responseObject)
|
||||
if err != nil {
|
||||
if err := transport.decorateVolumeResponseWithResourceID(responseObject); err != nil {
|
||||
return fmt.Errorf("failed decorating volume response: %w", err)
|
||||
}
|
||||
|
||||
|
@ -148,8 +145,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
|
|||
}
|
||||
defer cli.Close()
|
||||
|
||||
_, err = cli.VolumeInspect(context.Background(), volumeID)
|
||||
if err == nil {
|
||||
if _, err = cli.VolumeInspect(context.Background(), volumeID); err == nil {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusConflict,
|
||||
}, errors.New("a volume with the same name already exists")
|
||||
|
@ -164,6 +160,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
|
|||
if response.StatusCode == http.StatusCreated {
|
||||
err = transport.decorateVolumeCreationResponse(response, resourceType, tokenData.ID)
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
|
@ -195,7 +192,6 @@ func (transport *Transport) decorateVolumeCreationResponse(response *http.Respon
|
|||
}
|
||||
|
||||
func (transport *Transport) restrictedVolumeOperation(requestPath string, request *http.Request) (*http.Response, error) {
|
||||
|
||||
if request.Method == http.MethodGet {
|
||||
return transport.rewriteOperation(request, transport.volumeInspectOperation)
|
||||
}
|
||||
|
@ -210,6 +206,7 @@ func (transport *Transport) restrictedVolumeOperation(requestPath string, reques
|
|||
if request.Method == http.MethodDelete {
|
||||
return transport.executeGenericResourceDeletionOperation(request, resourceID, volumeName, portainer.VolumeResourceControl)
|
||||
}
|
||||
|
||||
return transport.restrictedResourceOperation(request, resourceID, volumeName, portainer.VolumeResourceControl, false)
|
||||
}
|
||||
|
||||
|
@ -218,6 +215,7 @@ func (transport *Transport) getVolumeResourceID(volumeName string) (string, erro
|
|||
if err != nil {
|
||||
return "", fmt.Errorf("failed fetching docker id: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s_%s", volumeName, dockerID), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/api/internal/tag"
|
||||
"github.com/portainer/portainer/api/tag"
|
||||
)
|
||||
|
||||
// EdgeGroupRelatedEndpoints returns a list of environments(endpoints) related to this Edge group
|
||||
|
|
|
@ -37,12 +37,12 @@ func (service *Service) BuildEdgeStack(
|
|||
registries []portainer.RegistryID,
|
||||
useManifestNamespaces bool,
|
||||
) (*portainer.EdgeStack, error) {
|
||||
err := validateUniqueName(tx.EdgeStack().EdgeStacks, name)
|
||||
if err != nil {
|
||||
if err := validateUniqueName(tx.EdgeStack().EdgeStacks, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stackID := tx.EdgeStack().GetNextIdentifier()
|
||||
|
||||
return &portainer.EdgeStack{
|
||||
ID: portainer.EdgeStackID(stackID),
|
||||
Name: name,
|
||||
|
@ -77,7 +77,6 @@ func (service *Service) PersistEdgeStack(
|
|||
storeManifest edgetypes.StoreManifestFunc) (*portainer.EdgeStack, error) {
|
||||
|
||||
relationConfig, err := edge.FetchEndpointRelationsConfig(tx)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find environment relations in database: %w", err)
|
||||
}
|
||||
|
@ -87,6 +86,7 @@ func (service *Service) PersistEdgeStack(
|
|||
if errors.Is(err, edge.ErrEdgeGroupNotFound) {
|
||||
return nil, httperrors.NewInvalidPayloadError(err.Error())
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to persist environment relation in database: %w", err)
|
||||
}
|
||||
|
||||
|
@ -101,13 +101,11 @@ func (service *Service) PersistEdgeStack(
|
|||
stack.EntryPoint = composePath
|
||||
stack.NumDeployments = len(relatedEndpointIds)
|
||||
|
||||
err = service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds)
|
||||
if err != nil {
|
||||
if err := service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds); err != nil {
|
||||
return nil, fmt.Errorf("unable to update endpoint relations: %w", err)
|
||||
}
|
||||
|
||||
err = tx.EdgeStack().Create(stack.ID, stack)
|
||||
if err != nil {
|
||||
if err := tx.EdgeStack().Create(stack.ID, stack); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -126,8 +124,7 @@ func (service *Service) updateEndpointRelations(tx dataservices.DataStoreTx, edg
|
|||
|
||||
relation.EdgeStacks[edgeStackID] = true
|
||||
|
||||
err = endpointRelationService.UpdateEndpointRelation(endpointID, relation)
|
||||
if err != nil {
|
||||
if err := endpointRelationService.UpdateEndpointRelation(endpointID, relation); err != nil {
|
||||
return fmt.Errorf("unable to persist endpoint relation in database: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -155,14 +152,12 @@ func (service *Service) DeleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID
|
|||
|
||||
delete(relation.EdgeStacks, edgeStackID)
|
||||
|
||||
err = tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation)
|
||||
if err != nil {
|
||||
if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil {
|
||||
return errors.WithMessage(err, "Unable to persist environment relation in database")
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.EdgeStack().DeleteEdgeStack(edgeStackID)
|
||||
if err != nil {
|
||||
if err := tx.EdgeStack().DeleteEdgeStack(edgeStackID); err != nil {
|
||||
return errors.WithMessage(err, "Unable to remove the edge stack from the database")
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
// NodesCount returns the total node number of all environments
|
||||
func NodesCount(endpoints []portainer.Endpoint) int {
|
||||
nodes := 0
|
||||
|
||||
for _, env := range endpoints {
|
||||
if !endpointutils.IsEdgeEndpoint(&env) || env.UserTrusted {
|
||||
nodes += countNodes(&env)
|
||||
|
@ -28,11 +29,3 @@ func countNodes(endpoint *portainer.Endpoint) int {
|
|||
|
||||
return 1
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
package securecookie
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
)
|
||||
|
||||
// GenerateRandomKey generates a random key of specified length
|
||||
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
|
||||
func GenerateRandomKey(length int) []byte {
|
||||
k := make([]byte, length)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return nil
|
||||
}
|
||||
return k
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package slices
|
||||
|
||||
// Map applies the given function to each element of the slice and returns a new slice with the results
|
||||
func Map[T, U any](s []T, f func(T) U) []U {
|
||||
result := make([]U, len(s))
|
||||
for i, v := range s {
|
||||
result[i] = f(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
|
||||
func Filter[T any](s []T, predicate func(T) bool) []T {
|
||||
n := 0
|
||||
for _, v := range s {
|
||||
if predicate(v) {
|
||||
s[n] = v
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return s[:n]
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
package unique
|
||||
|
||||
func Unique[T comparable](items []T) []T {
|
||||
return UniqueBy(items, func(item T) T {
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
|
||||
includedItems := make(map[ComparableType]bool)
|
||||
result := []ItemType{}
|
||||
|
||||
for _, item := range items {
|
||||
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
|
||||
includedItems[accessorFunc(item)] = true
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
type someType struct {
|
||||
id int
|
||||
fn func()
|
||||
}
|
||||
|
||||
func Test() {
|
||||
ids := []int{1, 2, 3, 3}
|
||||
_ = UniqueBy(ids, func(id int) int { return id })
|
||||
_ = Unique(ids) // shorthand for UniqueBy Identity/self
|
||||
|
||||
as := []someType{{id: 1}, {id: 2}, {id: 3}, {id: 3}}
|
||||
_ = UniqueBy(as, func(item someType) int { return item.id }) // no error
|
||||
_ = UniqueBy(as, func(item someType) someType { return item }) // compile error - someType is not comparable
|
||||
_ = Unique(as) // compile error - shorthand fails for the same reason
|
||||
}
|
||||
|
||||
*/
|
|
@ -6,10 +6,10 @@ import (
|
|||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/apikey"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/portainer/portainer/api/internal/securecookie"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
@ -51,7 +51,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
|
|||
return nil, err
|
||||
}
|
||||
|
||||
secret := securecookie.GenerateRandomKey(32)
|
||||
secret := apikey.GenerateRandomKey(32)
|
||||
if secret == nil {
|
||||
return nil, errSecretGeneration
|
||||
}
|
||||
|
@ -69,6 +69,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
|
|||
userSessionTimeout,
|
||||
dataStore,
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
|
@ -80,16 +81,18 @@ func getOrCreateKubeSecret(dataStore dataservices.DataStore) ([]byte, error) {
|
|||
|
||||
kubeSecret := settings.OAuthSettings.KubeSecretKey
|
||||
if kubeSecret == nil {
|
||||
kubeSecret = securecookie.GenerateRandomKey(32)
|
||||
kubeSecret = apikey.GenerateRandomKey(32)
|
||||
if kubeSecret == nil {
|
||||
return nil, errSecretGeneration
|
||||
}
|
||||
|
||||
settings.OAuthSettings.KubeSecretKey = kubeSecret
|
||||
err = dataStore.Settings().UpdateSettings(settings)
|
||||
if err != nil {
|
||||
|
||||
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return kubeSecret, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ package cli
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/portainer/portainer/api/concurrent"
|
||||
models "github.com/portainer/portainer/api/http/models/kubernetes"
|
||||
"github.com/portainer/portainer/api/internal/concurrent"
|
||||
|
||||
"k8s.io/apimachinery/pkg/api/errors"
|
||||
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
models "github.com/portainer/portainer/api/http/models/kubernetes"
|
||||
|
||||
v1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
labels "k8s.io/apimachinery/pkg/labels"
|
||||
|
@ -67,9 +68,7 @@ func (kcl *KubeClient) GetServices(namespace string, lookupApplications bool) ([
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// CreateService creates a new service in a given namespace in a k8s endpoint.
|
||||
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
|
||||
ServiceClient := kcl.cli.CoreV1().Services(namespace)
|
||||
func (kcl *KubeClient) fillService(info models.K8sServiceInfo) v1.Service {
|
||||
var service v1.Service
|
||||
|
||||
service.Name = info.Name
|
||||
|
@ -93,16 +92,21 @@ func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInf
|
|||
|
||||
// Set ingresses.
|
||||
for _, i := range info.IngressStatus {
|
||||
var ing v1.LoadBalancerIngress
|
||||
ing.IP = i.IP
|
||||
ing.Hostname = i.Host
|
||||
service.Status.LoadBalancer.Ingress = append(
|
||||
service.Status.LoadBalancer.Ingress,
|
||||
ing,
|
||||
v1.LoadBalancerIngress{IP: i.IP, Hostname: i.Host},
|
||||
)
|
||||
}
|
||||
|
||||
_, err := ServiceClient.Create(context.Background(), &service, metav1.CreateOptions{})
|
||||
return service
|
||||
}
|
||||
|
||||
// CreateService creates a new service in a given namespace in a k8s endpoint.
|
||||
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
|
||||
serviceClient := kcl.cli.CoreV1().Services(namespace)
|
||||
service := kcl.fillService(info)
|
||||
|
||||
_, err := serviceClient.Create(context.Background(), &service, metav1.CreateOptions{})
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -120,45 +124,16 @@ func (kcl *KubeClient) DeleteServices(reqs models.K8sServiceDeleteRequests) erro
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateService updates service in a given namespace in a k8s endpoint.
|
||||
func (kcl *KubeClient) UpdateService(namespace string, info models.K8sServiceInfo) error {
|
||||
ServiceClient := kcl.cli.CoreV1().Services(namespace)
|
||||
var service v1.Service
|
||||
serviceClient := kcl.cli.CoreV1().Services(namespace)
|
||||
service := kcl.fillService(info)
|
||||
|
||||
service.Name = info.Name
|
||||
service.Spec.Type = v1.ServiceType(info.Type)
|
||||
service.Namespace = info.Namespace
|
||||
service.Annotations = info.Annotations
|
||||
service.Labels = info.Labels
|
||||
service.Spec.AllocateLoadBalancerNodePorts = info.AllocateLoadBalancerNodePorts
|
||||
service.Spec.Selector = info.Selector
|
||||
|
||||
// Set ports.
|
||||
for _, p := range info.Ports {
|
||||
var port v1.ServicePort
|
||||
port.Name = p.Name
|
||||
port.NodePort = int32(p.NodePort)
|
||||
port.Port = int32(p.Port)
|
||||
port.Protocol = v1.Protocol(p.Protocol)
|
||||
port.TargetPort = intstr.FromString(p.TargetPort)
|
||||
service.Spec.Ports = append(service.Spec.Ports, port)
|
||||
}
|
||||
|
||||
// Set ingresses.
|
||||
for _, i := range info.IngressStatus {
|
||||
var ing v1.LoadBalancerIngress
|
||||
ing.IP = i.IP
|
||||
ing.Hostname = i.Host
|
||||
service.Status.LoadBalancer.Ingress = append(
|
||||
service.Status.LoadBalancer.Ingress,
|
||||
ing,
|
||||
)
|
||||
}
|
||||
|
||||
_, err := ServiceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
|
||||
_, err := serviceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -210,5 +185,4 @@ func makeApplication(meta metav1.Object) []models.K8sApplication {
|
|||
Name: ownerReference.Name,
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -16,8 +16,7 @@ func Test_getOAuthToken(t *testing.T) {
|
|||
|
||||
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
|
||||
code := ""
|
||||
_, err := getOAuthToken(code, config)
|
||||
if err == nil {
|
||||
if _, err := getOAuthToken(code, config); err == nil {
|
||||
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
|
||||
}
|
||||
})
|
||||
|
@ -91,22 +90,19 @@ func Test_getResource(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource("", config)
|
||||
if err == nil {
|
||||
if _, err := getResource("", config); err == nil {
|
||||
t.Errorf("getResource should fail if access token is not provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource("incorrect-token", config)
|
||||
if err == nil {
|
||||
if _, err := getResource("incorrect-token", config); err == nil {
|
||||
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
|
||||
_, err := getResource(oauthtest.AccessToken, config)
|
||||
if err != nil {
|
||||
if _, err := getResource(oauthtest.AccessToken, config); err != nil {
|
||||
t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
|
||||
}
|
||||
})
|
||||
|
@ -120,8 +116,7 @@ func Test_Authenticate(t *testing.T) {
|
|||
srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{})
|
||||
defer srv.Close()
|
||||
|
||||
_, err := authService.Authenticate(code, config)
|
||||
if err == nil {
|
||||
if _, err := authService.Authenticate(code, config); err == nil {
|
||||
t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided")
|
||||
}
|
||||
})
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
models "github.com/portainer/portainer/api/http/models/kubernetes"
|
||||
"github.com/portainer/portainer/pkg/featureflags"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/version"
|
||||
|
@ -322,14 +323,14 @@ type (
|
|||
Name string `json:"Name"`
|
||||
Status map[EndpointID]EdgeStackStatus `json:"Status"`
|
||||
// StatusArray map[EndpointID][]EdgeStackStatus `json:"StatusArray"`
|
||||
CreationDate int64 `json:"CreationDate"`
|
||||
EdgeGroups []EdgeGroupID `json:"EdgeGroups"`
|
||||
ProjectPath string `json:"ProjectPath"`
|
||||
EntryPoint string `json:"EntryPoint"`
|
||||
Version int `json:"Version"`
|
||||
NumDeployments int `json:"NumDeployments"`
|
||||
ManifestPath string
|
||||
DeploymentType EdgeStackDeploymentType
|
||||
CreationDate int64 `json:"CreationDate"`
|
||||
EdgeGroups []EdgeGroupID `json:"EdgeGroups"`
|
||||
ProjectPath string `json:"ProjectPath"`
|
||||
EntryPoint string `json:"EntryPoint"`
|
||||
Version int `json:"Version"`
|
||||
NumDeployments int `json:"NumDeployments"`
|
||||
ManifestPath string `json:"ManifestPath"`
|
||||
DeploymentType EdgeStackDeploymentType `json:"DeploymentType"`
|
||||
// Uses the manifest's namespaces instead of the default one
|
||||
UseManifestNamespaces bool
|
||||
|
||||
|
@ -554,23 +555,22 @@ type (
|
|||
|
||||
// Extension represents a deprecated Portainer extension
|
||||
Extension struct {
|
||||
// Extension Identifier
|
||||
ID ExtensionID `json:"Id" example:"1"`
|
||||
Enabled bool `json:"Enabled"`
|
||||
Name string `json:"Name,omitempty"`
|
||||
ShortDescription string `json:"ShortDescription,omitempty"`
|
||||
Description string `json:"Description,omitempty"`
|
||||
DescriptionURL string `json:"DescriptionURL,omitempty"`
|
||||
Price string `json:"Price,omitempty"`
|
||||
PriceDescription string `json:"PriceDescription,omitempty"`
|
||||
Deal bool `json:"Deal,omitempty"`
|
||||
Available bool `json:"Available,omitempty"`
|
||||
License LicenseInformation `json:"License,omitempty"`
|
||||
Version string `json:"Version"`
|
||||
UpdateAvailable bool `json:"UpdateAvailable"`
|
||||
ShopURL string `json:"ShopURL,omitempty"`
|
||||
Images []string `json:"Images,omitempty"`
|
||||
Logo string `json:"Logo,omitempty"`
|
||||
ID ExtensionID `json:"Id" example:"1"`
|
||||
Enabled bool `json:"Enabled"`
|
||||
Name string `json:"Name,omitempty"`
|
||||
ShortDescription string `json:"ShortDescription,omitempty"`
|
||||
Description string `json:"Description,omitempty"`
|
||||
DescriptionURL string `json:"DescriptionURL,omitempty"`
|
||||
Price string `json:"Price,omitempty"`
|
||||
PriceDescription string `json:"PriceDescription,omitempty"`
|
||||
Deal bool `json:"Deal,omitempty"`
|
||||
Available bool `json:"Available,omitempty"`
|
||||
License ExtensionLicenseInformation `json:"License,omitempty"`
|
||||
Version string `json:"Version"`
|
||||
UpdateAvailable bool `json:"UpdateAvailable"`
|
||||
ShopURL string `json:"ShopURL,omitempty"`
|
||||
Images []string `json:"Images,omitempty"`
|
||||
Logo string `json:"Logo,omitempty"`
|
||||
}
|
||||
|
||||
// ExtensionID represents a extension identifier
|
||||
|
@ -737,8 +737,8 @@ type (
|
|||
Groups []string
|
||||
}
|
||||
|
||||
// LicenseInformation represents information about an extension license
|
||||
LicenseInformation struct {
|
||||
// ExtensionLicenseInformation represents information about an extension license
|
||||
ExtensionLicenseInformation struct {
|
||||
LicenseKey string `json:"LicenseKey,omitempty"`
|
||||
Company string `json:"Company,omitempty"`
|
||||
Expiration string `json:"Expiration,omitempty"`
|
||||
|
@ -939,6 +939,18 @@ type (
|
|||
HideStacksFunctionality bool `json:"hideStacksFunctionality" example:"false"`
|
||||
}
|
||||
|
||||
Edge struct {
|
||||
// The command list interval for edge agent - used in edge async mode (in seconds)
|
||||
CommandInterval int `json:"CommandInterval" example:"5"`
|
||||
// The ping interval for edge agent - used in edge async mode (in seconds)
|
||||
PingInterval int `json:"PingInterval" example:"5"`
|
||||
// The snapshot interval for edge agent - used in edge async mode (in seconds)
|
||||
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
|
||||
|
||||
// Deprecated 2.18
|
||||
AsyncMode bool `json:"AsyncMode,omitempty" example:"false"`
|
||||
}
|
||||
|
||||
// Settings represents the application settings
|
||||
Settings struct {
|
||||
// URL to a logo that will be displayed on the login page as well as on top of the sidebar. Will use default Portainer logo when value is empty string
|
||||
|
@ -984,17 +996,7 @@ type (
|
|||
// EdgePortainerURL is the URL that is exposed to edge agents
|
||||
EdgePortainerURL string `json:"EdgePortainerUrl"`
|
||||
|
||||
Edge struct {
|
||||
// The command list interval for edge agent - used in edge async mode (in seconds)
|
||||
CommandInterval int `json:"CommandInterval" example:"5"`
|
||||
// The ping interval for edge agent - used in edge async mode (in seconds)
|
||||
PingInterval int `json:"PingInterval" example:"5"`
|
||||
// The snapshot interval for edge agent - used in edge async mode (in seconds)
|
||||
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
|
||||
|
||||
// Deprecated 2.18
|
||||
AsyncMode bool
|
||||
}
|
||||
Edge Edge `json:"Edge"`
|
||||
|
||||
// Deprecated fields
|
||||
DisplayDonationHeader bool `json:"DisplayDonationHeader,omitempty"`
|
||||
|
|
|
@ -58,8 +58,8 @@ func (s Set[T]) Copy() Set[T] {
|
|||
}
|
||||
|
||||
// Difference returns a new set containing the keys that are in the first set but not in the second set.
|
||||
func (set Set[T]) Difference(second Set[T]) Set[T] {
|
||||
difference := set.Copy()
|
||||
func (s Set[T]) Difference(second Set[T]) Set[T] {
|
||||
difference := s.Copy()
|
||||
|
||||
for key := range second {
|
||||
difference.Remove(key)
|
|
@ -0,0 +1,43 @@
|
|||
package slicesx
|
||||
|
||||
// Map applies the given function to each element of the slice and returns a new slice with the results
|
||||
func Map[T, U any](s []T, f func(T) U) []U {
|
||||
result := make([]U, len(s))
|
||||
for i, v := range s {
|
||||
result[i] = f(v)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
|
||||
func Filter[T any](s []T, predicate func(T) bool) []T {
|
||||
n := 0
|
||||
for _, v := range s {
|
||||
if predicate(v) {
|
||||
s[n] = v
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
func Unique[T comparable](items []T) []T {
|
||||
return UniqueBy(items, func(item T) T {
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
|
||||
includedItems := make(map[ComparableType]bool)
|
||||
result := []ItemType{}
|
||||
|
||||
for _, item := range items {
|
||||
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
|
||||
includedItems[accessorFunc(item)] = true
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package slices
|
||||
package slicesx
|
||||
|
||||
import (
|
||||
"strconv"
|
Loading…
Reference in New Issue