Compare commits
62 Commits
2.33.0-rc2
...
develop
Author | SHA1 | Date |
---|---|---|
|
24a092836b | |
|
290374f6fc | |
|
2e7acc73d8 | |
|
666d51482e | |
|
eedf37d18a | |
|
16f210966b | |
|
30e70b6327 | |
|
f91a2e3b65 | |
|
fdc405c912 | |
|
2f2e70bb86 | |
|
eef54f4153 | |
|
ad1c015f01 | |
|
326fdcf6ea | |
|
26a0c4e809 | |
|
acb465ae33 | |
|
5418a0bee6 | |
|
a59815264d | |
|
3ac0be4e35 | |
|
feae930293 | |
|
7ebb52ec6d | |
|
8b73ad3b6f | |
|
6fc2a8234d | |
|
e2c2724e36 | |
|
6abfbe8553 | |
|
54f6add45d | |
|
f8ae5368bf | |
|
2ba348551d | |
|
110f88f22d | |
|
c90a15dd0f | |
|
f4335e1e72 | |
|
8d9e1a0ad5 | |
|
48dcfcb08f | |
|
def19be230 | |
|
36154e9d33 | |
|
7cf6bb78d6 | |
|
541f281b29 | |
|
965ef5246b | |
|
9c88057bd1 | |
|
8c52e92705 | |
|
3a727d24ce | |
|
185558a642 | |
|
35aa525bd2 | |
|
2ce8788487 | |
|
ec0e98a64b | |
|
121e9f03a4 | |
|
a0295b1a39 | |
|
30aba86380 | |
|
89f5a20786 | |
|
ef7caa260b | |
|
39d50ef70e | |
|
58a1392480 | |
|
06f6bcc340 | |
|
c9d18b614b | |
|
2035c42c3c | |
|
a760426b87 | |
|
10b129a02e | |
|
129b9d5db9 | |
|
2c08becf6c | |
|
a3bfe7cb0c | |
|
7049a8a2bb | |
|
1197b1dd8d | |
|
7f167ff2fc |
|
@ -6,7 +6,7 @@ body:
|
|||
|
||||
Thanks for suggesting an idea for Portainer!
|
||||
|
||||
Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion cagetory](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.
|
||||
Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion category](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.
|
||||
|
||||
Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io) as they may point you toward a solution.
|
||||
|
||||
|
|
|
@ -94,6 +94,9 @@ body:
|
|||
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
|
||||
multiple: false
|
||||
options:
|
||||
- '2.34.0'
|
||||
- '2.33.1'
|
||||
- '2.33.0'
|
||||
- '2.32.0'
|
||||
- '2.31.3'
|
||||
- '2.31.2'
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- forbidigo
|
||||
settings:
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^dataservices.DataStore.(EdgeGroup|EdgeJob|EdgeStack|EndpointRelation|Endpoint|GitCredential|Registry|ResourceControl|Role|Settings|Snapshot|Stack|Tag|User)$
|
||||
msg: Use a transaction instead
|
||||
analyze-types: true
|
|
@ -13,6 +13,12 @@ linters:
|
|||
- perfsprint
|
||||
- staticcheck
|
||||
- unused
|
||||
- mirror
|
||||
- durationcheck
|
||||
- errorlint
|
||||
- govet
|
||||
- zerologlint
|
||||
- testifylint
|
||||
settings:
|
||||
staticcheck:
|
||||
checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
|
||||
|
@ -32,12 +38,20 @@ linters:
|
|||
desc: use github.com/portainer/portainer/pkg/libcrypto
|
||||
- pkg: github.com/portainer/libhttp
|
||||
desc: use github.com/portainer/portainer/pkg/libhttp
|
||||
- pkg: golang.org/x/crypto
|
||||
desc: golang.org/x/crypto is not allowed because of FIPS mode
|
||||
- pkg: github.com/ProtonMail/go-crypto/openpgp
|
||||
desc: github.com/ProtonMail/go-crypto/openpgp is not allowed because of FIPS mode
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^tls\.Config$
|
||||
msg: Use crypto.CreateTLSConfiguration() instead
|
||||
- pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
|
||||
msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
|
||||
- pattern: ^object\.(Commit|Tag)\.Verify$
|
||||
msg: "Not allowed because of FIPS mode"
|
||||
- pattern: ^(types\.SystemContext\.)?(DockerDaemonInsecureSkipTLSVerify|DockerInsecureSkipTLSVerify|OCIInsecureSkipTLSVerify)$
|
||||
msg: "Not allowed because of FIPS mode"
|
||||
analyze-types: true
|
||||
exclusions:
|
||||
generated: lax
|
||||
|
|
7
Makefile
7
Makefile
|
@ -54,14 +54,12 @@ client-deps: ## Install client dependencies
|
|||
tidy: ## Tidy up the go.mod file
|
||||
@go mod tidy
|
||||
|
||||
|
||||
##@ Cleanup
|
||||
.PHONY: clean
|
||||
clean: ## Remove all build and download artifacts
|
||||
@echo "Clearing the dist directory..."
|
||||
@rm -rf dist/*
|
||||
|
||||
|
||||
##@ Testing
|
||||
.PHONY: test test-client test-server
|
||||
test: test-server test-client ## Run all tests
|
||||
|
@ -105,16 +103,15 @@ lint: lint-client lint-server ## Lint all code
|
|||
lint-client: ## Lint client code
|
||||
yarn lint
|
||||
|
||||
lint-server: ## Lint server code
|
||||
lint-server: tidy ## Lint server code
|
||||
golangci-lint run --timeout=10m -c .golangci.yaml
|
||||
|
||||
golangci-lint run --timeout=10m --new-from-rev=HEAD~ -c .golangci-forward.yaml
|
||||
|
||||
##@ Extension
|
||||
.PHONY: dev-extension
|
||||
dev-extension: build-server build-client ## Run the extension in development mode
|
||||
make local -f build/docker-extension/Makefile
|
||||
|
||||
|
||||
##@ Docs
|
||||
.PHONY: docs-build docs-validate docs-clean docs-validate-clean
|
||||
docs-build: init-dist ## Build docs
|
||||
|
|
|
@ -10,31 +10,31 @@ func Test_generateRandomKey(t *testing.T) {
|
|||
is := assert.New(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLenth int
|
||||
name string
|
||||
wantLength int
|
||||
}{
|
||||
{
|
||||
name: "Generate a random key of length 16",
|
||||
wantLenth: 16,
|
||||
name: "Generate a random key of length 16",
|
||||
wantLength: 16,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 32",
|
||||
wantLenth: 32,
|
||||
name: "Generate a random key of length 32",
|
||||
wantLength: 32,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 64",
|
||||
wantLenth: 64,
|
||||
name: "Generate a random key of length 64",
|
||||
wantLength: 64,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 128",
|
||||
wantLenth: 128,
|
||||
name: "Generate a random key of length 128",
|
||||
wantLength: 128,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateRandomKey(tt.wantLenth)
|
||||
is.Equal(tt.wantLenth, len(got))
|
||||
got := GenerateRandomKey(tt.wantLength)
|
||||
is.Len(got, tt.wantLength)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -10,9 +10,10 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) {
|
||||
|
@ -30,7 +31,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
|||
t.Run("Successfully generates API key", func(t *testing.T) {
|
||||
desc := "test-1"
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, desc)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.NotEmpty(rawKey)
|
||||
is.NotEmpty(apiKey)
|
||||
is.Equal(desc, apiKey.Description)
|
||||
|
@ -38,7 +39,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
|||
|
||||
t.Run("Api key prefix is 7 chars", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-2")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
is.Equal(rawKey[:7], apiKey.Prefix)
|
||||
is.Len(apiKey.Prefix, 7)
|
||||
|
@ -46,7 +47,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
|||
|
||||
t.Run("Api key has 'ptr_' as prefix", func(t *testing.T) {
|
||||
rawKey, _, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
is.Equal(portainerAPIKeyPrefix, "ptr_")
|
||||
is.True(strings.HasPrefix(rawKey, "ptr_"))
|
||||
|
@ -55,7 +56,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
|||
t.Run("Successfully caches API key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-3")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
userFromCache, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
|
@ -65,7 +66,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
|||
|
||||
t.Run("Decoded raw api-key digest matches generated digest", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-4")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
generatedDigest := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
|
@ -83,10 +84,10 @@ func Test_GetAPIKey(t *testing.T) {
|
|||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
apiKeyGot, err := service.GetAPIKey(apiKey.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
is.Equal(apiKey, apiKeyGot)
|
||||
})
|
||||
|
@ -102,12 +103,12 @@ func Test_GetAPIKeys(t *testing.T) {
|
|||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, _, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
_, _, err = service.GenerateApiKey(user, "test-2")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := service.GetAPIKeys(user.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Len(keys, 2)
|
||||
})
|
||||
}
|
||||
|
@ -122,10 +123,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
|||
t.Run("Successfully returns user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
})
|
||||
|
@ -133,10 +134,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
|||
t.Run("Successfully caches user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
|
@ -158,14 +159,14 @@ func Test_UpdateAPIKey(t *testing.T) {
|
|||
user := portainer.User{ID: 1}
|
||||
store.User().Create(&user)
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-x")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
apiKey.LastUsed = time.Now().UTC().Unix()
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
log.Debug().Str("wanted", fmt.Sprintf("%+v", apiKey)).Str("got", fmt.Sprintf("%+v", apiKeyGot)).Msg("")
|
||||
|
||||
|
@ -174,7 +175,7 @@ func Test_UpdateAPIKey(t *testing.T) {
|
|||
|
||||
t.Run("Successfully updates api-key in cache upon api-key update", func(t *testing.T) {
|
||||
_, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x2")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
|
@ -184,7 +185,7 @@ func Test_UpdateAPIKey(t *testing.T) {
|
|||
is.NotEqual(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, updatedAPIKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
|
@ -202,30 +203,30 @@ func Test_DeleteAPIKey(t *testing.T) {
|
|||
t.Run("Successfully updates the api-key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = service.GetDigestUserAndKey(apiKey.Digest)
|
||||
is.Error(err)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Successfully removes api-key from cache upon deletion", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
is.Equal(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, ok = service.cache.Get(apiKey.Digest)
|
||||
is.False(ok)
|
||||
|
@ -243,10 +244,10 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
|||
// generate api keys
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKey2, err := service.GenerateApiKey(user, "test-2")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify api keys are present in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
|
@ -273,11 +274,11 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
|||
// generate keys for 2 users
|
||||
user1 := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user1, "test-1")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
user2 := portainer.User{ID: 2}
|
||||
_, apiKey2, err := service.GenerateApiKey(user2, "test-2")
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify keys in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
|
|
|
@ -8,15 +8,19 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func listFiles(dir string) []string {
|
||||
items := make([]string, 0)
|
||||
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if path == dir {
|
||||
return nil
|
||||
}
|
||||
|
||||
items = append(items, path)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
|
@ -26,13 +30,21 @@ func listFiles(dir string) []string {
|
|||
func Test_shouldCreateArchive(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
content := []byte("content")
|
||||
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
|
@ -45,7 +57,8 @@ func Test_shouldCreateArchive(t *testing.T) {
|
|||
wasExtracted := func(p string) {
|
||||
fullpath := path.Join(extractionDir, p)
|
||||
assert.Contains(t, extractedFiles, fullpath)
|
||||
copyContent, _ := os.ReadFile(fullpath)
|
||||
copyContent, err := os.ReadFile(fullpath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, copyContent)
|
||||
}
|
||||
|
||||
|
@ -57,13 +70,21 @@ func Test_shouldCreateArchive(t *testing.T) {
|
|||
func Test_shouldCreateArchive2(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
content := []byte("content")
|
||||
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
assert.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnzipFile(t *testing.T) {
|
||||
|
@ -20,7 +21,7 @@ func TestUnzipFile(t *testing.T) {
|
|||
|
||||
err := UnzipFile("./testdata/sample_archive.zip", dir)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
archiveDir := dir + "/sample_archive"
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0.txt"))
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0", "1.txt"))
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
// Service implements the CLIService interface
|
||||
|
@ -35,16 +35,9 @@ func CLIFlags() *portainer.CLIFlags {
|
|||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
|
||||
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
|
||||
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
|
||||
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
|
||||
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
|
||||
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
|
||||
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
|
||||
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
|
||||
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
|
||||
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
|
||||
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
|
||||
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
|
||||
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
|
||||
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
|
||||
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
|
||||
|
@ -70,8 +63,37 @@ func CLIFlags() *portainer.CLIFlags {
|
|||
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
|
||||
sslFlag := kingpin.Flag(
|
||||
"ssl",
|
||||
"Secure Portainer instance using SSL (deprecated)",
|
||||
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
|
||||
ssl := sslFlag.Bool()
|
||||
sslCertFlag := kingpin.Flag(
|
||||
"sslcert",
|
||||
"Path to the SSL certificate used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLCertFlag)
|
||||
sslCert := sslCertFlag.String()
|
||||
sslKeyFlag := kingpin.Flag(
|
||||
"sslkey",
|
||||
"Path to the SSL key used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLKeyFlag)
|
||||
sslKey := sslKeyFlag.String()
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
|
||||
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
|
||||
flags.TLS = tlsFlag.Bool()
|
||||
tlsCertFlag := kingpin.Flag(
|
||||
"tlscert",
|
||||
"Path to the TLS certificate file",
|
||||
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
|
||||
flags.TLSCert = tlsCertFlag.String()
|
||||
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
|
||||
flags.TLSKey = tlsKeyFlag.String()
|
||||
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
if !filepath.IsAbs(*flags.Assets) {
|
||||
|
@ -83,6 +105,41 @@ func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
|||
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
|
||||
}
|
||||
|
||||
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
|
||||
if !hasTLSFlag {
|
||||
if !hasTLSCertFlag {
|
||||
*flags.TLSCert = ""
|
||||
}
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
*flags.TLSKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
|
||||
|
||||
if !hasTLSFlag {
|
||||
flags.TLS = ssl
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLCertFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
|
||||
|
||||
if !hasTLSCertFlag {
|
||||
flags.TLSCert = sslCert
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLKeyFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
flags.TLSKey = sslKey
|
||||
}
|
||||
}
|
||||
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
|
@ -109,10 +166,6 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
|
|||
if *flags.NoAnalytics {
|
||||
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
|
||||
}
|
||||
|
||||
if *flags.SSL {
|
||||
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpointURL(endpointURL string) error {
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
zerolog "github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -22,3 +25,185 @@ func TestOptionParser(t *testing.T) {
|
|||
require.False(t, *opts.HTTPDisabled)
|
||||
require.True(t, *opts.EnableEdgeComputeFeatures)
|
||||
}
|
||||
|
||||
func TestParseTLSFlags(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedTLSFlag bool
|
||||
expectedTLSCertFlag string
|
||||
expectedTLSKeyFlag string
|
||||
expectedLogMessages []string
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
expectedTLSFlag: false,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only ssl flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only tls flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: defaultTLSCertPath,
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "partial tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags 2",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var logOutput strings.Builder
|
||||
setupLogOutput(t, &logOutput)
|
||||
|
||||
if tc.args == nil {
|
||||
tc.args = []string{"portainer"}
|
||||
}
|
||||
setOsArgs(t, tc.args)
|
||||
|
||||
s := Service{}
|
||||
flags, err := s.ParseFlags("test-version")
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing flags: %v", err)
|
||||
}
|
||||
|
||||
if flags.TLS == nil {
|
||||
t.Fatal("TLS flag was nil")
|
||||
}
|
||||
|
||||
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
|
||||
|
||||
for _, expectedLogMessage := range tc.expectedLogMessages {
|
||||
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setOsArgs(t *testing.T, args []string) {
|
||||
t.Helper()
|
||||
previousArgs := os.Args
|
||||
os.Args = args
|
||||
t.Cleanup(func() {
|
||||
os.Args = previousArgs
|
||||
})
|
||||
}
|
||||
|
||||
func setupLogOutput(t *testing.T, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
oldLogger := zerolog.Logger
|
||||
zerolog.Logger = zerolog.Output(w)
|
||||
t.Cleanup(func() {
|
||||
zerolog.Logger = oldLogger
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
)
|
||||
|
||||
type pairList []portainer.Pair
|
||||
|
|
|
@ -309,13 +309,13 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
|
|||
|
||||
// dbSecretPath build the path to the file that contains the db encryption
|
||||
// secret. Normally in Docker this is built from the static path inside
|
||||
// /run/portainer for example: /run/portainer/<keyFilenameFlag> but for ease of
|
||||
// /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
|
||||
// use outside Docker it also accepts an absolute path
|
||||
func dbSecretPath(keyFilenameFlag string) string {
|
||||
if path.IsAbs(keyFilenameFlag) {
|
||||
return keyFilenameFlag
|
||||
}
|
||||
return path.Join("/run/portainer", keyFilenameFlag)
|
||||
return path.Join("/run/secrets", keyFilenameFlag)
|
||||
}
|
||||
|
||||
func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
|
@ -408,7 +408,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
|||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
|
|
@ -43,12 +43,12 @@ func TestDBSecretPath(t *testing.T) {
|
|||
keyFilenameFlag string
|
||||
expected string
|
||||
}{
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
|
||||
{keyFilenameFlag: "/run/portainer/secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/portainer/secret.txt"},
|
||||
{keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/portainer/foo/bar/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
@ -14,9 +15,9 @@ import (
|
|||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/argon2" //nolint:depguard
|
||||
"golang.org/x/crypto/scrypt" //nolint:depguard
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -248,7 +249,10 @@ func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) err
|
|||
return err
|
||||
}
|
||||
|
||||
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New)
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
|
@ -315,7 +319,10 @@ func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New)
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
|
@ -382,3 +389,18 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
|||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
|
||||
// mode changes this behavior and so will only recognize data encrypted by the
|
||||
// same mode (fips enabled or disabled)
|
||||
func HasEncryptedHeader(data []byte) bool {
|
||||
return hasEncryptedHeader(data, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
|
||||
if fipsMode {
|
||||
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
|
||||
}
|
||||
|
||||
return bytes.HasPrefix(data, []byte(aesGcmHeader))
|
||||
}
|
||||
|
|
|
@ -55,17 +55,19 @@ func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
|
|||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.Nil(t, err, "Failed to encrypt a file")
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.Nil(t, err, "Couldn't read encrypted file")
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
|
@ -155,11 +157,11 @@ func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
|
|||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
|
@ -169,7 +171,7 @@ func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
|
|||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
|
@ -205,25 +207,29 @@ func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
|
|||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
|
@ -247,32 +253,40 @@ func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
|
|||
)
|
||||
|
||||
content := randBytes(1024 * 50)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
|
@ -305,9 +319,9 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
|
|||
defer encryptedFileWriter.Close()
|
||||
|
||||
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
|
@ -317,7 +331,7 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
|
|||
defer decryptedFileWriter.Close()
|
||||
|
||||
_, err = decrypt(encryptedFileReader, []byte("garbage"))
|
||||
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
require.Error(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
|
@ -350,3 +364,62 @@ func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) erro
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_hasEncryptedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fipsMode bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-FIPS mode with valid header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-FIPS mode with FIPS header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with valid header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with non-FIPS header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header",
|
||||
data: []byte("INVALID-HEADER" + "some data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasEncryptedHeader(tt.data, tt.fipsMode)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,12 +11,12 @@ func TestCreateSignature(t *testing.T) {
|
|||
|
||||
privKey, pubKey, err := s.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(privKey), 0)
|
||||
require.Greater(t, len(pubKey), 0)
|
||||
require.NotEmpty(t, privKey)
|
||||
require.NotEmpty(t, pubKey)
|
||||
|
||||
m := "test message"
|
||||
r, err := s.CreateSignature(m)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, r, m)
|
||||
require.Greater(t, len(r), 0)
|
||||
require.NotEmpty(t, r)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/bcrypt" //nolint:depguard
|
||||
)
|
||||
|
||||
// Service represents a service for encrypting/hashing data.
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/fips140"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
)
|
||||
|
||||
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
|
||||
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfiguration(fips140.Enabled(), insecureSkipVerify)
|
||||
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
|
||||
}
|
||||
|
||||
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
|
@ -58,8 +57,7 @@ func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Conf
|
|||
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from memory.
|
||||
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfigurationFromBytes(fips140.Enabled(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
}
|
||||
|
||||
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
|
@ -90,8 +88,7 @@ func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key
|
|||
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from disk.
|
||||
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
// TODO: use fips.FIPSMode() instead
|
||||
return createTLSConfigurationFromDisk(fips140.Enabled(), config)
|
||||
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
|
||||
}
|
||||
|
||||
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestCreateTLSConfigurationFIPS(t *testing.T) {
|
|||
func TestCreateTLSConfigurationFromBytes(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS client/server verifications
|
||||
|
@ -61,7 +61,7 @@ func TestCreateTLSConfigurationFromBytes(t *testing.T) {
|
|||
func TestCreateTLSConfigurationFromDisk(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS verifications
|
||||
|
|
|
@ -94,7 +94,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(test.expected, string(data))
|
||||
})
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ func Test_UnMarshalObjectUnencrypted(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
var object string
|
||||
err := conn.UnmarshalObject(test.object, &object)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(test.expected, object)
|
||||
})
|
||||
}
|
||||
|
@ -172,12 +172,12 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
var object []byte
|
||||
err = conn.UnmarshalObject(data, &object)
|
||||
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(test.object, object)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -28,13 +28,12 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).Create(customTemplate)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreate(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
e, err := ds.CustomTemplate().Read(1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, portainer.CustomTemplateID(1), e.ID)
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package customtemplate
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
// Service represents a service for managing custom template data.
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.Connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreateTx(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})
|
||||
}))
|
||||
|
||||
var template *portainer.CustomTemplate
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
template, err = tx.CustomTemplate().Read(1)
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
}
|
|
@ -91,9 +91,9 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
|||
})
|
||||
}
|
||||
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -102,3 +103,38 @@ func TestUpdateRelation(t *testing.T) {
|
|||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
|
||||
}
|
||||
|
||||
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
|
||||
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
|
||||
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
|
||||
}
|
||||
|
||||
func TestEndpointRelations(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
|
||||
rels, err := service.EndpointRelations()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rels, 1)
|
||||
}
|
||||
|
|
|
@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
|||
return nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
for _, endpointID := range endpointIDs {
|
||||
rel, err := service.EndpointRelation(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rel.EdgeStacks[edgeStackID] = true
|
||||
rel.EdgeStacks[edgeStack.ID] = true
|
||||
|
||||
identifier := service.service.connection.ConvertToKey(int(endpointID))
|
||||
err = service.tx.UpdateObject(BucketName, identifier, rel)
|
||||
|
@ -97,8 +97,12 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
|
|||
service.service.endpointRelationsCache = nil
|
||||
service.service.mu.Unlock()
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments += len(endpointIDs)
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
|
||||
es.NumDeployments += len(endpointIDs)
|
||||
|
||||
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
|
||||
// to avoid overriding with the previous values
|
||||
edgeStack.NumDeployments = es.NumDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
|
|
|
@ -126,7 +126,7 @@ type (
|
|||
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
|
||||
Create(endpointRelation *portainer.EndpointRelation) error
|
||||
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
|
||||
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
DeleteEndpointRelation(EndpointID portainer.EndpointID) error
|
||||
BucketName() string
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const BucketName = "pending_actions"
|
||||
|
@ -16,10 +11,6 @@ type Service struct {
|
|||
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
|
@ -34,6 +25,11 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (s Service) Create(config *portainer.PendingAction) error {
|
||||
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(config)
|
||||
|
@ -61,43 +57,3 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
|||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
|
@ -4,17 +4,18 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGuidString(t *testing.T) string {
|
||||
uuid, err := uuid.NewV4()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
return uuid.String()
|
||||
}
|
||||
|
@ -41,7 +42,7 @@ func TestService_StackByWebhookID(t *testing.T) {
|
|||
|
||||
// can find a stack by webhook ID
|
||||
got, err := store.StackService.StackByWebhookID(webhookID)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stack, *got)
|
||||
|
||||
// returns nil and object not found error if there's no stack associated with the webhook
|
||||
|
@ -94,10 +95,10 @@ func Test_RefreshableStacks(t *testing.T) {
|
|||
|
||||
for _, stack := range []*portainer.Stack{&staticStack, &stackWithWebhook, &refreshableStack} {
|
||||
err := store.Stack().Create(stack)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stacks, err := store.Stack().RefreshableStacks()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.ElementsMatch(t, []portainer.Stack{refreshableStack}, stacks)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
|
||||
"github.com/portainer/portainer/api/dataservices/errors"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_teamByName(t *testing.T) {
|
||||
|
@ -13,7 +15,7 @@ func Test_teamByName(t *testing.T) {
|
|||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
_, err := store.Team().TeamByName("name")
|
||||
assert.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
require.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
|
||||
})
|
||||
|
||||
|
@ -29,7 +31,7 @@ func Test_teamByName(t *testing.T) {
|
|||
teamBuilder.createNew("name1")
|
||||
|
||||
_, err := store.Team().TeamByName("name")
|
||||
assert.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
require.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
})
|
||||
|
||||
t.Run("When there is an object with the same name should return the object", func(t *testing.T) {
|
||||
|
@ -44,7 +46,7 @@ func Test_teamByName(t *testing.T) {
|
|||
expectedTeam := teamBuilder.createNew("name1")
|
||||
|
||||
team, err := store.Team().TeamByName("name1")
|
||||
assert.NoError(t, err, "TeamByName should succeed")
|
||||
require.NoError(t, err, "TeamByName should succeed")
|
||||
assert.Equal(t, expectedTeam, team)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,15 +5,14 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TestStoreCreation(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, true)
|
||||
if store == nil {
|
||||
t.Fatal("Expect to create a store")
|
||||
}
|
||||
require.NotNil(t, store)
|
||||
|
||||
v, err := store.VersionService.Version()
|
||||
if err != nil {
|
||||
|
|
|
@ -6,12 +6,14 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/chisel"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,45 +32,21 @@ func TestStoreFull(t *testing.T) {
|
|||
_, store := MustNewTestStore(t, true, true)
|
||||
|
||||
testCases := map[string]func(t *testing.T){
|
||||
"User Accounts": func(t *testing.T) {
|
||||
store.testUserAccounts(t)
|
||||
},
|
||||
"Environments": func(t *testing.T) {
|
||||
store.testEnvironments(t)
|
||||
},
|
||||
"Settings": func(t *testing.T) {
|
||||
store.testSettings(t)
|
||||
},
|
||||
"SSL Settings": func(t *testing.T) {
|
||||
store.testSSLSettings(t)
|
||||
},
|
||||
"Tunnel Server": func(t *testing.T) {
|
||||
store.testTunnelServer(t)
|
||||
},
|
||||
"Custom Templates": func(t *testing.T) {
|
||||
store.testCustomTemplates(t)
|
||||
},
|
||||
"Registries": func(t *testing.T) {
|
||||
store.testRegistries(t)
|
||||
},
|
||||
"Resource Control": func(t *testing.T) {
|
||||
store.testResourceControl(t)
|
||||
},
|
||||
"Schedules": func(t *testing.T) {
|
||||
store.testSchedules(t)
|
||||
},
|
||||
"Tags": func(t *testing.T) {
|
||||
store.testTags(t)
|
||||
},
|
||||
|
||||
// "Test Title": func(t *testing.T) {
|
||||
// },
|
||||
"User Accounts": store.testUserAccounts,
|
||||
"Environments": store.testEnvironments,
|
||||
"Settings": store.testSettings,
|
||||
"SSL Settings": store.testSSLSettings,
|
||||
"Tunnel Server": store.testTunnelServer,
|
||||
"Custom Templates": store.testCustomTemplates,
|
||||
"Registries": store.testRegistries,
|
||||
"Resource Control": store.testResourceControl,
|
||||
"Schedules": store.testSchedules,
|
||||
"Tags": store.testTags,
|
||||
}
|
||||
|
||||
for name, test := range testCases {
|
||||
t.Run(name, test)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (store *Store) testEnvironments(t *testing.T) {
|
||||
|
@ -167,7 +145,7 @@ func (store *Store) CreateEndpoint(t *testing.T, name string, endpointType porta
|
|||
store.Endpoint().Create(expectedEndpoint)
|
||||
|
||||
endpoint, err := store.Endpoint().Endpoint(id)
|
||||
is.NoError(err, "Endpoint() should not return an error")
|
||||
require.NoError(t, err, "Endpoint() should not return an error")
|
||||
is.Equal(expectedEndpoint, endpoint, "endpoint should be the same")
|
||||
|
||||
return endpoint.ID
|
||||
|
@ -194,7 +172,7 @@ func (store *Store) testSSLSettings(t *testing.T) {
|
|||
store.SSLSettings().UpdateSettings(ssl)
|
||||
|
||||
settings, err := store.SSLSettings().Settings()
|
||||
is.NoError(err, "Get sslsettings should succeed")
|
||||
require.NoError(t, err, "Get sslsettings should succeed")
|
||||
is.Equal(ssl, settings, "Stored SSLSettings should be the same as what is read out")
|
||||
}
|
||||
|
||||
|
@ -203,27 +181,27 @@ func (store *Store) testTunnelServer(t *testing.T) {
|
|||
expectPrivateKeySeed := uniuri.NewLen(16)
|
||||
|
||||
err := store.TunnelServer().UpdateInfo(&portainer.TunnelServerInfo{PrivateKeySeed: expectPrivateKeySeed})
|
||||
is.NoError(err, "UpdateInfo should have succeeded")
|
||||
require.NoError(t, err, "UpdateInfo should have succeeded")
|
||||
|
||||
serverInfo, err := store.TunnelServer().Info()
|
||||
is.NoError(err, "Info should have succeeded")
|
||||
require.NoError(t, err, "Info should have succeeded")
|
||||
|
||||
is.Equal(expectPrivateKeySeed, serverInfo.PrivateKeySeed, "hashed passwords should not differ")
|
||||
}
|
||||
|
||||
// add users, read them back and check the details are unchanged
|
||||
func (store *Store) testUserAccounts(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
err := store.createAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
is.NoError(err, "CreateAccount should succeed")
|
||||
store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
is.NoError(err, "Account failure")
|
||||
require.NoError(t, err, "CreateAccount should succeed")
|
||||
|
||||
err = store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
require.NoError(t, err, "Account failure")
|
||||
|
||||
err = store.createAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
is.NoError(err, "CreateAccount should succeed")
|
||||
store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
is.NoError(err, "Account failure")
|
||||
require.NoError(t, err, "CreateAccount should succeed")
|
||||
|
||||
err = store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
require.NoError(t, err, "Account failure")
|
||||
}
|
||||
|
||||
// create an account with the provided details
|
||||
|
@ -238,12 +216,7 @@ func (store *Store) createAccount(username, password string, role portainer.User
|
|||
return err
|
||||
}
|
||||
|
||||
err = store.User().Create(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return store.User().Create(user)
|
||||
}
|
||||
|
||||
func (store *Store) checkAccount(username, expectPassword string, expectRole portainer.UserRole) error {
|
||||
|
@ -260,12 +233,7 @@ func (store *Store) checkAccount(username, expectPassword string, expectRole por
|
|||
|
||||
// Check the password
|
||||
cs := crypto.Service{}
|
||||
expectPasswordHash, err := cs.Hash(expectPassword)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "hash failed")
|
||||
}
|
||||
|
||||
if user.Password != expectPasswordHash {
|
||||
if cs.CompareHashAndData(user.Password, expectPassword) != nil {
|
||||
return fmt.Errorf("%s user password hash failure", user.Username)
|
||||
}
|
||||
|
||||
|
@ -277,7 +245,7 @@ func (store *Store) testSettings(t *testing.T) {
|
|||
|
||||
// since many settings are default and basically nil, I'm going to update some and read them back
|
||||
expectedSettings, err := store.Settings().Settings()
|
||||
is.NoError(err, "Settings() should not return an error")
|
||||
require.NoError(t, err, "Settings() should not return an error")
|
||||
expectedSettings.TemplatesURL = "http://portainer.io/application-templates"
|
||||
expectedSettings.HelmRepositoryURL = "http://portainer.io/helm-repository"
|
||||
expectedSettings.EdgeAgentCheckinInterval = 60
|
||||
|
@ -291,10 +259,10 @@ func (store *Store) testSettings(t *testing.T) {
|
|||
expectedSettings.SnapshotInterval = "10m"
|
||||
|
||||
err = store.Settings().UpdateSettings(expectedSettings)
|
||||
is.NoError(err, "UpdateSettings() should succeed")
|
||||
require.NoError(t, err, "UpdateSettings() should succeed")
|
||||
|
||||
settings, err := store.Settings().Settings()
|
||||
is.NoError(err, "Settings() should not return an error")
|
||||
require.NoError(t, err, "Settings() should not return an error")
|
||||
is.Equal(expectedSettings, settings, "stored settings should match")
|
||||
}
|
||||
|
||||
|
@ -317,7 +285,7 @@ func (store *Store) testCustomTemplates(t *testing.T) {
|
|||
customTemplate.Create(expectedTemplate)
|
||||
|
||||
actualTemplate, err := customTemplate.Read(expectedTemplate.ID)
|
||||
is.NoError(err, "CustomTemplate should not return an error")
|
||||
require.NoError(t, err, "CustomTemplate should not return an error")
|
||||
is.Equal(expectedTemplate, actualTemplate, "expected and actual template do not match")
|
||||
}
|
||||
|
||||
|
@ -345,17 +313,17 @@ func (store *Store) testRegistries(t *testing.T) {
|
|||
}
|
||||
|
||||
err := regService.Create(reg1)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = regService.Create(reg2)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualReg1, err := regService.Read(reg1.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(reg1, actualReg1, "registries differ")
|
||||
|
||||
actualReg2, err := regService.Read(reg2.ID)
|
||||
is.NoError(err)
|
||||
require.NoError(t, err)
|
||||
is.Equal(reg2, actualReg2, "registries differ")
|
||||
}
|
||||
|
||||
|
@ -378,10 +346,10 @@ func (store *Store) testSchedules(t *testing.T) {
|
|||
}
|
||||
|
||||
err := schedule.CreateSchedule(s)
|
||||
is.NoError(err, "CreateSchedule should succeed")
|
||||
require.NoError(t, err, "CreateSchedule should succeed")
|
||||
|
||||
actual, err := schedule.Schedule(s.ID)
|
||||
is.NoError(err, "schedule should be found")
|
||||
require.NoError(t, err, "schedule should be found")
|
||||
is.Equal(s, actual, "schedules differ")
|
||||
}
|
||||
|
||||
|
@ -401,16 +369,16 @@ func (store *Store) testTags(t *testing.T) {
|
|||
}
|
||||
|
||||
err := tags.Create(tag1)
|
||||
is.NoError(err, "Tags.Create should succeed")
|
||||
require.NoError(t, err, "Tags.Create should succeed")
|
||||
|
||||
err = tags.Create(tag2)
|
||||
is.NoError(err, "Tags.Create should succeed")
|
||||
require.NoError(t, err, "Tags.Create should succeed")
|
||||
|
||||
actual, err := tags.Read(tag1.ID)
|
||||
is.NoError(err, "tag1 should be found")
|
||||
require.NoError(t, err, "tag1 should be found")
|
||||
is.Equal(tag1, actual, "tags differ")
|
||||
|
||||
actual, err = tags.Read(tag2.ID)
|
||||
is.NoError(err, "tag2 should be found")
|
||||
require.NoError(t, err, "tag2 should be found")
|
||||
is.Equal(tag2, actual, "tags differ")
|
||||
}
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore/migrator"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateStackEntryPoint(t *testing.T) {
|
||||
|
@ -28,25 +30,25 @@ func TestMigrateStackEntryPoint(t *testing.T) {
|
|||
|
||||
for _, s := range stacks {
|
||||
err := stackService.Create(s)
|
||||
assert.NoError(t, err, "failed to create stack")
|
||||
require.NoError(t, err, "failed to create stack")
|
||||
}
|
||||
|
||||
s, err := stackService.Read(1)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, s.GitConfig, "first stack should not have git config")
|
||||
|
||||
s, err = stackService.Read(2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", s.GitConfig.ConfigFilePath, "not migrated yet migrated")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, s.GitConfig.ConfigFilePath, "not migrated yet migrated")
|
||||
|
||||
err = migrator.MigrateStackEntryPoint(stackService)
|
||||
assert.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath")
|
||||
require.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath")
|
||||
|
||||
s, err = stackService.Read(1)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, s.GitConfig, "first stack should not have git config")
|
||||
|
||||
s, err = stackService.Read(2)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "dir/sub/compose.yml", s.GitConfig.ConfigFilePath, "second stack should have config file path migrated")
|
||||
}
|
||||
|
|
|
@ -11,8 +11,10 @@ func (m *Migrator) migrateEdgeGroupEndpointsToRoars_2_33_0() error {
|
|||
}
|
||||
|
||||
for _, eg := range egs {
|
||||
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
|
||||
eg.Endpoints = nil
|
||||
if eg.EndpointIDs.Len() == 0 {
|
||||
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
|
||||
eg.Endpoints = nil
|
||||
}
|
||||
|
||||
if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil {
|
||||
return err
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package migrator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgegroup"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateEdgeGroupEndpointsToRoars_2_33_0Idempotency(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
edgeGroupService, err := edgegroup.NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeGroup := &portainer.EdgeGroup{
|
||||
ID: 1,
|
||||
Name: "test-edge-group",
|
||||
Endpoints: []portainer.EndpointID{1, 2, 3},
|
||||
}
|
||||
|
||||
err = conn.CreateObjectWithId(edgegroup.BucketName, int(edgeGroup.ID), edgeGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := NewMigrator(&MigratorParameters{EdgeGroupService: edgeGroupService})
|
||||
|
||||
// Run migration once
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err := edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Empty(t, migratedEdgeGroup.Endpoints)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
|
||||
// Run migration again to ensure the results didn't change
|
||||
|
||||
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
|
||||
require.NoError(t, err)
|
||||
|
||||
migratedEdgeGroup, err = edgeGroupService.Read(edgeGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Empty(t, migratedEdgeGroup.Endpoints)
|
||||
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
|
||||
}
|
|
@ -256,10 +256,9 @@ func (m *Migrator) initMigrations() {
|
|||
|
||||
m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0)
|
||||
|
||||
m.addMigrations("2.33.0-rc1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
m.addMigrations("2.33.1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
|
||||
//m.addMigrations("2.33.0", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
|
||||
// when we release 2.33.0 it will also run the rc-1 migration function
|
||||
// WARNING: do not change migrations that have already been released!
|
||||
|
||||
// Add new migrations above...
|
||||
// One function per migration, each versions migration funcs in the same file.
|
||||
|
|
|
@ -2,6 +2,7 @@ package postinit
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
|
@ -83,17 +84,27 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
|
|||
|
||||
// try to create a post init migration pending action. If it already exists, do nothing
|
||||
// this function exists for readability, not reusability
|
||||
// TODO: This should be moved into pending actions as part of the pending action migration
|
||||
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
|
||||
// If there are no pending actions for the given endpoint, create one
|
||||
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
|
||||
action := portainer.PendingAction{
|
||||
EndpointID: environmentID,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
|
||||
}
|
||||
return nil
|
||||
pendingActions, err := postInitMigrator.dataStore.PendingActions().ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending actions: %w", err)
|
||||
}
|
||||
|
||||
for _, dba := range pendingActions {
|
||||
if dba.EndpointID == action.EndpointID && dba.Action == action.Action {
|
||||
log.Debug().
|
||||
Str("action", action.Action).
|
||||
Int("endpoint_id", int(action.EndpointID)).
|
||||
Msg("pending action already exists for environment, skipping...")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return postInitMigrator.dataStore.PendingActions().Create(&action)
|
||||
}
|
||||
|
||||
// MigrateEnvironment runs migrations on a single environment
|
||||
|
|
|
@ -8,10 +8,12 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/pendingactions/actions"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -20,8 +22,9 @@ func TestMigrateGPUs(t *testing.T) {
|
|||
if strings.HasSuffix(r.URL.Path, "/containers/json") {
|
||||
containerSummary := []container.Summary{{ID: "container1"}}
|
||||
|
||||
err := json.NewEncoder(w).Encode(containerSummary)
|
||||
require.NoError(t, err)
|
||||
if err := json.NewEncoder(w).Encode(containerSummary); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -39,8 +42,9 @@ func TestMigrateGPUs(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
err := json.NewEncoder(w).Encode(container)
|
||||
require.NoError(t, err)
|
||||
if err := json.NewEncoder(w).Encode(container); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
|
@ -73,3 +77,98 @@ func TestMigrateGPUs(t *testing.T) {
|
|||
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
|
||||
require.True(t, migratedEndpoint.EnableGPUManagement)
|
||||
}
|
||||
|
||||
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
existingPendingActions []*portainer.PendingAction
|
||||
expectedPendingActions int
|
||||
expectedAction string
|
||||
}{
|
||||
{
|
||||
name: "when existing non-matching action exists, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: "some-other-action",
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 2,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when matching action exists, should not add duplicate",
|
||||
existingPendingActions: []*portainer.PendingAction{
|
||||
{
|
||||
EndpointID: 7,
|
||||
Action: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
{
|
||||
name: "when no actions exist, should add migration action",
|
||||
existingPendingActions: []*portainer.PendingAction{},
|
||||
expectedPendingActions: 1,
|
||||
expectedAction: actions.PostInitMigrateEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
// Create test endpoint
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 7,
|
||||
UserTrusted: true,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
Edge: portainer.EnvironmentEdgeSettings{
|
||||
AsyncMode: false,
|
||||
},
|
||||
EdgeID: "edgeID",
|
||||
}
|
||||
err := store.Endpoint().Create(endpoint)
|
||||
require.NoError(t, err, "error creating endpoint")
|
||||
|
||||
// Create any existing pending actions
|
||||
for _, action := range tt.existingPendingActions {
|
||||
err = store.PendingActions().Create(action)
|
||||
require.NoError(t, err, "error creating pending action")
|
||||
}
|
||||
|
||||
migrator := NewPostInitMigrator(
|
||||
nil, // kubeFactory not needed for this test
|
||||
nil, // dockerFactory not needed for this test
|
||||
store,
|
||||
"", // assetsPath not needed for this test
|
||||
nil, // kubernetesDeployer not needed for this test
|
||||
)
|
||||
|
||||
err = migrator.PostInitMigrate()
|
||||
require.NoError(t, err, "PostInitMigrate should not return error")
|
||||
|
||||
// Verify the results
|
||||
pendingActions, err := store.PendingActions().ReadAll()
|
||||
require.NoError(t, err, "error reading pending actions")
|
||||
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
|
||||
|
||||
// If we expect any actions, verify at least one has the expected action type
|
||||
if tt.expectedPendingActions > 0 {
|
||||
hasExpectedAction := false
|
||||
for _, action := range pendingActions {
|
||||
if action.Action == tt.expectedAction {
|
||||
hasExpectedAction = true
|
||||
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
is.True(hasExpectedAction, "should have found action of expected type")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,9 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
|
|||
return tx.store.IsErrObjectNotFound(err)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil }
|
||||
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
|
||||
return tx.store.CustomTemplateService.Tx(tx.tx)
|
||||
}
|
||||
|
||||
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService {
|
||||
return tx.store.PendingActionsService.Tx(tx.tx)
|
||||
|
|
|
@ -83,7 +83,6 @@
|
|||
"MigrateIngresses": true
|
||||
},
|
||||
"PublicURL": "",
|
||||
"QueryDate": 0,
|
||||
"SecuritySettings": {
|
||||
"allowBindMountsForRegularUsers": true,
|
||||
"allowContainerCapabilitiesForRegularUsers": true,
|
||||
|
@ -615,7 +614,7 @@
|
|||
"RequiredPasswordLength": 12
|
||||
},
|
||||
"KubeconfigExpiry": "0",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.33.0-rc1",
|
||||
"KubectlShellImage": "portainer/kubectl-shell:2.34.0",
|
||||
"LDAPSettings": {
|
||||
"AnonymousMode": true,
|
||||
"AutoCreateUsers": true,
|
||||
|
@ -944,7 +943,7 @@
|
|||
}
|
||||
],
|
||||
"version": {
|
||||
"VERSION": "{\"SchemaVersion\":\"2.33.0-rc1\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
"VERSION": "{\"SchemaVersion\":\"2.34.0\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
|
||||
},
|
||||
"webhooks": null
|
||||
}
|
|
@ -4,10 +4,14 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHttpClient(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
// Valid TLS configuration
|
||||
endpoint := &portainer.Endpoint{}
|
||||
endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true}
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
package docker
|
||||
|
||||
import "github.com/docker/docker/api/types"
|
||||
|
||||
type ContainerStats struct {
|
||||
Running int `json:"running"`
|
||||
Stopped int `json:"stopped"`
|
||||
Healthy int `json:"healthy"`
|
||||
Unhealthy int `json:"unhealthy"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
func CalculateContainerStats(containers []types.Container) ContainerStats {
|
||||
var running, stopped, healthy, unhealthy int
|
||||
for _, container := range containers {
|
||||
switch container.State {
|
||||
case "running":
|
||||
running++
|
||||
case "healthy":
|
||||
running++
|
||||
healthy++
|
||||
case "unhealthy":
|
||||
running++
|
||||
unhealthy++
|
||||
case "exited", "stopped":
|
||||
stopped++
|
||||
}
|
||||
}
|
||||
|
||||
return ContainerStats{
|
||||
Running: running,
|
||||
Stopped: stopped,
|
||||
Healthy: healthy,
|
||||
Unhealthy: unhealthy,
|
||||
Total: len(containers),
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package docker
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCalculateContainerStats(t *testing.T) {
|
||||
containers := []types.Container{
|
||||
{State: "running"},
|
||||
{State: "running"},
|
||||
{State: "exited"},
|
||||
{State: "stopped"},
|
||||
{State: "healthy"},
|
||||
{State: "unhealthy"},
|
||||
}
|
||||
|
||||
stats := CalculateContainerStats(containers)
|
||||
|
||||
assert.Equal(t, 4, stats.Running)
|
||||
assert.Equal(t, 2, stats.Stopped)
|
||||
assert.Equal(t, 1, stats.Healthy)
|
||||
assert.Equal(t, 1, stats.Unhealthy)
|
||||
assert.Equal(t, 6, stats.Total)
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestImageParser(t *testing.T) {
|
||||
|
@ -14,7 +15,7 @@ func TestImageParser(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "portainer/portainer-ee",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
require.NoError(t, err)
|
||||
is.Equal("docker.io/portainer/portainer-ee:latest", image.FullName())
|
||||
is.Equal("portainer/portainer-ee", image.Opts.Name)
|
||||
is.Equal("latest", image.Tag)
|
||||
|
@ -30,10 +31,10 @@ func TestImageParser(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
require.NoError(t, err)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("", image.Tag)
|
||||
is.Empty(image.Tag)
|
||||
is.Equal("k8s-minikube/kicbase", image.Path)
|
||||
is.Equal("gcr.io", image.Domain)
|
||||
is.Equal("https://gcr.io/k8s-minikube/kicbase", image.HubLink)
|
||||
|
@ -47,7 +48,7 @@ func TestImageParser(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
require.NoError(t, err)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
|
@ -68,8 +69,9 @@ func TestUpdateParsedImage(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
_ = image.WithTag("v0.0.31")
|
||||
require.NoError(t, err)
|
||||
err = image.WithTag("v0.0.31")
|
||||
require.NoError(t, err)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.31", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.31", image.Tag)
|
||||
|
@ -86,8 +88,9 @@ func TestUpdateParsedImage(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
_ = image.WithDigest("sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b3")
|
||||
require.NoError(t, err)
|
||||
err = image.WithDigest("sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b3")
|
||||
require.NoError(t, err)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
|
@ -104,8 +107,9 @@ func TestUpdateParsedImage(t *testing.T) {
|
|||
image, err := ParseImage(ParseImageOptions{
|
||||
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
|
||||
})
|
||||
is.NoError(err, "")
|
||||
_ = image.TrimDigest()
|
||||
require.NoError(t, err)
|
||||
err = image.TrimDigest()
|
||||
require.NoError(t, err)
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
|
||||
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
|
||||
is.Equal("v0.0.30", image.Tag)
|
||||
|
|
|
@ -4,7 +4,9 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
|
||||
|
@ -15,9 +17,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
|
|||
registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", false),
|
||||
createNewRegistry("hub-mirror.c.163.com", "", false)}
|
||||
r, err := findBestMatchRegistry(image, registries)
|
||||
is.NoError(err, "")
|
||||
is.NotNil(r, "")
|
||||
is.False(r.Authentication, "")
|
||||
require.NoError(t, err)
|
||||
is.NotNil(r)
|
||||
is.False(r.Authentication)
|
||||
is.Equal("docker.io", r.URL)
|
||||
})
|
||||
|
||||
|
@ -26,9 +28,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
|
|||
registries := []portainer.Registry{createNewRegistry("docker.io", "", false),
|
||||
createNewRegistry("hub-mirror.c.163.com", "USERNAME", false)}
|
||||
r, err := findBestMatchRegistry(image, registries)
|
||||
is.NoError(err, "")
|
||||
is.NotNil(r, "")
|
||||
is.False(r.Authentication, "")
|
||||
require.NoError(t, err)
|
||||
is.NotNil(r)
|
||||
is.False(r.Authentication)
|
||||
is.Equal("docker.io", r.URL)
|
||||
})
|
||||
|
||||
|
@ -37,9 +39,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
|
|||
registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", true),
|
||||
createNewRegistry("hub-mirror.c.163.com", "", false)}
|
||||
r, err := findBestMatchRegistry(image, registries)
|
||||
is.NoError(err, "")
|
||||
is.NotNil(r, "")
|
||||
is.True(r.Authentication, "")
|
||||
require.NoError(t, err)
|
||||
is.NotNil(r)
|
||||
is.True(r.Authentication)
|
||||
is.Equal("docker.io", r.URL)
|
||||
})
|
||||
|
||||
|
@ -47,9 +49,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
|
|||
image := "portainer/portainer-ee:latest"
|
||||
registries := []portainer.Registry{createNewRegistry("docker.io", "", true)}
|
||||
r, err := findBestMatchRegistry(image, registries)
|
||||
is.NoError(err, "")
|
||||
is.NotNil(r, "")
|
||||
is.True(r.Authentication, "")
|
||||
require.NoError(t, err)
|
||||
is.NotNil(r)
|
||||
is.True(r.Authentication)
|
||||
is.Equal("docker.io", r.URL)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
package stats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
)
|
||||
|
||||
type ContainerStats struct {
|
||||
Running int `json:"running"`
|
||||
Stopped int `json:"stopped"`
|
||||
Healthy int `json:"healthy"`
|
||||
Unhealthy int `json:"unhealthy"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
type DockerClient interface {
|
||||
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
|
||||
}
|
||||
|
||||
func CalculateContainerStats(ctx context.Context, cli DockerClient, isSwarm bool, containers []container.Summary) (ContainerStats, error) {
|
||||
if isSwarm {
|
||||
return CalculateContainerStatsForSwarm(containers), nil
|
||||
}
|
||||
|
||||
var running, stopped, healthy, unhealthy int
|
||||
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 5)
|
||||
|
||||
var aggErr error
|
||||
var aggMu sync.Mutex
|
||||
|
||||
for i := range containers {
|
||||
id := containers[i].ID
|
||||
semaphore <- struct{}{}
|
||||
wg.Go(func() {
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
containerInspection, err := cli.ContainerInspect(ctx, id)
|
||||
stat := ContainerStats{}
|
||||
if err != nil {
|
||||
aggMu.Lock()
|
||||
aggErr = errors.Join(aggErr, err)
|
||||
aggMu.Unlock()
|
||||
return
|
||||
}
|
||||
stat = getContainerStatus(containerInspection.State)
|
||||
|
||||
mu.Lock()
|
||||
running += stat.Running
|
||||
stopped += stat.Stopped
|
||||
healthy += stat.Healthy
|
||||
unhealthy += stat.Unhealthy
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return ContainerStats{
|
||||
Running: running,
|
||||
Stopped: stopped,
|
||||
Healthy: healthy,
|
||||
Unhealthy: unhealthy,
|
||||
Total: len(containers),
|
||||
}, aggErr
|
||||
}
|
||||
|
||||
func getContainerStatus(state *container.State) ContainerStats {
|
||||
stat := ContainerStats{}
|
||||
if state == nil {
|
||||
return stat
|
||||
}
|
||||
|
||||
switch state.Status {
|
||||
case container.StateRunning:
|
||||
stat.Running++
|
||||
case container.StateExited, container.StateDead:
|
||||
stat.Stopped++
|
||||
}
|
||||
|
||||
if state.Health != nil {
|
||||
switch state.Health.Status {
|
||||
case container.Healthy:
|
||||
stat.Healthy++
|
||||
case container.Unhealthy:
|
||||
stat.Unhealthy++
|
||||
}
|
||||
}
|
||||
|
||||
return stat
|
||||
}
|
||||
|
||||
// This is a temporary workaround to calculate container stats for Swarm
|
||||
// TODO: Remove this once we have a proper way to calculate container stats for Swarm
|
||||
func CalculateContainerStatsForSwarm(containers []container.Summary) ContainerStats {
|
||||
var running, stopped, healthy, unhealthy int
|
||||
for _, container := range containers {
|
||||
switch container.State {
|
||||
case "running":
|
||||
running++
|
||||
case "exited", "stopped":
|
||||
stopped++
|
||||
}
|
||||
|
||||
if strings.Contains(container.Status, "(healthy)") {
|
||||
healthy++
|
||||
} else if strings.Contains(container.Status, "(unhealthy)") {
|
||||
unhealthy++
|
||||
}
|
||||
}
|
||||
|
||||
return ContainerStats{
|
||||
Running: running,
|
||||
Stopped: stopped,
|
||||
Healthy: healthy,
|
||||
Unhealthy: unhealthy,
|
||||
Total: len(containers),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,253 @@
|
|||
package stats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockDockerClient implements the DockerClient interface for testing
|
||||
type MockDockerClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
|
||||
args := m.Called(ctx, containerID)
|
||||
return args.Get(0).(container.InspectResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func TestCalculateContainerStats(t *testing.T) {
|
||||
mockClient := new(MockDockerClient)
|
||||
|
||||
// Test containers - using enough containers to test concurrent processing
|
||||
containers := []container.Summary{
|
||||
{ID: "container1"},
|
||||
{ID: "container2"},
|
||||
{ID: "container3"},
|
||||
{ID: "container4"},
|
||||
{ID: "container5"},
|
||||
{ID: "container6"},
|
||||
{ID: "container7"},
|
||||
{ID: "container8"},
|
||||
{ID: "container9"},
|
||||
{ID: "container10"},
|
||||
}
|
||||
|
||||
// Setup mock expectations with different container states to test various scenarios
|
||||
containerStates := []struct {
|
||||
id string
|
||||
status string
|
||||
health *container.Health
|
||||
expected ContainerStats
|
||||
}{
|
||||
{"container1", container.StateRunning, &container.Health{Status: container.Healthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 1, Unhealthy: 0}},
|
||||
{"container2", container.StateRunning, &container.Health{Status: container.Unhealthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 1}},
|
||||
{"container3", container.StateRunning, nil, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 0}},
|
||||
{"container4", container.StateExited, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
|
||||
{"container5", container.StateDead, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
|
||||
{"container6", container.StateRunning, &container.Health{Status: container.Healthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 1, Unhealthy: 0}},
|
||||
{"container7", container.StateRunning, &container.Health{Status: container.Unhealthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 1}},
|
||||
{"container8", container.StateExited, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
|
||||
{"container9", container.StateRunning, nil, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 0}},
|
||||
{"container10", container.StateDead, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
|
||||
}
|
||||
|
||||
expected := ContainerStats{}
|
||||
// Setup mock expectations for all containers with artificial delays to simulate real Docker calls
|
||||
for _, state := range containerStates {
|
||||
mockClient.On("ContainerInspect", mock.Anything, state.id).Return(container.InspectResponse{
|
||||
ContainerJSONBase: &container.ContainerJSONBase{
|
||||
State: &container.State{
|
||||
Status: state.status,
|
||||
Health: state.health,
|
||||
},
|
||||
},
|
||||
}, nil).After(50 * time.Millisecond) // Simulate 50ms Docker API call
|
||||
|
||||
expected.Running += state.expected.Running
|
||||
expected.Stopped += state.expected.Stopped
|
||||
expected.Healthy += state.expected.Healthy
|
||||
expected.Unhealthy += state.expected.Unhealthy
|
||||
expected.Total++
|
||||
}
|
||||
|
||||
// Call the function and measure time
|
||||
startTime := time.Now()
|
||||
stats, err := CalculateContainerStats(context.Background(), mockClient, false, containers)
|
||||
require.NoError(t, err, "failed to calculate container stats")
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// Assert results
|
||||
assert.Equal(t, expected, stats)
|
||||
assert.Equal(t, expected.Running, stats.Running)
|
||||
assert.Equal(t, expected.Stopped, stats.Stopped)
|
||||
assert.Equal(t, expected.Healthy, stats.Healthy)
|
||||
assert.Equal(t, expected.Unhealthy, stats.Unhealthy)
|
||||
assert.Equal(t, 10, stats.Total)
|
||||
|
||||
// Verify concurrent processing by checking that all mock calls were made
|
||||
mockClient.AssertExpectations(t)
|
||||
|
||||
// Test concurrency: With 5 workers and 10 containers taking 50ms each:
|
||||
// Sequential would take: 10 * 50ms = 500ms
|
||||
sequentialTime := 10 * 50 * time.Millisecond
|
||||
|
||||
// Verify that concurrent processing is actually faster than sequential
|
||||
// Allow some overhead for goroutine scheduling
|
||||
assert.Less(t, duration, sequentialTime, "Concurrent processing should be faster than sequential")
|
||||
// Concurrent should take: ~100-150ms (depending on scheduling)
|
||||
assert.Less(t, duration, 150*time.Millisecond, "Concurrent processing should be significantly faster")
|
||||
assert.Greater(t, duration, 100*time.Millisecond, "Concurrent processing should be longer than 100ms")
|
||||
}
|
||||
|
||||
func TestCalculateContainerStatsAllErrors(t *testing.T) {
|
||||
mockClient := new(MockDockerClient)
|
||||
|
||||
// Test containers
|
||||
containers := []container.Summary{
|
||||
{ID: "container1"},
|
||||
{ID: "container2"},
|
||||
}
|
||||
|
||||
// Setup mock expectations with all calls returning errors
|
||||
mockClient.On("ContainerInspect", mock.Anything, "container1").Return(container.InspectResponse{}, errors.New("network error"))
|
||||
mockClient.On("ContainerInspect", mock.Anything, "container2").Return(container.InspectResponse{}, errors.New("permission denied"))
|
||||
|
||||
// Call the function
|
||||
stats, err := CalculateContainerStats(context.Background(), mockClient, false, containers)
|
||||
|
||||
// Assert that an error was returned
|
||||
require.Error(t, err, "should return error when all containers fail to inspect")
|
||||
assert.Contains(t, err.Error(), "network error", "error should contain one of the original error messages")
|
||||
assert.Contains(t, err.Error(), "permission denied", "error should contain the other original error message")
|
||||
|
||||
// Assert that stats are zero since no containers were successfully processed
|
||||
expectedStats := ContainerStats{
|
||||
Running: 0,
|
||||
Stopped: 0,
|
||||
Healthy: 0,
|
||||
Unhealthy: 0,
|
||||
Total: 2, // total containers processed
|
||||
}
|
||||
assert.Equal(t, expectedStats, stats)
|
||||
|
||||
// Verify all mock calls were made
|
||||
mockClient.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetContainerStatus(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state *container.State
|
||||
expected ContainerStats
|
||||
}{
|
||||
{
|
||||
name: "running healthy container",
|
||||
state: &container.State{
|
||||
Status: container.StateRunning,
|
||||
Health: &container.Health{
|
||||
Status: container.Healthy,
|
||||
},
|
||||
},
|
||||
expected: ContainerStats{
|
||||
Running: 1,
|
||||
Stopped: 0,
|
||||
Healthy: 1,
|
||||
Unhealthy: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "running unhealthy container",
|
||||
state: &container.State{
|
||||
Status: container.StateRunning,
|
||||
Health: &container.Health{
|
||||
Status: container.Unhealthy,
|
||||
},
|
||||
},
|
||||
expected: ContainerStats{
|
||||
Running: 1,
|
||||
Stopped: 0,
|
||||
Healthy: 0,
|
||||
Unhealthy: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "running container without health check",
|
||||
state: &container.State{
|
||||
Status: container.StateRunning,
|
||||
},
|
||||
expected: ContainerStats{
|
||||
Running: 1,
|
||||
Stopped: 0,
|
||||
Healthy: 0,
|
||||
Unhealthy: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exited container",
|
||||
state: &container.State{
|
||||
Status: container.StateExited,
|
||||
},
|
||||
expected: ContainerStats{
|
||||
Running: 0,
|
||||
Stopped: 1,
|
||||
Healthy: 0,
|
||||
Unhealthy: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dead container",
|
||||
state: &container.State{
|
||||
Status: container.StateDead,
|
||||
},
|
||||
expected: ContainerStats{
|
||||
Running: 0,
|
||||
Stopped: 1,
|
||||
Healthy: 0,
|
||||
Unhealthy: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil state",
|
||||
state: nil,
|
||||
expected: ContainerStats{
|
||||
Running: 0,
|
||||
Stopped: 0,
|
||||
Healthy: 0,
|
||||
Unhealthy: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
stat := getContainerStatus(testCase.state)
|
||||
assert.Equal(t, testCase.expected, stat)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateContainerStatsForSwarm(t *testing.T) {
|
||||
containers := []container.Summary{
|
||||
{State: "running"},
|
||||
{State: "running", Status: "Up 5 minutes (healthy)"},
|
||||
{State: "exited"},
|
||||
{State: "stopped"},
|
||||
{State: "running", Status: "Up 10 minutes"},
|
||||
{State: "running", Status: "Up about an hour (unhealthy)"},
|
||||
}
|
||||
|
||||
stats := CalculateContainerStatsForSwarm(containers)
|
||||
|
||||
assert.Equal(t, 4, stats.Running)
|
||||
assert.Equal(t, 2, stats.Stopped)
|
||||
assert.Equal(t, 1, stats.Healthy)
|
||||
assert.Equal(t, 1, stats.Unhealthy)
|
||||
assert.Equal(t, 6, stats.Total)
|
||||
}
|
|
@ -49,6 +49,11 @@ type (
|
|||
|
||||
// Is relative path supported
|
||||
SupportRelativePath bool
|
||||
// AlwaysCloneGitRepoForRelativePath is a flag indicating if the agent must always clone the git repository for relative path.
|
||||
// This field is only valid when SupportRelativePath is true.
|
||||
// Used only for EE
|
||||
AlwaysCloneGitRepoForRelativePath bool
|
||||
|
||||
// Mount point for relative path
|
||||
FilesystemPath string
|
||||
// Used only for EE
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_createEnvFile(t *testing.T) {
|
||||
|
@ -61,7 +63,7 @@ func Test_createEnvFile(t *testing.T) {
|
|||
|
||||
assert.Equal(t, tt.expected, string(content))
|
||||
} else {
|
||||
assert.Equal(t, "", result)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -79,7 +81,7 @@ func Test_createEnvFile_mergesDefultAndInplaceEnvVars(t *testing.T) {
|
|||
}
|
||||
result, err := createEnvFile(stack)
|
||||
assert.Equal(t, filepath.Join(stack.ProjectPath, "stack.env"), result)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.FileExists(t, path.Join(dir, "stack.env"))
|
||||
f, _ := os.Open(path.Join(dir, "stack.env"))
|
||||
content, _ := io.ReadAll(f)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockKubectlClient struct {
|
||||
|
@ -68,7 +69,7 @@ func TestExecuteKubectlOperation_Apply_Success(t *testing.T) {
|
|||
manifests := []string{"manifest1.yaml", "manifest2.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "apply", manifests)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
|
@ -86,7 +87,7 @@ func TestExecuteKubectlOperation_Apply_Error(t *testing.T) {
|
|||
manifests := []string{"error.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "apply", manifests)
|
||||
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
@ -104,7 +105,7 @@ func TestExecuteKubectlOperation_Delete_Success(t *testing.T) {
|
|||
manifests := []string{"manifest1.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "delete", manifests)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
|
@ -122,7 +123,7 @@ func TestExecuteKubectlOperation_Delete_Error(t *testing.T) {
|
|||
manifests := []string{"error.yaml"}
|
||||
err := testExecuteKubectlOperation(mockClient, "delete", manifests)
|
||||
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
@ -140,7 +141,7 @@ func TestExecuteKubectlOperation_RolloutRestart_Success(t *testing.T) {
|
|||
resources := []string{"deployment/nginx"}
|
||||
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
|
@ -158,7 +159,7 @@ func TestExecuteKubectlOperation_RolloutRestart_Error(t *testing.T) {
|
|||
resources := []string{"deployment/error"}
|
||||
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
|
||||
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), expectedErr.Error())
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
@ -168,6 +169,6 @@ func TestExecuteKubectlOperation_UnsupportedOperation(t *testing.T) {
|
|||
|
||||
err := testExecuteKubectlOperation(mockClient, "unsupported", []string{})
|
||||
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported operation")
|
||||
}
|
||||
|
|
|
@ -7,12 +7,13 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_copyFile_returnsError_whenSourceDoesNotExist(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
err := copyFile("does-not-exist", tmpdir)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_copyFile_shouldMakeAbackup(t *testing.T) {
|
||||
|
@ -21,7 +22,7 @@ func Test_copyFile_shouldMakeAbackup(t *testing.T) {
|
|||
os.WriteFile(path.Join(tmpdir, "origin"), content, 0600)
|
||||
|
||||
err := copyFile(path.Join(tmpdir, "origin"), path.Join(tmpdir, "copy"))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
copyContent, _ := os.ReadFile(path.Join(tmpdir, "copy"))
|
||||
assert.Equal(t, content, copyContent)
|
||||
|
@ -30,7 +31,7 @@ func Test_copyFile_shouldMakeAbackup(t *testing.T) {
|
|||
func Test_CopyDir_shouldCopyAllFilesAndDirectories(t *testing.T) {
|
||||
destination := t.TempDir()
|
||||
err := CopyDir("./testdata/copy_test", destination, true)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.FileExists(t, filepath.Join(destination, "copy_test", "outer"))
|
||||
assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile"))
|
||||
|
@ -40,7 +41,7 @@ func Test_CopyDir_shouldCopyAllFilesAndDirectories(t *testing.T) {
|
|||
func Test_CopyDir_shouldCopyOnlyDirContents(t *testing.T) {
|
||||
destination := t.TempDir()
|
||||
err := CopyDir("./testdata/copy_test", destination, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.FileExists(t, filepath.Join(destination, "outer"))
|
||||
assert.FileExists(t, filepath.Join(destination, "dir", ".dotfile"))
|
||||
|
@ -50,7 +51,7 @@ func Test_CopyDir_shouldCopyOnlyDirContents(t *testing.T) {
|
|||
func Test_CopyPath_shouldSkipWhenNotExist(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
err := CopyPath("does-not-exists", tmpdir)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NoFileExists(t, tmpdir)
|
||||
}
|
||||
|
@ -62,17 +63,17 @@ func Test_CopyPath_shouldCopyFile(t *testing.T) {
|
|||
|
||||
os.MkdirAll(path.Join(tmpdir, "backup"), 0700)
|
||||
err := CopyPath(path.Join(tmpdir, "file"), path.Join(tmpdir, "backup"))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
copyContent, err := os.ReadFile(path.Join(tmpdir, "backup", "file"))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, copyContent)
|
||||
}
|
||||
|
||||
func Test_CopyPath_shouldCopyDir(t *testing.T) {
|
||||
destination := t.TempDir()
|
||||
err := CopyPath("./testdata/copy_test", destination)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.FileExists(t, filepath.Join(destination, "copy_test", "outer"))
|
||||
assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile"))
|
||||
|
|
|
@ -848,7 +848,7 @@ func defaultMTLSCertPathUnderFileStore() (string, string, string) {
|
|||
return caCertPath, certPath, keyPath
|
||||
}
|
||||
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisle private key path
|
||||
// GetDefaultChiselPrivateKeyPath returns the chisel private key path
|
||||
func (service *Service) GetDefaultChiselPrivateKeyPath() string {
|
||||
privateKeyPath := defaultChiselPrivateKeyPathUnderFileStore()
|
||||
return service.wrapFileStore(privateKeyPath)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_fileSystemService_FileExists_whenFileExistsShouldReturnTrue(t *testing.T) {
|
||||
|
@ -30,14 +31,14 @@ func Test_FileExists_whenFileNotExistsShouldReturnFalse(t *testing.T) {
|
|||
|
||||
func testHelperFileExists_fileExists(t *testing.T, checker func(path string) (bool, error)) {
|
||||
file, err := os.CreateTemp("", t.Name())
|
||||
assert.NoError(t, err, "CreateTemp should not fail")
|
||||
require.NoError(t, err, "CreateTemp should not fail")
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(file.Name())
|
||||
})
|
||||
|
||||
exists, err := checker(file.Name())
|
||||
assert.NoError(t, err, "FileExists should not fail")
|
||||
require.NoError(t, err, "FileExists should not fail")
|
||||
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
@ -46,10 +47,10 @@ func testHelperFileExists_fileNotExists(t *testing.T, checker func(path string)
|
|||
filePath := path.Join(t.TempDir(), fmt.Sprintf("%s%d", t.Name(), rand.Int()))
|
||||
|
||||
err := os.RemoveAll(filePath)
|
||||
assert.NoError(t, err, "RemoveAll should not fail")
|
||||
require.NoError(t, err, "RemoveAll should not fail")
|
||||
|
||||
exists, err := checker(filePath)
|
||||
assert.NoError(t, err, "FileExists should not fail")
|
||||
require.NoError(t, err, "FileExists should not fail")
|
||||
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var content = []byte("content")
|
||||
|
@ -13,25 +14,25 @@ var content = []byte("content")
|
|||
func Test_movePath_shouldFailIfSourceDirDoesNotExist(t *testing.T) {
|
||||
sourceDir := "missing"
|
||||
destinationDir := t.TempDir()
|
||||
file1 := addFile(destinationDir, "dir", "file")
|
||||
file2 := addFile(destinationDir, "file")
|
||||
file1 := addFile(t, destinationDir, "dir", "file")
|
||||
file2 := addFile(t, destinationDir, "file")
|
||||
|
||||
err := MoveDirectory(sourceDir, destinationDir, false)
|
||||
assert.Error(t, err, "move directory should fail when source path is missing")
|
||||
require.Error(t, err, "move directory should fail when source path is missing")
|
||||
assert.FileExists(t, file1, "destination dir contents should remain")
|
||||
assert.FileExists(t, file2, "destination dir contents should remain")
|
||||
}
|
||||
|
||||
func Test_movePath_shouldFailIfDestinationDirExists(t *testing.T) {
|
||||
sourceDir := t.TempDir()
|
||||
file1 := addFile(sourceDir, "dir", "file")
|
||||
file2 := addFile(sourceDir, "file")
|
||||
file1 := addFile(t, sourceDir, "dir", "file")
|
||||
file2 := addFile(t, sourceDir, "file")
|
||||
destinationDir := t.TempDir()
|
||||
file3 := addFile(destinationDir, "dir", "file")
|
||||
file4 := addFile(destinationDir, "file")
|
||||
file3 := addFile(t, destinationDir, "dir", "file")
|
||||
file4 := addFile(t, destinationDir, "file")
|
||||
|
||||
err := MoveDirectory(sourceDir, destinationDir, false)
|
||||
assert.Error(t, err, "move directory should fail when destination directory already exists")
|
||||
require.Error(t, err, "move directory should fail when destination directory already exists")
|
||||
assert.FileExists(t, file1, "source dir contents should remain")
|
||||
assert.FileExists(t, file2, "source dir contents should remain")
|
||||
assert.FileExists(t, file3, "destination dir contents should remain")
|
||||
|
@ -40,14 +41,14 @@ func Test_movePath_shouldFailIfDestinationDirExists(t *testing.T) {
|
|||
|
||||
func Test_movePath_succesIfOverwriteSetWhenDestinationDirExists(t *testing.T) {
|
||||
sourceDir := t.TempDir()
|
||||
file1 := addFile(sourceDir, "dir", "file")
|
||||
file2 := addFile(sourceDir, "file")
|
||||
file1 := addFile(t, sourceDir, "dir", "file")
|
||||
file2 := addFile(t, sourceDir, "file")
|
||||
destinationDir := t.TempDir()
|
||||
file3 := addFile(destinationDir, "dir", "file")
|
||||
file4 := addFile(destinationDir, "file")
|
||||
file3 := addFile(t, destinationDir, "dir", "file")
|
||||
file4 := addFile(t, destinationDir, "file")
|
||||
|
||||
err := MoveDirectory(sourceDir, destinationDir, true)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NoFileExists(t, file1, "source dir contents should be moved")
|
||||
assert.NoFileExists(t, file2, "source dir contents should be moved")
|
||||
assert.FileExists(t, file3, "destination dir contents should remain")
|
||||
|
@ -58,32 +59,34 @@ func Test_movePath_successWhenSourceExistsAndDestinationIsMissing(t *testing.T)
|
|||
tmp := t.TempDir()
|
||||
sourceDir := path.Join(tmp, "source")
|
||||
os.Mkdir(sourceDir, 0766)
|
||||
file1 := addFile(sourceDir, "dir", "file")
|
||||
file2 := addFile(sourceDir, "file")
|
||||
file1 := addFile(t, sourceDir, "dir", "file")
|
||||
file2 := addFile(t, sourceDir, "file")
|
||||
destinationDir := path.Join(tmp, "destination")
|
||||
|
||||
err := MoveDirectory(sourceDir, destinationDir, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NoFileExists(t, file1, "source dir contents should be moved")
|
||||
assert.NoFileExists(t, file2, "source dir contents should be moved")
|
||||
assertFileContent(t, path.Join(destinationDir, "file"))
|
||||
assertFileContent(t, path.Join(destinationDir, "dir", "file"))
|
||||
}
|
||||
|
||||
func addFile(fileParts ...string) (filepath string) {
|
||||
func addFile(t *testing.T, fileParts ...string) (filepath string) {
|
||||
if len(fileParts) > 2 {
|
||||
dir := path.Join(fileParts[:len(fileParts)-1]...)
|
||||
os.MkdirAll(dir, 0766)
|
||||
err := os.MkdirAll(dir, 0766)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
p := path.Join(fileParts...)
|
||||
os.WriteFile(p, content, 0766)
|
||||
err := os.WriteFile(p, content, 0766)
|
||||
require.NoError(t, err)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func assertFileContent(t *testing.T, filePath string) {
|
||||
actualContent, err := os.ReadFile(filePath)
|
||||
assert.NoErrorf(t, err, "failed to read file %s", filePath)
|
||||
require.NoError(t, err, "failed to read file %s", filePath)
|
||||
assert.Equal(t, content, actualContent, "file %s content doesn't match", filePath)
|
||||
}
|
||||
|
|
|
@ -5,14 +5,14 @@ import (
|
|||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createService(t *testing.T) *Service {
|
||||
dataStorePath := path.Join(t.TempDir(), t.Name())
|
||||
|
||||
service, err := NewService(dataStorePath, "")
|
||||
assert.NoError(t, err, "NewService should not fail")
|
||||
require.NoError(t, err, "NewService should not fail")
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(dataStorePath)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_WriteFile_CanStoreContentInANewFile(t *testing.T) {
|
||||
|
@ -14,7 +15,7 @@ func Test_WriteFile_CanStoreContentInANewFile(t *testing.T) {
|
|||
|
||||
content := []byte("content")
|
||||
err := WriteToFile(tmpFilePath, content)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent, _ := os.ReadFile(tmpFilePath)
|
||||
assert.Equal(t, content, fileContent)
|
||||
|
@ -25,11 +26,11 @@ func Test_WriteFile_CanOverwriteExistingFile(t *testing.T) {
|
|||
tmpFilePath := path.Join(tmpDir, "dummy")
|
||||
|
||||
err := WriteToFile(tmpFilePath, []byte("content"))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
content := []byte("new content")
|
||||
err = WriteToFile(tmpFilePath, content)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent, _ := os.ReadFile(tmpFilePath)
|
||||
assert.Equal(t, content, fileContent)
|
||||
|
@ -41,7 +42,7 @@ func Test_WriteFile_CanWriteANestedPath(t *testing.T) {
|
|||
|
||||
content := []byte("content")
|
||||
err := WriteToFile(tmpFilePath, content)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
fileContent, _ := os.ReadFile(tmpFilePath)
|
||||
assert.Equal(t, content, fileContent)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const privateAzureRepoURL = "https://portainer.visualstudio.com/gitops-test/_git/gitops-test"
|
||||
|
@ -67,7 +68,7 @@ func TestService_ClonePublicRepository_Azure(t *testing.T) {
|
|||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
})
|
||||
}
|
||||
|
@ -90,7 +91,7 @@ func TestService_ClonePrivateRepository_Azure(t *testing.T) {
|
|||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
}
|
||||
|
||||
|
@ -108,7 +109,7 @@ func TestService_LatestCommitID_Azure(t *testing.T) {
|
|||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
|
||||
}
|
||||
|
||||
|
@ -127,7 +128,7 @@ func TestService_ListRefs_Azure(t *testing.T) {
|
|||
false,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
}
|
||||
|
||||
|
@ -289,14 +290,14 @@ func TestService_ListFiles_Azure(t *testing.T) {
|
|||
false,
|
||||
)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.matchedCount > 0 {
|
||||
assert.Greater(t, len(paths), 0)
|
||||
assert.NotEmpty(t, paths)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -8,7 +8,10 @@ import (
|
|||
"testing"
|
||||
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_buildDownloadUrl(t *testing.T) {
|
||||
|
@ -18,15 +21,18 @@ func Test_buildDownloadUrl(t *testing.T) {
|
|||
project: "project",
|
||||
repository: "repository",
|
||||
}, "refs/heads/main")
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&download=true&versionDescriptor.version=main&$format=zip&recursionLevel=full&api-version=6.0&versionDescriptor.versionType=branch")
|
||||
actualUrl, _ := url.Parse(u)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, expectedUrl.Host, actualUrl.Host)
|
||||
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
|
||||
assert.Equal(t, expectedUrl.Path, actualUrl.Path)
|
||||
assert.Equal(t, expectedUrl.Query(), actualUrl.Query())
|
||||
}
|
||||
expectedUrl, err := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&download=true&versionDescriptor.version=main&$format=zip&recursionLevel=full&api-version=6.0&versionDescriptor.versionType=branch")
|
||||
require.NoError(t, err)
|
||||
|
||||
actualUrl, err := url.Parse(u)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expectedUrl.Host, actualUrl.Host)
|
||||
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
|
||||
assert.Equal(t, expectedUrl.Path, actualUrl.Path)
|
||||
assert.Equal(t, expectedUrl.Query(), actualUrl.Query())
|
||||
}
|
||||
|
||||
func Test_buildRootItemUrl(t *testing.T) {
|
||||
|
@ -39,7 +45,7 @@ func Test_buildRootItemUrl(t *testing.T) {
|
|||
|
||||
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&api-version=6.0&versionDescriptor.version=main&versionDescriptor.versionType=branch")
|
||||
actualUrl, _ := url.Parse(u)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedUrl.Host, actualUrl.Host)
|
||||
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
|
||||
assert.Equal(t, expectedUrl.Path, actualUrl.Path)
|
||||
|
@ -56,7 +62,7 @@ func Test_buildRefsUrl(t *testing.T) {
|
|||
|
||||
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/refs?api-version=6.0")
|
||||
actualUrl, _ := url.Parse(u)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedUrl.Host, actualUrl.Host)
|
||||
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
|
||||
assert.Equal(t, expectedUrl.Path, actualUrl.Path)
|
||||
|
@ -73,7 +79,7 @@ func Test_buildTreeUrl(t *testing.T) {
|
|||
|
||||
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/trees/sha1?api-version=6.0&recursive=true")
|
||||
actualUrl, _ := url.Parse(u)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedUrl.Host, actualUrl.Host)
|
||||
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
|
||||
assert.Equal(t, expectedUrl.Path, actualUrl.Path)
|
||||
|
@ -234,6 +240,8 @@ func Test_isAzureUrl(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
type args struct {
|
||||
options baseOption
|
||||
}
|
||||
|
@ -301,13 +309,15 @@ func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
|
|||
},
|
||||
}
|
||||
_, err := a.downloadZipFromAzureDevOps(context.Background(), option)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.want, zipRequestAuth)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_azureDownloader_latestCommitID(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := `{
|
||||
"count": 1,
|
||||
|
@ -497,12 +507,12 @@ func Test_listRefs_azure(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
refs, err := client.listRefs(context.TODO(), tt.args)
|
||||
if tt.expect.err == nil {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.refsCount > 0 {
|
||||
assert.Greater(t, len(refs), 0)
|
||||
assert.NotEmpty(t, refs)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
})
|
||||
|
@ -608,14 +618,14 @@ func Test_listFiles_azure(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
paths, err := client.listFiles(context.TODO(), tt.args)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.matchedCount > 0 {
|
||||
assert.Greater(t, len(paths), 0)
|
||||
assert.NotEmpty(t, paths)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -9,7 +9,9 @@ import (
|
|||
"time"
|
||||
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -35,7 +37,7 @@ func TestService_ClonePrivateRepository_GitHub(t *testing.T) {
|
|||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.FileExists(t, filepath.Join(dst, "README.md"))
|
||||
}
|
||||
|
||||
|
@ -55,7 +57,7 @@ func TestService_LatestCommitID_GitHub(t *testing.T) {
|
|||
gittypes.GitCredentialAuthType_Basic,
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
|
||||
}
|
||||
|
||||
|
@ -68,7 +70,7 @@ func TestService_ListRefs_GitHub(t *testing.T) {
|
|||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
}
|
||||
|
||||
|
@ -231,14 +233,14 @@ func TestService_ListFiles_GitHub(t *testing.T) {
|
|||
false,
|
||||
)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.matchedCount > 0 {
|
||||
assert.Greater(t, len(paths), 0)
|
||||
assert.NotEmpty(t, paths)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -361,12 +363,12 @@ func TestService_HardRefresh_ListRefs_GitHub(t *testing.T) {
|
|||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
}
|
||||
|
||||
|
@ -379,7 +381,7 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
|
|||
|
||||
repositoryUrl := privateGitRepoURL
|
||||
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(refs), 1)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
|
@ -394,7 +396,7 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
|
|||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
|
||||
|
@ -409,16 +411,16 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
|
|||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 2, service.repoFileCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
|
||||
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, true, false)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, service.repoRefCache.Len())
|
||||
// The relevant file caches should be removed too
|
||||
assert.Equal(t, 0, service.repoFileCache.Len())
|
||||
|
@ -442,7 +444,7 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
|
|||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(files), 1)
|
||||
assert.Equal(t, 1, service.repoFileCache.Len())
|
||||
|
||||
|
@ -457,7 +459,7 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
|
|||
[]string{},
|
||||
false,
|
||||
)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 0, service.repoFileCache.Len())
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/go-git/go-git/v5/plumbing/object"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setup(t *testing.T) string {
|
||||
|
@ -39,7 +40,7 @@ func Test_ClonePublicRepository_Shallow(t *testing.T) {
|
|||
dir := t.TempDir()
|
||||
t.Logf("Cloning into %s", dir)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
|
||||
}
|
||||
|
||||
|
@ -51,7 +52,7 @@ func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
|
|||
dir := t.TempDir()
|
||||
t.Logf("Cloning into %s", dir)
|
||||
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NoDirExists(t, filepath.Join(dir, ".git"))
|
||||
}
|
||||
|
||||
|
@ -74,7 +75,7 @@ func Test_cloneRepository(t *testing.T) {
|
|||
depth: 10,
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
|
||||
}
|
||||
|
||||
|
@ -86,7 +87,7 @@ func Test_latestCommitID(t *testing.T) {
|
|||
|
||||
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "68dcaa7bd452494043c64252ab90db0f98ecf8d2", id)
|
||||
}
|
||||
|
||||
|
@ -97,7 +98,7 @@ func Test_ListRefs(t *testing.T) {
|
|||
|
||||
fs, err := service.ListRefs(repositoryURL, "", "", gittypes.GitCredentialAuthType_Basic, false, false)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"refs/heads/main"}, fs)
|
||||
}
|
||||
|
||||
|
@ -119,7 +120,7 @@ func Test_ListFiles(t *testing.T) {
|
|||
false,
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"docker-compose.yml"}, fs)
|
||||
}
|
||||
|
||||
|
@ -214,12 +215,12 @@ func Test_listRefsPrivateRepository(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
refs, err := client.listRefs(context.TODO(), tt.args)
|
||||
if tt.expect.err == nil {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.refsCount > 0 {
|
||||
assert.Greater(t, len(refs), 0)
|
||||
assert.NotEmpty(t, refs)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
})
|
||||
|
@ -325,14 +326,14 @@ func Test_listFilesPrivateRepository(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
paths, err := client.listFiles(context.TODO(), tt.args)
|
||||
if tt.expect.shouldFail {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tt.expect.err != nil {
|
||||
assert.Equal(t, tt.expect.err, err)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
if tt.expect.matchedCount > 0 {
|
||||
assert.Greater(t, len(paths), 0)
|
||||
assert.NotEmpty(t, paths)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -21,10 +21,14 @@ func ValidateAutoUpdateSettings(autoUpdate *portainer.AutoUpdateSettings) error
|
|||
return httperrors.NewInvalidPayloadError("invalid Webhook format")
|
||||
}
|
||||
|
||||
if autoUpdate.Interval != "" {
|
||||
if _, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
}
|
||||
if autoUpdate.Interval == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d, err := time.ParseDuration(autoUpdate.Interval); err != nil {
|
||||
return httperrors.NewInvalidPayloadError("invalid Interval format")
|
||||
} else if d < time.Minute {
|
||||
return httperrors.NewInvalidPayloadError("interval must be at least 1 minute")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -23,6 +23,16 @@ func Test_ValidateAutoUpdate(t *testing.T) {
|
|||
value: &portainer.AutoUpdateSettings{Interval: "1dd2hh3mm"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "short interval value",
|
||||
value: &portainer.AutoUpdateSettings{Interval: "1s"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid webhook without interval",
|
||||
value: &portainer.AutoUpdateSettings{Webhook: "8dce8c2f-9ca1-482b-ad20-271e86536ada"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid auto update",
|
||||
value: &portainer.AutoUpdateSettings{
|
||||
|
|
|
@ -4,10 +4,14 @@ import (
|
|||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
service := NewService(true)
|
||||
require.NotNil(t, service)
|
||||
require.True(t, service.httpsClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
|
|
@ -6,11 +6,14 @@ import (
|
|||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecutePingOperationFailure(t *testing.T) {
|
||||
fips.InitFIPS(false)
|
||||
|
||||
host := "http://localhost:1"
|
||||
config := portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
)
|
||||
|
||||
func TxResponse(err error, validResponse func() *httperror.HandlerError) *httperror.HandlerError {
|
||||
if err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return validResponse()
|
||||
}
|
|
@ -26,11 +26,10 @@ func (handler *Handler) logout(w http.ResponseWriter, r *http.Request) *httperro
|
|||
handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID)
|
||||
handler.KubernetesClientFactory.ClearUserClientCache(strconv.Itoa(int(tokenData.ID)))
|
||||
logoutcontext.Cancel(tokenData.Token)
|
||||
handler.bouncer.RevokeJWT(tokenData.Token)
|
||||
}
|
||||
|
||||
security.RemoveAuthCookie(w)
|
||||
|
||||
handler.bouncer.RevokeJWT(tokenData.Token)
|
||||
|
||||
return response.Empty(w)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/kubernetes/cli"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockBouncer struct {
|
||||
security.BouncerService
|
||||
}
|
||||
|
||||
func NewMockBouncer() *mockBouncer {
|
||||
return &mockBouncer{BouncerService: testhelpers.NewTestRequestBouncer()}
|
||||
}
|
||||
|
||||
func (*mockBouncer) CookieAuthLookup(r *http.Request) (*portainer.TokenData, error) {
|
||||
return &portainer.TokenData{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Token: "valid-token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
h := NewHandler(NewMockBouncer(), nil, nil, nil)
|
||||
h.KubernetesTokenCacheManager = kubernetes.NewTokenCacheManager()
|
||||
k, err := cli.NewClientFactory(nil, nil, nil, "", "", "")
|
||||
require.NoError(t, err)
|
||||
h.KubernetesClientFactory = k
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/auth/logout", nil)
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusNoContent, rr.Code)
|
||||
}
|
||||
|
||||
func TestLogoutNoPanic(t *testing.T) {
|
||||
h := NewHandler(testhelpers.NewTestRequestBouncer(), nil, nil, nil)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/auth/logout", nil)
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusNoContent, rr.Code)
|
||||
}
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_restoreArchive_usingCombinationOfPasswords(t *testing.T) {
|
||||
|
@ -64,13 +65,12 @@ func Test_restoreArchive_usingCombinationOfPasswords(t *testing.T) {
|
|||
adminMonitor,
|
||||
)
|
||||
|
||||
//backup
|
||||
// backup
|
||||
archive := backup(t, h, test.backupPassword)
|
||||
|
||||
//restore
|
||||
// restore
|
||||
w := httptest.NewRecorder()
|
||||
r, err := prepareMultipartRequest(test.restorePassword, archive)
|
||||
assert.Nil(t, err, "Shouldn't fail to write multipart form")
|
||||
r := prepareMultipartRequest(t, test.restorePassword, archive)
|
||||
|
||||
restoreErr := h.restore(w, r)
|
||||
assert.Equal(t, test.fails, restoreErr != nil, "Didn't meet expectation of failing restore handler")
|
||||
|
@ -96,13 +96,12 @@ func Test_restoreArchive_shouldFailIfSystemWasAlreadyInitialized(t *testing.T) {
|
|||
adminMonitor,
|
||||
)
|
||||
|
||||
//backup
|
||||
// backup
|
||||
archive := backup(t, h, "password")
|
||||
|
||||
//restore
|
||||
// restore
|
||||
w := httptest.NewRecorder()
|
||||
r, err := prepareMultipartRequest("password", archive)
|
||||
assert.Nil(t, err, "Shouldn't fail to write multipart form")
|
||||
r := prepareMultipartRequest(t, "password", archive)
|
||||
|
||||
restoreErr := h.restore(w, r)
|
||||
assert.NotNil(t, restoreErr, "Should fail, because system it already initialized")
|
||||
|
@ -117,31 +116,31 @@ func backup(t *testing.T, h *Handler, password string) []byte {
|
|||
assert.Nil(t, backupErr, "Backup should not fail")
|
||||
|
||||
response := w.Result()
|
||||
archive, _ := io.ReadAll(response.Body)
|
||||
archive, err := io.ReadAll(response.Body)
|
||||
require.NoError(t, err)
|
||||
response.Body.Close()
|
||||
|
||||
return archive
|
||||
}
|
||||
|
||||
func prepareMultipartRequest(password string, file []byte) (*http.Request, error) {
|
||||
func prepareMultipartRequest(t *testing.T, password string, file []byte) *http.Request {
|
||||
var body bytes.Buffer
|
||||
|
||||
w := multipart.NewWriter(&body)
|
||||
err := w.WriteField("password", password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
fw, err := w.CreateFormFile("file", "filename")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
io.Copy(fw, bytes.NewReader(file))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(fw, bytes.NewReader(file))
|
||||
require.NoError(t, err)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "http://localhost/", &body)
|
||||
r.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
w.Close()
|
||||
err = w.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
return r, nil
|
||||
return r
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -25,6 +24,8 @@ import (
|
|||
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -112,15 +113,14 @@ func createTestFile(targetPath string) error {
|
|||
}
|
||||
|
||||
func prepareTestFolder(projectPath, filename string) error {
|
||||
err := os.MkdirAll(projectPath, fs.ModePerm)
|
||||
if err != nil {
|
||||
if err := os.MkdirAll(projectPath, fs.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return createTestFile(filepath.Join(projectPath, filename))
|
||||
}
|
||||
|
||||
func singleAPIRequest(h *Handler, jwt string, is *assert.Assertions, expect string) {
|
||||
func singleAPIRequest(h *Handler, jwt string, expect string) error {
|
||||
type response struct {
|
||||
FileContent string
|
||||
}
|
||||
|
@ -131,15 +131,25 @@ func singleAPIRequest(h *Handler, jwt string, is *assert.Assertions, expect stri
|
|||
rr := httptest.NewRecorder()
|
||||
h.ServeHTTP(rr, req)
|
||||
|
||||
is.Equal(http.StatusOK, rr.Code)
|
||||
if rr.Code != http.StatusOK {
|
||||
return errors.New("unexpected status code: " + http.StatusText(rr.Code))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(rr.Body)
|
||||
is.NoError(err, "ReadAll should not return error")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var resp response
|
||||
err = json.Unmarshal(body, &resp)
|
||||
is.NoError(err, "response should be list json")
|
||||
is.Equal(resp.FileContent, expect)
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.FileContent != expect {
|
||||
return errors.New("unexpected file content: " + resp.FileContent + ", expected: " + expect)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_customTemplateGitFetch(t *testing.T) {
|
||||
|
@ -150,28 +160,29 @@ func Test_customTemplateGitFetch(t *testing.T) {
|
|||
// create user(s)
|
||||
user1 := &portainer.User{ID: 1, Username: "user-1", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()}
|
||||
err := store.User().Create(user1)
|
||||
is.NoError(err, "error creating user 1")
|
||||
require.NoError(t, err, "error creating user 1")
|
||||
|
||||
user2 := &portainer.User{ID: 2, Username: "user-2", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()}
|
||||
err = store.User().Create(user2)
|
||||
is.NoError(err, "error creating user 2")
|
||||
require.NoError(t, err, "error creating user 2")
|
||||
|
||||
dir, err := os.Getwd()
|
||||
is.NoError(err, "error to get working directory")
|
||||
require.NoError(t, err, "error to get working directory")
|
||||
|
||||
template1 := &portainer.CustomTemplate{ID: 1, Title: "custom-template-1", ProjectPath: filepath.Join(dir, "fixtures/custom_template_1"), GitConfig: &gittypes.RepoConfig{ConfigFilePath: "test-config-path.txt"}}
|
||||
err = store.CustomTemplateService.Create(template1)
|
||||
is.NoError(err, "error creating custom template 1")
|
||||
require.NoError(t, err, "error creating custom template 1")
|
||||
|
||||
// prepare testing folder
|
||||
err = prepareTestFolder(template1.ProjectPath, template1.GitConfig.ConfigFilePath)
|
||||
is.NoError(err, "error creating testing folder")
|
||||
require.NoError(t, err, "error creating testing folder")
|
||||
|
||||
defer os.RemoveAll(filepath.Join(dir, "fixtures"))
|
||||
|
||||
// setup services
|
||||
jwtService, err := jwt.NewService("1h", store)
|
||||
is.NoError(err, "Error initiating jwt service")
|
||||
require.NoError(t, err, "Error initiating jwt service")
|
||||
|
||||
requestBouncer := security.NewRequestBouncer(store, jwtService, nil)
|
||||
|
||||
gitService := &TestGitService{
|
||||
|
@ -182,52 +193,55 @@ func Test_customTemplateGitFetch(t *testing.T) {
|
|||
h := NewHandler(requestBouncer, store, fileService, gitService)
|
||||
|
||||
// generate two standard users' tokens
|
||||
jwt1, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: user1.ID, Username: user1.Username, Role: user1.Role})
|
||||
jwt2, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: user2.ID, Username: user2.Username, Role: user2.Role})
|
||||
jwt1, _, err := jwtService.GenerateToken(&portainer.TokenData{ID: user1.ID, Username: user1.Username, Role: user1.Role})
|
||||
require.NoError(t, err)
|
||||
|
||||
jwt2, _, err := jwtService.GenerateToken(&portainer.TokenData{ID: user2.ID, Username: user2.Username, Role: user2.Role})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("can return the expected file content by a single call from one user", func(t *testing.T) {
|
||||
singleAPIRequest(h, jwt1, is, "abcdefg")
|
||||
err := singleAPIRequest(h, jwt1, "abcdefg")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can return the expected file content by multiple calls from one user", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
var g errgroup.Group
|
||||
|
||||
for range 5 {
|
||||
go func() {
|
||||
singleAPIRequest(h, jwt1, is, "abcdefg")
|
||||
wg.Done()
|
||||
}()
|
||||
g.Go(func() error {
|
||||
return singleAPIRequest(h, jwt1, "abcdefg")
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
err := g.Wait()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can return the expected file content by multiple calls from different users", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(10)
|
||||
var g errgroup.Group
|
||||
|
||||
for i := range 10 {
|
||||
go func(j int) {
|
||||
if j%2 == 0 {
|
||||
singleAPIRequest(h, jwt1, is, "abcdefg")
|
||||
} else {
|
||||
singleAPIRequest(h, jwt2, is, "abcdefg")
|
||||
g.Go(func() error {
|
||||
if i%2 == 0 {
|
||||
return singleAPIRequest(h, jwt1, "abcdefg")
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}(i)
|
||||
return singleAPIRequest(h, jwt2, "abcdefg")
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
err := g.Wait()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can return the expected file content after a new commit is made", func(t *testing.T) {
|
||||
singleAPIRequest(h, jwt1, is, "abcdefg")
|
||||
err := singleAPIRequest(h, jwt1, "abcdefg")
|
||||
require.NoError(t, err)
|
||||
|
||||
testFileContent = "gfedcba"
|
||||
|
||||
singleAPIRequest(h, jwt2, is, "gfedcba")
|
||||
err = singleAPIRequest(h, jwt2, "gfedcba")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("restore git repository if it is failed to download the new git repository", func(t *testing.T) {
|
||||
|
@ -246,11 +260,11 @@ func Test_customTemplateGitFetch(t *testing.T) {
|
|||
|
||||
var errResp httperror.HandlerError
|
||||
err = json.NewDecoder(rr.Body).Decode(&errResp)
|
||||
assert.NoError(t, err, "failed to parse error body")
|
||||
require.NoError(t, err, "failed to parse error body")
|
||||
|
||||
assert.FileExists(t, gitService.targetFilePath, "previous git repository is not restored")
|
||||
fileContent, err := os.ReadFile(gitService.targetFilePath)
|
||||
assert.NoError(t, err, "failed to read target file")
|
||||
require.NoError(t, err, "failed to read target file")
|
||||
assert.Equal(t, "gfedcba", string(fileContent))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -5,8 +5,11 @@ import (
|
|||
"strconv"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
httperrors "github.com/portainer/portainer/api/http/errors"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/authorization"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
@ -32,31 +35,45 @@ func (handler *Handler) customTemplateInspect(w http.ResponseWriter, r *http.Req
|
|||
return httperror.BadRequest("Invalid Custom template identifier route variable", err)
|
||||
}
|
||||
|
||||
customTemplate, err := handler.DataStore.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID))
|
||||
if handler.DataStore.IsErrObjectNotFound(err) {
|
||||
return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
} else if err != nil {
|
||||
return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
}
|
||||
var customTemplate *portainer.CustomTemplate
|
||||
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
customTemplate, err = tx.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID))
|
||||
if handler.DataStore.IsErrObjectNotFound(err) {
|
||||
return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
} else if err != nil {
|
||||
return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err)
|
||||
}
|
||||
|
||||
securityContext, err := security.RetrieveRestrictedRequestContext(r)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve user info from request context", err)
|
||||
}
|
||||
resourceControl, err := tx.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
|
||||
}
|
||||
|
||||
resourceControl, err := handler.DataStore.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
|
||||
}
|
||||
securityContext, err := security.RetrieveRestrictedRequestContext(r)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve user info from request context", err)
|
||||
}
|
||||
|
||||
canEdit := userCanEditTemplate(customTemplate, securityContext)
|
||||
hasAccess := false
|
||||
|
||||
if resourceControl != nil {
|
||||
customTemplate.ResourceControl = resourceControl
|
||||
|
||||
teamIDs := slicesx.Map(securityContext.UserMemberships, func(m portainer.TeamMembership) portainer.TeamID {
|
||||
return m.TeamID
|
||||
})
|
||||
|
||||
hasAccess = authorization.UserCanAccessResource(securityContext.UserID, teamIDs, resourceControl)
|
||||
|
||||
}
|
||||
|
||||
if canEdit || hasAccess {
|
||||
return nil
|
||||
}
|
||||
|
||||
access := userCanEditTemplate(customTemplate, securityContext)
|
||||
if !access {
|
||||
return httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)
|
||||
}
|
||||
})
|
||||
|
||||
if resourceControl != nil {
|
||||
customTemplate.ResourceControl = resourceControl
|
||||
}
|
||||
|
||||
return response.JSON(w, customTemplate)
|
||||
return response.TxResponse(w, customTemplate, err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
package customtemplates
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/segmentio/encoding/json"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInspectHandler(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 2, Username: "std2", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 3, Username: "std3", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.User().Create(&portainer.User{ID: 4, Username: "std4", Role: portainer.StandardUserRole}))
|
||||
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ID: 1,
|
||||
UserAccessPolicies: portainer.UserAccessPolicies{
|
||||
2: portainer.AccessPolicy{RoleID: 0},
|
||||
3: portainer.AccessPolicy{RoleID: 0},
|
||||
}}))
|
||||
require.NoError(t, tx.Team().Create(&portainer.Team{ID: 1}))
|
||||
require.NoError(t, tx.TeamMembership().Create(&portainer.TeamMembership{ID: 1, UserID: 3, TeamID: 1, Role: portainer.TeamMember}))
|
||||
|
||||
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 2}))
|
||||
require.NoError(t, tx.ResourceControl().Create(&portainer.ResourceControl{ID: 1, ResourceID: "2", Type: portainer.CustomTemplateResourceControl,
|
||||
UserAccesses: []portainer.UserResourceAccess{{UserID: 2}},
|
||||
TeamAccesses: []portainer.TeamResourceAccess{{TeamID: 1}},
|
||||
}))
|
||||
return nil
|
||||
}))
|
||||
|
||||
handler := NewHandler(testhelpers.NewTestRequestBouncer(), ds, &TestFileService{}, nil)
|
||||
|
||||
test := func(templateID string, restrictedContext *security.RestrictedRequestContext) (*httptest.ResponseRecorder, *httperror.HandlerError) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/custom_templates/"+templateID, nil)
|
||||
r = mux.SetURLVars(r, map[string]string{"id": templateID})
|
||||
ctx := security.StoreRestrictedRequestContext(r, restrictedContext)
|
||||
r = r.WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
return rr, handler.customTemplateInspect(rr, r)
|
||||
}
|
||||
|
||||
t.Run("unknown id should get not found error", func(t *testing.T) {
|
||||
_, r := test("0", &security.RestrictedRequestContext{UserID: 1})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusNotFound, r.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("admin should access adminonly template", func(t *testing.T) {
|
||||
rr, r := test("1", &security.RestrictedRequestContext{UserID: 1, IsAdmin: true})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should not access adminonly template", func(t *testing.T) {
|
||||
_, r := test("1", &security.RestrictedRequestContext{UserID: 2})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusForbidden, r.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("std should access template via direct user access", func(t *testing.T) {
|
||||
rr, r := test("2", &security.RestrictedRequestContext{UserID: 2})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should access template via team access", func(t *testing.T) {
|
||||
rr, r := test("2", &security.RestrictedRequestContext{UserID: 3, UserMemberships: []portainer.TeamMembership{{ID: 1, UserID: 3, TeamID: 1}}})
|
||||
require.Nil(t, r)
|
||||
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
|
||||
var template portainer.CustomTemplate
|
||||
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
|
||||
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
|
||||
})
|
||||
|
||||
t.Run("std should not access template without access", func(t *testing.T) {
|
||||
_, r := test("2", &security.RestrictedRequestContext{UserID: 4})
|
||||
require.NotNil(t, r)
|
||||
require.Equal(t, http.StatusForbidden, r.StatusCode)
|
||||
})
|
||||
}
|
|
@ -11,8 +11,7 @@ import (
|
|||
"github.com/docker/docker/api/types/volume"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/docker"
|
||||
"github.com/portainer/portainer/api/http/errors"
|
||||
"github.com/portainer/portainer/api/docker/stats"
|
||||
"github.com/portainer/portainer/api/http/handler/docker/utils"
|
||||
"github.com/portainer/portainer/api/http/middlewares"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
|
@ -26,12 +25,12 @@ type imagesCounters struct {
|
|||
}
|
||||
|
||||
type dashboardResponse struct {
|
||||
Containers docker.ContainerStats `json:"containers"`
|
||||
Services int `json:"services"`
|
||||
Images imagesCounters `json:"images"`
|
||||
Volumes int `json:"volumes"`
|
||||
Networks int `json:"networks"`
|
||||
Stacks int `json:"stacks"`
|
||||
Containers stats.ContainerStats `json:"containers"`
|
||||
Services int `json:"services"`
|
||||
Images imagesCounters `json:"images"`
|
||||
Volumes int `json:"volumes"`
|
||||
Networks int `json:"networks"`
|
||||
Stacks int `json:"stacks"`
|
||||
}
|
||||
|
||||
// @id dockerDashboard
|
||||
|
@ -144,13 +143,18 @@ func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.H
|
|||
stackCount = len(stacks)
|
||||
}
|
||||
|
||||
containersStats, err := stats.CalculateContainerStats(r.Context(), cli, info.Swarm.ControlAvailable, containers)
|
||||
if err != nil {
|
||||
return httperror.InternalServerError("Unable to retrieve Docker containers stats", err)
|
||||
}
|
||||
|
||||
resp = dashboardResponse{
|
||||
Images: imagesCounters{
|
||||
Total: len(images),
|
||||
Size: totalSize,
|
||||
},
|
||||
Services: len(services),
|
||||
Containers: docker.CalculateContainerStats(containers),
|
||||
Containers: containersStats,
|
||||
Networks: len(networks),
|
||||
Volumes: len(volumes),
|
||||
Stacks: stackCount,
|
||||
|
@ -159,7 +163,5 @@ func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.H
|
|||
return nil
|
||||
})
|
||||
|
||||
return errors.TxResponse(err, func() *httperror.HandlerError {
|
||||
return response.JSON(w, resp)
|
||||
})
|
||||
return response.TxResponse(w, resp, err)
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandler_getDockerStacks(t *testing.T) {
|
||||
|
@ -69,7 +70,7 @@ func TestHandler_getDockerStacks(t *testing.T) {
|
|||
stacksList, err := GetDockerStacks(datastore, &security.RestrictedRequestContext{
|
||||
IsAdmin: true,
|
||||
}, environment.ID, containers, services)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stacksList, 3)
|
||||
|
||||
expectedStacks := []StackViewModel{
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/portainer/portainer/api/roar"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
||||
type edgeGroupCreatePayload struct {
|
||||
|
@ -111,5 +112,5 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
|
|||
return nil
|
||||
})
|
||||
|
||||
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
}
|
||||
|
|
|
@ -32,16 +32,8 @@ func (handler *Handler) edgeGroupDelete(w http.ResponseWriter, r *http.Request)
|
|||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return deleteEdgeGroup(tx, portainer.EdgeGroupID(edgeGroupID))
|
||||
})
|
||||
if err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func deleteEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) error {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/portainer/portainer/api/roar"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
||||
// @id EdgeGroupInspect
|
||||
|
@ -36,7 +37,7 @@ func (handler *Handler) edgeGroupInspect(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice()
|
||||
|
||||
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
}
|
||||
|
||||
func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) {
|
||||
|
|
|
@ -105,7 +105,7 @@ func TestEmptyEdgeGroupInspectHandler(t *testing.T) {
|
|||
|
||||
// Make sure the frontend does not get a null value but a [] instead
|
||||
require.NotNil(t, responseGroup.Endpoints)
|
||||
require.Len(t, responseGroup.Endpoints, 0)
|
||||
require.Empty(t, responseGroup.Endpoints)
|
||||
}
|
||||
|
||||
func TestDynamicEdgeGroupInspectHandler(t *testing.T) {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/roar"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
||||
type shadowedEdgeGroup struct {
|
||||
|
@ -44,7 +45,7 @@ func (handler *Handler) edgeGroupList(w http.ResponseWriter, r *http.Request) *h
|
|||
return err
|
||||
})
|
||||
|
||||
return txResponse(w, decoratedEdgeGroups, err)
|
||||
return response.TxResponse(w, decoratedEdgeGroups, err)
|
||||
}
|
||||
|
||||
func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error) {
|
||||
|
|
|
@ -47,7 +47,7 @@ func Test_getEndpointTypes(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
ans, err := getEndpointTypes(datastore, roar.FromSlice(test.endpointIds))
|
||||
assert.NoError(t, err, "getEndpointTypes shouldn't fail")
|
||||
require.NoError(t, err, "getEndpointTypes shouldn't fail")
|
||||
|
||||
assert.ElementsMatch(t, test.expected, ans, "getEndpointTypes expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ func Test_getEndpointTypes_failWhenEndpointDontExist(t *testing.T) {
|
|||
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
|
||||
|
||||
_, err := getEndpointTypes(datastore, roar.FromSlice([]portainer.EndpointID{1}))
|
||||
assert.Error(t, err, "getEndpointTypes should fail")
|
||||
require.Error(t, err, "getEndpointTypes should fail")
|
||||
}
|
||||
|
||||
func TestEdgeGroupListHandler(t *testing.T) {
|
||||
|
@ -112,5 +112,5 @@ func TestEdgeGroupListHandler(t *testing.T) {
|
|||
|
||||
require.Len(t, responseGroups, 1)
|
||||
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroups[0].Endpoints)
|
||||
require.Len(t, responseGroups[0].TrustedEndpoints, 0)
|
||||
require.Empty(t, responseGroups[0].TrustedEndpoints)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/portainer/portainer/api/slicesx"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
||||
type edgeGroupUpdatePayload struct {
|
||||
|
@ -158,7 +159,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
|
|||
return nil
|
||||
})
|
||||
|
||||
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, edgeGroups []portainer.EdgeGroup, edgeStacks []portainer.EdgeStack) error {
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
package edgegroups
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
@ -38,16 +36,3 @@ func NewHandler(bouncer security.BouncerService) *Handler {
|
|||
|
||||
return h
|
||||
}
|
||||
|
||||
func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError {
|
||||
if err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.JSON(w, r)
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
)
|
||||
|
||||
|
@ -85,19 +86,18 @@ func (payload *edgeJobCreateFromFileContentPayload) Validate(r *http.Request) er
|
|||
// @router /edge_jobs/create/string [post]
|
||||
func (handler *Handler) createEdgeJobFromFileContent(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload edgeJobCreateFromFileContentPayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
var edgeJob *portainer.EdgeJob
|
||||
var err error
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, []byte(payload.FileContent))
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return txResponse(w, edgeJob, err)
|
||||
return response.TxResponse(w, edgeJob, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) createEdgeJob(tx dataservices.DataStoreTx, payload *edgeJobBasePayload, fileContent []byte) (*portainer.EdgeJob, error) {
|
||||
|
@ -191,19 +191,18 @@ func (payload *edgeJobCreateFromFilePayload) Validate(r *http.Request) error {
|
|||
// @router /edge_jobs/create/file [post]
|
||||
func (handler *Handler) createEdgeJobFromFile(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
payload := &edgeJobCreateFromFilePayload{}
|
||||
err := payload.Validate(r)
|
||||
if err != nil {
|
||||
if err := payload.Validate(r); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
var edgeJob *portainer.EdgeJob
|
||||
var err error
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, payload.File)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return txResponse(w, edgeJob, err)
|
||||
return response.TxResponse(w, edgeJob, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) createEdgeJobObjectFromPayload(tx dataservices.DataStoreTx, payload *edgeJobBasePayload) *portainer.EdgeJob {
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
package edgejobs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockFileService struct {
|
||||
mock.Mock
|
||||
portainer.FileService
|
||||
}
|
||||
|
||||
func (m *mockFileService) StoreEdgeJobFileFromBytes(id string, file []byte) (string, error) {
|
||||
args := m.Called(id, file)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockFileService) GetEdgeJobFolder(id string) string {
|
||||
args := m.Called(id)
|
||||
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *mockFileService) RemoveDirectory(path string) error {
|
||||
args := m.Called(path)
|
||||
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func initStore(t *testing.T) *datastore.Store {
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
require.NotNil(t, store)
|
||||
|
||||
require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{
|
||||
ID: 1,
|
||||
Name: "endpoint-1",
|
||||
EdgeID: "edge-id-1",
|
||||
GroupID: 1,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
UserTrusted: true,
|
||||
}))
|
||||
|
||||
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{
|
||||
ID: 2,
|
||||
Name: "endpoint-2",
|
||||
EdgeID: "edge-id-2",
|
||||
GroupID: 1,
|
||||
Type: portainer.EdgeAgentOnDockerEnvironment,
|
||||
UserTrusted: false,
|
||||
}))
|
||||
return nil
|
||||
}))
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func Test_edgeJobCreate_StringMethod_Success(t *testing.T) {
|
||||
store := initStore(t)
|
||||
|
||||
fileService := &mockFileService{}
|
||||
fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil)
|
||||
|
||||
handler := &Handler{
|
||||
DataStore: store,
|
||||
FileService: fileService,
|
||||
}
|
||||
|
||||
payload := edgeJobCreateFromFileContentPayload{
|
||||
edgeJobBasePayload: edgeJobBasePayload{
|
||||
Name: "testjob",
|
||||
CronExpression: "* * * * *",
|
||||
Endpoints: []portainer.EndpointID{1, 2},
|
||||
},
|
||||
FileContent: "echo hello",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/string", bytes.NewReader(body))
|
||||
req = mux.SetURLVars(req, map[string]string{"method": "string"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call handler
|
||||
errh := handler.edgeJobCreate(w, req)
|
||||
require.Nil(t, errh)
|
||||
require.Equal(t, http.StatusOK, w.Result().StatusCode)
|
||||
|
||||
// Get edge job ID from response
|
||||
var resp struct {
|
||||
ID int `json:"Id"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, edgeJob.Endpoints, 2)
|
||||
require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1))
|
||||
}
|
||||
|
||||
func Test_edgeJobCreate_FileMethod_Success(t *testing.T) {
|
||||
store := initStore(t)
|
||||
|
||||
fileService := &mockFileService{}
|
||||
fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil)
|
||||
|
||||
handler := &Handler{
|
||||
DataStore: store,
|
||||
FileService: fileService,
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("Name", "testjob"))
|
||||
require.NoError(t, writer.WriteField("CronExpression", "* * * * *"))
|
||||
require.NoError(t, writer.WriteField("Endpoints", "[1,2]"))
|
||||
|
||||
fileWriter, err := writer.CreateFormFile("file", "test.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(fileWriter, strings.NewReader("echo hello"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/file", &body)
|
||||
req = mux.SetURLVars(req, map[string]string{"method": "file"})
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handlerErr := handler.edgeJobCreate(w, req)
|
||||
require.Nil(t, handlerErr)
|
||||
require.Equal(t, http.StatusOK, w.Result().StatusCode)
|
||||
|
||||
var resp struct {
|
||||
ID int `json:"Id"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, edgeJob.Endpoints, 2)
|
||||
require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1))
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
package edgejobs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"maps"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
@ -35,18 +34,11 @@ func (handler *Handler) edgeJobDelete(w http.ResponseWriter, r *http.Request) *h
|
|||
return httperror.BadRequest("Invalid Edge job identifier route variable", err)
|
||||
}
|
||||
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return handler.deleteEdgeJob(tx, portainer.EdgeJobID(edgeJobID))
|
||||
}); err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
})
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) deleteEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) error {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package edgejobs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
@ -54,7 +53,7 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
}
|
||||
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
updateEdgeJobFn := func(edgeJob *portainer.EdgeJob, endpointID portainer.EndpointID, endpointsFromGroups []portainer.EndpointID) error {
|
||||
mutationFn(edgeJob, endpointID, endpointsFromGroups)
|
||||
|
||||
|
@ -62,16 +61,9 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
return handler.clearEdgeJobTaskLogs(tx, portainer.EdgeJobID(edgeJobID), portainer.EndpointID(taskID), updateEdgeJobFn)
|
||||
}); err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
})
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) clearEdgeJobTaskLogs(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, endpointID portainer.EndpointID, updateEdgeJob func(*portainer.EdgeJob, portainer.EndpointID, []portainer.EndpointID) error) error {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package edgejobs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
|
@ -39,7 +38,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque
|
|||
return httperror.BadRequest("Invalid Task identifier route variable", err)
|
||||
}
|
||||
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
edgeJob, err := tx.EdgeJob().Read(portainer.EdgeJobID(edgeJobID))
|
||||
if tx.IsErrObjectNotFound(err) {
|
||||
return httperror.NotFound("Unable to find an Edge job with the specified identifier inside the database", err)
|
||||
|
@ -81,14 +80,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque
|
|||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
})
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/edge"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
)
|
||||
|
||||
type taskContainer struct {
|
||||
|
@ -49,31 +50,33 @@ func (handler *Handler) edgeJobTasksList(w http.ResponseWriter, r *http.Request)
|
|||
return err
|
||||
})
|
||||
|
||||
results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{
|
||||
SearchAccessors: []filters.SearchAccessor[*taskContainer]{
|
||||
func(tc *taskContainer) (string, error) {
|
||||
switch tc.LogsStatus {
|
||||
case portainer.EdgeJobLogsStatusPending:
|
||||
return "pending", nil
|
||||
case 0, portainer.EdgeJobLogsStatusIdle:
|
||||
return "idle", nil
|
||||
case portainer.EdgeJobLogsStatusCollected:
|
||||
return "collected", nil
|
||||
}
|
||||
return "", errors.New("unknown state")
|
||||
return response.TxFuncResponse(err, func() *httperror.HandlerError {
|
||||
results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{
|
||||
SearchAccessors: []filters.SearchAccessor[*taskContainer]{
|
||||
func(tc *taskContainer) (string, error) {
|
||||
switch tc.LogsStatus {
|
||||
case portainer.EdgeJobLogsStatusPending:
|
||||
return "pending", nil
|
||||
case 0, portainer.EdgeJobLogsStatusIdle:
|
||||
return "idle", nil
|
||||
case portainer.EdgeJobLogsStatusCollected:
|
||||
return "collected", nil
|
||||
}
|
||||
return "", errors.New("unknown state")
|
||||
},
|
||||
func(tc *taskContainer) (string, error) {
|
||||
return tc.EndpointName, nil
|
||||
},
|
||||
},
|
||||
func(tc *taskContainer) (string, error) {
|
||||
return tc.EndpointName, nil
|
||||
SortBindings: []filters.SortBinding[*taskContainer]{
|
||||
{Key: "EndpointName", Fn: func(a, b *taskContainer) int { return strings.Compare(a.EndpointName, b.EndpointName) }},
|
||||
},
|
||||
},
|
||||
SortBindings: []filters.SortBinding[*taskContainer]{
|
||||
{Key: "EndpointName", Fn: func(a, b *taskContainer) int { return strings.Compare(a.EndpointName, b.EndpointName) }},
|
||||
},
|
||||
})
|
||||
|
||||
filters.ApplyFilterResultsHeaders(&w, results)
|
||||
|
||||
return response.JSON(w, results.Items)
|
||||
})
|
||||
|
||||
filters.ApplyFilterResultsHeaders(&w, results)
|
||||
|
||||
return txResponse(w, results.Items, err)
|
||||
}
|
||||
|
||||
func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]*taskContainer, error) {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
"github.com/portainer/portainer/api/roar"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -68,14 +69,16 @@ func Test_EdgeJobTasksListHandler(t *testing.T) {
|
|||
|
||||
tcStr := rr.Header().Get("x-total-count")
|
||||
assert.NotEmpty(t, tcStr)
|
||||
|
||||
totalCount, err := strconv.Atoi(tcStr)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedCount, totalCount)
|
||||
|
||||
taStr := rr.Header().Get("x-total-available")
|
||||
assert.NotEmpty(t, taStr)
|
||||
|
||||
totalAvailable, err := strconv.Atoi(taStr)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, envCount, totalAvailable)
|
||||
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/request"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
)
|
||||
|
||||
|
@ -66,7 +67,7 @@ func (handler *Handler) edgeJobUpdate(w http.ResponseWriter, r *http.Request) *h
|
|||
return err
|
||||
})
|
||||
|
||||
return txResponse(w, edgeJob, err)
|
||||
return response.TxResponse(w, edgeJob, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) updateEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, payload edgeJobUpdatePayload) (*portainer.EdgeJob, error) {
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
package edgejobs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/http/security"
|
||||
httperror "github.com/portainer/portainer/pkg/libhttp/error"
|
||||
"github.com/portainer/portainer/pkg/libhttp/response"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
@ -60,16 +58,3 @@ func convertEndpointsToMetaObject(endpoints []portainer.EndpointID) map[portaine
|
|||
|
||||
return endpointsMap
|
||||
}
|
||||
|
||||
func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError {
|
||||
if err != nil {
|
||||
var handlerError *httperror.HandlerError
|
||||
if errors.As(err, &handlerError) {
|
||||
return handlerError
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.JSON(w, r)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package edgestacks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
|
@ -30,18 +29,11 @@ func (handler *Handler) edgeStackDelete(w http.ResponseWriter, r *http.Request)
|
|||
return httperror.BadRequest("Invalid edge stack identifier route variable", err)
|
||||
}
|
||||
|
||||
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return handler.deleteEdgeStack(tx, portainer.EdgeStackID(edgeStackID))
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
})
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) deleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID) error {
|
||||
|
|
|
@ -96,12 +96,7 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
|
|||
|
||||
return nil
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
return response.TxErrorResponse(err)
|
||||
}
|
||||
|
||||
if ok, _ := strconv.ParseBool(r.Header.Get("X-Portainer-No-Body")); ok {
|
||||
|
|
|
@ -66,12 +66,7 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request)
|
|||
stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload)
|
||||
return err
|
||||
}); err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
return response.TxErrorResponse(err)
|
||||
}
|
||||
|
||||
if err := fillEdgeStackStatus(handler.DataStore, stack); err != nil {
|
||||
|
@ -99,7 +94,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
|
|||
|
||||
groupsIds := stack.EdgeGroups
|
||||
if payload.EdgeGroups != nil {
|
||||
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack.ID, payload.EdgeGroups, relatedEndpointIds, relationConfig)
|
||||
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack, payload.EdgeGroups, relatedEndpointIds, relationConfig)
|
||||
if err != nil {
|
||||
return nil, httperror.InternalServerError("Unable to handle edge groups change", err)
|
||||
}
|
||||
|
@ -136,7 +131,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
|
|||
return stack, nil
|
||||
}
|
||||
|
||||
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
|
||||
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStack *portainer.EdgeStack, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
|
||||
newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups)
|
||||
if err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database")
|
||||
|
@ -149,13 +144,13 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge
|
|||
relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet)
|
||||
|
||||
if len(relatedEnvironmentsToRemove) > 0 {
|
||||
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID); err != nil {
|
||||
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStack.ID); err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database")
|
||||
}
|
||||
}
|
||||
|
||||
if len(relatedEnvironmentsToAdd) > 0 {
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID); err != nil {
|
||||
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStack); err != nil {
|
||||
return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/internal/testhelpers"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_hasKubeEndpoint(t *testing.T) {
|
||||
|
@ -40,7 +42,7 @@ func Test_hasKubeEndpoint(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
|
||||
ans, err := hasKubeEndpoint(datastore.Endpoint(), test.endpointIds)
|
||||
assert.NoError(t, err, "hasKubeEndpoint shouldn't fail")
|
||||
require.NoError(t, err, "hasKubeEndpoint shouldn't fail")
|
||||
|
||||
assert.Equal(t, test.expected, ans, "hasKubeEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
|
||||
}
|
||||
|
@ -50,7 +52,7 @@ func Test_hasKubeEndpoint_failWhenEndpointDontExist(t *testing.T) {
|
|||
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
|
||||
|
||||
_, err := hasKubeEndpoint(datastore.Endpoint(), []portainer.EndpointID{1})
|
||||
assert.Error(t, err, "hasKubeEndpoint should fail")
|
||||
require.Error(t, err, "hasKubeEndpoint should fail")
|
||||
}
|
||||
|
||||
func Test_hasDockerEndpoint(t *testing.T) {
|
||||
|
@ -85,7 +87,7 @@ func Test_hasDockerEndpoint(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
|
||||
ans, err := hasDockerEndpoint(datastore.Endpoint(), test.endpointIds)
|
||||
assert.NoError(t, err, "hasDockerEndpoint shouldn't fail")
|
||||
require.NoError(t, err, "hasDockerEndpoint shouldn't fail")
|
||||
|
||||
assert.Equal(t, test.expected, ans, "hasDockerEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
|
||||
}
|
||||
|
@ -95,5 +97,5 @@ func Test_hasDockerEndpoint_failWhenEndpointDontExist(t *testing.T) {
|
|||
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
|
||||
|
||||
_, err := hasDockerEndpoint(datastore.Endpoint(), []portainer.EndpointID{1})
|
||||
assert.Error(t, err, "hasDockerEndpoint should fail")
|
||||
require.Error(t, err, "hasDockerEndpoint should fail")
|
||||
}
|
||||
|
|
|
@ -49,26 +49,18 @@ func (payload *endpointGroupCreatePayload) Validate(r *http.Request) error {
|
|||
// @router /endpoint_groups [post]
|
||||
func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
|
||||
var payload endpointGroupCreatePayload
|
||||
err := request.DecodeAndValidateJSONPayload(r, &payload)
|
||||
if err != nil {
|
||||
if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
|
||||
return httperror.BadRequest("Invalid request payload", err)
|
||||
}
|
||||
|
||||
var endpointGroup *portainer.EndpointGroup
|
||||
var err error
|
||||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
endpointGroup, err = handler.createEndpointGroup(tx, payload)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.JSON(w, endpointGroup)
|
||||
return response.TxResponse(w, endpointGroup, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) createEndpointGroup(tx dataservices.DataStoreTx, payload endpointGroupCreatePayload) (*portainer.EndpointGroup, error) {
|
||||
|
|
|
@ -37,16 +37,8 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque
|
|||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return handler.deleteEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID))
|
||||
})
|
||||
if err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) deleteEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID) error {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package endpointgroups
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
@ -39,16 +38,8 @@ func (handler *Handler) endpointGroupAddEndpoint(w http.ResponseWriter, r *http.
|
|||
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return handler.addEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID))
|
||||
})
|
||||
if err != nil {
|
||||
var httpErr *httperror.HandlerError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr
|
||||
}
|
||||
|
||||
return httperror.InternalServerError("Unexpected error", err)
|
||||
}
|
||||
|
||||
return response.Empty(w)
|
||||
return response.TxEmptyResponse(w, err)
|
||||
}
|
||||
|
||||
func (handler *Handler) addEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue