chore(code): reduce the code duplication EE-7278 (#11969)

pull/11499/merge
andres-portainer 2024-06-26 18:14:22 -03:00 committed by GitHub
parent 39bdfa4512
commit 9ee092aa5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
85 changed files with 520 additions and 618 deletions

View File

@ -10,7 +10,7 @@ import (
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/url"
"github.com/portainer/portainer/api/url"
)
// GetAgentVersionAndPlatform returns the agent version and platform

View File

@ -3,7 +3,6 @@ package apikey
import (
"testing"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/stretchr/testify/assert"
)
@ -34,7 +33,7 @@ func Test_generateRandomKey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := securecookie.GenerateRandomKey(tt.wantLenth)
got := GenerateRandomKey(tt.wantLenth)
is.Equal(tt.wantLenth, len(got))
})
}
@ -42,7 +41,7 @@ func Test_generateRandomKey(t *testing.T) {
t.Run("Generated keys are unique", func(t *testing.T) {
keys := make(map[string]bool)
for i := 0; i < 100; i++ {
key := securecookie.GenerateRandomKey(8)
key := GenerateRandomKey(8)
_, ok := keys[string(key)]
is.False(ok)
keys[string(key)] = true

View File

@ -1,69 +1,79 @@
package apikey
import (
lru "github.com/hashicorp/golang-lru"
portainer "github.com/portainer/portainer/api"
lru "github.com/hashicorp/golang-lru"
)
const defaultAPIKeyCacheSize = 1024
const DefaultAPIKeyCacheSize = 1024
// entry is a tuple containing the user and API key associated to an API key digest
type entry struct {
user portainer.User
type entry[T any] struct {
user T
apiKey portainer.APIKey
}
// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
type UserCompareFn[T any] func(T, portainer.UserID) bool
// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// We store the api-key digest (keys) and the associated user and key-data (values) in the cache.
// This is required because HTTP requests will contain only the api-key digest in the x-api-key request header;
// digest value must be mapped to a portainer user (and respective key data) for validation.
// This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest.
type apiKeyCache struct {
type ApiKeyCache[T any] struct {
// cache type [string]entry cache (key: string(digest), value: user/key entry)
// note: []byte keys are not supported by golang-lru Cache
cache *lru.Cache
cache *lru.Cache
userCmpFn UserCompareFn[T]
}
// NewAPIKeyCache creates a new cache for API keys
func NewAPIKeyCache(cacheSize int) *apiKeyCache {
func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] {
cache, _ := lru.New(cacheSize)
return &apiKeyCache{cache: cache}
return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn}
}
// Get returns the user/key associated to an api-key's digest
// This is required because HTTP requests will contain the digest of the API key in header,
// the digest value must be mapped to a portainer user.
func (c *apiKeyCache) Get(digest string) (portainer.User, portainer.APIKey, bool) {
func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) {
val, ok := c.cache.Get(digest)
if !ok {
return portainer.User{}, portainer.APIKey{}, false
var t T
return t, portainer.APIKey{}, false
}
tuple := val.(entry)
tuple := val.(entry[T])
return tuple.user, tuple.apiKey, true
}
// Set persists a user/key entry to the cache
func (c *apiKeyCache) Set(digest string, user portainer.User, apiKey portainer.APIKey) {
c.cache.Add(digest, entry{
func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) {
c.cache.Add(digest, entry[T]{
user: user,
apiKey: apiKey,
})
}
// Delete evicts a digest's user/key entry key from the cache
func (c *apiKeyCache) Delete(digest string) {
func (c *ApiKeyCache[T]) Delete(digest string) {
c.cache.Remove(digest)
}
// InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache
func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool {
func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool {
present := false
for _, k := range c.cache.Keys() {
user, _, _ := c.Get(k.(string))
if user.ID == userId {
if c.userCmpFn(user, userId) {
present = c.cache.Remove(k)
}
}
return present
}

View File

@ -10,11 +10,11 @@ import (
func Test_apiKeyCacheGet(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10)
keyCache := NewAPIKeyCache(10, compareUser)
// pre-populate cache
keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
tests := []struct {
digest string
@ -45,7 +45,7 @@ func Test_apiKeyCacheGet(t *testing.T) {
func Test_apiKeyCacheSet(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10)
keyCache := NewAPIKeyCache(10, compareUser)
// pre-populate cache
keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{})
@ -57,23 +57,23 @@ func Test_apiKeyCacheSet(t *testing.T) {
val, ok := keyCache.cache.Get(string("bar"))
is.True(ok)
tuple := val.(entry)
tuple := val.(entry[portainer.User])
is.Equal(portainer.User{ID: 2}, tuple.user)
val, ok = keyCache.cache.Get(string("foo"))
is.True(ok)
tuple = val.(entry)
tuple = val.(entry[portainer.User])
is.Equal(portainer.User{ID: 3}, tuple.user)
}
func Test_apiKeyCacheDelete(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10)
keyCache := NewAPIKeyCache(10, compareUser)
t.Run("Delete an existing entry", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.Delete("foo")
_, ok := keyCache.cache.Get(string("foo"))
@ -128,7 +128,7 @@ func Test_apiKeyCacheLRU(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
keyCache := NewAPIKeyCache(test.cacheLen)
keyCache := NewAPIKeyCache(test.cacheLen, compareUser)
for _, key := range test.key {
keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{})
@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) {
func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10)
keyCache := NewAPIKeyCache(10, compareUser)
t.Run("Removes users keys from cache", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok)
@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
})
t.Run("Does not affect other keys", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok)

View File

@ -1,14 +1,15 @@
package apikey
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/pkg/errors"
)
@ -20,30 +21,45 @@ var ErrInvalidAPIKey = errors.New("Invalid API key")
type apiKeyService struct {
apiKeyRepository dataservices.APIKeyRepository
userRepository dataservices.UserService
cache *apiKeyCache
cache *ApiKeyCache[portainer.User]
}
// GenerateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
func compareUser(u portainer.User, id portainer.UserID) bool {
return u.ID == id
}
func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService {
return &apiKeyService{
apiKeyRepository: apiKeyRepository,
userRepository: userRepository,
cache: NewAPIKeyCache(defaultAPIKeyCacheSize),
cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser),
}
}
// HashRaw computes a hash digest of provided raw API key.
func (a *apiKeyService) HashRaw(rawKey string) string {
hashDigest := sha256.Sum256([]byte(rawKey))
return base64.StdEncoding.EncodeToString(hashDigest[:])
}
// GenerateApiKey generates a raw API key for a user (for one-time display).
// The generated API key is stored in the cache and database.
func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) {
randKey := securecookie.GenerateRandomKey(32)
randKey := GenerateRandomKey(32)
encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey)
prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey
hashDigest := a.HashRaw(prefixedAPIKey)
apiKey := &portainer.APIKey{
@ -54,8 +70,7 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
Digest: hashDigest,
}
err := a.apiKeyRepository.Create(apiKey)
if err != nil {
if err := a.apiKeyRepository.Create(apiKey); err != nil {
return "", nil, errors.Wrap(err, "Unable to create API key")
}
@ -78,7 +93,6 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey,
// GetDigestUserAndKey returns the user and api-key associated to a specified hash digest.
// A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed.
func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) {
// get api key from cache if possible
cachedUser, cachedKey, ok := a.cache.Get(digest)
if ok {
return cachedUser, cachedKey, nil
@ -106,20 +120,21 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error {
if err != nil {
return errors.Wrap(err, "Unable to retrieve API key")
}
a.cache.Set(apiKey.Digest, user, *apiKey)
return a.apiKeyRepository.Update(apiKey.ID, apiKey)
}
// DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache.
func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error {
// get api-key digest to remove from cache
apiKey, err := a.apiKeyRepository.Read(apiKeyID)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID))
}
// delete the user/api-key from cache
a.cache.Delete(apiKey.Digest)
return a.apiKeyRepository.Delete(apiKeyID)
}

View File

@ -17,17 +17,14 @@ import (
type Service struct{}
var (
errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
errInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
)
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
flags := &portainer.CLIFlags{
func CLIFlags() *portainer.CLIFlags {
return &portainer.CLIFlags{
Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(),
AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(),
TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(),
@ -63,6 +60,13 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
}
}
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
flags := CLIFlags()
kingpin.Parse()
@ -82,18 +86,16 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
displayDeprecationWarnings(flags)
err := validateEndpointURL(*flags.EndpointURL)
if err != nil {
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
return err
}
err = validateSnapshotInterval(*flags.SnapshotInterval)
if err != nil {
if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil {
return err
}
if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" {
return errAdminPassExcludeAdminPassFile
return ErrAdminPassExcludeAdminPassFile
}
return nil
@ -115,15 +117,16 @@ func validateEndpointURL(endpointURL string) error {
}
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
return errInvalidEndpointProtocol
return ErrInvalidEndpointProtocol
}
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
socketPath := strings.TrimPrefix(endpointURL, "unix://")
socketPath = strings.TrimPrefix(socketPath, "npipe://")
if _, err := os.Stat(socketPath); err != nil {
if os.IsNotExist(err) {
return errSocketOrNamedPipeNotFound
return ErrSocketOrNamedPipeNotFound
}
return err
@ -138,9 +141,8 @@ func validateSnapshotInterval(snapshotInterval string) error {
return nil
}
_, err := time.ParseDuration(snapshotInterval)
if err != nil {
return errInvalidSnapshotInterval
if _, err := time.ParseDuration(snapshotInterval); err != nil {
return ErrInvalidSnapshotInterval
}
return nil

View File

@ -56,14 +56,14 @@ import (
)
func initCLI() *portainer.CLIFlags {
var cliService portainer.CLIService = &cli.Service{}
cliService := &cli.Service{}
flags, err := cliService.ParseFlags(portainer.APIVersion)
if err != nil {
log.Fatal().Err(err).Msg("failed parsing flags")
}
err = cliService.ValidateFlags(flags)
if err != nil {
if err := cliService.ValidateFlags(flags); err != nil {
log.Fatal().Err(err).Msg("failed validating flags")
}
@ -94,14 +94,14 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
store := datastore.NewStore(*flags.Data, fileService, connection)
isNew, err := store.Open()
if err != nil {
log.Fatal().Err(err).Msg("failed opening store")
}
if *flags.Rollback {
err := store.Rollback(false)
if err != nil {
if err := store.Rollback(false); err != nil {
log.Fatal().Err(err).Msg("failed rolling back")
}
@ -110,8 +110,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
// Init sets some defaults - it's basically a migration
err = store.Init()
if err != nil {
if err := store.Init(); err != nil {
log.Fatal().Err(err).Msg("failed initializing data store")
}
@ -133,25 +132,23 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
store.VersionService.UpdateVersion(&v)
err = updateSettingsFromFlags(store, flags)
if err != nil {
if err := updateSettingsFromFlags(store, flags); err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags")
}
} else {
err = store.MigrateData()
if err != nil {
if err := store.MigrateData(); err != nil {
log.Fatal().Err(err).Msg("failed migration")
}
}
err = updateSettingsFromFlags(store, flags)
if err != nil {
if err := updateSettingsFromFlags(store, flags); err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags")
}
// this is for the db restore functionality - needs more tests.
go func() {
<-shutdownCtx.Done()
defer connection.Close()
}()
@ -205,36 +202,16 @@ func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore)
userSessionTimeout = portainer.DefaultUserSessionTimeout
}
jwtService, err := jwt.NewService(userSessionTimeout, dataStore)
if err != nil {
return nil, err
}
return jwtService, nil
return jwt.NewService(userSessionTimeout, dataStore)
}
func initDigitalSignatureService() portainer.DigitalSignatureService {
return crypto.NewECDSAService(os.Getenv("AGENT_SECRET"))
}
func initCryptoService() portainer.CryptoService {
return &crypto.Service{}
}
func initLDAPService() portainer.LDAPService {
return &ldap.Service{}
}
func initOAuthService() portainer.OAuthService {
return oauth.NewService()
}
func initGitService(ctx context.Context) portainer.GitService {
return git.NewService(ctx)
}
func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) {
slices := strings.Split(addr, ":")
host := slices[0]
if host == "" {
host = "0.0.0.0"
@ -242,22 +219,13 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe
sslService := ssl.NewService(fileService, dataStore, shutdownTrigger)
err := sslService.Init(host, certPath, keyPath)
if err != nil {
if err := sslService.Init(host, certPath, keyPath); err != nil {
return nil, err
}
return sslService, nil
}
func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *dockerclient.ClientFactory {
return dockerclient.NewClientFactory(signatureService, reverseTunnelService)
}
func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, dataStore dataservices.DataStore, instanceID, addrHTTPS, userSessionTimeout string) (*kubecli.ClientFactory, error) {
return kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, addrHTTPS, userSessionTimeout)
}
func initSnapshotService(
snapshotIntervalFromFlag string,
dataStore dataservices.DataStore,
@ -310,14 +278,12 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
settings.BlackListedLabels = *flags.Labels
}
settings.AgentSecret = ""
if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok {
settings.AgentSecret = agentKey
} else {
settings.AgentSecret = ""
}
err = dataStore.Settings().UpdateSettings(settings)
if err != nil {
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
return err
}
@ -340,6 +306,7 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por
if err != nil {
return err
}
return signatureService.ParseKeyPair(private, public)
}
@ -348,7 +315,9 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService
if err != nil {
return err
}
privateHeader, publicHeader := signatureService.PEMHeaders()
return fileService.StoreKeyPair(private, public, privateHeader, publicHeader)
}
@ -361,6 +330,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
if existingKeyPair {
return loadAndParseKeyPair(fileService, signatureService)
}
return generateAndStoreKeyPair(fileService, signatureService)
}
@ -378,6 +348,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
// return a 32 byte hash of the secret (required for AES)
hash := sha256.Sum256(content)
return hash[:]
}
@ -422,17 +393,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed initializing JWT service")
}
ldapService := initLDAPService()
ldapService := &ldap.Service{}
oauthService := initOAuthService()
oauthService := oauth.NewService()
gitService := initGitService(shutdownCtx)
gitService := git.NewService(shutdownCtx)
openAMTService := openamt.NewService()
cryptoService := initCryptoService()
cryptoService := &crypto.Service{}
digitalSignatureService := initDigitalSignatureService()
signatureService := initDigitalSignatureService()
edgeStacksService := edgestacks.NewService(dataStore)
@ -446,15 +417,18 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed to get SSL settings")
}
err = initKeyPair(fileService, digitalSignatureService)
if err != nil {
if err := initKeyPair(fileService, signatureService); err != nil {
log.Fatal().Err(err).Msg("failed initializing key pair")
}
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService)
dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService)
kubernetesClientFactory, err := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService)
kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service")
}
authorizationService := authorization.NewService(dataStore)
authorizationService.K8sClientFactory = kubernetesClientFactory
@ -476,12 +450,12 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
composeStackManager := initComposeStackManager(composeDeployer, proxyManager)
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore)
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
}
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets)
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
@ -492,17 +466,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
if err != nil {
log.Fatal().Err(err).Msg("failed initializing snapshot service")
}
snapshotService.Start()
proxyManager.NewProxyFactory(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
helmPackageManager, err := initHelmPackageManager(*flags.Assets)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing helm package manager")
}
err = edge.LoadEdgeJobs(dataStore, reverseTunnelService)
if err != nil {
if err := edge.LoadEdgeJobs(dataStore, reverseTunnelService); err != nil {
log.Fatal().Err(err).Msg("failed loading edge jobs from database")
}
@ -514,6 +488,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService)
adminPasswordHash := ""
if *flags.AdminPasswordFile != "" {
content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "")
if err != nil {
@ -536,14 +511,14 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
if len(users) == 0 {
log.Info().Msg("created admin user with the given password.")
user := &portainer.User{
Username: "admin",
Role: portainer.AdministratorRole,
Password: adminPasswordHash,
}
err := dataStore.User().Create(user)
if err != nil {
if err := dataStore.User().Create(user); err != nil {
log.Fatal().Err(err).Msg("failed creating admin user")
}
@ -554,8 +529,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
}
}
err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService)
if err != nil {
if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil {
log.Fatal().Err(err).Msg("failed starting tunnel server")
}
@ -613,7 +587,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
ProxyManager: proxyManager,
KubernetesTokenCacheManager: kubernetesTokenCacheManager,
KubeClusterAccessService: kubeClusterAccessService,
SignatureService: digitalSignatureService,
SignatureService: signatureService,
SnapshotService: snapshotService,
SSLService: sslService,
DockerClientFactory: dockerClientFactory,
@ -639,6 +613,7 @@ func main() {
for {
server := buildServer(flags)
log.Info().
Str("version", portainer.APIVersion).
Str("build_number", build.BuildNumber).

View File

@ -203,6 +203,7 @@ func (connection *DbConnection) ExportRaw(filename string) error {
func (connection *DbConnection) ConvertToKey(v int) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(v))
return b
}

View File

@ -46,8 +46,8 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
return errors.Wrap(err, "Failed decrypting object")
}
}
e := json.Unmarshal(data, object)
if e != nil {
if e := json.Unmarshal(data, object); e != nil {
// Special case for the VERSION bucket. Here we're not using json
// So we need to return it as a string
s, ok := object.(*string)
@ -57,6 +57,7 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object interface{})
*s = string(data)
}
return err
}
@ -71,7 +72,7 @@ func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error)
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return encrypted, err
}

View File

@ -78,6 +78,7 @@ func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
id, err := bucket.NextSequence()
if err != nil {
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
return 0
}

View File

@ -111,5 +111,6 @@ func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) e
store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey))
store.connection.DeleteObject(bucketName, []byte(legacyEditionKey))
store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
return err
}

View File

@ -39,20 +39,19 @@ func (m *Migrator) Migrate() error {
latestMigrations := m.LatestMigrations()
if latestMigrations.Version.Equal(schemaVersion) &&
version.MigratorCount != len(latestMigrations.MigrationFuncs) {
err := runMigrations(latestMigrations.MigrationFuncs)
if err != nil {
if err := runMigrations(latestMigrations.MigrationFuncs); err != nil {
return err
}
newMigratorCount = len(latestMigrations.MigrationFuncs)
}
} else {
// regular path when major/minor/patch versions differ
for _, migration := range m.migrations {
if schemaVersion.LessThan(migration.Version) {
log.Info().Msgf("migrating data to %s", migration.Version.String())
err := runMigrations(migration.MigrationFuncs)
if err != nil {
if err := runMigrations(migration.MigrationFuncs); err != nil {
return err
}
}
@ -63,16 +62,14 @@ func (m *Migrator) Migrate() error {
}
}
err = m.Always()
if err != nil {
if err := m.Always(); err != nil {
return migrationError(err, "Always migrations returned error")
}
version.SchemaVersion = portainer.APIVersion
version.MigratorCount = newMigratorCount
err = m.versionService.UpdateVersion(version)
if err != nil {
if err := m.versionService.UpdateVersion(version); err != nil {
return migrationError(err, "StoreDBVersion")
}
@ -99,6 +96,7 @@ func (m *Migrator) NeedsMigration() bool {
// In this particular instance we should log a fatal error
if m.CurrentDBEdition() != portainer.PortainerCE {
log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/")
return false
}

View File

@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/chisel/crypto"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
@ -37,9 +38,11 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Info().Msg("ServerInfo object not found")
return nil
}
log.Error().
Err(err).
Msg("Failed to read ServerInfo from DB")
return err
}
@ -49,14 +52,15 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Error().
Err(err).
Msg("Failed to read ServerInfo from DB")
return err
}
err = m.fileService.StoreChiselPrivateKey(key)
if err != nil {
if err := m.fileService.StoreChiselPrivateKey(key); err != nil {
log.Error().
Err(err).
Msg("Failed to save Chisel private key to disk")
return err
}
} else {
@ -64,14 +68,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
}
serverInfo.PrivateKeySeed = ""
err = m.TunnelServerService.UpdateInfo(serverInfo)
if err != nil {
if err := m.TunnelServerService.UpdateInfo(serverInfo); err != nil {
log.Error().
Err(err).
Msg("Failed to clean private key seed in DB")
} else {
log.Info().Msg("Success to migrate private key seed to private key file")
}
return err
}
@ -84,9 +88,8 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
}
for _, edgeStack := range edgeStacks {
for environmentID, environmentStatus := range edgeStack.Status {
// skip if status is already updated
// Skip if status is already updated
if len(environmentStatus.Status) > 0 {
continue
}
@ -146,8 +149,7 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
edgeStack.Status[environmentID] = environmentStatus
}
err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack)
if err != nil {
if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
return err
}
}

View File

@ -32,8 +32,8 @@ func (m *Migrator) updateStacksToDB24() error {
for idx := range stacks {
stack := &stacks[idx]
stack.Status = portainer.StackStatusActive
err := m.stackService.Update(stack.ID, stack)
if err != nil {
if err := m.stackService.Update(stack.ID, stack); err != nil {
return err
}
}

View File

@ -583,7 +583,6 @@
"AuthenticationMethod": 1,
"BlackListedLabels": [],
"Edge": {
"AsyncMode": false,
"CommandInterval": 0,
"PingInterval": 0,
"SnapshotInterval": 0

View File

@ -52,27 +52,24 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
}
if init {
err = store.Init()
if err != nil {
if err := store.Init(); err != nil {
return newStore, nil, nil, err
}
}
if newStore {
// from MigrateData
// From MigrateData
v := models.Version{
SchemaVersion: portainer.APIVersion,
Edition: int(portainer.PortainerCE),
}
err = store.VersionService.UpdateVersion(&v)
if err != nil {
if err := store.VersionService.UpdateVersion(&v); err != nil {
return newStore, nil, nil, err
}
}
teardown := func() {
err := store.Close()
if err != nil {
if err := store.Close(); err != nil {
log.Fatal().Err(err).Msg("")
}
}

View File

@ -36,7 +36,6 @@ func (c *ContainerService) Recreate(ctx context.Context, endpoint *portainer.End
if err != nil {
return nil, errors.Wrap(err, "create client error")
}
defer cli.Close()
log.Debug().Str("container_id", containerId).Msg("starting to fetch container information")

View File

@ -5,10 +5,10 @@ import (
"fmt"
"net/http"
"github.com/portainer/portainer/api/http/security"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
gorillacsrf "github.com/gorilla/csrf"
"github.com/portainer/portainer/api/http/security"
"github.com/urfave/negroni"
)
@ -16,8 +16,7 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
handler = withSendCSRFToken(handler)
token := make([]byte, 32)
_, err := rand.Read(token)
if err != nil {
if _, err := rand.Read(token); err != nil {
return nil, fmt.Errorf("failed to generate CSRF token: %w", err)
}
@ -32,7 +31,6 @@ func WithProtect(handler http.Handler) (http.Handler, error) {
func withSendCSRFToken(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sw := negroni.NewResponseWriter(w)
sw.Before(func(sw negroni.ResponseWriter) {
@ -44,16 +42,15 @@ func withSendCSRFToken(handler http.Handler) http.Handler {
})
handler.ServeHTTP(sw, r)
})
}
func withSkipCSRF(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
skip, err := security.ShouldSkipCSRFCheck(r)
if err != nil {
httperror.WriteError(w, http.StatusForbidden, err.Error(), err)
return
}

View File

@ -56,8 +56,7 @@ func (payload *authenticatePayload) Validate(r *http.Request) error {
// @router /auth [post]
func (handler *Handler) authenticate(rw http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload authenticatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil {
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
@ -104,8 +103,7 @@ func isUserInitialAdmin(user *portainer.User) bool {
}
func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portainer.User, password string) *httperror.HandlerError {
err := handler.CryptoService.CompareHashAndData(user.Password, password)
if err != nil {
if err := handler.CryptoService.CompareHashAndData(user.Password, password); err != nil {
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
}
@ -115,8 +113,7 @@ func (handler *Handler) authenticateInternal(w http.ResponseWriter, user *portai
}
func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.User, username, password string, ldapSettings *portainer.LDAPSettings) *httperror.HandlerError {
err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings)
if err != nil {
if err := handler.LDAPService.AuthenticateUser(username, password, ldapSettings); err != nil {
if errors.Is(err, httperrors.ErrUnauthorized) {
return httperror.NewError(http.StatusUnprocessableEntity, "Invalid credentials", httperrors.ErrUnauthorized)
}
@ -131,14 +128,12 @@ func (handler *Handler) authenticateLDAP(w http.ResponseWriter, user *portainer.
PortainerAuthorizations: authorization.DefaultPortainerAuthorizations(),
}
err = handler.DataStore.User().Create(user)
if err != nil {
if err := handler.DataStore.User().Create(user); err != nil {
return httperror.InternalServerError("Unable to persist user inside the database", err)
}
}
err = handler.syncUserTeamsWithLDAPGroups(user, ldapSettings)
if err != nil {
if err := handler.syncUserTeamsWithLDAPGroups(user, ldapSettings); err != nil {
log.Warn().Err(err).Msg("unable to automatically sync user teams with ldap")
}
@ -186,7 +181,6 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
for _, team := range teams {
if teamExists(team.Name, userGroups) {
if teamMembershipExists(team.ID, userMemberships) {
continue
}
@ -197,8 +191,7 @@ func (handler *Handler) syncUserTeamsWithLDAPGroups(user *portainer.User, settin
Role: portainer.TeamMember,
}
err := handler.DataStore.TeamMembership().Create(membership)
if err != nil {
if err := handler.DataStore.TeamMembership().Create(membership); err != nil {
return err
}
}

View File

@ -41,5 +41,6 @@ func NewHandler(bouncer security.BouncerService, rateLimiter *security.RateLimit
rateLimiter.LimitAccess(bouncer.PublicAccess(httperror.LoggerHandler(h.authenticate)))).Methods(http.MethodPost)
h.Handle("/auth/logout",
bouncer.PublicAccess(httperror.LoggerHandler(h.logout))).Methods(http.MethodPost)
return h
}

View File

@ -4,7 +4,7 @@ import (
"net/http"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/logoutcontext"
"github.com/portainer/portainer/api/logoutcontext"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
)

View File

@ -4,14 +4,15 @@ import (
"net/http"
"strconv"
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
@ -70,7 +71,7 @@ func (handler *Handler) customTemplateList(w http.ResponseWriter, r *http.Reques
customTemplates = filterByType(customTemplates, templateTypes)
if edge != nil {
customTemplates = slices.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool {
customTemplates = slicesx.Filter(customTemplates, func(customTemplate portainer.CustomTemplate) bool {
return customTemplate.EdgeTemplate == *edge
})
}

View File

@ -6,7 +6,7 @@ import (
"github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/http/handler/docker/utils"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"

View File

@ -7,7 +7,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/slicesx"
)
// filterByResourceControl filters a list of items based on the user's role and the resource control associated to the item.
@ -16,7 +16,7 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
return items, nil
}
userTeamIDs := slices.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID {
userTeamIDs := slicesx.Map(securityContext.UserMemberships, func(membership portainer.TeamMembership) portainer.TeamID {
return membership.TeamID
})
@ -32,5 +32,6 @@ func FilterByResourceControl[T any](tx dataservices.DataStoreTx, items []T, rcTy
}
}
return filteredItems, nil
}

View File

@ -36,23 +36,25 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error {
func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error {
if edgeGroup.Dynamic {
edgeGroup.TagIDs = tagIDs
} else {
endpointIDs := []portainer.EndpointID{}
for _, endpointID := range endpoints {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
}
return nil
}
if endpointutils.IsEdgeEndpoint(endpoint) {
endpointIDs = append(endpointIDs, endpoint.ID)
}
endpointIDs := []portainer.EndpointID{}
for _, endpointID := range endpoints {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return httperror.InternalServerError("Unable to retrieve environment from the database", err)
}
edgeGroup.Endpoints = endpointIDs
if endpointutils.IsEdgeEndpoint(endpoint) {
endpointIDs = append(endpointIDs, endpoint.ID)
}
}
edgeGroup.Endpoints = endpointIDs
return nil
}
@ -71,13 +73,13 @@ func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.
// @router /edge_groups [post]
func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload edgeGroupCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil {
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
var edgeGroup *portainer.EdgeGroup
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
edgeGroups, err := tx.EdgeGroup().ReadAll()
if err != nil {
return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err)
@ -101,8 +103,7 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
return err
}
err = tx.EdgeGroup().Create(edgeGroup)
if err != nil {
if err := tx.EdgeGroup().Create(edgeGroup); err != nil {
return httperror.InternalServerError("Unable to persist the Edge group inside the database", err)
}

View File

@ -9,7 +9,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/unique"
"github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
@ -113,7 +113,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
}
newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups)
endpointsToUpdate := unique.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...))
endpointsToUpdate := slicesx.Unique(append(newRelatedEndpoints, oldRelatedEndpoints...))
edgeJobs, err := tx.EdgeJob().ReadAll()
if err != nil {

View File

@ -31,8 +31,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
}
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
err = store.User().Create(user)
if err != nil {
if err := store.User().Create(user); err != nil {
t.Fatal(err)
}
@ -66,8 +65,7 @@ func setupHandler(t *testing.T) (*Handler, string) {
}
settings.EnableEdgeComputeFeatures = true
err = handler.DataStore.Settings().UpdateSettings(settings)
if err != nil {
if err := handler.DataStore.Settings().UpdateSettings(settings); err != nil {
t.Fatal(err)
}
@ -88,8 +86,7 @@ func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID
LastCheckInDate: time.Now().Unix(),
}
err := store.Endpoint().Create(&endpoint)
if err != nil {
if err := store.Endpoint().Create(&endpoint); err != nil {
t.Fatal(err)
}
@ -112,8 +109,7 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
PartialMatch: false,
}
err := store.EdgeGroup().Create(&edgeGroup)
if err != nil {
if err := store.EdgeGroup().Create(&edgeGroup); err != nil {
t.Fatal(err)
}
@ -138,13 +134,11 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
},
}
err = store.EdgeStack().Create(edgeStack.ID, &edgeStack)
if err != nil {
if err := store.EdgeStack().Create(edgeStack.ID, &edgeStack); err != nil {
t.Fatal(err)
}
err = store.EndpointRelation().Create(&endpointRelation)
if err != nil {
if err := store.EndpointRelation().Create(&endpointRelation); err != nil {
t.Fatal(err)
}

View File

@ -6,7 +6,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"

View File

@ -9,6 +9,7 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
"github.com/segmentio/encoding/json"
)
@ -24,8 +25,7 @@ func TestUpdateAndInspect(t *testing.T) {
endpointID := portainer.EndpointID(6)
newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID)
err := handler.DataStore.Endpoint().Create(&newEndpoint)
if err != nil {
if err := handler.DataStore.Endpoint().Create(&newEndpoint); err != nil {
t.Fatal(err)
}
@ -36,8 +36,7 @@ func TestUpdateAndInspect(t *testing.T) {
},
}
err = handler.DataStore.EndpointRelation().Create(&endpointRelation)
if err != nil {
if err := handler.DataStore.EndpointRelation().Create(&endpointRelation); err != nil {
t.Fatal(err)
}
@ -50,8 +49,7 @@ func TestUpdateAndInspect(t *testing.T) {
PartialMatch: false,
}
err = handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
if err != nil {
if err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup); err != nil {
t.Fatal(err)
}
@ -96,8 +94,7 @@ func TestUpdateAndInspect(t *testing.T) {
}
updatedStack := portainer.EdgeStack{}
err = json.NewDecoder(rec.Body).Decode(&updatedStack)
if err != nil {
if err := json.NewDecoder(rec.Body).Decode(&updatedStack); err != nil {
t.Fatal("error decoding response:", err)
}
@ -120,7 +117,6 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
endpoint := createEndpoint(t, handler.DataStore)
edgeStack := createEdgeStack(t, handler.DataStore, endpoint.ID)
//newEndpoint := createEndpoint(t, handler.DataStore)
newEdgeGroup := portainer.EdgeGroup{
ID: 2,
Name: "EdgeGroup 2",
@ -130,7 +126,8 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
PartialMatch: false,
}
handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
err := handler.DataStore.EdgeGroup().Create(&newEdgeGroup)
require.NoError(t, err)
cases := []struct {
Name string

View File

@ -18,6 +18,7 @@ import (
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type endpointTestCase struct {
@ -99,8 +100,7 @@ func mustSetupHandler(t *testing.T) *Handler {
}
settings.TrustOnFirstConnect = true
err = store.Settings().UpdateSettings(settings)
if err != nil {
if err = store.Settings().UpdateSettings(settings); err != nil {
t.Fatalf("could not update settings: %s", err)
}
@ -122,8 +122,7 @@ func createEndpoint(handler *Handler, endpoint portainer.Endpoint, endpointRelat
return nil
}
err = handler.DataStore.Endpoint().Create(&endpoint)
if err != nil {
if err := handler.DataStore.Endpoint().Create(&endpoint); err != nil {
return err
}
@ -134,14 +133,13 @@ func TestMissingEdgeIdentifier(t *testing.T) {
handler := mustSetupHandler(t)
endpointID := portainer.EndpointID(45)
err := createEndpoint(handler, portainer.Endpoint{
if err := createEndpoint(handler, portainer.Endpoint{
ID: endpointID,
Name: "endpoint-id-45",
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "https://portainer.io:9443",
EdgeID: "edge-id",
}, portainer.EndpointRelation{EndpointID: endpointID})
if err != nil {
}, portainer.EndpointRelation{EndpointID: endpointID}); err != nil {
t.Fatal(err)
}
@ -201,8 +199,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
EndpointID: endpoint.ID,
}
err := createEndpoint(handler, endpoint, endpointRelation)
if err != nil {
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
t.Fatal(err)
}
@ -212,6 +209,7 @@ func TestLastCheckInDateIncreases(t *testing.T) {
if err != nil {
t.Fatal("request error:", err)
}
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -246,8 +244,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
EndpointID: endpoint.ID,
}
err := createEndpoint(handler, endpoint, endpointRelation)
if err != nil {
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
t.Fatal(err)
}
@ -255,6 +252,7 @@ func TestEmptyEdgeIdWithAgentPlatformHeader(t *testing.T) {
if err != nil {
t.Fatal("request error:", err)
}
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, edgeId)
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -308,10 +306,11 @@ func TestEdgeStackStatus(t *testing.T) {
edgeStack.ID: true,
},
}
handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
err := createEndpoint(handler, endpoint, endpointRelation)
if err != nil {
err := handler.DataStore.EdgeStack().Create(edgeStack.ID, &edgeStack)
require.NoError(t, err)
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
t.Fatal(err)
}
@ -319,6 +318,7 @@ func TestEdgeStackStatus(t *testing.T) {
if err != nil {
t.Fatal("request error:", err)
}
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -330,8 +330,7 @@ func TestEdgeStackStatus(t *testing.T) {
}
var data endpointEdgeStatusInspectResponse
err = json.NewDecoder(rec.Body).Decode(&data)
if err != nil {
if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
t.Fatal("error decoding response:", err)
}
@ -357,8 +356,7 @@ func TestEdgeJobsResponse(t *testing.T) {
EndpointID: endpoint.ID,
}
err := createEndpoint(handler, endpoint, endpointRelation)
if err != nil {
if err := createEndpoint(handler, endpoint, endpointRelation); err != nil {
t.Fatal(err)
}
@ -384,6 +382,7 @@ func TestEdgeJobsResponse(t *testing.T) {
if err != nil {
t.Fatal("request error:", err)
}
req.Header.Set(portainer.PortainerAgentEdgeIDHeader, "edge-id")
req.Header.Set(portainer.HTTPResponseAgentPlatform, "1")
@ -395,8 +394,7 @@ func TestEdgeJobsResponse(t *testing.T) {
}
var data endpointEdgeStatusInspectResponse
err = json.NewDecoder(rec.Body).Decode(&data)
if err != nil {
if err := json.NewDecoder(rec.Body).Decode(&data); err != nil {
t.Fatal("error decoding response:", err)
}

View File

@ -8,8 +8,8 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/tag"
"github.com/portainer/portainer/api/pendingactions/handlers"
"github.com/portainer/portainer/api/tag"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"

View File

@ -4,7 +4,7 @@ import (
"net/http"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
)

View File

@ -73,8 +73,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
payload.GroupID = groupID
var tagIDs []portainer.TagID
err = request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true)
if err != nil {
if err := request.RetrieveMultiPartFormJSONValue(r, "TagIds", &tagIDs, true); err != nil {
return errors.New("invalid TagIds parameter")
}
payload.TagIDs = tagIDs
@ -96,6 +95,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid CA certificate file. Ensure that the file is uploaded correctly")
}
payload.TLSCACertFile = caCert
}
@ -110,6 +110,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid key file. Ensure that the file is uploaded correctly")
}
payload.TLSKeyFile = key
}
}
@ -120,6 +121,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil {
return errors.New("invalid Azure application ID")
}
payload.AzureApplicationID = azureApplicationID
azureTenantID, err := request.RetrieveMultiPartFormValue(r, "AzureTenantID", false)
@ -139,6 +141,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
if err != nil || strings.EqualFold("", strings.Trim(endpointURL, " ")) {
return errors.New("URL cannot be empty")
}
payload.URL = endpointURL
publicURL, _ := request.RetrieveMultiPartFormValue(r, "PublicURL", true)
@ -156,10 +159,10 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
}
gpus := make([]portainer.Pair, 0)
err = request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true)
if err != nil {
if err := request.RetrieveMultiPartFormJSONValue(r, "Gpus", &gpus, true); err != nil {
return errors.New("invalid Gpus parameter")
}
payload.Gpus = gpus
edgeCheckinInterval, _ := request.RetrieveNumericMultiPartFormValue(r, "EdgeCheckinInterval", true)
@ -206,8 +209,7 @@ func (payload *endpointCreatePayload) Validate(r *http.Request) error {
// @router /endpoints [post]
func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
payload := &endpointCreatePayload{}
err := payload.Validate(r)
if err != nil {
if err := payload.Validate(r); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
@ -268,8 +270,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
)
}
err = handler.DataStore.EndpointRelation().Create(relationObject)
if err != nil {
if err := handler.DataStore.EndpointRelation().Create(relationObject); err != nil {
return httperror.InternalServerError("Unable to persist the relation object inside the database", err)
}
@ -278,6 +279,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
var err error
switch payload.EndpointCreationType {
case azureEnvironment:
return handler.createAzureEndpoint(tx, payload)
@ -329,8 +331,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
}
httpClient := client.NewHTTPClient()
_, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials)
if err != nil {
if _, err := httpClient.ExecuteAzureAuthenticationRequest(&credentials); err != nil {
return nil, httperror.InternalServerError("Unable to authenticate against Azure", err)
}
@ -352,8 +353,7 @@ func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload
Kubernetes: portainer.KubernetesDefault(),
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -405,8 +405,7 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
endpoint.EdgeID = edgeID.String()
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -443,8 +442,7 @@ func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, pay
Kubernetes: portainer.KubernetesDefault(),
}
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -478,8 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, pa
Kubernetes: portainer.KubernetesDefault(),
}
err := handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -510,13 +507,11 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
endpoint.Agent.Version = agentVersion
err := handler.storeTLSFiles(endpoint, payload)
if err != nil {
if err := handler.storeTLSFiles(endpoint, payload); err != nil {
return nil, err
}
err = handler.snapshotAndPersistEndpoint(tx, endpoint)
if err != nil {
if err := handler.snapshotAndPersistEndpoint(tx, endpoint); err != nil {
return nil, err
}
@ -524,17 +519,16 @@ func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, pa
}
func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError {
err := handler.SnapshotService.SnapshotEndpoint(endpoint)
if err != nil {
if err := handler.SnapshotService.SnapshotEndpoint(endpoint); err != nil {
if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) ||
(endpoint.Type == portainer.AgentOnKubernetesEnvironment && strings.Contains(err.Error(), "unknown")) {
err = errors.New("agent already paired with another Portainer instance")
}
return httperror.InternalServerError("Unable to initiate communications with environment", err)
}
err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint)
if err != nil {
if err := handler.saveEndpointAndUpdateAuthorizations(tx, endpoint); err != nil {
return httperror.InternalServerError("An error occurred while trying to create the environment", err)
}
@ -555,16 +549,14 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.Data
AllowStackManagementForRegularUsers: true,
}
err := tx.Endpoint().Create(endpoint)
if err != nil {
if err := tx.Endpoint().Create(endpoint); err != nil {
return err
}
for _, tagID := range endpoint.TagIDs {
err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
if err := tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) {
tag.Endpoints[endpoint.ID] = true
})
if err != nil {
}); err != nil {
return err
}
}
@ -580,22 +572,26 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
if err != nil {
return httperror.InternalServerError("Unable to persist TLS CA certificate file on disk", err)
}
endpoint.TLSConfig.TLSCACertPath = caCertPath
}
if !payload.TLSSkipClientVerify {
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
if payload.TLSSkipClientVerify {
return nil
}
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS certificate file on disk", err)
}
endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil {
return httperror.InternalServerError("Unable to persist TLS key file on disk", err)
}
endpoint.TLSConfig.TLSKeyPath = keyPath
return nil
}

View File

@ -30,24 +30,22 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
for i := 0; i < endpointsCount; i++ {
endpointID := portainer.EndpointID(i) + 1
err := store.Endpoint().Create(&portainer.Endpoint{
if err := store.Endpoint().Create(&portainer.Endpoint{
ID: endpointID,
Name: "env-" + strconv.Itoa(int(endpointID)),
Type: portainer.EdgeAgentOnDockerEnvironment,
})
if err != nil {
}); err != nil {
t.Fatal("could not create endpoint:", err)
}
endpointIDs = append(endpointIDs, endpointID)
}
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "edgegroup-1",
Endpoints: endpointIDs,
})
if err != nil {
}); err != nil {
t.Fatal("could not create edge group:", err)
}

View File

@ -102,7 +102,6 @@ func Test_EndpointList_AgentVersion(t *testing.T) {
}
func Test_endpointList_edgeFilter(t *testing.T) {
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
@ -227,8 +226,7 @@ func doEndpointListRequest(req *http.Request, h *Handler, is *assert.Assertions)
}
resp := []portainer.Endpoint{}
err = json.Unmarshal(body, &resp)
if err != nil {
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}

View File

@ -34,12 +34,10 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re
}
var registries []portainer.Registry
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID))
return err
})
if err != nil {
}); err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
@ -104,11 +102,9 @@ func (handler *Handler) filterKubernetesEndpointRegistries(r *http.Request, regi
}
if namespaceParam != "" {
authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin)
if err != nil {
if authorized, err := handler.isNamespaceAuthorized(endpoint, namespaceParam, user.ID, memberships, isAdmin); err != nil {
return nil, httperror.NotFound("Unable to check for namespace authorization", err)
}
if !authorized {
} else if !authorized {
return nil, httperror.Forbidden("User is not authorized to use namespace", errors.New("user is not authorized to use namespace"))
}

View File

@ -13,7 +13,7 @@ import (
"github.com/portainer/portainer/api/http/handler/edgegroups"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/unique"
"github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/pkg/errors"
@ -254,6 +254,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
if err != nil {
return nil, errors.WithMessage(err, "Unable to retrieve edge group from the database")
}
if edgeGroup.Dynamic {
endpointIDs, err := edgegroups.GetEndpointsByTags(datastore, edgeGroup.TagIDs, edgeGroup.PartialMatch)
if err != nil {
@ -261,6 +262,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
}
edgeGroup.Endpoints = endpointIDs
}
envIds = append(envIds, edgeGroup.Endpoints...)
}
@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
envIds = envIds[:n]
}
uniqueIds := unique.Unique(envIds)
uniqueIds := slicesx.Unique(envIds)
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
return filteredEndpoints, nil

View File

@ -5,8 +5,8 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert"
)
@ -129,7 +129,7 @@ func Test_Filter_edgeFilter(t *testing.T) {
func Test_Filter_excludeIDs(t *testing.T) {
ids := []portainer.EndpointID{1, 2, 3, 4, 5, 6, 7, 8, 9}
environments := slices.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
environments := slicesx.Map(ids, func(id portainer.EndpointID) portainer.Endpoint {
return portainer.Endpoint{ID: id, GroupID: 1, Type: portainer.DockerEnvironment}
})

View File

@ -4,7 +4,8 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/slices"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert"
)
@ -162,7 +163,7 @@ func TestSortEndpointsByField(t *testing.T) {
}
func getEndpointIDs(environments []portainer.Endpoint) []portainer.EndpointID {
return slices.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
return slicesx.Map(environments, func(environment portainer.Endpoint) portainer.EndpointID {
return environment.ID
})
}

View File

@ -6,7 +6,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
// updateEdgeRelations updates the edge stacks associated to an edge endpoint

View File

@ -6,7 +6,7 @@ import (
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {

View File

@ -10,7 +10,6 @@ import (
)
func Test_updateEdgeGroups(t *testing.T) {
createGroups := func(store *datastore.Store, names []string) ([]portainer.EdgeGroup, error) {
groups := make([]portainer.EdgeGroup, len(names))
for index, name := range names {
@ -21,8 +20,7 @@ func Test_updateEdgeGroups(t *testing.T) {
Endpoints: make([]portainer.EndpointID, 0),
}
err := store.EdgeGroup().Create(group)
if err != nil {
if err := store.EdgeGroup().Create(group); err != nil {
return nil, err
}
@ -42,6 +40,7 @@ func Test_updateEdgeGroups(t *testing.T) {
return
}
}
is.Fail("expected endpoint to be in group")
}
}
@ -52,6 +51,7 @@ func Test_updateEdgeGroups(t *testing.T) {
for j, tag := range groups {
if tag.Name == tagName {
result[i] = groups[j]
break
}
}
@ -88,6 +88,7 @@ func Test_updateEdgeGroups(t *testing.T) {
}
expectedGroups := groupsByName(groups, testCase.groupsToApply)
expectedIDs := make([]portainer.EdgeGroupID, len(expectedGroups))
for i, tag := range expectedGroups {
expectedIDs[i] = tag.ID

View File

@ -4,7 +4,7 @@ import (
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/set"
"github.com/portainer/portainer/api/set"
)
// updateEnvironmentTags updates the tags associated to an environment

View File

@ -10,14 +10,14 @@ import (
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/exec/exectest"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
helper "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/jwt"
"github.com/portainer/portainer/api/kubernetes"
"github.com/portainer/portainer/pkg/libhelm/binary/test"
"github.com/portainer/portainer/pkg/libhelm/options"
"github.com/stretchr/testify/assert"
"github.com/portainer/portainer/api/internal/testhelpers"
helper "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert"
)
func Test_helmDelete(t *testing.T) {

View File

@ -97,13 +97,13 @@ func (handler *Handler) userHasRegistryAccess(r *http.Request) (hasAccess bool,
if err != nil {
return false, false, err
}
endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID))
if err != nil {
return false, false, err
}
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
if err != nil {
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
return false, false, err
}

View File

@ -71,6 +71,7 @@ func (handler *Handler) settingsPublic(w http.ResponseWriter, r *http.Request) *
}
publicSettings := generatePublicSettings(settings)
return response.JSON(w, publicSettings)
}
@ -96,7 +97,7 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
publicSettings.IsDockerDesktopExtension = appSettings.IsDockerDesktopExtension
//if OAuth authentication is on, compose the related fields from application settings
// If OAuth authentication is on, compose the related fields from application settings
if publicSettings.AuthenticationMethod == portainer.AuthenticationOAuth {
publicSettings.OAuthLogoutURI = appSettings.OAuthSettings.LogoutURI
publicSettings.OAuthLoginURI = fmt.Sprintf("%s?response_type=code&client_id=%s&redirect_uri=%s&scope=%s",
@ -104,16 +105,18 @@ func generatePublicSettings(appSettings *portainer.Settings) *publicSettingsResp
appSettings.OAuthSettings.ClientID,
appSettings.OAuthSettings.RedirectURI,
appSettings.OAuthSettings.Scopes)
//control prompt=login param according to the SSO setting
// Control prompt=login param according to the SSO setting
if !appSettings.OAuthSettings.SSO {
publicSettings.OAuthLoginURI += "&prompt=login"
}
}
//if LDAP authentication is on, compose the related fields from application settings
// If LDAP authentication is on, compose the related fields from application settings
if publicSettings.AuthenticationMethod == portainer.AuthenticationLDAP && appSettings.LDAPSettings.GroupSearchSettings != nil {
if len(appSettings.LDAPSettings.GroupSearchSettings) > 0 {
publicSettings.TeamSync = len(appSettings.LDAPSettings.GroupSearchSettings[0].GroupBaseDN) > 0
}
}
return publicSettings
}

View File

@ -40,14 +40,17 @@ func setup() {
func TestGeneratePublicSettingsWithSSO(t *testing.T) {
setup()
mockAppSettings.OAuthSettings.SSO = true
publicSettings := generatePublicSettings(mockAppSettings)
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
}
if publicSettings.OAuthLoginURI != dummyOAuthLoginURI {
t.Errorf("wrong OAuthLoginURI when SSO is switched on, want: %s, got: %s", dummyOAuthLoginURI, publicSettings.OAuthLoginURI)
}
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
}
@ -55,15 +58,18 @@ func TestGeneratePublicSettingsWithSSO(t *testing.T) {
func TestGeneratePublicSettingsWithoutSSO(t *testing.T) {
setup()
mockAppSettings.OAuthSettings.SSO = false
publicSettings := generatePublicSettings(mockAppSettings)
if publicSettings.AuthenticationMethod != portainer.AuthenticationOAuth {
t.Errorf("wrong AuthenticationMethod, want: %d, got: %d", portainer.AuthenticationOAuth, publicSettings.AuthenticationMethod)
}
expectedOAuthLoginURI := dummyOAuthLoginURI + "&prompt=login"
if publicSettings.OAuthLoginURI != expectedOAuthLoginURI {
t.Errorf("wrong OAuthLoginURI when SSO is switched off, want: %s, got: %s", expectedOAuthLoginURI, publicSettings.OAuthLoginURI)
}
if publicSettings.OAuthLogoutURI != dummyOAuthLogoutURI {
t.Errorf("wrong OAuthLogoutURI, want: %s, got: %s", dummyOAuthLogoutURI, publicSettings.OAuthLogoutURI)
}

View File

@ -89,8 +89,7 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
}
if !isOrphaned {
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
if err != nil {
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
return httperror.Forbidden("Permission denied to access endpoint", err)
}
@ -119,25 +118,21 @@ func (handler *Handler) stackDelete(w http.ResponseWriter, r *http.Request) *htt
deployments.StopAutoupdate(stack.ID, stack.AutoUpdate.JobID, handler.Scheduler)
}
err = handler.deleteStack(securityContext.UserID, stack, endpoint)
if err != nil {
if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
return httperror.InternalServerError(err.Error(), err)
}
err = handler.DataStore.Stack().Delete(portainer.StackID(id))
if err != nil {
if err := handler.DataStore.Stack().Delete(portainer.StackID(id)); err != nil {
return httperror.InternalServerError("Unable to remove the stack from the database", err)
}
if resourceControl != nil {
err = handler.DataStore.ResourceControl().Delete(resourceControl.ID)
if err != nil {
if err := handler.DataStore.ResourceControl().Delete(resourceControl.ID); err != nil {
return httperror.InternalServerError("Unable to remove the associated resource control from the database", err)
}
}
err = handler.FileService.RemoveDirectory(stack.ProjectPath)
if err != nil {
if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
log.Warn().Err(err).Msg("Unable to remove stack files from disk")
}
@ -169,8 +164,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
return httperror.InternalServerError("Unable to find the endpoint associated to the stack inside the database", err)
}
err = handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint)
if err != nil {
if err := handler.requestBouncer.AuthorizedEndpointOperation(r, endpoint); err != nil {
return httperror.Forbidden("Permission denied to access endpoint", err)
}
@ -179,8 +173,7 @@ func (handler *Handler) deleteExternalStack(r *http.Request, w http.ResponseWrit
Type: portainer.DockerSwarmStack,
}
err = handler.deleteStack(securityContext.UserID, stack, endpoint)
if err != nil {
if err := handler.deleteStack(securityContext.UserID, stack, endpoint); err != nil {
return httperror.InternalServerError("Unable to delete stack", err)
}
@ -255,6 +248,7 @@ func (handler *Handler) deleteStack(userID portainer.UserID, stack *portainer.St
}
}
}
return errors.WithMessagef(err, "failed to remove kubernetes resources: %q", out)
}
@ -369,18 +363,18 @@ func (handler *Handler) stackDeleteKubernetesByName(w http.ResponseWriter, r *ht
if err != nil {
log.Err(err).Msgf("Unable to delete Kubernetes stack `%d`", stack.ID)
errors = append(errors, err)
continue
}
err = handler.DataStore.Stack().Delete(stack.ID)
if err != nil {
if err := handler.DataStore.Stack().Delete(stack.ID); err != nil {
errors = append(errors, err)
log.Err(err).Msgf("Unable to remove the stack `%d` from the database", stack.ID)
continue
}
err = handler.FileService.RemoveDirectory(stack.ProjectPath)
if err != nil {
if err := handler.FileService.RemoveDirectory(stack.ProjectPath); err != nil {
errors = append(errors, err)
log.Warn().Err(err).Msg("Unable to remove stack files from disk")
}

View File

@ -18,8 +18,7 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, false)
user := &portainer.User{ID: 2, Username: "admin", Role: portainer.AdministratorRole}
err := store.User().Create(user)
if err != nil {
if err := store.User().Create(user); err != nil {
t.Fatal("could not create admin user:", err)
}
@ -33,29 +32,28 @@ func TestTagDeleteEdgeGroupsConcurrently(t *testing.T) {
for i := 0; i < tagsCount; i++ {
tagID := portainer.TagID(i) + 1
err = store.Tag().Create(&portainer.Tag{
if err := store.Tag().Create(&portainer.Tag{
ID: tagID,
Name: "tag-" + strconv.Itoa(int(tagID)),
})
if err != nil {
}); err != nil {
t.Fatal("could not create tag:", err)
}
tagIDs = append(tagIDs, tagID)
}
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "edgegroup-1",
TagIDs: tagIDs,
})
if err != nil {
}); err != nil {
t.Fatal("could not create edge group:", err)
}
// Remove the tags concurrently
var wg sync.WaitGroup
wg.Add(len(tagIDs))
for _, tagID := range tagIDs {

View File

@ -27,6 +27,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
if payload.Role != 1 && payload.Role != 2 {
return errors.New("Invalid role value. Value must be one of: 1 (administrator) or 2 (regular user)")
}
return nil
}
@ -49,8 +50,7 @@ func (payload *userCreatePayload) Validate(r *http.Request) error {
// @router /users [post]
func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload userCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil {
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
return httperror.BadRequest("Invalid request payload", err)
}
@ -89,11 +89,11 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http
}
}
err = handler.DataStore.User().Create(user)
if err != nil {
if err := handler.DataStore.User().Create(user); err != nil {
return httperror.InternalServerError("Unable to persist user inside the database", err)
}
hideFields(user)
return response.JSON(w, user)
}

View File

@ -26,12 +26,12 @@ func Test_userList(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
// create admin and standard user(s)
// Create admin and standard user(s)
adminUser := &portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}
err := store.User().Create(adminUser)
is.NoError(err, "error creating admin user")
// setup services
// Setup services
jwtService, err := jwt.NewService("1h", store)
is.NoError(err, "Error initiating jwt service")
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())
@ -42,7 +42,7 @@ func Test_userList(t *testing.T) {
h := NewHandler(requestBouncer, rateLimiter, apiKeyService, passwordChecker)
h.DataStore = store
// generate admin user tokens
// Generate admin user tokens
adminJWT, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: adminUser.ID, Username: adminUser.Username, Role: adminUser.Role})
// Case 1: the user is given the endpoint access directly
@ -54,12 +54,12 @@ func Test_userList(t *testing.T) {
err = store.User().Create(userWithoutEndpointAccess)
is.NoError(err, "error creating user")
// create environment group
// Create environment group
endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "default-endpoint-group"}
err = store.EndpointGroup().Create(endpointGroup)
is.NoError(err, "error creating endpoint group")
// create endpoint and user access policies
// Create endpoint and user access policies
userAccessPolicies := make(portainer.UserAccessPolicies, 0)
userAccessPolicies[userWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userWithEndpointAccess.Role)}
@ -129,7 +129,7 @@ func Test_userList(t *testing.T) {
err = store.User().Create(userUnderGroup)
is.NoError(err, "error creating user")
// create environment group including a user
// Create environment group including a user
userAccessPoliciesUnderGroup := make(portainer.UserAccessPolicies, 0)
userAccessPoliciesUnderGroup[userUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderGroup.Role)}
@ -137,7 +137,7 @@ func Test_userList(t *testing.T) {
err = store.EndpointGroup().Create(endpointGroupWithUser)
is.NoError(err, "error creating endpoint group")
// create endpoint
// Create endpoint
endpointUnderGroupWithUser := &portainer.Endpoint{ID: 2, GroupID: endpointGroupWithUser.ID}
err = store.Endpoint().Create(endpointUnderGroupWithUser)
is.NoError(err, "error creating endpoint")
@ -182,7 +182,7 @@ func Test_userList(t *testing.T) {
err = store.TeamMembership().Create(teamMembership)
is.NoError(err, "error creating team membership")
// create environment group including a team
// Create environment group including a team
teamAccessPoliciesUnderGroup := make(portainer.TeamAccessPolicies, 0)
teamAccessPoliciesUnderGroup[teamUnderGroup.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeam.Role)}
@ -190,7 +190,7 @@ func Test_userList(t *testing.T) {
err = store.EndpointGroup().Create(endpointGroupWithTeam)
is.NoError(err, "error creating endpoint group")
// create endpoint
// Create endpoint
endpointUnderGroupWithTeam := &portainer.Endpoint{ID: 3, GroupID: endpointGroupWithTeam.ID}
err = store.Endpoint().Create(endpointUnderGroupWithTeam)
is.NoError(err, "error creating endpoint")
@ -233,12 +233,12 @@ func Test_userList(t *testing.T) {
err = store.TeamMembership().Create(teamMembershipWithEndpointAccess)
is.NoError(err, "error creating team membership")
// create environment group
// Create environment group
endpointGroupWithoutTeam := &portainer.EndpointGroup{ID: 4, Name: "endpoint-group-without-team"}
err = store.EndpointGroup().Create(endpointGroupWithoutTeam)
is.NoError(err, "error creating endpoint group")
// create endpoint and team access policies
// Create endpoint and team access policies
teamAccessPolicies := make(portainer.TeamAccessPolicies, 0)
teamAccessPolicies[teamWithEndpointAccess.ID] = portainer.AccessPolicy{RoleID: portainer.RoleID(userUnderTeamWithEndpointAccess.Role)}

View File

@ -19,12 +19,12 @@ func Test_updateUserRemovesAccessTokens(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
// create standard user
// Create standard user
user := &portainer.User{ID: 2, Username: "standard", Role: portainer.StandardUserRole}
err := store.User().Create(user)
is.NoError(err, "error creating user")
// setup services
// Setup services
jwtService, err := jwt.NewService("1h", store)
is.NoError(err, "Error initiating jwt service")
apiKeyService := apikey.NewAPIKeyService(store.APIKeyRepository(), store.User())

View File

@ -8,12 +8,12 @@ import (
"net/url"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/logoutcontext"
"github.com/portainer/portainer/api/logoutcontext"
"github.com/gorilla/websocket"
"github.com/koding/websocketproxy"
"github.com/portainer/portainer/api/crypto"
"github.com/rs/zerolog/log"
)

View File

@ -9,7 +9,7 @@ import (
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/proxy/factory/agent"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/url"
"github.com/portainer/portainer/api/url"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"

View File

@ -54,6 +54,7 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
portainerMetadata := object["Portainer"].(map[string]interface{})
portainerMetadata["ResourceControl"] = resourceControl
return object
}
@ -64,8 +65,7 @@ func (transport *Transport) createPrivateResourceControl(
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
err := transport.dataStore.ResourceControl().Create(resourceControl)
if err != nil {
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
log.Error().
Str("resource", resourceIdentifier).
Err(err).
@ -84,6 +84,7 @@ func (transport *Transport) userCanDeleteContainerGroup(request *http.Request, c
resourceIdentifier := request.URL.Path
resourceControl := transport.findResourceControl(resourceIdentifier, context)
return authorization.UserCanAccessResource(context.userID, context.userTeamIDs, resourceControl)
}
@ -136,20 +137,19 @@ func (transport *Transport) filterContainerGroups(containerGroups []interface{},
func (transport *Transport) removeResourceControl(containerGroup map[string]interface{}, context *azureRequestContext) error {
containerGroupID, ok := containerGroup["id"].(string)
if ok {
resourceControl := transport.findResourceControl(containerGroupID, context)
if resourceControl != nil {
err := transport.dataStore.ResourceControl().Delete(resourceControl.ID)
return err
}
} else {
if !ok {
log.Debug().Msg("missing ID in container group")
return nil
}
if resourceControl := transport.findResourceControl(containerGroupID, context); resourceControl != nil {
return transport.dataStore.ResourceControl().Delete(resourceControl.ID)
}
return nil
}
func (transport *Transport) findResourceControl(containerGroupId string, context *azureRequestContext) *portainer.ResourceControl {
resourceControl := authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls)
return resourceControl
return authorization.GetResourceControlByResourceIDAndType(containerGroupId, portainer.ContainerGroupResourceControl, context.resourceControls)
}

View File

@ -8,7 +8,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/proxy/factory/docker"
"github.com/portainer/portainer/api/internal/url"
"github.com/portainer/portainer/api/url"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/rs/zerolog/log"

View File

@ -105,8 +105,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
resourceControl := authorization.NewRestrictedResourceControl(resourceID, resourceType, userIDs, teamIDs)
err := transport.dataStore.ResourceControl().Create(resourceControl)
if err != nil {
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
return nil, err
}
@ -119,8 +118,7 @@ func (transport *Transport) newResourceControlFromPortainerLabels(labelsObject m
func (transport *Transport) createPrivateResourceControl(resourceIdentifier string, resourceType portainer.ResourceControlType, userID portainer.UserID) (*portainer.ResourceControl, error) {
resourceControl := authorization.NewPrivateResourceControl(resourceIdentifier, resourceType, userID)
err := transport.dataStore.ResourceControl().Create(resourceControl)
if err != nil {
if err := transport.dataStore.ResourceControl().Create(resourceControl); err != nil {
log.Error().
Str("resource", resourceIdentifier).
Err(err).
@ -170,6 +168,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
systemResourceControl := findSystemNetworkResourceControl(responseObject)
if systemResourceControl != nil {
responseObject = decorateObject(responseObject, systemResourceControl)
return utils.RewriteResponse(response, responseObject, http.StatusOK)
}
}
@ -188,6 +187,7 @@ func (transport *Transport) applyAccessControlOnResource(parameters *resourceOpe
if executor.operationContext.isAdmin || (resourceControl != nil && authorization.UserCanAccessResource(executor.operationContext.userID, executor.operationContext.userTeamIDs, resourceControl)) {
responseObject = decorateObject(responseObject, resourceControl)
return utils.RewriteResponse(response, responseObject, http.StatusOK)
}
@ -221,6 +221,7 @@ func (transport *Transport) decorateResourceList(parameters *resourceOperationPa
if systemResourceControl != nil {
resourceObject = decorateObject(resourceObject, systemResourceControl)
decoratedResourceData = append(decoratedResourceData, resourceObject)
continue
}
}
@ -264,6 +265,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
if systemResourceControl != nil {
resourceObject = decorateObject(resourceObject, systemResourceControl)
filteredResourceData = append(filteredResourceData, resourceObject)
continue
}
}
@ -277,6 +279,7 @@ func (transport *Transport) filterResourceList(parameters *resourceOperationPara
if context.isAdmin {
filteredResourceData = append(filteredResourceData, resourceObject)
}
continue
}
@ -334,11 +337,13 @@ func (transport *Transport) findResourceControl(resourceIdentifier string, resou
func getStackResourceIDFromLabels(resourceLabelsObject map[string]string, endpointID portainer.EndpointID) string {
if resourceLabelsObject[resourceLabelForDockerSwarmStackName] != "" {
stackName := resourceLabelsObject[resourceLabelForDockerSwarmStackName]
return stackutils.ResourceControlID(endpointID, stackName)
}
if resourceLabelsObject[resourceLabelForDockerComposeStackName] != "" {
stackName := resourceLabelsObject[resourceLabelForDockerComposeStackName]
return stackutils.ResourceControlID(endpointID, stackName)
}
@ -352,5 +357,6 @@ func decorateObject(object map[string]interface{}, resourceControl *portainer.Re
portainerMetadata := object["Portainer"].(map[string]interface{})
portainerMetadata["ResourceControl"] = resourceControl
return object
}

View File

@ -11,9 +11,7 @@ import (
"github.com/portainer/portainer/api/internal/authorization"
)
const (
configObjectIdentifier = "ID"
)
const configObjectIdentifier = "ID"
func getInheritedResourceControlFromConfigLabels(dockerClient *client.Client, endpointID portainer.EndpointID, configID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
config, _, err := dockerClient.ConfigInspectWithRaw(context.Background(), configID)
@ -78,10 +76,9 @@ func (transport *Transport) configInspectOperation(response *http.Response, exec
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigList
// https://docs.docker.com/engine/api/v1.37/#operation/ConfigInspect
func selectorConfigLabels(responseObject map[string]interface{}) map[string]interface{} {
secretSpec := utils.GetJSONObject(responseObject, "Spec")
if secretSpec != nil {
secretLabelsObject := utils.GetJSONObject(secretSpec, "Labels")
return secretLabelsObject
if secretSpec := utils.GetJSONObject(responseObject, "Spec"); secretSpec != nil {
return utils.GetJSONObject(secretSpec, "Labels")
}
return nil
}

View File

@ -7,9 +7,7 @@ import (
"github.com/portainer/portainer/api/http/proxy/factory/utils"
)
const (
taskServiceObjectIdentifier = "ServiceID"
)
const taskServiceObjectIdentifier = "ServiceID"
// taskListOperation extracts the response as a JSON array, loop through the tasks array
// and filter the containers based on resource controls before rewriting the response.
@ -46,5 +44,6 @@ func selectorTaskLabels(responseObject map[string]interface{}) map[string]interf
return utils.GetJSONObject(containerSpecObject, "Labels")
}
}
return nil
}

View File

@ -7,19 +7,17 @@ import (
"net/http"
"path"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/utils"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/snapshot"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
)
const (
volumeObjectIdentifier = "ResourceID"
)
const volumeObjectIdentifier = "ResourceID"
func getInheritedResourceControlFromVolumeLabels(dockerClient *client.Client, endpointID portainer.EndpointID, volumeID string, resourceControls []portainer.ResourceControl) (*portainer.ResourceControl, error) {
volume, err := dockerClient.VolumeInspect(context.Background(), volumeID)
@ -57,14 +55,13 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
Msg("snapshot is not filled into the endpoint.")
}
}
for _, volumeObject := range volumeData {
volume := volumeObject.(map[string]interface{})
err = transport.decorateVolumeResponseWithResourceID(volume)
if err != nil {
if err := transport.decorateVolumeResponseWithResourceID(volume); err != nil {
return fmt.Errorf("failed decorating volume response: %w", err)
}
}
resourceOperationParameters := &resourceOperationParameters{
@ -77,6 +74,7 @@ func (transport *Transport) volumeListOperation(response *http.Response, executo
if err != nil {
return err
}
// Overwrite the original volume list
responseObject["Volumes"] = volumeData
}
@ -94,8 +92,7 @@ func (transport *Transport) volumeInspectOperation(response *http.Response, exec
return err
}
err = transport.decorateVolumeResponseWithResourceID(responseObject)
if err != nil {
if err := transport.decorateVolumeResponseWithResourceID(responseObject); err != nil {
return fmt.Errorf("failed decorating volume response: %w", err)
}
@ -148,8 +145,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
}
defer cli.Close()
_, err = cli.VolumeInspect(context.Background(), volumeID)
if err == nil {
if _, err = cli.VolumeInspect(context.Background(), volumeID); err == nil {
return &http.Response{
StatusCode: http.StatusConflict,
}, errors.New("a volume with the same name already exists")
@ -164,6 +160,7 @@ func (transport *Transport) decorateVolumeResourceCreationOperation(request *htt
if response.StatusCode == http.StatusCreated {
err = transport.decorateVolumeCreationResponse(response, resourceType, tokenData.ID)
}
return response, err
}
@ -195,7 +192,6 @@ func (transport *Transport) decorateVolumeCreationResponse(response *http.Respon
}
func (transport *Transport) restrictedVolumeOperation(requestPath string, request *http.Request) (*http.Response, error) {
if request.Method == http.MethodGet {
return transport.rewriteOperation(request, transport.volumeInspectOperation)
}
@ -210,6 +206,7 @@ func (transport *Transport) restrictedVolumeOperation(requestPath string, reques
if request.Method == http.MethodDelete {
return transport.executeGenericResourceDeletionOperation(request, resourceID, volumeName, portainer.VolumeResourceControl)
}
return transport.restrictedResourceOperation(request, resourceID, volumeName, portainer.VolumeResourceControl, false)
}
@ -218,6 +215,7 @@ func (transport *Transport) getVolumeResourceID(volumeName string) (string, erro
if err != nil {
return "", fmt.Errorf("failed fetching docker id: %w", err)
}
return fmt.Sprintf("%s_%s", volumeName, dockerID), nil
}

View File

@ -4,7 +4,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/internal/tag"
"github.com/portainer/portainer/api/tag"
)
// EdgeGroupRelatedEndpoints returns a list of environments(endpoints) related to this Edge group

View File

@ -37,12 +37,12 @@ func (service *Service) BuildEdgeStack(
registries []portainer.RegistryID,
useManifestNamespaces bool,
) (*portainer.EdgeStack, error) {
err := validateUniqueName(tx.EdgeStack().EdgeStacks, name)
if err != nil {
if err := validateUniqueName(tx.EdgeStack().EdgeStacks, name); err != nil {
return nil, err
}
stackID := tx.EdgeStack().GetNextIdentifier()
return &portainer.EdgeStack{
ID: portainer.EdgeStackID(stackID),
Name: name,
@ -77,7 +77,6 @@ func (service *Service) PersistEdgeStack(
storeManifest edgetypes.StoreManifestFunc) (*portainer.EdgeStack, error) {
relationConfig, err := edge.FetchEndpointRelationsConfig(tx)
if err != nil {
return nil, fmt.Errorf("unable to find environment relations in database: %w", err)
}
@ -87,6 +86,7 @@ func (service *Service) PersistEdgeStack(
if errors.Is(err, edge.ErrEdgeGroupNotFound) {
return nil, httperrors.NewInvalidPayloadError(err.Error())
}
return nil, fmt.Errorf("unable to persist environment relation in database: %w", err)
}
@ -101,13 +101,11 @@ func (service *Service) PersistEdgeStack(
stack.EntryPoint = composePath
stack.NumDeployments = len(relatedEndpointIds)
err = service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds)
if err != nil {
if err := service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds); err != nil {
return nil, fmt.Errorf("unable to update endpoint relations: %w", err)
}
err = tx.EdgeStack().Create(stack.ID, stack)
if err != nil {
if err := tx.EdgeStack().Create(stack.ID, stack); err != nil {
return nil, err
}
@ -126,8 +124,7 @@ func (service *Service) updateEndpointRelations(tx dataservices.DataStoreTx, edg
relation.EdgeStacks[edgeStackID] = true
err = endpointRelationService.UpdateEndpointRelation(endpointID, relation)
if err != nil {
if err := endpointRelationService.UpdateEndpointRelation(endpointID, relation); err != nil {
return fmt.Errorf("unable to persist endpoint relation in database: %w", err)
}
}
@ -155,14 +152,12 @@ func (service *Service) DeleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID
delete(relation.EdgeStacks, edgeStackID)
err = tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation)
if err != nil {
if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil {
return errors.WithMessage(err, "Unable to persist environment relation in database")
}
}
err = tx.EdgeStack().DeleteEdgeStack(edgeStackID)
if err != nil {
if err := tx.EdgeStack().DeleteEdgeStack(edgeStackID); err != nil {
return errors.WithMessage(err, "Unable to remove the edge stack from the database")
}

View File

@ -4,6 +4,7 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/assert"
)

View File

@ -8,6 +8,7 @@ import (
// NodesCount returns the total node number of all environments
func NodesCount(endpoints []portainer.Endpoint) int {
nodes := 0
for _, env := range endpoints {
if !endpointutils.IsEdgeEndpoint(&env) || env.UserTrusted {
nodes += countNodes(&env)
@ -28,11 +29,3 @@ func countNodes(endpoint *portainer.Endpoint) int {
return 1
}
func max(a, b int) int {
if a > b {
return a
}
return b
}

View File

@ -1,16 +0,0 @@
package securecookie
import (
"crypto/rand"
"io"
)
// GenerateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}

View File

@ -1,23 +0,0 @@
package slices
// Map applies the given function to each element of the slice and returns a new slice with the results
func Map[T, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
func Filter[T any](s []T, predicate func(T) bool) []T {
n := 0
for _, v := range s {
if predicate(v) {
s[n] = v
n++
}
}
return s[:n]
}

View File

@ -1,41 +0,0 @@
package unique
func Unique[T comparable](items []T) []T {
return UniqueBy(items, func(item T) T {
return item
})
}
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
includedItems := make(map[ComparableType]bool)
result := []ItemType{}
for _, item := range items {
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
includedItems[accessorFunc(item)] = true
result = append(result, item)
}
}
return result
}
/**
type someType struct {
id int
fn func()
}
func Test() {
ids := []int{1, 2, 3, 3}
_ = UniqueBy(ids, func(id int) int { return id })
_ = Unique(ids) // shorthand for UniqueBy Identity/self
as := []someType{{id: 1}, {id: 2}, {id: 3}, {id: 3}}
_ = UniqueBy(as, func(item someType) int { return item.id }) // no error
_ = UniqueBy(as, func(item someType) someType { return item }) // compile error - someType is not comparable
_ = Unique(as) // compile error - shorthand fails for the same reason
}
*/

View File

@ -6,10 +6,10 @@ import (
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/apikey"
"github.com/portainer/portainer/api/dataservices"
"github.com/golang-jwt/jwt/v4"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/rs/zerolog/log"
)
@ -51,7 +51,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
return nil, err
}
secret := securecookie.GenerateRandomKey(32)
secret := apikey.GenerateRandomKey(32)
if secret == nil {
return nil, errSecretGeneration
}
@ -69,6 +69,7 @@ func NewService(userSessionDuration string, dataStore dataservices.DataStore) (*
userSessionTimeout,
dataStore,
}
return service, nil
}
@ -80,16 +81,18 @@ func getOrCreateKubeSecret(dataStore dataservices.DataStore) ([]byte, error) {
kubeSecret := settings.OAuthSettings.KubeSecretKey
if kubeSecret == nil {
kubeSecret = securecookie.GenerateRandomKey(32)
kubeSecret = apikey.GenerateRandomKey(32)
if kubeSecret == nil {
return nil, errSecretGeneration
}
settings.OAuthSettings.KubeSecretKey = kubeSecret
err = dataStore.Settings().UpdateSettings(settings)
if err != nil {
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
return nil, err
}
}
return kubeSecret, nil
}

View File

@ -3,8 +3,8 @@ package cli
import (
"context"
"github.com/portainer/portainer/api/concurrent"
models "github.com/portainer/portainer/api/http/models/kubernetes"
"github.com/portainer/portainer/api/internal/concurrent"
"k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

View File

@ -4,6 +4,7 @@ import (
"context"
models "github.com/portainer/portainer/api/http/models/kubernetes"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
labels "k8s.io/apimachinery/pkg/labels"
@ -67,9 +68,7 @@ func (kcl *KubeClient) GetServices(namespace string, lookupApplications bool) ([
return result, nil
}
// CreateService creates a new service in a given namespace in a k8s endpoint.
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
ServiceClient := kcl.cli.CoreV1().Services(namespace)
func (kcl *KubeClient) fillService(info models.K8sServiceInfo) v1.Service {
var service v1.Service
service.Name = info.Name
@ -93,16 +92,21 @@ func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInf
// Set ingresses.
for _, i := range info.IngressStatus {
var ing v1.LoadBalancerIngress
ing.IP = i.IP
ing.Hostname = i.Host
service.Status.LoadBalancer.Ingress = append(
service.Status.LoadBalancer.Ingress,
ing,
v1.LoadBalancerIngress{IP: i.IP, Hostname: i.Host},
)
}
_, err := ServiceClient.Create(context.Background(), &service, metav1.CreateOptions{})
return service
}
// CreateService creates a new service in a given namespace in a k8s endpoint.
func (kcl *KubeClient) CreateService(namespace string, info models.K8sServiceInfo) error {
serviceClient := kcl.cli.CoreV1().Services(namespace)
service := kcl.fillService(info)
_, err := serviceClient.Create(context.Background(), &service, metav1.CreateOptions{})
return err
}
@ -120,45 +124,16 @@ func (kcl *KubeClient) DeleteServices(reqs models.K8sServiceDeleteRequests) erro
)
}
}
return err
}
// UpdateService updates service in a given namespace in a k8s endpoint.
func (kcl *KubeClient) UpdateService(namespace string, info models.K8sServiceInfo) error {
ServiceClient := kcl.cli.CoreV1().Services(namespace)
var service v1.Service
serviceClient := kcl.cli.CoreV1().Services(namespace)
service := kcl.fillService(info)
service.Name = info.Name
service.Spec.Type = v1.ServiceType(info.Type)
service.Namespace = info.Namespace
service.Annotations = info.Annotations
service.Labels = info.Labels
service.Spec.AllocateLoadBalancerNodePorts = info.AllocateLoadBalancerNodePorts
service.Spec.Selector = info.Selector
// Set ports.
for _, p := range info.Ports {
var port v1.ServicePort
port.Name = p.Name
port.NodePort = int32(p.NodePort)
port.Port = int32(p.Port)
port.Protocol = v1.Protocol(p.Protocol)
port.TargetPort = intstr.FromString(p.TargetPort)
service.Spec.Ports = append(service.Spec.Ports, port)
}
// Set ingresses.
for _, i := range info.IngressStatus {
var ing v1.LoadBalancerIngress
ing.IP = i.IP
ing.Hostname = i.Host
service.Status.LoadBalancer.Ingress = append(
service.Status.LoadBalancer.Ingress,
ing,
)
}
_, err := ServiceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
_, err := serviceClient.Update(context.Background(), &service, metav1.UpdateOptions{})
return err
}
@ -210,5 +185,4 @@ func makeApplication(meta metav1.Object) []models.K8sApplication {
Name: ownerReference.Name,
},
}
}

View File

@ -16,8 +16,7 @@ func Test_getOAuthToken(t *testing.T) {
t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) {
code := ""
_, err := getOAuthToken(code, config)
if err == nil {
if _, err := getOAuthToken(code, config); err == nil {
t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code)
}
})
@ -91,22 +90,19 @@ func Test_getResource(t *testing.T) {
defer srv.Close()
t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) {
_, err := getResource("", config)
if err == nil {
if _, err := getResource("", config); err == nil {
t.Errorf("getResource should fail if access token is not provided in auth bearer header")
}
})
t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) {
_, err := getResource("incorrect-token", config)
if err == nil {
if _, err := getResource("incorrect-token", config); err == nil {
t.Errorf("getResource should fail if incorrect access token provided in auth bearer header")
}
})
t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) {
_, err := getResource(oauthtest.AccessToken, config)
if err != nil {
if _, err := getResource(oauthtest.AccessToken, config); err != nil {
t.Errorf("getResource should succeed if correct access token provided in auth bearer header")
}
})
@ -120,8 +116,7 @@ func Test_Authenticate(t *testing.T) {
srv, config := oauthtest.RunOAuthServer(code, &portainer.OAuthSettings{})
defer srv.Close()
_, err := authService.Authenticate(code, config)
if err == nil {
if _, err := authService.Authenticate(code, config); err == nil {
t.Error("Authenticate should fail to extract username from resource if incorrect UserIdentifier provided")
}
})

View File

@ -12,6 +12,7 @@ import (
gittypes "github.com/portainer/portainer/api/git/types"
models "github.com/portainer/portainer/api/http/models/kubernetes"
"github.com/portainer/portainer/pkg/featureflags"
"golang.org/x/oauth2"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/version"
@ -322,14 +323,14 @@ type (
Name string `json:"Name"`
Status map[EndpointID]EdgeStackStatus `json:"Status"`
// StatusArray map[EndpointID][]EdgeStackStatus `json:"StatusArray"`
CreationDate int64 `json:"CreationDate"`
EdgeGroups []EdgeGroupID `json:"EdgeGroups"`
ProjectPath string `json:"ProjectPath"`
EntryPoint string `json:"EntryPoint"`
Version int `json:"Version"`
NumDeployments int `json:"NumDeployments"`
ManifestPath string
DeploymentType EdgeStackDeploymentType
CreationDate int64 `json:"CreationDate"`
EdgeGroups []EdgeGroupID `json:"EdgeGroups"`
ProjectPath string `json:"ProjectPath"`
EntryPoint string `json:"EntryPoint"`
Version int `json:"Version"`
NumDeployments int `json:"NumDeployments"`
ManifestPath string `json:"ManifestPath"`
DeploymentType EdgeStackDeploymentType `json:"DeploymentType"`
// Uses the manifest's namespaces instead of the default one
UseManifestNamespaces bool
@ -554,23 +555,22 @@ type (
// Extension represents a deprecated Portainer extension
Extension struct {
// Extension Identifier
ID ExtensionID `json:"Id" example:"1"`
Enabled bool `json:"Enabled"`
Name string `json:"Name,omitempty"`
ShortDescription string `json:"ShortDescription,omitempty"`
Description string `json:"Description,omitempty"`
DescriptionURL string `json:"DescriptionURL,omitempty"`
Price string `json:"Price,omitempty"`
PriceDescription string `json:"PriceDescription,omitempty"`
Deal bool `json:"Deal,omitempty"`
Available bool `json:"Available,omitempty"`
License LicenseInformation `json:"License,omitempty"`
Version string `json:"Version"`
UpdateAvailable bool `json:"UpdateAvailable"`
ShopURL string `json:"ShopURL,omitempty"`
Images []string `json:"Images,omitempty"`
Logo string `json:"Logo,omitempty"`
ID ExtensionID `json:"Id" example:"1"`
Enabled bool `json:"Enabled"`
Name string `json:"Name,omitempty"`
ShortDescription string `json:"ShortDescription,omitempty"`
Description string `json:"Description,omitempty"`
DescriptionURL string `json:"DescriptionURL,omitempty"`
Price string `json:"Price,omitempty"`
PriceDescription string `json:"PriceDescription,omitempty"`
Deal bool `json:"Deal,omitempty"`
Available bool `json:"Available,omitempty"`
License ExtensionLicenseInformation `json:"License,omitempty"`
Version string `json:"Version"`
UpdateAvailable bool `json:"UpdateAvailable"`
ShopURL string `json:"ShopURL,omitempty"`
Images []string `json:"Images,omitempty"`
Logo string `json:"Logo,omitempty"`
}
// ExtensionID represents a extension identifier
@ -737,8 +737,8 @@ type (
Groups []string
}
// LicenseInformation represents information about an extension license
LicenseInformation struct {
// ExtensionLicenseInformation represents information about an extension license
ExtensionLicenseInformation struct {
LicenseKey string `json:"LicenseKey,omitempty"`
Company string `json:"Company,omitempty"`
Expiration string `json:"Expiration,omitempty"`
@ -939,6 +939,18 @@ type (
HideStacksFunctionality bool `json:"hideStacksFunctionality" example:"false"`
}
Edge struct {
// The command list interval for edge agent - used in edge async mode (in seconds)
CommandInterval int `json:"CommandInterval" example:"5"`
// The ping interval for edge agent - used in edge async mode (in seconds)
PingInterval int `json:"PingInterval" example:"5"`
// The snapshot interval for edge agent - used in edge async mode (in seconds)
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
// Deprecated 2.18
AsyncMode bool `json:"AsyncMode,omitempty" example:"false"`
}
// Settings represents the application settings
Settings struct {
// URL to a logo that will be displayed on the login page as well as on top of the sidebar. Will use default Portainer logo when value is empty string
@ -984,17 +996,7 @@ type (
// EdgePortainerURL is the URL that is exposed to edge agents
EdgePortainerURL string `json:"EdgePortainerUrl"`
Edge struct {
// The command list interval for edge agent - used in edge async mode (in seconds)
CommandInterval int `json:"CommandInterval" example:"5"`
// The ping interval for edge agent - used in edge async mode (in seconds)
PingInterval int `json:"PingInterval" example:"5"`
// The snapshot interval for edge agent - used in edge async mode (in seconds)
SnapshotInterval int `json:"SnapshotInterval" example:"5"`
// Deprecated 2.18
AsyncMode bool
}
Edge Edge `json:"Edge"`
// Deprecated fields
DisplayDonationHeader bool `json:"DisplayDonationHeader,omitempty"`

View File

@ -58,8 +58,8 @@ func (s Set[T]) Copy() Set[T] {
}
// Difference returns a new set containing the keys that are in the first set but not in the second set.
func (set Set[T]) Difference(second Set[T]) Set[T] {
difference := set.Copy()
func (s Set[T]) Difference(second Set[T]) Set[T] {
difference := s.Copy()
for key := range second {
difference.Remove(key)

43
api/slicesx/slices.go Normal file
View File

@ -0,0 +1,43 @@
package slicesx
// Map applies the given function to each element of the slice and returns a new slice with the results
func Map[T, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// Filter returns a new slice containing only the elements of the slice for which the given predicate returns true
func Filter[T any](s []T, predicate func(T) bool) []T {
n := 0
for _, v := range s {
if predicate(v) {
s[n] = v
n++
}
}
return s[:n]
}
func Unique[T comparable](items []T) []T {
return UniqueBy(items, func(item T) T {
return item
})
}
func UniqueBy[ItemType any, ComparableType comparable](items []ItemType, accessorFunc func(ItemType) ComparableType) []ItemType {
includedItems := make(map[ComparableType]bool)
result := []ItemType{}
for _, item := range items {
if _, isIncluded := includedItems[accessorFunc(item)]; !isIncluded {
includedItems[accessorFunc(item)] = true
result = append(result, item)
}
}
return result
}

View File

@ -1,4 +1,4 @@
package slices
package slicesx
import (
"strconv"