Compare commits

...

63 Commits

Author SHA1 Message Date
Ali 657cd04af2 fix(cve): fix frontend CVEs [r8s-563] (#1239) 2025-09-22 10:15:29 +12:00
Oscar Zhou 24a092836b fix(activitylog): remove export limit and fix search function [BE-12270] (#1235) 2025-09-19 14:52:33 +12:00
andres-portainer 290374f6fc fix(kubernetes/cli): unexport a field BE-12259 (#1228) 2025-09-18 14:39:38 -03:00
andres-portainer 2e7acc73d8 fix(kubernetes/cli): fix a data-race BE-12259 (#1218) 2025-09-18 09:19:29 -03:00
Oscar Zhou 666d51482e fix(container): apply less accurate solution to calculate container status for swarm environment [BE-12256] (#1225) 2025-09-18 16:29:35 +12:00
Oscar Zhou eedf37d18a feat(edge): add option to allow always clone git repository [BE-12240] (#1215) 2025-09-17 18:25:42 +12:00
Viktor Pettersson 16f210966b fix(version): change API version support from LTS to STS (#1223) 2025-09-17 17:18:03 +12:00
andres-portainer 30e70b6327 chore(version): bump to v2.34.0 (#1216) 2025-09-15 22:13:51 -03:00
andres-portainer f91a2e3b65 fix(csp): update the Content-Security-Policy header BE-12228 (#1201) 2025-09-15 10:47:50 -03:00
Ali fdc405c912 feat(docker-networks): allow ipv6 for ipvlan networks [portainer-pr12608] (#1196)
Co-authored-by: ar0311 <arogers0311@gmail.com>
2025-09-15 11:49:06 +12:00
Phil Calder 2f2e70bb86 Fix typo (#1186) 2025-09-13 14:31:52 +12:00
andres-portainer eef54f4153 chore(golangci-lint): add forward-looking static checking rules BE-12183 (#1200) 2025-09-12 16:54:30 -03:00
LP B ad1c015f01 fix(api/custom-templates): UAC-allowed users cannot fetch custom template details (#1113) 2025-09-11 16:08:52 +02:00
LP B 326fdcf6ea refactor(api): remove duplicates of TxResponse + HandlerError detection (#1117) 2025-09-11 11:33:30 +02:00
Malcolm Lockyer 26a0c4e809 fix(encryption): set correct default secret key path [r8s-555] (#1182)
Co-authored-by: Gorbasch <57012534+mbegerau@users.noreply.github.com>
2025-09-11 16:32:43 +12:00
Ali acb465ae33 fix(node): revert table css selector, add new specific selector [r8s-331] (#1170) 2025-09-11 10:53:35 +12:00
andres-portainer 5418a0bee6 fix(mingit): remove mingit BE-12245 (#1177) 2025-09-10 15:01:12 -03:00
andres-portainer a59815264d fix(csp): add google.com to the CSP header BE-12228 (#1175) 2025-09-10 15:00:25 -03:00
Viktor Pettersson 3ac0be4e35 chore(gomod): add `go mod tidy` checks in the CI BE-12233 (#1151) 2025-09-10 08:28:58 +12:00
Ali feae930293 fix(node): allow switching tabs [r8s-546] (#1161) 2025-09-10 08:17:40 +12:00
LP B 7ebb52ec6d fix(api/container): standard users cannot connect or disconnect containers to networks (#1118) 2025-09-09 22:07:19 +02:00
Ali 8b73ad3b6f chore(kubernetes): node view react migration [r8s-331] (#746) 2025-09-08 22:51:32 +12:00
Ali 6fc2a8234d fix(registry): allow trusted tls custom registries [r8s-489] (#1116) 2025-09-08 09:28:40 +12:00
Ali e2c2724e36 fix(helm): update helm repo validation to match helm cli [r8s-531] (#1141) 2025-09-08 08:58:04 +12:00
Malcolm Lockyer 6abfbe8553 fix(fips): encrypt the chisel private key file for fips [be-12132] (#1143) 2025-09-05 13:17:30 +12:00
andres-portainer 54f6add45d fix(compose): fix a data race in a test BE-12231 (#1148) 2025-09-04 17:31:57 -03:00
andres-portainer f8ae5368bf fix(git): add a minimum interval validation BE-12220 (#1144) 2025-09-04 15:11:12 -03:00
andres-portainer 2ba348551d fix(scheduler): fix a data race in the job scheduler BE-12229 (#1146) 2025-09-04 15:09:52 -03:00
andres-portainer 110f88f22d chore(endpointutils): remove unnecessary field BE-10415 (#1136) 2025-09-04 11:22:46 -03:00
James Player c90a15dd0f refactor(app/repository): migrate edit repository view to React [R8S-332] (#768) 2025-09-04 16:27:39 +12:00
andres-portainer f4335e1e72 fix(registries): clear sensitive fields in the update handler BE-12215 (#1128) 2025-09-02 15:44:09 -03:00
andres-portainer 8d9e1a0ad5 fix(csp): add object-src to the CSP header BE-12217 (#1126) 2025-09-02 11:39:46 -03:00
andres-portainer 48dcfcb08f fix(forbidigo): add more rules to avoid skipping TLS verifications BE-11973 (#1123) 2025-09-01 16:57:22 -03:00
andres-portainer def19be230 fix(depguard): mitigate improper usage of openpgp BE-11977 (#1122) 2025-09-01 14:44:45 -03:00
andres-portainer 36154e9d33 fix(depguard): add a rule against golang.org/x/crypto BE-11978 (#1119) 2025-09-01 10:54:24 -03:00
Oscar Zhou 7cf6bb78d6 fix(container): inaccurate healthy container count [BE-2290] (#1114) 2025-09-01 17:01:13 +12:00
Cara Ryan 541f281b29 fix(kubernetes): Namespace resource limits and requests display consistent value (#1055) 2025-09-01 10:25:53 +12:00
Viktor Pettersson 965ef5246b feat(autopatch): implement OCI registry patch finder BE-12111 (#1044) 2025-08-27 19:04:41 +12:00
James Carppe 9c88057bd1 Updates for release 2.33.1 (#1109) 2025-08-27 16:56:01 +12:00
andres-portainer 8c52e92705 chore(bbolt): upgrade bbolt to v1.4.3 BE-12193 (#1103) 2025-08-25 15:51:56 -03:00
Devon Steenberg 3a727d24ce fix(sslflags): Deprecate ssl flags [BE-12168] (#1075) 2025-08-25 14:35:55 +12:00
Malcolm Lockyer 185558a642 fix(standard): manual endpoint refresh fails to save new status [be-12188] (#1092) 2025-08-25 13:49:17 +12:00
Ali 35aa525bd2 fix(environments): create k8s specific edge agent before connecting [r8s-438] (#1088)
Merging because this change is unrelated to the failing kubernetes/tests/helm-oci.spec.ts tests
2025-08-25 09:32:10 +12:00
Oscar Zhou 2ce8788487 fix(autoupdate): update tooltips in edge stack gitops update [BE-12177] (#1084) 2025-08-23 10:56:04 +12:00
andres-portainer ec0e98a64b chore(linters): enable testifylint BE-12183 (#1091) 2025-08-22 15:31:10 -03:00
Steven Kang 121e9f03a4 fix: GHSA-2464-8j7c-4cjm - develop [R8S-495] (#1087) 2025-08-22 14:03:13 +12:00
andres-portainer a0295b1a39 chore(go): upgrade Go to v1.25.0 BE-12181 (#1071) 2025-08-20 12:55:06 -03:00
andres-portainer 30aba86380 chore(benchmarks): use b.Loop() BE-12182 (#1072) 2025-08-20 12:54:26 -03:00
James Carppe 89f5a20786 Updates for release 2.33.0 (#1067) 2025-08-20 15:35:58 +12:00
James Player ef7caa260b fix(UI): add experimental features back in [r8s-483] (#1061) 2025-08-19 16:55:24 +12:00
Steven Kang 39d50ef70e fix: cve-2025-55198 and cve-2025-55199 - develop [R8S-482] (#1057) 2025-08-19 16:22:52 +12:00
James Player 58a1392480 fix(helm): support http and custom tls helm registries, give help when misconfigured - develop [r8s-472] (#1050)
Co-authored-by: testA113 <aliharriss1995@gmail.com>
2025-08-19 13:32:32 +12:00
James Player 06f6bcc340 fix(ui): Fixed react-select TooManyResultsSelector filter and improved scrolling (#1024) 2025-08-19 09:35:00 +12:00
LP B c9d18b614b fix(api/edge-stacks): avoid overriding updates with old values (#1047) 2025-08-16 03:52:13 +02:00
andres-portainer 2035c42c3c fix(migrator): rewrite a migration so it is idempotent BE-12053 (#1042) 2025-08-15 09:26:10 -03:00
Malcolm Lockyer a760426b87 fix(fips): use standard lib pbkdf2 [be-12164] (#1038) 2025-08-15 11:44:35 +12:00
andres-portainer 10b129a02e fix(crypto): replace fips140 calls with fips calls BE-11979 (#1033) 2025-08-14 19:36:15 -03:00
Cara Ryan 129b9d5db9 fix(pending-actions): Small improvements to pending actions (R8S-350) (#949) 2025-08-15 10:07:51 +12:00
andres-portainer 2c08becf6c feat(openai): remove OpenAI BE-12018 (#873) 2025-08-14 10:42:21 -03:00
Ali a3bfe7cb0c fix(logs): improve log rendering performance [r8s-437] (#993) 2025-08-14 13:55:37 +12:00
andres-portainer 7049a8a2bb fix(linters): add many linters BE-12112 (#1009) 2025-08-13 19:42:24 -03:00
LP B 1197b1dd8d feat(api): Permissions-Policy header deny all (#1021) 2025-08-13 22:07:55 +02:00
andres-portainer 7f167ff2fc fix(auth): remove a nil pointer dereference BE-12149 (#1014) 2025-08-13 13:20:56 -03:00
357 changed files with 9431 additions and 3151 deletions

View File

@ -6,7 +6,7 @@ body:
Thanks for suggesting an idea for Portainer! Thanks for suggesting an idea for Portainer!
Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion cagetory](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead. Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion category](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.
Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io) as they may point you toward a solution. Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io) as they may point you toward a solution.

View File

@ -94,6 +94,9 @@ body:
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed. description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
multiple: false multiple: false
options: options:
- '2.34.0'
- '2.33.1'
- '2.33.0'
- '2.32.0' - '2.32.0'
- '2.31.3' - '2.31.3'
- '2.31.2' - '2.31.2'

11
.golangci-forward.yaml Normal file
View File

@ -0,0 +1,11 @@
version: "2"
linters:
default: none
enable:
- forbidigo
settings:
forbidigo:
forbid:
- pattern: ^dataservices.DataStore.(EdgeGroup|EdgeJob|EdgeStack|EndpointRelation|Endpoint|GitCredential|Registry|ResourceControl|Role|Settings|Snapshot|Stack|Tag|User)$
msg: Use a transaction instead
analyze-types: true

View File

@ -13,6 +13,12 @@ linters:
- perfsprint - perfsprint
- staticcheck - staticcheck
- unused - unused
- mirror
- durationcheck
- errorlint
- govet
- zerologlint
- testifylint
settings: settings:
staticcheck: staticcheck:
checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"] checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
@ -32,12 +38,20 @@ linters:
desc: use github.com/portainer/portainer/pkg/libcrypto desc: use github.com/portainer/portainer/pkg/libcrypto
- pkg: github.com/portainer/libhttp - pkg: github.com/portainer/libhttp
desc: use github.com/portainer/portainer/pkg/libhttp desc: use github.com/portainer/portainer/pkg/libhttp
- pkg: golang.org/x/crypto
desc: golang.org/x/crypto is not allowed because of FIPS mode
- pkg: github.com/ProtonMail/go-crypto/openpgp
desc: github.com/ProtonMail/go-crypto/openpgp is not allowed because of FIPS mode
forbidigo: forbidigo:
forbid: forbid:
- pattern: ^tls\.Config$ - pattern: ^tls\.Config$
msg: Use crypto.CreateTLSConfiguration() instead msg: Use crypto.CreateTLSConfiguration() instead
- pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$ - pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
- pattern: ^object\.(Commit|Tag)\.Verify$
msg: "Not allowed because of FIPS mode"
- pattern: ^(types\.SystemContext\.)?(DockerDaemonInsecureSkipTLSVerify|DockerInsecureSkipTLSVerify|OCIInsecureSkipTLSVerify)$
msg: "Not allowed because of FIPS mode"
analyze-types: true analyze-types: true
exclusions: exclusions:
generated: lax generated: lax

View File

@ -54,14 +54,12 @@ client-deps: ## Install client dependencies
tidy: ## Tidy up the go.mod file tidy: ## Tidy up the go.mod file
@go mod tidy @go mod tidy
##@ Cleanup ##@ Cleanup
.PHONY: clean .PHONY: clean
clean: ## Remove all build and download artifacts clean: ## Remove all build and download artifacts
@echo "Clearing the dist directory..." @echo "Clearing the dist directory..."
@rm -rf dist/* @rm -rf dist/*
##@ Testing ##@ Testing
.PHONY: test test-client test-server .PHONY: test test-client test-server
test: test-server test-client ## Run all tests test: test-server test-client ## Run all tests
@ -105,16 +103,15 @@ lint: lint-client lint-server ## Lint all code
lint-client: ## Lint client code lint-client: ## Lint client code
yarn lint yarn lint
lint-server: ## Lint server code lint-server: tidy ## Lint server code
golangci-lint run --timeout=10m -c .golangci.yaml golangci-lint run --timeout=10m -c .golangci.yaml
golangci-lint run --timeout=10m --new-from-rev=HEAD~ -c .golangci-forward.yaml
##@ Extension ##@ Extension
.PHONY: dev-extension .PHONY: dev-extension
dev-extension: build-server build-client ## Run the extension in development mode dev-extension: build-server build-client ## Run the extension in development mode
make local -f build/docker-extension/Makefile make local -f build/docker-extension/Makefile
##@ Docs ##@ Docs
.PHONY: docs-build docs-validate docs-clean docs-validate-clean .PHONY: docs-build docs-validate docs-clean docs-validate-clean
docs-build: init-dist ## Build docs docs-build: init-dist ## Build docs

View File

@ -11,30 +11,30 @@ func Test_generateRandomKey(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
wantLenth int wantLength int
}{ }{
{ {
name: "Generate a random key of length 16", name: "Generate a random key of length 16",
wantLenth: 16, wantLength: 16,
}, },
{ {
name: "Generate a random key of length 32", name: "Generate a random key of length 32",
wantLenth: 32, wantLength: 32,
}, },
{ {
name: "Generate a random key of length 64", name: "Generate a random key of length 64",
wantLenth: 64, wantLength: 64,
}, },
{ {
name: "Generate a random key of length 128", name: "Generate a random key of length 128",
wantLenth: 128, wantLength: 128,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := GenerateRandomKey(tt.wantLenth) got := GenerateRandomKey(tt.wantLength)
is.Equal(tt.wantLenth, len(got)) is.Len(got, tt.wantLength)
}) })
} }

View File

@ -10,9 +10,10 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/assert"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) { func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) {
@ -30,7 +31,7 @@ func Test_GenerateApiKey(t *testing.T) {
t.Run("Successfully generates API key", func(t *testing.T) { t.Run("Successfully generates API key", func(t *testing.T) {
desc := "test-1" desc := "test-1"
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, desc) rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, desc)
is.NoError(err) require.NoError(t, err)
is.NotEmpty(rawKey) is.NotEmpty(rawKey)
is.NotEmpty(apiKey) is.NotEmpty(apiKey)
is.Equal(desc, apiKey.Description) is.Equal(desc, apiKey.Description)
@ -38,7 +39,7 @@ func Test_GenerateApiKey(t *testing.T) {
t.Run("Api key prefix is 7 chars", func(t *testing.T) { t.Run("Api key prefix is 7 chars", func(t *testing.T) {
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-2") rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-2")
is.NoError(err) require.NoError(t, err)
is.Equal(rawKey[:7], apiKey.Prefix) is.Equal(rawKey[:7], apiKey.Prefix)
is.Len(apiKey.Prefix, 7) is.Len(apiKey.Prefix, 7)
@ -46,7 +47,7 @@ func Test_GenerateApiKey(t *testing.T) {
t.Run("Api key has 'ptr_' as prefix", func(t *testing.T) { t.Run("Api key has 'ptr_' as prefix", func(t *testing.T) {
rawKey, _, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x") rawKey, _, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x")
is.NoError(err) require.NoError(t, err)
is.Equal(portainerAPIKeyPrefix, "ptr_") is.Equal(portainerAPIKeyPrefix, "ptr_")
is.True(strings.HasPrefix(rawKey, "ptr_")) is.True(strings.HasPrefix(rawKey, "ptr_"))
@ -55,7 +56,7 @@ func Test_GenerateApiKey(t *testing.T) {
t.Run("Successfully caches API key", func(t *testing.T) { t.Run("Successfully caches API key", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-3") _, apiKey, err := service.GenerateApiKey(user, "test-3")
is.NoError(err) require.NoError(t, err)
userFromCache, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest) userFromCache, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
is.True(ok) is.True(ok)
@ -65,7 +66,7 @@ func Test_GenerateApiKey(t *testing.T) {
t.Run("Decoded raw api-key digest matches generated digest", func(t *testing.T) { t.Run("Decoded raw api-key digest matches generated digest", func(t *testing.T) {
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-4") rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-4")
is.NoError(err) require.NoError(t, err)
generatedDigest := sha256.Sum256([]byte(rawKey)) generatedDigest := sha256.Sum256([]byte(rawKey))
@ -83,10 +84,10 @@ func Test_GetAPIKey(t *testing.T) {
t.Run("Successfully returns all API keys", func(t *testing.T) { t.Run("Successfully returns all API keys", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-1") _, apiKey, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
apiKeyGot, err := service.GetAPIKey(apiKey.ID) apiKeyGot, err := service.GetAPIKey(apiKey.ID)
is.NoError(err) require.NoError(t, err)
is.Equal(apiKey, apiKeyGot) is.Equal(apiKey, apiKeyGot)
}) })
@ -102,12 +103,12 @@ func Test_GetAPIKeys(t *testing.T) {
t.Run("Successfully returns all API keys", func(t *testing.T) { t.Run("Successfully returns all API keys", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, _, err := service.GenerateApiKey(user, "test-1") _, _, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
_, _, err = service.GenerateApiKey(user, "test-2") _, _, err = service.GenerateApiKey(user, "test-2")
is.NoError(err) require.NoError(t, err)
keys, err := service.GetAPIKeys(user.ID) keys, err := service.GetAPIKeys(user.ID)
is.NoError(err) require.NoError(t, err)
is.Len(keys, 2) is.Len(keys, 2)
}) })
} }
@ -122,10 +123,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
t.Run("Successfully returns user and api key associated to digest", func(t *testing.T) { t.Run("Successfully returns user and api key associated to digest", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-1") _, apiKey, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest) userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
is.NoError(err) require.NoError(t, err)
is.Equal(user, userGot) is.Equal(user, userGot)
is.Equal(*apiKey, apiKeyGot) is.Equal(*apiKey, apiKeyGot)
}) })
@ -133,10 +134,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
t.Run("Successfully caches user and api key associated to digest", func(t *testing.T) { t.Run("Successfully caches user and api key associated to digest", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-1") _, apiKey, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest) userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
is.NoError(err) require.NoError(t, err)
is.Equal(user, userGot) is.Equal(user, userGot)
is.Equal(*apiKey, apiKeyGot) is.Equal(*apiKey, apiKeyGot)
@ -158,14 +159,14 @@ func Test_UpdateAPIKey(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
store.User().Create(&user) store.User().Create(&user)
_, apiKey, err := service.GenerateApiKey(user, "test-x") _, apiKey, err := service.GenerateApiKey(user, "test-x")
is.NoError(err) require.NoError(t, err)
apiKey.LastUsed = time.Now().UTC().Unix() apiKey.LastUsed = time.Now().UTC().Unix()
err = service.UpdateAPIKey(apiKey) err = service.UpdateAPIKey(apiKey)
is.NoError(err) require.NoError(t, err)
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest) _, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
is.NoError(err) require.NoError(t, err)
log.Debug().Str("wanted", fmt.Sprintf("%+v", apiKey)).Str("got", fmt.Sprintf("%+v", apiKeyGot)).Msg("") log.Debug().Str("wanted", fmt.Sprintf("%+v", apiKey)).Str("got", fmt.Sprintf("%+v", apiKeyGot)).Msg("")
@ -174,7 +175,7 @@ func Test_UpdateAPIKey(t *testing.T) {
t.Run("Successfully updates api-key in cache upon api-key update", func(t *testing.T) { t.Run("Successfully updates api-key in cache upon api-key update", func(t *testing.T) {
_, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x2") _, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x2")
is.NoError(err) require.NoError(t, err)
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest) _, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
is.True(ok) is.True(ok)
@ -184,7 +185,7 @@ func Test_UpdateAPIKey(t *testing.T) {
is.NotEqual(*apiKey, apiKeyFromCache) is.NotEqual(*apiKey, apiKeyFromCache)
err = service.UpdateAPIKey(apiKey) err = service.UpdateAPIKey(apiKey)
is.NoError(err) require.NoError(t, err)
_, updatedAPIKeyFromCache, ok := service.cache.Get(apiKey.Digest) _, updatedAPIKeyFromCache, ok := service.cache.Get(apiKey.Digest)
is.True(ok) is.True(ok)
@ -202,30 +203,30 @@ func Test_DeleteAPIKey(t *testing.T) {
t.Run("Successfully updates the api-key", func(t *testing.T) { t.Run("Successfully updates the api-key", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-1") _, apiKey, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest) _, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
is.NoError(err) require.NoError(t, err)
is.Equal(*apiKey, apiKeyGot) is.Equal(*apiKey, apiKeyGot)
err = service.DeleteAPIKey(apiKey.ID) err = service.DeleteAPIKey(apiKey.ID)
is.NoError(err) require.NoError(t, err)
_, _, err = service.GetDigestUserAndKey(apiKey.Digest) _, _, err = service.GetDigestUserAndKey(apiKey.Digest)
is.Error(err) require.Error(t, err)
}) })
t.Run("Successfully removes api-key from cache upon deletion", func(t *testing.T) { t.Run("Successfully removes api-key from cache upon deletion", func(t *testing.T) {
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey, err := service.GenerateApiKey(user, "test-1") _, apiKey, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest) _, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
is.True(ok) is.True(ok)
is.Equal(*apiKey, apiKeyFromCache) is.Equal(*apiKey, apiKeyFromCache)
err = service.DeleteAPIKey(apiKey.ID) err = service.DeleteAPIKey(apiKey.ID)
is.NoError(err) require.NoError(t, err)
_, _, ok = service.cache.Get(apiKey.Digest) _, _, ok = service.cache.Get(apiKey.Digest)
is.False(ok) is.False(ok)
@ -243,10 +244,10 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
// generate api keys // generate api keys
user := portainer.User{ID: 1} user := portainer.User{ID: 1}
_, apiKey1, err := service.GenerateApiKey(user, "test-1") _, apiKey1, err := service.GenerateApiKey(user, "test-1")
is.NoError(err) require.NoError(t, err)
_, apiKey2, err := service.GenerateApiKey(user, "test-2") _, apiKey2, err := service.GenerateApiKey(user, "test-2")
is.NoError(err) require.NoError(t, err)
// verify api keys are present in cache // verify api keys are present in cache
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest) _, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
@ -273,11 +274,11 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
// generate keys for 2 users // generate keys for 2 users
user1 := portainer.User{ID: 1} user1 := portainer.User{ID: 1}
_, apiKey1, err := service.GenerateApiKey(user1, "test-1") _, apiKey1, err := service.GenerateApiKey(user1, "test-1")
is.NoError(err) require.NoError(t, err)
user2 := portainer.User{ID: 2} user2 := portainer.User{ID: 2}
_, apiKey2, err := service.GenerateApiKey(user2, "test-2") _, apiKey2, err := service.GenerateApiKey(user2, "test-2")
is.NoError(err) require.NoError(t, err)
// verify keys in cache // verify keys in cache
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest) _, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)

View File

@ -8,15 +8,19 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func listFiles(dir string) []string { func listFiles(dir string) []string {
items := make([]string, 0) items := make([]string, 0)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if path == dir { if path == dir {
return nil return nil
} }
items = append(items, path) items = append(items, path)
return nil return nil
}) })
@ -26,13 +30,21 @@ func listFiles(dir string) []string {
func Test_shouldCreateArchive(t *testing.T) { func Test_shouldCreateArchive(t *testing.T) {
tmpdir := t.TempDir() tmpdir := t.TempDir()
content := []byte("content") content := []byte("content")
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
require.NoError(t, err)
os.MkdirAll(path.Join(tmpdir, "dir"), 0700) os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600) require.NoError(t, err)
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
require.NoError(t, err)
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
require.NoError(t, err)
gzPath, err := TarGzDir(tmpdir) gzPath, err := TarGzDir(tmpdir)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath) assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
extractionDir := t.TempDir() extractionDir := t.TempDir()
@ -45,7 +57,8 @@ func Test_shouldCreateArchive(t *testing.T) {
wasExtracted := func(p string) { wasExtracted := func(p string) {
fullpath := path.Join(extractionDir, p) fullpath := path.Join(extractionDir, p)
assert.Contains(t, extractedFiles, fullpath) assert.Contains(t, extractedFiles, fullpath)
copyContent, _ := os.ReadFile(fullpath) copyContent, err := os.ReadFile(fullpath)
require.NoError(t, err)
assert.Equal(t, content, copyContent) assert.Equal(t, content, copyContent)
} }
@ -57,13 +70,21 @@ func Test_shouldCreateArchive(t *testing.T) {
func Test_shouldCreateArchive2(t *testing.T) { func Test_shouldCreateArchive2(t *testing.T) {
tmpdir := t.TempDir() tmpdir := t.TempDir()
content := []byte("content") content := []byte("content")
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
os.MkdirAll(path.Join(tmpdir, "dir"), 0700) err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600) require.NoError(t, err)
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
require.NoError(t, err)
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
require.NoError(t, err)
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
require.NoError(t, err)
gzPath, err := TarGzDir(tmpdir) gzPath, err := TarGzDir(tmpdir)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath) assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
extractionDir := t.TempDir() extractionDir := t.TempDir()

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUnzipFile(t *testing.T) { func TestUnzipFile(t *testing.T) {
@ -20,7 +21,7 @@ func TestUnzipFile(t *testing.T) {
err := UnzipFile("./testdata/sample_archive.zip", dir) err := UnzipFile("./testdata/sample_archive.zip", dir)
assert.NoError(t, err) require.NoError(t, err)
archiveDir := dir + "/sample_archive" archiveDir := dir + "/sample_archive"
assert.FileExists(t, filepath.Join(archiveDir, "0.txt")) assert.FileExists(t, filepath.Join(archiveDir, "0.txt"))
assert.FileExists(t, filepath.Join(archiveDir, "0", "1.txt")) assert.FileExists(t, filepath.Join(archiveDir, "0", "1.txt"))

View File

@ -9,8 +9,8 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/alecthomas/kingpin/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/alecthomas/kingpin.v2"
) )
// Service implements the CLIService interface // Service implements the CLIService interface
@ -35,16 +35,9 @@ func CLIFlags() *portainer.CLIFlags {
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(), FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(), EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(), NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(), TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(), HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(), HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(), Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(), SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(), AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
@ -70,8 +63,37 @@ func CLIFlags() *portainer.CLIFlags {
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) { func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version) kingpin.Version(version)
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
sslFlag := kingpin.Flag(
"ssl",
"Secure Portainer instance using SSL (deprecated)",
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
ssl := sslFlag.Bool()
sslCertFlag := kingpin.Flag(
"sslcert",
"Path to the SSL certificate used to secure the Portainer instance",
).IsSetByUser(&hasSSLCertFlag)
sslCert := sslCertFlag.String()
sslKeyFlag := kingpin.Flag(
"sslkey",
"Path to the SSL key used to secure the Portainer instance",
).IsSetByUser(&hasSSLKeyFlag)
sslKey := sslKeyFlag.String()
flags := CLIFlags() flags := CLIFlags()
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
flags.TLS = tlsFlag.Bool()
tlsCertFlag := kingpin.Flag(
"tlscert",
"Path to the TLS certificate file",
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
flags.TLSCert = tlsCertFlag.String()
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
flags.TLSKey = tlsKeyFlag.String()
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
kingpin.Parse() kingpin.Parse()
if !filepath.IsAbs(*flags.Assets) { if !filepath.IsAbs(*flags.Assets) {
@ -83,6 +105,41 @@ func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets) *flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
} }
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
if !hasTLSFlag {
if !hasTLSCertFlag {
*flags.TLSCert = ""
}
if !hasTLSKeyFlag {
*flags.TLSKey = ""
}
}
if hasSSLFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
if !hasTLSFlag {
flags.TLS = ssl
}
}
if hasSSLCertFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
if !hasTLSCertFlag {
flags.TLSCert = sslCert
}
}
if hasSSLKeyFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
if !hasTLSKeyFlag {
flags.TLSKey = sslKey
}
}
return flags, nil return flags, nil
} }
@ -109,10 +166,6 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
if *flags.NoAnalytics { if *flags.NoAnalytics {
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect") log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
} }
if *flags.SSL {
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
}
} }
func validateEndpointURL(endpointURL string) error { func validateEndpointURL(endpointURL string) error {

View File

@ -1,9 +1,12 @@
package cli package cli
import ( import (
"io"
"os" "os"
"strings"
"testing" "testing"
zerolog "github.com/rs/zerolog/log"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -22,3 +25,185 @@ func TestOptionParser(t *testing.T) {
require.False(t, *opts.HTTPDisabled) require.False(t, *opts.HTTPDisabled)
require.True(t, *opts.EnableEdgeComputeFeatures) require.True(t, *opts.EnableEdgeComputeFeatures)
} }
func TestParseTLSFlags(t *testing.T) {
testCases := []struct {
name string
args []string
expectedTLSFlag bool
expectedTLSCertFlag string
expectedTLSKeyFlag string
expectedLogMessages []string
}{
{
name: "no flags",
expectedTLSFlag: false,
expectedTLSCertFlag: "",
expectedTLSKeyFlag: "",
},
{
name: "only ssl flag",
args: []string{
"portainer",
"--ssl",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "",
expectedTLSKeyFlag: "",
},
{
name: "only tls flag",
args: []string{
"portainer",
"--tlsverify",
},
expectedTLSFlag: true,
expectedTLSCertFlag: defaultTLSCertPath,
expectedTLSKeyFlag: defaultTLSKeyPath,
},
{
name: "partial ssl flags",
args: []string{
"portainer",
"--ssl",
"--sslcert=ssl-cert-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "ssl-cert-flag-value",
expectedTLSKeyFlag: "",
},
{
name: "partial tls flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: defaultTLSKeyPath,
},
{
name: "partial tls and ssl flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
},
{
name: "partial tls and ssl flags 2",
args: []string{
"portainer",
"--ssl",
"--tlscert=tls-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
},
{
name: "ssl flags",
args: []string{
"portainer",
"--ssl",
"--sslcert=ssl-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "ssl-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
expectedLogMessages: []string{
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
},
},
{
name: "tls flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--tlskey=tls-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "tls-key-flag-value",
},
{
name: "tls and ssl flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--tlskey=tls-key-flag-value",
"--ssl",
"--sslcert=ssl-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "tls-key-flag-value",
expectedLogMessages: []string{
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var logOutput strings.Builder
setupLogOutput(t, &logOutput)
if tc.args == nil {
tc.args = []string{"portainer"}
}
setOsArgs(t, tc.args)
s := Service{}
flags, err := s.ParseFlags("test-version")
if err != nil {
t.Fatalf("error parsing flags: %v", err)
}
if flags.TLS == nil {
t.Fatal("TLS flag was nil")
}
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
for _, expectedLogMessage := range tc.expectedLogMessages {
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
}
})
}
}
func setOsArgs(t *testing.T, args []string) {
t.Helper()
previousArgs := os.Args
os.Args = args
t.Cleanup(func() {
os.Args = previousArgs
})
}
func setupLogOutput(t *testing.T, w io.Writer) {
t.Helper()
oldLogger := zerolog.Logger
zerolog.Logger = zerolog.Output(w)
t.Cleanup(func() {
zerolog.Logger = oldLogger
})
}

View File

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"gopkg.in/alecthomas/kingpin.v2" "github.com/alecthomas/kingpin/v2"
) )
type pairList []portainer.Pair type pairList []portainer.Pair

View File

@ -309,13 +309,13 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
// dbSecretPath build the path to the file that contains the db encryption // dbSecretPath build the path to the file that contains the db encryption
// secret. Normally in Docker this is built from the static path inside // secret. Normally in Docker this is built from the static path inside
// /run/portainer for example: /run/portainer/<keyFilenameFlag> but for ease of // /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
// use outside Docker it also accepts an absolute path // use outside Docker it also accepts an absolute path
func dbSecretPath(keyFilenameFlag string) string { func dbSecretPath(keyFilenameFlag string) string {
if path.IsAbs(keyFilenameFlag) { if path.IsAbs(keyFilenameFlag) {
return keyFilenameFlag return keyFilenameFlag
} }
return path.Join("/run/portainer", keyFilenameFlag) return path.Join("/run/secrets", keyFilenameFlag)
} }
func loadEncryptionSecretKey(keyfilename string) []byte { func loadEncryptionSecretKey(keyfilename string) []byte {
@ -408,7 +408,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
edgeStacksService := edgestacks.NewService(dataStore) edgeStacksService := edgestacks.NewService(dataStore)
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger) sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("") log.Fatal().Err(err).Msg("")
} }

View File

@ -43,12 +43,12 @@ func TestDBSecretPath(t *testing.T) {
keyFilenameFlag string keyFilenameFlag string
expected string expected string
}{ }{
{keyFilenameFlag: "secret.txt", expected: "/run/portainer/secret.txt"}, {keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"}, {keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
{keyFilenameFlag: "/run/portainer/secret.txt", expected: "/run/portainer/secret.txt"}, {keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
{keyFilenameFlag: "./secret.txt", expected: "/run/portainer/secret.txt"}, {keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"}, {keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/portainer/foo/bar/secret.txt"}, {keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
} }
for _, test := range tests { for _, test := range tests {

View File

@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/pbkdf2"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
@ -14,9 +15,9 @@ import (
"github.com/portainer/portainer/pkg/fips" "github.com/portainer/portainer/pkg/fips"
"golang.org/x/crypto/argon2" // Not allowed in FIPS mode
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/argon2" //nolint:depguard
"golang.org/x/crypto/scrypt" "golang.org/x/crypto/scrypt" //nolint:depguard
) )
const ( const (
@ -248,7 +249,10 @@ func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) err
return err return err
} }
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New) key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
if err != nil {
return fmt.Errorf("error deriving key: %w", err)
}
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
@ -315,7 +319,10 @@ func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
return nil, err return nil, err
} }
key := pbkdf2.Key(passphrase, salt, pbkdf2Iterations, 32, sha256.New) key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
if err != nil {
return nil, fmt.Errorf("error deriving key: %w", err)
}
// Initialize AES cipher block // Initialize AES cipher block
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
@ -382,3 +389,18 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
return reader, nil return reader, nil
} }
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
// mode changes this behavior and so will only recognize data encrypted by the
// same mode (fips enabled or disabled)
func HasEncryptedHeader(data []byte) bool {
return hasEncryptedHeader(data, fips.FIPSMode())
}
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
if fipsMode {
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
}
return bytes.HasPrefix(data, []byte(aesGcmHeader))
}

View File

@ -55,17 +55,19 @@ func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
encryptedFileWriter, _ := os.Create(encryptedFilePath) encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase)) err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
require.Nil(t, err, "Failed to encrypt a file") require.NoError(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close() encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath) encryptedContent, err := os.ReadFile(encryptedFilePath)
require.Nil(t, err, "Couldn't read encrypted file") require.NoError(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted") assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath) encryptedFileReader, err := os.Open(encryptedFilePath)
require.NoError(t, err)
defer encryptedFileReader.Close() defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath) decryptedFileWriter, err := os.Create(decryptedFilePath)
require.NoError(t, err)
defer decryptedFileWriter.Close() defer decryptedFileWriter.Close()
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase)) decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
@ -155,11 +157,11 @@ func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
encryptedFileWriter, _ := os.Create(encryptedFilePath) encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase)) err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file") require.NoError(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close() encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath) encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file") require.NoError(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted") assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath) encryptedFileReader, _ := os.Open(encryptedFilePath)
@ -169,7 +171,7 @@ func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
defer decryptedFileWriter.Close() defer decryptedFileWriter.Close()
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase)) decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file") require.NoError(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader) io.Copy(decryptedFileWriter, decryptedReader)
@ -205,25 +207,29 @@ func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
encryptedFileWriter, _ := os.Create(encryptedFilePath) encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase")) err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file") require.NoError(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close() encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath) encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file") require.NoError(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted") assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath) encryptedFileReader, err := os.Open(encryptedFilePath)
require.NoError(t, err)
defer encryptedFileReader.Close() defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath) decryptedFileWriter, err := os.Create(decryptedFilePath)
require.NoError(t, err)
defer decryptedFileWriter.Close() defer decryptedFileWriter.Close()
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase")) decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
assert.Nil(t, err, "Failed to decrypt file") require.NoError(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader) _, err = io.Copy(decryptedFileWriter, decryptedReader)
require.NoError(t, err)
decryptedContent, _ := os.ReadFile(decryptedFilePath) decryptedContent, err := os.ReadFile(decryptedFilePath)
require.NoError(t, err)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match") assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
} }
@ -247,32 +253,40 @@ func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
) )
content := randBytes(1024 * 50) content := randBytes(1024 * 50)
os.WriteFile(originFilePath, content, 0600) err := os.WriteFile(originFilePath, content, 0600)
require.NoError(t, err)
originFile, _ := os.Open(originFilePath) originFile, err := os.Open(originFilePath)
require.NoError(t, err)
defer originFile.Close() defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath) encryptedFileWriter, err := os.Create(encryptedFilePath)
require.NoError(t, err)
defer encryptedFileWriter.Close() defer encryptedFileWriter.Close()
err := encrypt(originFile, encryptedFileWriter, []byte("")) err = encrypt(originFile, encryptedFileWriter, []byte(""))
assert.Nil(t, err, "Failed to encrypt a file") require.NoError(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath) encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file") require.NoError(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted") assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath) encryptedFileReader, err := os.Open(encryptedFilePath)
require.NoError(t, err)
defer encryptedFileReader.Close() defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath) decryptedFileWriter, err := os.Create(decryptedFilePath)
require.NoError(t, err)
defer decryptedFileWriter.Close() defer decryptedFileWriter.Close()
decryptedReader, err := decrypt(encryptedFileReader, []byte("")) decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
assert.Nil(t, err, "Failed to decrypt file") require.NoError(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader) _, err = io.Copy(decryptedFileWriter, decryptedReader)
require.NoError(t, err)
decryptedContent, _ := os.ReadFile(decryptedFilePath) decryptedContent, err := os.ReadFile(decryptedFilePath)
require.NoError(t, err)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match") assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
} }
@ -305,9 +319,9 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
defer encryptedFileWriter.Close() defer encryptedFileWriter.Close()
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase")) err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file") require.NoError(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath) encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file") require.NoError(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted") assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath) encryptedFileReader, _ := os.Open(encryptedFilePath)
@ -317,7 +331,7 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
defer decryptedFileWriter.Close() defer decryptedFileWriter.Close()
_, err = decrypt(encryptedFileReader, []byte("garbage")) _, err = decrypt(encryptedFileReader, []byte("garbage"))
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase") require.Error(t, err, "Should not allow decrypt with wrong passphrase")
} }
t.Run("fips", func(t *testing.T) { t.Run("fips", func(t *testing.T) {
@ -350,3 +364,62 @@ func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) erro
return nil return nil
} }
func Test_hasEncryptedHeader(t *testing.T) {
tests := []struct {
name string
data []byte
fipsMode bool
want bool
}{
{
name: "non-FIPS mode with valid header",
data: []byte("AES256-GCM" + "some encrypted data"),
fipsMode: false,
want: true,
},
{
name: "non-FIPS mode with FIPS header",
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
fipsMode: false,
want: false,
},
{
name: "FIPS mode with valid header",
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
fipsMode: true,
want: true,
},
{
name: "FIPS mode with non-FIPS header",
data: []byte("AES256-GCM" + "some encrypted data"),
fipsMode: true,
want: false,
},
{
name: "invalid header",
data: []byte("INVALID-HEADER" + "some data"),
fipsMode: false,
want: false,
},
{
name: "empty data",
data: []byte{},
fipsMode: false,
want: false,
},
{
name: "nil data",
data: nil,
fipsMode: false,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := hasEncryptedHeader(tt.data, tt.fipsMode)
assert.Equal(t, tt.want, got)
})
}
}

View File

@ -11,12 +11,12 @@ func TestCreateSignature(t *testing.T) {
privKey, pubKey, err := s.GenerateKeyPair() privKey, pubKey, err := s.GenerateKeyPair()
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, len(privKey), 0) require.NotEmpty(t, privKey)
require.Greater(t, len(pubKey), 0) require.NotEmpty(t, pubKey)
m := "test message" m := "test message"
r, err := s.CreateSignature(m) r, err := s.CreateSignature(m)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, r, m) require.NotEqual(t, r, m)
require.Greater(t, len(r), 0) require.NotEmpty(t, r)
} }

View File

@ -1,7 +1,8 @@
package crypto package crypto
import ( import (
"golang.org/x/crypto/bcrypt" // Not allowed in FIPS mode
"golang.org/x/crypto/bcrypt" //nolint:depguard
) )
// Service represents a service for encrypting/hashing data. // Service represents a service for encrypting/hashing data.

View File

@ -1,18 +1,17 @@
package crypto package crypto
import ( import (
"crypto/fips140"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"os" "os"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
) )
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings // CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
// TODO: use fips.FIPSMode() instead return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
return createTLSConfiguration(fips140.Enabled(), insecureSkipVerify)
} }
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
@ -58,8 +57,7 @@ func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Conf
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key // CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
// loaded from memory. // loaded from memory.
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
// TODO: use fips.FIPSMode() instead return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
return createTLSConfigurationFromBytes(fips140.Enabled(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
} }
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
@ -90,8 +88,7 @@ func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key // CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
// loaded from disk. // loaded from disk.
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
// TODO: use fips.FIPSMode() instead return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
return createTLSConfigurationFromDisk(fips140.Enabled(), config)
} }
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo

View File

@ -44,7 +44,7 @@ func TestCreateTLSConfigurationFIPS(t *testing.T) {
func TestCreateTLSConfigurationFromBytes(t *testing.T) { func TestCreateTLSConfigurationFromBytes(t *testing.T) {
// No TLS // No TLS
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false) config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
require.Nil(t, err) require.NoError(t, err)
require.Nil(t, config) require.Nil(t, config)
// Skip TLS client/server verifications // Skip TLS client/server verifications
@ -61,7 +61,7 @@ func TestCreateTLSConfigurationFromBytes(t *testing.T) {
func TestCreateTLSConfigurationFromDisk(t *testing.T) { func TestCreateTLSConfigurationFromDisk(t *testing.T) {
// No TLS // No TLS
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{}) config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
require.Nil(t, err) require.NoError(t, err)
require.Nil(t, config) require.Nil(t, config)
// Skip TLS verifications // Skip TLS verifications

View File

@ -94,7 +94,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) { t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
data, err := conn.MarshalObject(test.object) data, err := conn.MarshalObject(test.object)
is.NoError(err) require.NoError(t, err)
is.Equal(test.expected, string(data)) is.Equal(test.expected, string(data))
}) })
} }
@ -135,7 +135,7 @@ func Test_UnMarshalObjectUnencrypted(t *testing.T) {
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) { t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
var object string var object string
err := conn.UnmarshalObject(test.object, &object) err := conn.UnmarshalObject(test.object, &object)
is.NoError(err) require.NoError(t, err)
is.Equal(test.expected, object) is.Equal(test.expected, object)
}) })
} }
@ -172,12 +172,12 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) { t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
data, err := conn.MarshalObject(test.object) data, err := conn.MarshalObject(test.object)
is.NoError(err) require.NoError(t, err)
var object []byte var object []byte
err = conn.UnmarshalObject(data, &object) err = conn.UnmarshalObject(data, &object)
is.NoError(err) require.NoError(t, err)
is.Equal(test.object, object) is.Equal(test.object, object)
}) })
} }

View File

@ -28,13 +28,12 @@ func NewService(connection portainer.Connection) (*Service, error) {
}, nil }, nil
} }
// CreateCustomTemplate uses the existing id and saves it.
// TODO: where does the ID come from, and is it safe?
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service *Service) GetNextIdentifier() int { func (service *Service) GetNextIdentifier() int {
return service.Connection.GetNextIdentifier(BucketName) return service.Connection.GetNextIdentifier(BucketName)
} }
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).Create(customTemplate)
})
}

View File

@ -0,0 +1,19 @@
package customtemplate_test
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestCustomTemplateCreate(t *testing.T) {
_, ds := datastore.MustNewTestStore(t, true, false)
require.NotNil(t, ds)
require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
e, err := ds.CustomTemplate().Read(1)
require.NoError(t, err)
require.Equal(t, portainer.CustomTemplateID(1), e.ID)
}

View File

@ -0,0 +1,31 @@
package customtemplate
import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
)
// Service represents a service for managing custom template data.
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]
}
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{
Bucket: BucketName,
Connection: service.Connection,
Tx: tx,
},
}
}
func (service ServiceTx) GetNextIdentifier() int {
return service.Tx.GetNextIdentifier(BucketName)
}
// CreateCustomTemplate uses the existing id and saves it.
// TODO: where does the ID come from, and is it safe?
func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error {
return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
}

View File

@ -0,0 +1,28 @@
package customtemplate_test
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestCustomTemplateCreateTx(t *testing.T) {
_, ds := datastore.MustNewTestStore(t, true, false)
require.NotNil(t, ds)
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})
}))
var template *portainer.CustomTemplate
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
var err error
template, err = tx.CustomTemplate().Read(1)
return err
}))
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
}

View File

@ -91,9 +91,9 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID,
}) })
} }
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
return service.connection.UpdateTx(func(tx portainer.Transaction) error { return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID) return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
}) })
} }

View File

@ -5,6 +5,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/boltdb" "github.com/portainer/portainer/api/database/boltdb"
"github.com/portainer/portainer/api/dataservices/edgestack"
"github.com/portainer/portainer/api/internal/edge/cache" "github.com/portainer/portainer/api/internal/edge/cache"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -102,3 +103,38 @@ func TestUpdateRelation(t *testing.T) {
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments) require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments) require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
} }
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn)
require.NoError(t, err)
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
require.NoError(t, err)
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
}
func TestEndpointRelations(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn)
require.NoError(t, err)
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
rels, err := service.EndpointRelations()
require.NoError(t, err)
require.Len(t, rels, 1)
}

View File

@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
return nil return nil
} }
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
for _, endpointID := range endpointIDs { for _, endpointID := range endpointIDs {
rel, err := service.EndpointRelation(endpointID) rel, err := service.EndpointRelation(endpointID)
if err != nil { if err != nil {
return err return err
} }
rel.EdgeStacks[edgeStackID] = true rel.EdgeStacks[edgeStack.ID] = true
identifier := service.service.connection.ConvertToKey(int(endpointID)) identifier := service.service.connection.ConvertToKey(int(endpointID))
err = service.tx.UpdateObject(BucketName, identifier, rel) err = service.tx.UpdateObject(BucketName, identifier, rel)
@ -97,8 +97,12 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
service.service.endpointRelationsCache = nil service.service.endpointRelationsCache = nil
service.service.mu.Unlock() service.service.mu.Unlock()
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) { if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
edgeStack.NumDeployments += len(endpointIDs) es.NumDeployments += len(endpointIDs)
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
// to avoid overriding with the previous values
edgeStack.NumDeployments = es.NumDeployments
}); err != nil { }); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments") log.Error().Err(err).Msg("could not update the number of deployments")
} }

View File

@ -126,7 +126,7 @@ type (
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error) EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
Create(endpointRelation *portainer.EndpointRelation) error Create(endpointRelation *portainer.EndpointRelation) error
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
DeleteEndpointRelation(EndpointID portainer.EndpointID) error DeleteEndpointRelation(EndpointID portainer.EndpointID) error
BucketName() string BucketName() string

View File

@ -1,13 +1,8 @@
package pendingactions package pendingactions
import ( import (
"fmt"
"time"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
) )
const BucketName = "pending_actions" const BucketName = "pending_actions"
@ -16,10 +11,6 @@ type Service struct {
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID] dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
} }
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
}
func NewService(connection portainer.Connection) (*Service, error) { func NewService(connection portainer.Connection) (*Service, error) {
err := connection.SetServiceName(BucketName) err := connection.SetServiceName(BucketName)
if err != nil { if err != nil {
@ -34,6 +25,11 @@ func NewService(connection portainer.Connection) (*Service, error) {
}, nil }, nil
} }
// GetNextIdentifier returns the next identifier for a custom template.
func (service *Service) GetNextIdentifier() int {
return service.Connection.GetNextIdentifier(BucketName)
}
func (s Service) Create(config *portainer.PendingAction) error { func (s Service) Create(config *portainer.PendingAction) error {
return s.Connection.UpdateTx(func(tx portainer.Transaction) error { return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
return s.Tx(tx).Create(config) return s.Tx(tx).Create(config)
@ -61,43 +57,3 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
}, },
} }
} }
func (s ServiceTx) Create(config *portainer.PendingAction) error {
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
config.ID = portainer.PendingActionID(id)
config.CreatedAt = time.Now().Unix()
return int(config.ID), config
})
}
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
return s.BaseDataServiceTx.Update(ID, config)
}
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
pendingActions, err := s.ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
}
for _, pendingAction := range pendingActions {
if pendingAction.EndpointID == ID {
if err := s.Delete(pendingAction.ID); err != nil {
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
}
}
}
return nil
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service ServiceTx) GetNextIdentifier() int {
return service.Tx.GetNextIdentifier(BucketName)
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service *Service) GetNextIdentifier() int {
return service.Connection.GetNextIdentifier(BucketName)
}

View File

@ -0,0 +1,49 @@
package pendingactions
import (
"fmt"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
}
func (s ServiceTx) Create(config *portainer.PendingAction) error {
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
config.ID = portainer.PendingActionID(id)
config.CreatedAt = time.Now().Unix()
return int(config.ID), config
})
}
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
return s.BaseDataServiceTx.Update(ID, config)
}
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
pendingActions, err := s.ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
}
for _, pendingAction := range pendingActions {
if pendingAction.EndpointID == ID {
if err := s.Delete(pendingAction.ID); err != nil {
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
}
}
}
return nil
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service ServiceTx) GetNextIdentifier() int {
return service.Tx.GetNextIdentifier(BucketName)
}

View File

@ -4,17 +4,18 @@ import (
"testing" "testing"
"time" "time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/filesystem"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/filesystem"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func newGuidString(t *testing.T) string { func newGuidString(t *testing.T) string {
uuid, err := uuid.NewV4() uuid, err := uuid.NewV4()
assert.NoError(t, err) require.NoError(t, err)
return uuid.String() return uuid.String()
} }
@ -41,7 +42,7 @@ func TestService_StackByWebhookID(t *testing.T) {
// can find a stack by webhook ID // can find a stack by webhook ID
got, err := store.StackService.StackByWebhookID(webhookID) got, err := store.StackService.StackByWebhookID(webhookID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, stack, *got) assert.Equal(t, stack, *got)
// returns nil and object not found error if there's no stack associated with the webhook // returns nil and object not found error if there's no stack associated with the webhook
@ -94,10 +95,10 @@ func Test_RefreshableStacks(t *testing.T) {
for _, stack := range []*portainer.Stack{&staticStack, &stackWithWebhook, &refreshableStack} { for _, stack := range []*portainer.Stack{&staticStack, &stackWithWebhook, &refreshableStack} {
err := store.Stack().Create(stack) err := store.Stack().Create(stack)
assert.NoError(t, err) require.NoError(t, err)
} }
stacks, err := store.Stack().RefreshableStacks() stacks, err := store.Stack().RefreshableStacks()
assert.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, []portainer.Stack{refreshableStack}, stacks) assert.ElementsMatch(t, []portainer.Stack{refreshableStack}, stacks)
} }

View File

@ -5,7 +5,9 @@ import (
"github.com/portainer/portainer/api/dataservices/errors" "github.com/portainer/portainer/api/dataservices/errors"
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_teamByName(t *testing.T) { func Test_teamByName(t *testing.T) {
@ -13,7 +15,7 @@ func Test_teamByName(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true) _, store := datastore.MustNewTestStore(t, true, true)
_, err := store.Team().TeamByName("name") _, err := store.Team().TeamByName("name")
assert.ErrorIs(t, err, errors.ErrObjectNotFound) require.ErrorIs(t, err, errors.ErrObjectNotFound)
}) })
@ -29,7 +31,7 @@ func Test_teamByName(t *testing.T) {
teamBuilder.createNew("name1") teamBuilder.createNew("name1")
_, err := store.Team().TeamByName("name") _, err := store.Team().TeamByName("name")
assert.ErrorIs(t, err, errors.ErrObjectNotFound) require.ErrorIs(t, err, errors.ErrObjectNotFound)
}) })
t.Run("When there is an object with the same name should return the object", func(t *testing.T) { t.Run("When there is an object with the same name should return the object", func(t *testing.T) {
@ -44,7 +46,7 @@ func Test_teamByName(t *testing.T) {
expectedTeam := teamBuilder.createNew("name1") expectedTeam := teamBuilder.createNew("name1")
team, err := store.Team().TeamByName("name1") team, err := store.Team().TeamByName("name1")
assert.NoError(t, err, "TeamByName should succeed") require.NoError(t, err, "TeamByName should succeed")
assert.Equal(t, expectedTeam, team) assert.Equal(t, expectedTeam, team)
}) })
} }

View File

@ -5,15 +5,14 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/models" "github.com/portainer/portainer/api/database/models"
"github.com/stretchr/testify/require"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func TestStoreCreation(t *testing.T) { func TestStoreCreation(t *testing.T) {
_, store := MustNewTestStore(t, true, true) _, store := MustNewTestStore(t, true, true)
if store == nil { require.NotNil(t, store)
t.Fatal("Expect to create a store")
}
v, err := store.VersionService.Version() v, err := store.VersionService.Version()
if err != nil { if err != nil {

View File

@ -6,12 +6,14 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/dchest/uniuri"
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/chisel" "github.com/portainer/portainer/api/chisel"
"github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/crypto"
"github.com/dchest/uniuri"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -30,45 +32,21 @@ func TestStoreFull(t *testing.T) {
_, store := MustNewTestStore(t, true, true) _, store := MustNewTestStore(t, true, true)
testCases := map[string]func(t *testing.T){ testCases := map[string]func(t *testing.T){
"User Accounts": func(t *testing.T) { "User Accounts": store.testUserAccounts,
store.testUserAccounts(t) "Environments": store.testEnvironments,
}, "Settings": store.testSettings,
"Environments": func(t *testing.T) { "SSL Settings": store.testSSLSettings,
store.testEnvironments(t) "Tunnel Server": store.testTunnelServer,
}, "Custom Templates": store.testCustomTemplates,
"Settings": func(t *testing.T) { "Registries": store.testRegistries,
store.testSettings(t) "Resource Control": store.testResourceControl,
}, "Schedules": store.testSchedules,
"SSL Settings": func(t *testing.T) { "Tags": store.testTags,
store.testSSLSettings(t)
},
"Tunnel Server": func(t *testing.T) {
store.testTunnelServer(t)
},
"Custom Templates": func(t *testing.T) {
store.testCustomTemplates(t)
},
"Registries": func(t *testing.T) {
store.testRegistries(t)
},
"Resource Control": func(t *testing.T) {
store.testResourceControl(t)
},
"Schedules": func(t *testing.T) {
store.testSchedules(t)
},
"Tags": func(t *testing.T) {
store.testTags(t)
},
// "Test Title": func(t *testing.T) {
// },
} }
for name, test := range testCases { for name, test := range testCases {
t.Run(name, test) t.Run(name, test)
} }
} }
func (store *Store) testEnvironments(t *testing.T) { func (store *Store) testEnvironments(t *testing.T) {
@ -167,7 +145,7 @@ func (store *Store) CreateEndpoint(t *testing.T, name string, endpointType porta
store.Endpoint().Create(expectedEndpoint) store.Endpoint().Create(expectedEndpoint)
endpoint, err := store.Endpoint().Endpoint(id) endpoint, err := store.Endpoint().Endpoint(id)
is.NoError(err, "Endpoint() should not return an error") require.NoError(t, err, "Endpoint() should not return an error")
is.Equal(expectedEndpoint, endpoint, "endpoint should be the same") is.Equal(expectedEndpoint, endpoint, "endpoint should be the same")
return endpoint.ID return endpoint.ID
@ -194,7 +172,7 @@ func (store *Store) testSSLSettings(t *testing.T) {
store.SSLSettings().UpdateSettings(ssl) store.SSLSettings().UpdateSettings(ssl)
settings, err := store.SSLSettings().Settings() settings, err := store.SSLSettings().Settings()
is.NoError(err, "Get sslsettings should succeed") require.NoError(t, err, "Get sslsettings should succeed")
is.Equal(ssl, settings, "Stored SSLSettings should be the same as what is read out") is.Equal(ssl, settings, "Stored SSLSettings should be the same as what is read out")
} }
@ -203,27 +181,27 @@ func (store *Store) testTunnelServer(t *testing.T) {
expectPrivateKeySeed := uniuri.NewLen(16) expectPrivateKeySeed := uniuri.NewLen(16)
err := store.TunnelServer().UpdateInfo(&portainer.TunnelServerInfo{PrivateKeySeed: expectPrivateKeySeed}) err := store.TunnelServer().UpdateInfo(&portainer.TunnelServerInfo{PrivateKeySeed: expectPrivateKeySeed})
is.NoError(err, "UpdateInfo should have succeeded") require.NoError(t, err, "UpdateInfo should have succeeded")
serverInfo, err := store.TunnelServer().Info() serverInfo, err := store.TunnelServer().Info()
is.NoError(err, "Info should have succeeded") require.NoError(t, err, "Info should have succeeded")
is.Equal(expectPrivateKeySeed, serverInfo.PrivateKeySeed, "hashed passwords should not differ") is.Equal(expectPrivateKeySeed, serverInfo.PrivateKeySeed, "hashed passwords should not differ")
} }
// add users, read them back and check the details are unchanged // add users, read them back and check the details are unchanged
func (store *Store) testUserAccounts(t *testing.T) { func (store *Store) testUserAccounts(t *testing.T) {
is := assert.New(t)
err := store.createAccount(adminUsername, adminPassword, portainer.AdministratorRole) err := store.createAccount(adminUsername, adminPassword, portainer.AdministratorRole)
is.NoError(err, "CreateAccount should succeed") require.NoError(t, err, "CreateAccount should succeed")
store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
is.NoError(err, "Account failure") err = store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
require.NoError(t, err, "Account failure")
err = store.createAccount(standardUsername, standardPassword, portainer.StandardUserRole) err = store.createAccount(standardUsername, standardPassword, portainer.StandardUserRole)
is.NoError(err, "CreateAccount should succeed") require.NoError(t, err, "CreateAccount should succeed")
store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
is.NoError(err, "Account failure") err = store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
require.NoError(t, err, "Account failure")
} }
// create an account with the provided details // create an account with the provided details
@ -238,12 +216,7 @@ func (store *Store) createAccount(username, password string, role portainer.User
return err return err
} }
err = store.User().Create(user) return store.User().Create(user)
if err != nil {
return err
}
return nil
} }
func (store *Store) checkAccount(username, expectPassword string, expectRole portainer.UserRole) error { func (store *Store) checkAccount(username, expectPassword string, expectRole portainer.UserRole) error {
@ -260,12 +233,7 @@ func (store *Store) checkAccount(username, expectPassword string, expectRole por
// Check the password // Check the password
cs := crypto.Service{} cs := crypto.Service{}
expectPasswordHash, err := cs.Hash(expectPassword) if cs.CompareHashAndData(user.Password, expectPassword) != nil {
if err != nil {
return errors.Wrap(err, "hash failed")
}
if user.Password != expectPasswordHash {
return fmt.Errorf("%s user password hash failure", user.Username) return fmt.Errorf("%s user password hash failure", user.Username)
} }
@ -277,7 +245,7 @@ func (store *Store) testSettings(t *testing.T) {
// since many settings are default and basically nil, I'm going to update some and read them back // since many settings are default and basically nil, I'm going to update some and read them back
expectedSettings, err := store.Settings().Settings() expectedSettings, err := store.Settings().Settings()
is.NoError(err, "Settings() should not return an error") require.NoError(t, err, "Settings() should not return an error")
expectedSettings.TemplatesURL = "http://portainer.io/application-templates" expectedSettings.TemplatesURL = "http://portainer.io/application-templates"
expectedSettings.HelmRepositoryURL = "http://portainer.io/helm-repository" expectedSettings.HelmRepositoryURL = "http://portainer.io/helm-repository"
expectedSettings.EdgeAgentCheckinInterval = 60 expectedSettings.EdgeAgentCheckinInterval = 60
@ -291,10 +259,10 @@ func (store *Store) testSettings(t *testing.T) {
expectedSettings.SnapshotInterval = "10m" expectedSettings.SnapshotInterval = "10m"
err = store.Settings().UpdateSettings(expectedSettings) err = store.Settings().UpdateSettings(expectedSettings)
is.NoError(err, "UpdateSettings() should succeed") require.NoError(t, err, "UpdateSettings() should succeed")
settings, err := store.Settings().Settings() settings, err := store.Settings().Settings()
is.NoError(err, "Settings() should not return an error") require.NoError(t, err, "Settings() should not return an error")
is.Equal(expectedSettings, settings, "stored settings should match") is.Equal(expectedSettings, settings, "stored settings should match")
} }
@ -317,7 +285,7 @@ func (store *Store) testCustomTemplates(t *testing.T) {
customTemplate.Create(expectedTemplate) customTemplate.Create(expectedTemplate)
actualTemplate, err := customTemplate.Read(expectedTemplate.ID) actualTemplate, err := customTemplate.Read(expectedTemplate.ID)
is.NoError(err, "CustomTemplate should not return an error") require.NoError(t, err, "CustomTemplate should not return an error")
is.Equal(expectedTemplate, actualTemplate, "expected and actual template do not match") is.Equal(expectedTemplate, actualTemplate, "expected and actual template do not match")
} }
@ -345,17 +313,17 @@ func (store *Store) testRegistries(t *testing.T) {
} }
err := regService.Create(reg1) err := regService.Create(reg1)
is.NoError(err) require.NoError(t, err)
err = regService.Create(reg2) err = regService.Create(reg2)
is.NoError(err) require.NoError(t, err)
actualReg1, err := regService.Read(reg1.ID) actualReg1, err := regService.Read(reg1.ID)
is.NoError(err) require.NoError(t, err)
is.Equal(reg1, actualReg1, "registries differ") is.Equal(reg1, actualReg1, "registries differ")
actualReg2, err := regService.Read(reg2.ID) actualReg2, err := regService.Read(reg2.ID)
is.NoError(err) require.NoError(t, err)
is.Equal(reg2, actualReg2, "registries differ") is.Equal(reg2, actualReg2, "registries differ")
} }
@ -378,10 +346,10 @@ func (store *Store) testSchedules(t *testing.T) {
} }
err := schedule.CreateSchedule(s) err := schedule.CreateSchedule(s)
is.NoError(err, "CreateSchedule should succeed") require.NoError(t, err, "CreateSchedule should succeed")
actual, err := schedule.Schedule(s.ID) actual, err := schedule.Schedule(s.ID)
is.NoError(err, "schedule should be found") require.NoError(t, err, "schedule should be found")
is.Equal(s, actual, "schedules differ") is.Equal(s, actual, "schedules differ")
} }
@ -401,16 +369,16 @@ func (store *Store) testTags(t *testing.T) {
} }
err := tags.Create(tag1) err := tags.Create(tag1)
is.NoError(err, "Tags.Create should succeed") require.NoError(t, err, "Tags.Create should succeed")
err = tags.Create(tag2) err = tags.Create(tag2)
is.NoError(err, "Tags.Create should succeed") require.NoError(t, err, "Tags.Create should succeed")
actual, err := tags.Read(tag1.ID) actual, err := tags.Read(tag1.ID)
is.NoError(err, "tag1 should be found") require.NoError(t, err, "tag1 should be found")
is.Equal(tag1, actual, "tags differ") is.Equal(tag1, actual, "tags differ")
actual, err = tags.Read(tag2.ID) actual, err = tags.Read(tag2.ID)
is.NoError(err, "tag2 should be found") require.NoError(t, err, "tag2 should be found")
is.Equal(tag2, actual, "tags differ") is.Equal(tag2, actual, "tags differ")
} }

View File

@ -6,7 +6,9 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore/migrator" "github.com/portainer/portainer/api/datastore/migrator"
gittypes "github.com/portainer/portainer/api/git/types" gittypes "github.com/portainer/portainer/api/git/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestMigrateStackEntryPoint(t *testing.T) { func TestMigrateStackEntryPoint(t *testing.T) {
@ -28,25 +30,25 @@ func TestMigrateStackEntryPoint(t *testing.T) {
for _, s := range stacks { for _, s := range stacks {
err := stackService.Create(s) err := stackService.Create(s)
assert.NoError(t, err, "failed to create stack") require.NoError(t, err, "failed to create stack")
} }
s, err := stackService.Read(1) s, err := stackService.Read(1)
assert.NoError(t, err) require.NoError(t, err)
assert.Nil(t, s.GitConfig, "first stack should not have git config") assert.Nil(t, s.GitConfig, "first stack should not have git config")
s, err = stackService.Read(2) s, err = stackService.Read(2)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "", s.GitConfig.ConfigFilePath, "not migrated yet migrated") assert.Empty(t, s.GitConfig.ConfigFilePath, "not migrated yet migrated")
err = migrator.MigrateStackEntryPoint(stackService) err = migrator.MigrateStackEntryPoint(stackService)
assert.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath") require.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath")
s, err = stackService.Read(1) s, err = stackService.Read(1)
assert.NoError(t, err) require.NoError(t, err)
assert.Nil(t, s.GitConfig, "first stack should not have git config") assert.Nil(t, s.GitConfig, "first stack should not have git config")
s, err = stackService.Read(2) s, err = stackService.Read(2)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "dir/sub/compose.yml", s.GitConfig.ConfigFilePath, "second stack should have config file path migrated") assert.Equal(t, "dir/sub/compose.yml", s.GitConfig.ConfigFilePath, "second stack should have config file path migrated")
} }

View File

@ -11,8 +11,10 @@ func (m *Migrator) migrateEdgeGroupEndpointsToRoars_2_33_0() error {
} }
for _, eg := range egs { for _, eg := range egs {
if eg.EndpointIDs.Len() == 0 {
eg.EndpointIDs = roar.FromSlice(eg.Endpoints) eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
eg.Endpoints = nil eg.Endpoints = nil
}
if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil { if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil {
return err return err

View File

@ -0,0 +1,55 @@
package migrator
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/boltdb"
"github.com/portainer/portainer/api/dataservices/edgegroup"
"github.com/stretchr/testify/require"
)
func TestMigrateEdgeGroupEndpointsToRoars_2_33_0Idempotency(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
edgeGroupService, err := edgegroup.NewService(conn)
require.NoError(t, err)
edgeGroup := &portainer.EdgeGroup{
ID: 1,
Name: "test-edge-group",
Endpoints: []portainer.EndpointID{1, 2, 3},
}
err = conn.CreateObjectWithId(edgegroup.BucketName, int(edgeGroup.ID), edgeGroup)
require.NoError(t, err)
m := NewMigrator(&MigratorParameters{EdgeGroupService: edgeGroupService})
// Run migration once
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
require.NoError(t, err)
migratedEdgeGroup, err := edgeGroupService.Read(edgeGroup.ID)
require.NoError(t, err)
require.Empty(t, migratedEdgeGroup.Endpoints)
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
// Run migration again to ensure the results didn't change
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
require.NoError(t, err)
migratedEdgeGroup, err = edgeGroupService.Read(edgeGroup.ID)
require.NoError(t, err)
require.Empty(t, migratedEdgeGroup.Endpoints)
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
}

View File

@ -256,10 +256,9 @@ func (m *Migrator) initMigrations() {
m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0) m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0)
m.addMigrations("2.33.0-rc1", m.migrateEdgeGroupEndpointsToRoars_2_33_0) m.addMigrations("2.33.1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
//m.addMigrations("2.33.0", m.migrateEdgeGroupEndpointsToRoars_2_33_0) // WARNING: do not change migrations that have already been released!
// when we release 2.33.0 it will also run the rc-1 migration function
// Add new migrations above... // Add new migrations above...
// One function per migration, each versions migration funcs in the same file. // One function per migration, each versions migration funcs in the same file.

View File

@ -2,6 +2,7 @@ package postinit
import ( import (
"context" "context"
"fmt"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
@ -83,18 +84,28 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
// try to create a post init migration pending action. If it already exists, do nothing // try to create a post init migration pending action. If it already exists, do nothing
// this function exists for readability, not reusability // this function exists for readability, not reusability
// TODO: This should be moved into pending actions as part of the pending action migration
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error { func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
// If there are no pending actions for the given endpoint, create one action := portainer.PendingAction{
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
EndpointID: environmentID, EndpointID: environmentID,
Action: actions.PostInitMigrateEnvironment, Action: actions.PostInitMigrateEnvironment,
})
if err != nil {
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
} }
pendingActions, err := postInitMigrator.dataStore.PendingActions().ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending actions: %w", err)
}
for _, dba := range pendingActions {
if dba.EndpointID == action.EndpointID && dba.Action == action.Action {
log.Debug().
Str("action", action.Action).
Int("endpoint_id", int(action.EndpointID)).
Msg("pending action already exists for environment, skipping...")
return nil return nil
} }
}
return postInitMigrator.dataStore.PendingActions().Create(&action)
}
// MigrateEnvironment runs migrations on a single environment // MigrateEnvironment runs migrations on a single environment
func (migrator *PostInitMigrator) MigrateEnvironment(environment *portainer.Endpoint) error { func (migrator *PostInitMigrator) MigrateEnvironment(environment *portainer.Endpoint) error {

View File

@ -8,10 +8,12 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -20,8 +22,9 @@ func TestMigrateGPUs(t *testing.T) {
if strings.HasSuffix(r.URL.Path, "/containers/json") { if strings.HasSuffix(r.URL.Path, "/containers/json") {
containerSummary := []container.Summary{{ID: "container1"}} containerSummary := []container.Summary{{ID: "container1"}}
err := json.NewEncoder(w).Encode(containerSummary) if err := json.NewEncoder(w).Encode(containerSummary); err != nil {
require.NoError(t, err) w.WriteHeader(http.StatusInternalServerError)
}
return return
} }
@ -39,8 +42,9 @@ func TestMigrateGPUs(t *testing.T) {
}, },
} }
err := json.NewEncoder(w).Encode(container) if err := json.NewEncoder(w).Encode(container); err != nil {
require.NoError(t, err) w.WriteHeader(http.StatusInternalServerError)
}
})) }))
defer srv.Close() defer srv.Close()
@ -73,3 +77,98 @@ func TestMigrateGPUs(t *testing.T) {
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs) require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
require.True(t, migratedEndpoint.EnableGPUManagement) require.True(t, migratedEndpoint.EnableGPUManagement)
} }
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
tests := []struct {
name string
existingPendingActions []*portainer.PendingAction
expectedPendingActions int
expectedAction string
}{
{
name: "when existing non-matching action exists, should add migration action",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: "some-other-action",
},
},
expectedPendingActions: 2,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when matching action exists, should not add duplicate",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: actions.PostInitMigrateEnvironment,
},
},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when no actions exist, should add migration action",
existingPendingActions: []*portainer.PendingAction{},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
is := assert.New(t)
_, store := datastore.MustNewTestStore(t, true, true)
// Create test endpoint
endpoint := &portainer.Endpoint{
ID: 7,
UserTrusted: true,
Type: portainer.EdgeAgentOnDockerEnvironment,
Edge: portainer.EnvironmentEdgeSettings{
AsyncMode: false,
},
EdgeID: "edgeID",
}
err := store.Endpoint().Create(endpoint)
require.NoError(t, err, "error creating endpoint")
// Create any existing pending actions
for _, action := range tt.existingPendingActions {
err = store.PendingActions().Create(action)
require.NoError(t, err, "error creating pending action")
}
migrator := NewPostInitMigrator(
nil, // kubeFactory not needed for this test
nil, // dockerFactory not needed for this test
store,
"", // assetsPath not needed for this test
nil, // kubernetesDeployer not needed for this test
)
err = migrator.PostInitMigrate()
require.NoError(t, err, "PostInitMigrate should not return error")
// Verify the results
pendingActions, err := store.PendingActions().ReadAll()
require.NoError(t, err, "error reading pending actions")
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
// If we expect any actions, verify at least one has the expected action type
if tt.expectedPendingActions > 0 {
hasExpectedAction := false
for _, action := range pendingActions {
if action.Action == tt.expectedAction {
hasExpectedAction = true
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
break
}
}
is.True(hasExpectedAction, "should have found action of expected type")
}
})
}
}

View File

@ -14,7 +14,9 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
return tx.store.IsErrObjectNotFound(err) return tx.store.IsErrObjectNotFound(err)
} }
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil } func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService {
return tx.store.CustomTemplateService.Tx(tx.tx)
}
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService { func (tx *StoreTx) PendingActions() dataservices.PendingActionsService {
return tx.store.PendingActionsService.Tx(tx.tx) return tx.store.PendingActionsService.Tx(tx.tx)

View File

@ -83,7 +83,6 @@
"MigrateIngresses": true "MigrateIngresses": true
}, },
"PublicURL": "", "PublicURL": "",
"QueryDate": 0,
"SecuritySettings": { "SecuritySettings": {
"allowBindMountsForRegularUsers": true, "allowBindMountsForRegularUsers": true,
"allowContainerCapabilitiesForRegularUsers": true, "allowContainerCapabilitiesForRegularUsers": true,
@ -615,7 +614,7 @@
"RequiredPasswordLength": 12 "RequiredPasswordLength": 12
}, },
"KubeconfigExpiry": "0", "KubeconfigExpiry": "0",
"KubectlShellImage": "portainer/kubectl-shell:2.33.0-rc1", "KubectlShellImage": "portainer/kubectl-shell:2.34.0",
"LDAPSettings": { "LDAPSettings": {
"AnonymousMode": true, "AnonymousMode": true,
"AutoCreateUsers": true, "AutoCreateUsers": true,
@ -944,7 +943,7 @@
} }
], ],
"version": { "version": {
"VERSION": "{\"SchemaVersion\":\"2.33.0-rc1\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}" "VERSION": "{\"SchemaVersion\":\"2.34.0\",\"MigratorCount\":0,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
}, },
"webhooks": null "webhooks": null
} }

View File

@ -4,10 +4,14 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestHttpClient(t *testing.T) { func TestHttpClient(t *testing.T) {
fips.InitFIPS(false)
// Valid TLS configuration // Valid TLS configuration
endpoint := &portainer.Endpoint{} endpoint := &portainer.Endpoint{}
endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true} endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true}

View File

@ -1,37 +0,0 @@
package docker
import "github.com/docker/docker/api/types"
type ContainerStats struct {
Running int `json:"running"`
Stopped int `json:"stopped"`
Healthy int `json:"healthy"`
Unhealthy int `json:"unhealthy"`
Total int `json:"total"`
}
func CalculateContainerStats(containers []types.Container) ContainerStats {
var running, stopped, healthy, unhealthy int
for _, container := range containers {
switch container.State {
case "running":
running++
case "healthy":
running++
healthy++
case "unhealthy":
running++
unhealthy++
case "exited", "stopped":
stopped++
}
}
return ContainerStats{
Running: running,
Stopped: stopped,
Healthy: healthy,
Unhealthy: unhealthy,
Total: len(containers),
}
}

View File

@ -1,27 +0,0 @@
package docker
import (
"testing"
"github.com/docker/docker/api/types"
"github.com/stretchr/testify/assert"
)
func TestCalculateContainerStats(t *testing.T) {
containers := []types.Container{
{State: "running"},
{State: "running"},
{State: "exited"},
{State: "stopped"},
{State: "healthy"},
{State: "unhealthy"},
}
stats := CalculateContainerStats(containers)
assert.Equal(t, 4, stats.Running)
assert.Equal(t, 2, stats.Stopped)
assert.Equal(t, 1, stats.Healthy)
assert.Equal(t, 1, stats.Unhealthy)
assert.Equal(t, 6, stats.Total)
}

View File

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestImageParser(t *testing.T) { func TestImageParser(t *testing.T) {
@ -14,7 +15,7 @@ func TestImageParser(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "portainer/portainer-ee", Name: "portainer/portainer-ee",
}) })
is.NoError(err, "") require.NoError(t, err)
is.Equal("docker.io/portainer/portainer-ee:latest", image.FullName()) is.Equal("docker.io/portainer/portainer-ee:latest", image.FullName())
is.Equal("portainer/portainer-ee", image.Opts.Name) is.Equal("portainer/portainer-ee", image.Opts.Name)
is.Equal("latest", image.Tag) is.Equal("latest", image.Tag)
@ -30,10 +31,10 @@ func TestImageParser(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", Name: "gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
}) })
is.NoError(err, "") require.NoError(t, err)
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.FullName()) is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.FullName())
is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name) is.Equal("gcr.io/k8s-minikube/kicbase@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
is.Equal("", image.Tag) is.Empty(image.Tag)
is.Equal("k8s-minikube/kicbase", image.Path) is.Equal("k8s-minikube/kicbase", image.Path)
is.Equal("gcr.io", image.Domain) is.Equal("gcr.io", image.Domain)
is.Equal("https://gcr.io/k8s-minikube/kicbase", image.HubLink) is.Equal("https://gcr.io/k8s-minikube/kicbase", image.HubLink)
@ -47,7 +48,7 @@ func TestImageParser(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
}) })
is.NoError(err, "") require.NoError(t, err)
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName()) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
is.Equal("v0.0.30", image.Tag) is.Equal("v0.0.30", image.Tag)
@ -68,8 +69,9 @@ func TestUpdateParsedImage(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
}) })
is.NoError(err, "") require.NoError(t, err)
_ = image.WithTag("v0.0.31") err = image.WithTag("v0.0.31")
require.NoError(t, err)
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.31", image.FullName()) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.31", image.FullName())
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
is.Equal("v0.0.31", image.Tag) is.Equal("v0.0.31", image.Tag)
@ -86,8 +88,9 @@ func TestUpdateParsedImage(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
}) })
is.NoError(err, "") require.NoError(t, err)
_ = image.WithDigest("sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b3") err = image.WithDigest("sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b3")
require.NoError(t, err)
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName()) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
is.Equal("v0.0.30", image.Tag) is.Equal("v0.0.30", image.Tag)
@ -104,8 +107,9 @@ func TestUpdateParsedImage(t *testing.T) {
image, err := ParseImage(ParseImageOptions{ image, err := ParseImage(ParseImageOptions{
Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", Name: "gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2",
}) })
is.NoError(err, "") require.NoError(t, err)
_ = image.TrimDigest() err = image.TrimDigest()
require.NoError(t, err)
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName()) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30", image.FullName())
is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name) is.Equal("gcr.io/k8s-minikube/kicbase:v0.0.30@sha256:02c921df998f95e849058af14de7045efc3954d90320967418a0d1f182bbc0b2", image.Opts.Name)
is.Equal("v0.0.30", image.Tag) is.Equal("v0.0.30", image.Tag)

View File

@ -4,7 +4,9 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestFindBestMatchNeedAuthRegistry(t *testing.T) { func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
@ -15,9 +17,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", false), registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", false),
createNewRegistry("hub-mirror.c.163.com", "", false)} createNewRegistry("hub-mirror.c.163.com", "", false)}
r, err := findBestMatchRegistry(image, registries) r, err := findBestMatchRegistry(image, registries)
is.NoError(err, "") require.NoError(t, err)
is.NotNil(r, "") is.NotNil(r)
is.False(r.Authentication, "") is.False(r.Authentication)
is.Equal("docker.io", r.URL) is.Equal("docker.io", r.URL)
}) })
@ -26,9 +28,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
registries := []portainer.Registry{createNewRegistry("docker.io", "", false), registries := []portainer.Registry{createNewRegistry("docker.io", "", false),
createNewRegistry("hub-mirror.c.163.com", "USERNAME", false)} createNewRegistry("hub-mirror.c.163.com", "USERNAME", false)}
r, err := findBestMatchRegistry(image, registries) r, err := findBestMatchRegistry(image, registries)
is.NoError(err, "") require.NoError(t, err)
is.NotNil(r, "") is.NotNil(r)
is.False(r.Authentication, "") is.False(r.Authentication)
is.Equal("docker.io", r.URL) is.Equal("docker.io", r.URL)
}) })
@ -37,9 +39,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", true), registries := []portainer.Registry{createNewRegistry("docker.io", "USERNAME", true),
createNewRegistry("hub-mirror.c.163.com", "", false)} createNewRegistry("hub-mirror.c.163.com", "", false)}
r, err := findBestMatchRegistry(image, registries) r, err := findBestMatchRegistry(image, registries)
is.NoError(err, "") require.NoError(t, err)
is.NotNil(r, "") is.NotNil(r)
is.True(r.Authentication, "") is.True(r.Authentication)
is.Equal("docker.io", r.URL) is.Equal("docker.io", r.URL)
}) })
@ -47,9 +49,9 @@ func TestFindBestMatchNeedAuthRegistry(t *testing.T) {
image := "portainer/portainer-ee:latest" image := "portainer/portainer-ee:latest"
registries := []portainer.Registry{createNewRegistry("docker.io", "", true)} registries := []portainer.Registry{createNewRegistry("docker.io", "", true)}
r, err := findBestMatchRegistry(image, registries) r, err := findBestMatchRegistry(image, registries)
is.NoError(err, "") require.NoError(t, err)
is.NotNil(r, "") is.NotNil(r)
is.True(r.Authentication, "") is.True(r.Authentication)
is.Equal("docker.io", r.URL) is.Equal("docker.io", r.URL)
}) })
} }

View File

@ -0,0 +1,125 @@
package stats
import (
"context"
"errors"
"strings"
"sync"
"github.com/docker/docker/api/types/container"
)
type ContainerStats struct {
Running int `json:"running"`
Stopped int `json:"stopped"`
Healthy int `json:"healthy"`
Unhealthy int `json:"unhealthy"`
Total int `json:"total"`
}
type DockerClient interface {
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
}
func CalculateContainerStats(ctx context.Context, cli DockerClient, isSwarm bool, containers []container.Summary) (ContainerStats, error) {
if isSwarm {
return CalculateContainerStatsForSwarm(containers), nil
}
var running, stopped, healthy, unhealthy int
var mu sync.Mutex
var wg sync.WaitGroup
semaphore := make(chan struct{}, 5)
var aggErr error
var aggMu sync.Mutex
for i := range containers {
id := containers[i].ID
semaphore <- struct{}{}
wg.Go(func() {
defer func() { <-semaphore }()
containerInspection, err := cli.ContainerInspect(ctx, id)
stat := ContainerStats{}
if err != nil {
aggMu.Lock()
aggErr = errors.Join(aggErr, err)
aggMu.Unlock()
return
}
stat = getContainerStatus(containerInspection.State)
mu.Lock()
running += stat.Running
stopped += stat.Stopped
healthy += stat.Healthy
unhealthy += stat.Unhealthy
mu.Unlock()
})
}
wg.Wait()
return ContainerStats{
Running: running,
Stopped: stopped,
Healthy: healthy,
Unhealthy: unhealthy,
Total: len(containers),
}, aggErr
}
func getContainerStatus(state *container.State) ContainerStats {
stat := ContainerStats{}
if state == nil {
return stat
}
switch state.Status {
case container.StateRunning:
stat.Running++
case container.StateExited, container.StateDead:
stat.Stopped++
}
if state.Health != nil {
switch state.Health.Status {
case container.Healthy:
stat.Healthy++
case container.Unhealthy:
stat.Unhealthy++
}
}
return stat
}
// This is a temporary workaround to calculate container stats for Swarm
// TODO: Remove this once we have a proper way to calculate container stats for Swarm
func CalculateContainerStatsForSwarm(containers []container.Summary) ContainerStats {
var running, stopped, healthy, unhealthy int
for _, container := range containers {
switch container.State {
case "running":
running++
case "exited", "stopped":
stopped++
}
if strings.Contains(container.Status, "(healthy)") {
healthy++
} else if strings.Contains(container.Status, "(unhealthy)") {
unhealthy++
}
}
return ContainerStats{
Running: running,
Stopped: stopped,
Healthy: healthy,
Unhealthy: unhealthy,
Total: len(containers),
}
}

View File

@ -0,0 +1,253 @@
package stats
import (
"context"
"errors"
"testing"
"time"
"github.com/docker/docker/api/types/container"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// MockDockerClient implements the DockerClient interface for testing
type MockDockerClient struct {
mock.Mock
}
func (m *MockDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
args := m.Called(ctx, containerID)
return args.Get(0).(container.InspectResponse), args.Error(1)
}
func TestCalculateContainerStats(t *testing.T) {
mockClient := new(MockDockerClient)
// Test containers - using enough containers to test concurrent processing
containers := []container.Summary{
{ID: "container1"},
{ID: "container2"},
{ID: "container3"},
{ID: "container4"},
{ID: "container5"},
{ID: "container6"},
{ID: "container7"},
{ID: "container8"},
{ID: "container9"},
{ID: "container10"},
}
// Setup mock expectations with different container states to test various scenarios
containerStates := []struct {
id string
status string
health *container.Health
expected ContainerStats
}{
{"container1", container.StateRunning, &container.Health{Status: container.Healthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 1, Unhealthy: 0}},
{"container2", container.StateRunning, &container.Health{Status: container.Unhealthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 1}},
{"container3", container.StateRunning, nil, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 0}},
{"container4", container.StateExited, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
{"container5", container.StateDead, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
{"container6", container.StateRunning, &container.Health{Status: container.Healthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 1, Unhealthy: 0}},
{"container7", container.StateRunning, &container.Health{Status: container.Unhealthy}, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 1}},
{"container8", container.StateExited, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
{"container9", container.StateRunning, nil, ContainerStats{Running: 1, Stopped: 0, Healthy: 0, Unhealthy: 0}},
{"container10", container.StateDead, nil, ContainerStats{Running: 0, Stopped: 1, Healthy: 0, Unhealthy: 0}},
}
expected := ContainerStats{}
// Setup mock expectations for all containers with artificial delays to simulate real Docker calls
for _, state := range containerStates {
mockClient.On("ContainerInspect", mock.Anything, state.id).Return(container.InspectResponse{
ContainerJSONBase: &container.ContainerJSONBase{
State: &container.State{
Status: state.status,
Health: state.health,
},
},
}, nil).After(50 * time.Millisecond) // Simulate 50ms Docker API call
expected.Running += state.expected.Running
expected.Stopped += state.expected.Stopped
expected.Healthy += state.expected.Healthy
expected.Unhealthy += state.expected.Unhealthy
expected.Total++
}
// Call the function and measure time
startTime := time.Now()
stats, err := CalculateContainerStats(context.Background(), mockClient, false, containers)
require.NoError(t, err, "failed to calculate container stats")
duration := time.Since(startTime)
// Assert results
assert.Equal(t, expected, stats)
assert.Equal(t, expected.Running, stats.Running)
assert.Equal(t, expected.Stopped, stats.Stopped)
assert.Equal(t, expected.Healthy, stats.Healthy)
assert.Equal(t, expected.Unhealthy, stats.Unhealthy)
assert.Equal(t, 10, stats.Total)
// Verify concurrent processing by checking that all mock calls were made
mockClient.AssertExpectations(t)
// Test concurrency: With 5 workers and 10 containers taking 50ms each:
// Sequential would take: 10 * 50ms = 500ms
sequentialTime := 10 * 50 * time.Millisecond
// Verify that concurrent processing is actually faster than sequential
// Allow some overhead for goroutine scheduling
assert.Less(t, duration, sequentialTime, "Concurrent processing should be faster than sequential")
// Concurrent should take: ~100-150ms (depending on scheduling)
assert.Less(t, duration, 150*time.Millisecond, "Concurrent processing should be significantly faster")
assert.Greater(t, duration, 100*time.Millisecond, "Concurrent processing should be longer than 100ms")
}
func TestCalculateContainerStatsAllErrors(t *testing.T) {
mockClient := new(MockDockerClient)
// Test containers
containers := []container.Summary{
{ID: "container1"},
{ID: "container2"},
}
// Setup mock expectations with all calls returning errors
mockClient.On("ContainerInspect", mock.Anything, "container1").Return(container.InspectResponse{}, errors.New("network error"))
mockClient.On("ContainerInspect", mock.Anything, "container2").Return(container.InspectResponse{}, errors.New("permission denied"))
// Call the function
stats, err := CalculateContainerStats(context.Background(), mockClient, false, containers)
// Assert that an error was returned
require.Error(t, err, "should return error when all containers fail to inspect")
assert.Contains(t, err.Error(), "network error", "error should contain one of the original error messages")
assert.Contains(t, err.Error(), "permission denied", "error should contain the other original error message")
// Assert that stats are zero since no containers were successfully processed
expectedStats := ContainerStats{
Running: 0,
Stopped: 0,
Healthy: 0,
Unhealthy: 0,
Total: 2, // total containers processed
}
assert.Equal(t, expectedStats, stats)
// Verify all mock calls were made
mockClient.AssertExpectations(t)
}
func TestGetContainerStatus(t *testing.T) {
testCases := []struct {
name string
state *container.State
expected ContainerStats
}{
{
name: "running healthy container",
state: &container.State{
Status: container.StateRunning,
Health: &container.Health{
Status: container.Healthy,
},
},
expected: ContainerStats{
Running: 1,
Stopped: 0,
Healthy: 1,
Unhealthy: 0,
},
},
{
name: "running unhealthy container",
state: &container.State{
Status: container.StateRunning,
Health: &container.Health{
Status: container.Unhealthy,
},
},
expected: ContainerStats{
Running: 1,
Stopped: 0,
Healthy: 0,
Unhealthy: 1,
},
},
{
name: "running container without health check",
state: &container.State{
Status: container.StateRunning,
},
expected: ContainerStats{
Running: 1,
Stopped: 0,
Healthy: 0,
Unhealthy: 0,
},
},
{
name: "exited container",
state: &container.State{
Status: container.StateExited,
},
expected: ContainerStats{
Running: 0,
Stopped: 1,
Healthy: 0,
Unhealthy: 0,
},
},
{
name: "dead container",
state: &container.State{
Status: container.StateDead,
},
expected: ContainerStats{
Running: 0,
Stopped: 1,
Healthy: 0,
Unhealthy: 0,
},
},
{
name: "nil state",
state: nil,
expected: ContainerStats{
Running: 0,
Stopped: 0,
Healthy: 0,
Unhealthy: 0,
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
stat := getContainerStatus(testCase.state)
assert.Equal(t, testCase.expected, stat)
})
}
}
func TestCalculateContainerStatsForSwarm(t *testing.T) {
containers := []container.Summary{
{State: "running"},
{State: "running", Status: "Up 5 minutes (healthy)"},
{State: "exited"},
{State: "stopped"},
{State: "running", Status: "Up 10 minutes"},
{State: "running", Status: "Up about an hour (unhealthy)"},
}
stats := CalculateContainerStatsForSwarm(containers)
assert.Equal(t, 4, stats.Running)
assert.Equal(t, 2, stats.Stopped)
assert.Equal(t, 1, stats.Healthy)
assert.Equal(t, 1, stats.Unhealthy)
assert.Equal(t, 6, stats.Total)
}

View File

@ -49,6 +49,11 @@ type (
// Is relative path supported // Is relative path supported
SupportRelativePath bool SupportRelativePath bool
// AlwaysCloneGitRepoForRelativePath is a flag indicating if the agent must always clone the git repository for relative path.
// This field is only valid when SupportRelativePath is true.
// Used only for EE
AlwaysCloneGitRepoForRelativePath bool
// Mount point for relative path // Mount point for relative path
FilesystemPath string FilesystemPath string
// Used only for EE // Used only for EE

View File

@ -8,7 +8,9 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_createEnvFile(t *testing.T) { func Test_createEnvFile(t *testing.T) {
@ -61,7 +63,7 @@ func Test_createEnvFile(t *testing.T) {
assert.Equal(t, tt.expected, string(content)) assert.Equal(t, tt.expected, string(content))
} else { } else {
assert.Equal(t, "", result) assert.Empty(t, result)
} }
}) })
} }
@ -79,7 +81,7 @@ func Test_createEnvFile_mergesDefultAndInplaceEnvVars(t *testing.T) {
} }
result, err := createEnvFile(stack) result, err := createEnvFile(stack)
assert.Equal(t, filepath.Join(stack.ProjectPath, "stack.env"), result) assert.Equal(t, filepath.Join(stack.ProjectPath, "stack.env"), result)
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, path.Join(dir, "stack.env")) assert.FileExists(t, path.Join(dir, "stack.env"))
f, _ := os.Open(path.Join(dir, "stack.env")) f, _ := os.Open(path.Join(dir, "stack.env"))
content, _ := io.ReadAll(f) content, _ := io.ReadAll(f)

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type mockKubectlClient struct { type mockKubectlClient struct {
@ -68,7 +69,7 @@ func TestExecuteKubectlOperation_Apply_Success(t *testing.T) {
manifests := []string{"manifest1.yaml", "manifest2.yaml"} manifests := []string{"manifest1.yaml", "manifest2.yaml"}
err := testExecuteKubectlOperation(mockClient, "apply", manifests) err := testExecuteKubectlOperation(mockClient, "apply", manifests)
assert.NoError(t, err) require.NoError(t, err)
assert.True(t, called) assert.True(t, called)
} }
@ -86,7 +87,7 @@ func TestExecuteKubectlOperation_Apply_Error(t *testing.T) {
manifests := []string{"error.yaml"} manifests := []string{"error.yaml"}
err := testExecuteKubectlOperation(mockClient, "apply", manifests) err := testExecuteKubectlOperation(mockClient, "apply", manifests)
assert.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), expectedErr.Error()) assert.Contains(t, err.Error(), expectedErr.Error())
assert.True(t, called) assert.True(t, called)
} }
@ -104,7 +105,7 @@ func TestExecuteKubectlOperation_Delete_Success(t *testing.T) {
manifests := []string{"manifest1.yaml"} manifests := []string{"manifest1.yaml"}
err := testExecuteKubectlOperation(mockClient, "delete", manifests) err := testExecuteKubectlOperation(mockClient, "delete", manifests)
assert.NoError(t, err) require.NoError(t, err)
assert.True(t, called) assert.True(t, called)
} }
@ -122,7 +123,7 @@ func TestExecuteKubectlOperation_Delete_Error(t *testing.T) {
manifests := []string{"error.yaml"} manifests := []string{"error.yaml"}
err := testExecuteKubectlOperation(mockClient, "delete", manifests) err := testExecuteKubectlOperation(mockClient, "delete", manifests)
assert.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), expectedErr.Error()) assert.Contains(t, err.Error(), expectedErr.Error())
assert.True(t, called) assert.True(t, called)
} }
@ -140,7 +141,7 @@ func TestExecuteKubectlOperation_RolloutRestart_Success(t *testing.T) {
resources := []string{"deployment/nginx"} resources := []string{"deployment/nginx"}
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources) err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
assert.NoError(t, err) require.NoError(t, err)
assert.True(t, called) assert.True(t, called)
} }
@ -158,7 +159,7 @@ func TestExecuteKubectlOperation_RolloutRestart_Error(t *testing.T) {
resources := []string{"deployment/error"} resources := []string{"deployment/error"}
err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources) err := testExecuteKubectlOperation(mockClient, "rollout-restart", resources)
assert.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), expectedErr.Error()) assert.Contains(t, err.Error(), expectedErr.Error())
assert.True(t, called) assert.True(t, called)
} }
@ -168,6 +169,6 @@ func TestExecuteKubectlOperation_UnsupportedOperation(t *testing.T) {
err := testExecuteKubectlOperation(mockClient, "unsupported", []string{}) err := testExecuteKubectlOperation(mockClient, "unsupported", []string{})
assert.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported operation") assert.Contains(t, err.Error(), "unsupported operation")
} }

View File

@ -7,12 +7,13 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_copyFile_returnsError_whenSourceDoesNotExist(t *testing.T) { func Test_copyFile_returnsError_whenSourceDoesNotExist(t *testing.T) {
tmpdir := t.TempDir() tmpdir := t.TempDir()
err := copyFile("does-not-exist", tmpdir) err := copyFile("does-not-exist", tmpdir)
assert.Error(t, err) require.Error(t, err)
} }
func Test_copyFile_shouldMakeAbackup(t *testing.T) { func Test_copyFile_shouldMakeAbackup(t *testing.T) {
@ -21,7 +22,7 @@ func Test_copyFile_shouldMakeAbackup(t *testing.T) {
os.WriteFile(path.Join(tmpdir, "origin"), content, 0600) os.WriteFile(path.Join(tmpdir, "origin"), content, 0600)
err := copyFile(path.Join(tmpdir, "origin"), path.Join(tmpdir, "copy")) err := copyFile(path.Join(tmpdir, "origin"), path.Join(tmpdir, "copy"))
assert.NoError(t, err) require.NoError(t, err)
copyContent, _ := os.ReadFile(path.Join(tmpdir, "copy")) copyContent, _ := os.ReadFile(path.Join(tmpdir, "copy"))
assert.Equal(t, content, copyContent) assert.Equal(t, content, copyContent)
@ -30,7 +31,7 @@ func Test_copyFile_shouldMakeAbackup(t *testing.T) {
func Test_CopyDir_shouldCopyAllFilesAndDirectories(t *testing.T) { func Test_CopyDir_shouldCopyAllFilesAndDirectories(t *testing.T) {
destination := t.TempDir() destination := t.TempDir()
err := CopyDir("./testdata/copy_test", destination, true) err := CopyDir("./testdata/copy_test", destination, true)
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(destination, "copy_test", "outer")) assert.FileExists(t, filepath.Join(destination, "copy_test", "outer"))
assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile")) assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile"))
@ -40,7 +41,7 @@ func Test_CopyDir_shouldCopyAllFilesAndDirectories(t *testing.T) {
func Test_CopyDir_shouldCopyOnlyDirContents(t *testing.T) { func Test_CopyDir_shouldCopyOnlyDirContents(t *testing.T) {
destination := t.TempDir() destination := t.TempDir()
err := CopyDir("./testdata/copy_test", destination, false) err := CopyDir("./testdata/copy_test", destination, false)
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(destination, "outer")) assert.FileExists(t, filepath.Join(destination, "outer"))
assert.FileExists(t, filepath.Join(destination, "dir", ".dotfile")) assert.FileExists(t, filepath.Join(destination, "dir", ".dotfile"))
@ -50,7 +51,7 @@ func Test_CopyDir_shouldCopyOnlyDirContents(t *testing.T) {
func Test_CopyPath_shouldSkipWhenNotExist(t *testing.T) { func Test_CopyPath_shouldSkipWhenNotExist(t *testing.T) {
tmpdir := t.TempDir() tmpdir := t.TempDir()
err := CopyPath("does-not-exists", tmpdir) err := CopyPath("does-not-exists", tmpdir)
assert.NoError(t, err) require.NoError(t, err)
assert.NoFileExists(t, tmpdir) assert.NoFileExists(t, tmpdir)
} }
@ -62,17 +63,17 @@ func Test_CopyPath_shouldCopyFile(t *testing.T) {
os.MkdirAll(path.Join(tmpdir, "backup"), 0700) os.MkdirAll(path.Join(tmpdir, "backup"), 0700)
err := CopyPath(path.Join(tmpdir, "file"), path.Join(tmpdir, "backup")) err := CopyPath(path.Join(tmpdir, "file"), path.Join(tmpdir, "backup"))
assert.NoError(t, err) require.NoError(t, err)
copyContent, err := os.ReadFile(path.Join(tmpdir, "backup", "file")) copyContent, err := os.ReadFile(path.Join(tmpdir, "backup", "file"))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, content, copyContent) assert.Equal(t, content, copyContent)
} }
func Test_CopyPath_shouldCopyDir(t *testing.T) { func Test_CopyPath_shouldCopyDir(t *testing.T) {
destination := t.TempDir() destination := t.TempDir()
err := CopyPath("./testdata/copy_test", destination) err := CopyPath("./testdata/copy_test", destination)
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(destination, "copy_test", "outer")) assert.FileExists(t, filepath.Join(destination, "copy_test", "outer"))
assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile")) assert.FileExists(t, filepath.Join(destination, "copy_test", "dir", ".dotfile"))

View File

@ -848,7 +848,7 @@ func defaultMTLSCertPathUnderFileStore() (string, string, string) {
return caCertPath, certPath, keyPath return caCertPath, certPath, keyPath
} }
// GetDefaultChiselPrivateKeyPath returns the chisle private key path // GetDefaultChiselPrivateKeyPath returns the chisel private key path
func (service *Service) GetDefaultChiselPrivateKeyPath() string { func (service *Service) GetDefaultChiselPrivateKeyPath() string {
privateKeyPath := defaultChiselPrivateKeyPathUnderFileStore() privateKeyPath := defaultChiselPrivateKeyPathUnderFileStore()
return service.wrapFileStore(privateKeyPath) return service.wrapFileStore(privateKeyPath)

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_fileSystemService_FileExists_whenFileExistsShouldReturnTrue(t *testing.T) { func Test_fileSystemService_FileExists_whenFileExistsShouldReturnTrue(t *testing.T) {
@ -30,14 +31,14 @@ func Test_FileExists_whenFileNotExistsShouldReturnFalse(t *testing.T) {
func testHelperFileExists_fileExists(t *testing.T, checker func(path string) (bool, error)) { func testHelperFileExists_fileExists(t *testing.T, checker func(path string) (bool, error)) {
file, err := os.CreateTemp("", t.Name()) file, err := os.CreateTemp("", t.Name())
assert.NoError(t, err, "CreateTemp should not fail") require.NoError(t, err, "CreateTemp should not fail")
t.Cleanup(func() { t.Cleanup(func() {
os.RemoveAll(file.Name()) os.RemoveAll(file.Name())
}) })
exists, err := checker(file.Name()) exists, err := checker(file.Name())
assert.NoError(t, err, "FileExists should not fail") require.NoError(t, err, "FileExists should not fail")
assert.True(t, exists) assert.True(t, exists)
} }
@ -46,10 +47,10 @@ func testHelperFileExists_fileNotExists(t *testing.T, checker func(path string)
filePath := path.Join(t.TempDir(), fmt.Sprintf("%s%d", t.Name(), rand.Int())) filePath := path.Join(t.TempDir(), fmt.Sprintf("%s%d", t.Name(), rand.Int()))
err := os.RemoveAll(filePath) err := os.RemoveAll(filePath)
assert.NoError(t, err, "RemoveAll should not fail") require.NoError(t, err, "RemoveAll should not fail")
exists, err := checker(filePath) exists, err := checker(filePath)
assert.NoError(t, err, "FileExists should not fail") require.NoError(t, err, "FileExists should not fail")
assert.False(t, exists) assert.False(t, exists)
} }

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var content = []byte("content") var content = []byte("content")
@ -13,25 +14,25 @@ var content = []byte("content")
func Test_movePath_shouldFailIfSourceDirDoesNotExist(t *testing.T) { func Test_movePath_shouldFailIfSourceDirDoesNotExist(t *testing.T) {
sourceDir := "missing" sourceDir := "missing"
destinationDir := t.TempDir() destinationDir := t.TempDir()
file1 := addFile(destinationDir, "dir", "file") file1 := addFile(t, destinationDir, "dir", "file")
file2 := addFile(destinationDir, "file") file2 := addFile(t, destinationDir, "file")
err := MoveDirectory(sourceDir, destinationDir, false) err := MoveDirectory(sourceDir, destinationDir, false)
assert.Error(t, err, "move directory should fail when source path is missing") require.Error(t, err, "move directory should fail when source path is missing")
assert.FileExists(t, file1, "destination dir contents should remain") assert.FileExists(t, file1, "destination dir contents should remain")
assert.FileExists(t, file2, "destination dir contents should remain") assert.FileExists(t, file2, "destination dir contents should remain")
} }
func Test_movePath_shouldFailIfDestinationDirExists(t *testing.T) { func Test_movePath_shouldFailIfDestinationDirExists(t *testing.T) {
sourceDir := t.TempDir() sourceDir := t.TempDir()
file1 := addFile(sourceDir, "dir", "file") file1 := addFile(t, sourceDir, "dir", "file")
file2 := addFile(sourceDir, "file") file2 := addFile(t, sourceDir, "file")
destinationDir := t.TempDir() destinationDir := t.TempDir()
file3 := addFile(destinationDir, "dir", "file") file3 := addFile(t, destinationDir, "dir", "file")
file4 := addFile(destinationDir, "file") file4 := addFile(t, destinationDir, "file")
err := MoveDirectory(sourceDir, destinationDir, false) err := MoveDirectory(sourceDir, destinationDir, false)
assert.Error(t, err, "move directory should fail when destination directory already exists") require.Error(t, err, "move directory should fail when destination directory already exists")
assert.FileExists(t, file1, "source dir contents should remain") assert.FileExists(t, file1, "source dir contents should remain")
assert.FileExists(t, file2, "source dir contents should remain") assert.FileExists(t, file2, "source dir contents should remain")
assert.FileExists(t, file3, "destination dir contents should remain") assert.FileExists(t, file3, "destination dir contents should remain")
@ -40,14 +41,14 @@ func Test_movePath_shouldFailIfDestinationDirExists(t *testing.T) {
func Test_movePath_succesIfOverwriteSetWhenDestinationDirExists(t *testing.T) { func Test_movePath_succesIfOverwriteSetWhenDestinationDirExists(t *testing.T) {
sourceDir := t.TempDir() sourceDir := t.TempDir()
file1 := addFile(sourceDir, "dir", "file") file1 := addFile(t, sourceDir, "dir", "file")
file2 := addFile(sourceDir, "file") file2 := addFile(t, sourceDir, "file")
destinationDir := t.TempDir() destinationDir := t.TempDir()
file3 := addFile(destinationDir, "dir", "file") file3 := addFile(t, destinationDir, "dir", "file")
file4 := addFile(destinationDir, "file") file4 := addFile(t, destinationDir, "file")
err := MoveDirectory(sourceDir, destinationDir, true) err := MoveDirectory(sourceDir, destinationDir, true)
assert.NoError(t, err) require.NoError(t, err)
assert.NoFileExists(t, file1, "source dir contents should be moved") assert.NoFileExists(t, file1, "source dir contents should be moved")
assert.NoFileExists(t, file2, "source dir contents should be moved") assert.NoFileExists(t, file2, "source dir contents should be moved")
assert.FileExists(t, file3, "destination dir contents should remain") assert.FileExists(t, file3, "destination dir contents should remain")
@ -58,32 +59,34 @@ func Test_movePath_successWhenSourceExistsAndDestinationIsMissing(t *testing.T)
tmp := t.TempDir() tmp := t.TempDir()
sourceDir := path.Join(tmp, "source") sourceDir := path.Join(tmp, "source")
os.Mkdir(sourceDir, 0766) os.Mkdir(sourceDir, 0766)
file1 := addFile(sourceDir, "dir", "file") file1 := addFile(t, sourceDir, "dir", "file")
file2 := addFile(sourceDir, "file") file2 := addFile(t, sourceDir, "file")
destinationDir := path.Join(tmp, "destination") destinationDir := path.Join(tmp, "destination")
err := MoveDirectory(sourceDir, destinationDir, false) err := MoveDirectory(sourceDir, destinationDir, false)
assert.NoError(t, err) require.NoError(t, err)
assert.NoFileExists(t, file1, "source dir contents should be moved") assert.NoFileExists(t, file1, "source dir contents should be moved")
assert.NoFileExists(t, file2, "source dir contents should be moved") assert.NoFileExists(t, file2, "source dir contents should be moved")
assertFileContent(t, path.Join(destinationDir, "file")) assertFileContent(t, path.Join(destinationDir, "file"))
assertFileContent(t, path.Join(destinationDir, "dir", "file")) assertFileContent(t, path.Join(destinationDir, "dir", "file"))
} }
func addFile(fileParts ...string) (filepath string) { func addFile(t *testing.T, fileParts ...string) (filepath string) {
if len(fileParts) > 2 { if len(fileParts) > 2 {
dir := path.Join(fileParts[:len(fileParts)-1]...) dir := path.Join(fileParts[:len(fileParts)-1]...)
os.MkdirAll(dir, 0766) err := os.MkdirAll(dir, 0766)
require.NoError(t, err)
} }
p := path.Join(fileParts...) p := path.Join(fileParts...)
os.WriteFile(p, content, 0766) err := os.WriteFile(p, content, 0766)
require.NoError(t, err)
return p return p
} }
func assertFileContent(t *testing.T, filePath string) { func assertFileContent(t *testing.T, filePath string) {
actualContent, err := os.ReadFile(filePath) actualContent, err := os.ReadFile(filePath)
assert.NoErrorf(t, err, "failed to read file %s", filePath) require.NoError(t, err, "failed to read file %s", filePath)
assert.Equal(t, content, actualContent, "file %s content doesn't match", filePath) assert.Equal(t, content, actualContent, "file %s content doesn't match", filePath)
} }

View File

@ -5,14 +5,14 @@ import (
"path" "path"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func createService(t *testing.T) *Service { func createService(t *testing.T) *Service {
dataStorePath := path.Join(t.TempDir(), t.Name()) dataStorePath := path.Join(t.TempDir(), t.Name())
service, err := NewService(dataStorePath, "") service, err := NewService(dataStorePath, "")
assert.NoError(t, err, "NewService should not fail") require.NoError(t, err, "NewService should not fail")
t.Cleanup(func() { t.Cleanup(func() {
os.RemoveAll(dataStorePath) os.RemoveAll(dataStorePath)

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_WriteFile_CanStoreContentInANewFile(t *testing.T) { func Test_WriteFile_CanStoreContentInANewFile(t *testing.T) {
@ -14,7 +15,7 @@ func Test_WriteFile_CanStoreContentInANewFile(t *testing.T) {
content := []byte("content") content := []byte("content")
err := WriteToFile(tmpFilePath, content) err := WriteToFile(tmpFilePath, content)
assert.NoError(t, err) require.NoError(t, err)
fileContent, _ := os.ReadFile(tmpFilePath) fileContent, _ := os.ReadFile(tmpFilePath)
assert.Equal(t, content, fileContent) assert.Equal(t, content, fileContent)
@ -25,11 +26,11 @@ func Test_WriteFile_CanOverwriteExistingFile(t *testing.T) {
tmpFilePath := path.Join(tmpDir, "dummy") tmpFilePath := path.Join(tmpDir, "dummy")
err := WriteToFile(tmpFilePath, []byte("content")) err := WriteToFile(tmpFilePath, []byte("content"))
assert.NoError(t, err) require.NoError(t, err)
content := []byte("new content") content := []byte("new content")
err = WriteToFile(tmpFilePath, content) err = WriteToFile(tmpFilePath, content)
assert.NoError(t, err) require.NoError(t, err)
fileContent, _ := os.ReadFile(tmpFilePath) fileContent, _ := os.ReadFile(tmpFilePath)
assert.Equal(t, content, fileContent) assert.Equal(t, content, fileContent)
@ -41,7 +42,7 @@ func Test_WriteFile_CanWriteANestedPath(t *testing.T) {
content := []byte("content") content := []byte("content")
err := WriteToFile(tmpFilePath, content) err := WriteToFile(tmpFilePath, content)
assert.NoError(t, err) require.NoError(t, err)
fileContent, _ := os.ReadFile(tmpFilePath) fileContent, _ := os.ReadFile(tmpFilePath)
assert.Equal(t, content, fileContent) assert.Equal(t, content, fileContent)

View File

@ -12,6 +12,7 @@ import (
_ "github.com/joho/godotenv/autoload" _ "github.com/joho/godotenv/autoload"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const privateAzureRepoURL = "https://portainer.visualstudio.com/gitops-test/_git/gitops-test" const privateAzureRepoURL = "https://portainer.visualstudio.com/gitops-test/_git/gitops-test"
@ -67,7 +68,7 @@ func TestService_ClonePublicRepository_Azure(t *testing.T) {
gittypes.GitCredentialAuthType_Basic, gittypes.GitCredentialAuthType_Basic,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md")) assert.FileExists(t, filepath.Join(dst, "README.md"))
}) })
} }
@ -90,7 +91,7 @@ func TestService_ClonePrivateRepository_Azure(t *testing.T) {
gittypes.GitCredentialAuthType_Basic, gittypes.GitCredentialAuthType_Basic,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md")) assert.FileExists(t, filepath.Join(dst, "README.md"))
} }
@ -108,7 +109,7 @@ func TestService_LatestCommitID_Azure(t *testing.T) {
gittypes.GitCredentialAuthType_Basic, gittypes.GitCredentialAuthType_Basic,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty") assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
} }
@ -127,7 +128,7 @@ func TestService_ListRefs_Azure(t *testing.T) {
false, false,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1) assert.GreaterOrEqual(t, len(refs), 1)
} }
@ -289,14 +290,14 @@ func TestService_ListFiles_Azure(t *testing.T) {
false, false,
) )
if tt.expect.shouldFail { if tt.expect.shouldFail {
assert.Error(t, err) require.Error(t, err)
if tt.expect.err != nil { if tt.expect.err != nil {
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.matchedCount > 0 { if tt.expect.matchedCount > 0 {
assert.Greater(t, len(paths), 0) assert.NotEmpty(t, paths)
} }
} }
}) })

View File

@ -8,7 +8,10 @@ import (
"testing" "testing"
gittypes "github.com/portainer/portainer/api/git/types" gittypes "github.com/portainer/portainer/api/git/types"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_buildDownloadUrl(t *testing.T) { func Test_buildDownloadUrl(t *testing.T) {
@ -18,16 +21,19 @@ func Test_buildDownloadUrl(t *testing.T) {
project: "project", project: "project",
repository: "repository", repository: "repository",
}, "refs/heads/main") }, "refs/heads/main")
require.NoError(t, err)
expectedUrl, err := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&download=true&versionDescriptor.version=main&$format=zip&recursionLevel=full&api-version=6.0&versionDescriptor.versionType=branch")
require.NoError(t, err)
actualUrl, err := url.Parse(u)
require.NoError(t, err)
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&download=true&versionDescriptor.version=main&$format=zip&recursionLevel=full&api-version=6.0&versionDescriptor.versionType=branch")
actualUrl, _ := url.Parse(u)
if assert.NoError(t, err) {
assert.Equal(t, expectedUrl.Host, actualUrl.Host) assert.Equal(t, expectedUrl.Host, actualUrl.Host)
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme) assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
assert.Equal(t, expectedUrl.Path, actualUrl.Path) assert.Equal(t, expectedUrl.Path, actualUrl.Path)
assert.Equal(t, expectedUrl.Query(), actualUrl.Query()) assert.Equal(t, expectedUrl.Query(), actualUrl.Query())
} }
}
func Test_buildRootItemUrl(t *testing.T) { func Test_buildRootItemUrl(t *testing.T) {
a := NewAzureClient() a := NewAzureClient()
@ -39,7 +45,7 @@ func Test_buildRootItemUrl(t *testing.T) {
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&api-version=6.0&versionDescriptor.version=main&versionDescriptor.versionType=branch") expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/items?scopePath=/&api-version=6.0&versionDescriptor.version=main&versionDescriptor.versionType=branch")
actualUrl, _ := url.Parse(u) actualUrl, _ := url.Parse(u)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedUrl.Host, actualUrl.Host) assert.Equal(t, expectedUrl.Host, actualUrl.Host)
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme) assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
assert.Equal(t, expectedUrl.Path, actualUrl.Path) assert.Equal(t, expectedUrl.Path, actualUrl.Path)
@ -56,7 +62,7 @@ func Test_buildRefsUrl(t *testing.T) {
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/refs?api-version=6.0") expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/refs?api-version=6.0")
actualUrl, _ := url.Parse(u) actualUrl, _ := url.Parse(u)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedUrl.Host, actualUrl.Host) assert.Equal(t, expectedUrl.Host, actualUrl.Host)
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme) assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
assert.Equal(t, expectedUrl.Path, actualUrl.Path) assert.Equal(t, expectedUrl.Path, actualUrl.Path)
@ -73,7 +79,7 @@ func Test_buildTreeUrl(t *testing.T) {
expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/trees/sha1?api-version=6.0&recursive=true") expectedUrl, _ := url.Parse("https://dev.azure.com/organisation/project/_apis/git/repositories/repository/trees/sha1?api-version=6.0&recursive=true")
actualUrl, _ := url.Parse(u) actualUrl, _ := url.Parse(u)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedUrl.Host, actualUrl.Host) assert.Equal(t, expectedUrl.Host, actualUrl.Host)
assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme) assert.Equal(t, expectedUrl.Scheme, actualUrl.Scheme)
assert.Equal(t, expectedUrl.Path, actualUrl.Path) assert.Equal(t, expectedUrl.Path, actualUrl.Path)
@ -234,6 +240,8 @@ func Test_isAzureUrl(t *testing.T) {
} }
func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) { func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
fips.InitFIPS(false)
type args struct { type args struct {
options baseOption options baseOption
} }
@ -301,13 +309,15 @@ func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
}, },
} }
_, err := a.downloadZipFromAzureDevOps(context.Background(), option) _, err := a.downloadZipFromAzureDevOps(context.Background(), option)
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, tt.want, zipRequestAuth) assert.Equal(t, tt.want, zipRequestAuth)
}) })
} }
} }
func Test_azureDownloader_latestCommitID(t *testing.T) { func Test_azureDownloader_latestCommitID(t *testing.T) {
fips.InitFIPS(false)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{ response := `{
"count": 1, "count": 1,
@ -497,12 +507,12 @@ func Test_listRefs_azure(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
refs, err := client.listRefs(context.TODO(), tt.args) refs, err := client.listRefs(context.TODO(), tt.args)
if tt.expect.err == nil { if tt.expect.err == nil {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.refsCount > 0 { if tt.expect.refsCount > 0 {
assert.Greater(t, len(refs), 0) assert.NotEmpty(t, refs)
} }
} else { } else {
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
}) })
@ -608,14 +618,14 @@ func Test_listFiles_azure(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
paths, err := client.listFiles(context.TODO(), tt.args) paths, err := client.listFiles(context.TODO(), tt.args)
if tt.expect.shouldFail { if tt.expect.shouldFail {
assert.Error(t, err) require.Error(t, err)
if tt.expect.err != nil { if tt.expect.err != nil {
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.matchedCount > 0 { if tt.expect.matchedCount > 0 {
assert.Greater(t, len(paths), 0) assert.NotEmpty(t, paths)
} }
} }
}) })

View File

@ -9,7 +9,9 @@ import (
"time" "time"
gittypes "github.com/portainer/portainer/api/git/types" gittypes "github.com/portainer/portainer/api/git/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -35,7 +37,7 @@ func TestService_ClonePrivateRepository_GitHub(t *testing.T) {
gittypes.GitCredentialAuthType_Basic, gittypes.GitCredentialAuthType_Basic,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md")) assert.FileExists(t, filepath.Join(dst, "README.md"))
} }
@ -55,7 +57,7 @@ func TestService_LatestCommitID_GitHub(t *testing.T) {
gittypes.GitCredentialAuthType_Basic, gittypes.GitCredentialAuthType_Basic,
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty") assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
} }
@ -68,7 +70,7 @@ func TestService_ListRefs_GitHub(t *testing.T) {
repositoryUrl := privateGitRepoURL repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false) refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1) assert.GreaterOrEqual(t, len(refs), 1)
} }
@ -231,14 +233,14 @@ func TestService_ListFiles_GitHub(t *testing.T) {
false, false,
) )
if tt.expect.shouldFail { if tt.expect.shouldFail {
assert.Error(t, err) require.Error(t, err)
if tt.expect.err != nil { if tt.expect.err != nil {
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.matchedCount > 0 { if tt.expect.matchedCount > 0 {
assert.Greater(t, len(paths), 0) assert.NotEmpty(t, paths)
} }
} }
}) })
@ -361,12 +363,12 @@ func TestService_HardRefresh_ListRefs_GitHub(t *testing.T) {
repositoryUrl := privateGitRepoURL repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false) refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1) assert.GreaterOrEqual(t, len(refs), 1)
assert.Equal(t, 1, service.repoRefCache.Len()) assert.Equal(t, 1, service.repoRefCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false) _, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len()) assert.Equal(t, 1, service.repoRefCache.Len())
} }
@ -379,7 +381,7 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
repositoryUrl := privateGitRepoURL repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false) refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1) assert.GreaterOrEqual(t, len(refs), 1)
assert.Equal(t, 1, service.repoRefCache.Len()) assert.Equal(t, 1, service.repoRefCache.Len())
@ -394,7 +396,7 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
[]string{}, []string{},
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1) assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 1, service.repoFileCache.Len()) assert.Equal(t, 1, service.repoFileCache.Len())
@ -409,16 +411,16 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
[]string{}, []string{},
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1) assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 2, service.repoFileCache.Len()) assert.Equal(t, 2, service.repoFileCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false) _, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len()) assert.Equal(t, 1, service.repoRefCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, true, false) _, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, true, false)
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len()) assert.Equal(t, 1, service.repoRefCache.Len())
// The relevant file caches should be removed too // The relevant file caches should be removed too
assert.Equal(t, 0, service.repoFileCache.Len()) assert.Equal(t, 0, service.repoFileCache.Len())
@ -442,7 +444,7 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
[]string{}, []string{},
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1) assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 1, service.repoFileCache.Len()) assert.Equal(t, 1, service.repoFileCache.Len())
@ -457,7 +459,7 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
[]string{}, []string{},
false, false,
) )
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, 0, service.repoFileCache.Len()) assert.Equal(t, 0, service.repoFileCache.Len())
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/go-git/go-git/v5/plumbing/object" "github.com/go-git/go-git/v5/plumbing/object"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func setup(t *testing.T) string { func setup(t *testing.T) string {
@ -39,7 +40,7 @@ func Test_ClonePublicRepository_Shallow(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
t.Logf("Cloning into %s", dir) t.Logf("Cloning into %s", dir)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false) err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth") assert.Equal(t, 1, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
} }
@ -51,7 +52,7 @@ func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
t.Logf("Cloning into %s", dir) t.Logf("Cloning into %s", dir)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false) err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err) require.NoError(t, err)
assert.NoDirExists(t, filepath.Join(dir, ".git")) assert.NoDirExists(t, filepath.Join(dir, ".git"))
} }
@ -74,7 +75,7 @@ func Test_cloneRepository(t *testing.T) {
depth: 10, depth: 10,
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 4, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth") assert.Equal(t, 4, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
} }
@ -86,7 +87,7 @@ func Test_latestCommitID(t *testing.T) {
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false) id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "68dcaa7bd452494043c64252ab90db0f98ecf8d2", id) assert.Equal(t, "68dcaa7bd452494043c64252ab90db0f98ecf8d2", id)
} }
@ -97,7 +98,7 @@ func Test_ListRefs(t *testing.T) {
fs, err := service.ListRefs(repositoryURL, "", "", gittypes.GitCredentialAuthType_Basic, false, false) fs, err := service.ListRefs(repositoryURL, "", "", gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"refs/heads/main"}, fs) assert.Equal(t, []string{"refs/heads/main"}, fs)
} }
@ -119,7 +120,7 @@ func Test_ListFiles(t *testing.T) {
false, false,
) )
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"docker-compose.yml"}, fs) assert.Equal(t, []string{"docker-compose.yml"}, fs)
} }
@ -214,12 +215,12 @@ func Test_listRefsPrivateRepository(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
refs, err := client.listRefs(context.TODO(), tt.args) refs, err := client.listRefs(context.TODO(), tt.args)
if tt.expect.err == nil { if tt.expect.err == nil {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.refsCount > 0 { if tt.expect.refsCount > 0 {
assert.Greater(t, len(refs), 0) assert.NotEmpty(t, refs)
} }
} else { } else {
assert.Error(t, err) require.Error(t, err)
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
}) })
@ -325,14 +326,14 @@ func Test_listFilesPrivateRepository(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
paths, err := client.listFiles(context.TODO(), tt.args) paths, err := client.listFiles(context.TODO(), tt.args)
if tt.expect.shouldFail { if tt.expect.shouldFail {
assert.Error(t, err) require.Error(t, err)
if tt.expect.err != nil { if tt.expect.err != nil {
assert.Equal(t, tt.expect.err, err) assert.Equal(t, tt.expect.err, err)
} }
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
if tt.expect.matchedCount > 0 { if tt.expect.matchedCount > 0 {
assert.Greater(t, len(paths), 0) assert.NotEmpty(t, paths)
} }
} }
}) })

View File

@ -21,10 +21,14 @@ func ValidateAutoUpdateSettings(autoUpdate *portainer.AutoUpdateSettings) error
return httperrors.NewInvalidPayloadError("invalid Webhook format") return httperrors.NewInvalidPayloadError("invalid Webhook format")
} }
if autoUpdate.Interval != "" { if autoUpdate.Interval == "" {
if _, err := time.ParseDuration(autoUpdate.Interval); err != nil { return nil
return httperrors.NewInvalidPayloadError("invalid Interval format")
} }
if d, err := time.ParseDuration(autoUpdate.Interval); err != nil {
return httperrors.NewInvalidPayloadError("invalid Interval format")
} else if d < time.Minute {
return httperrors.NewInvalidPayloadError("interval must be at least 1 minute")
} }
return nil return nil

View File

@ -23,6 +23,16 @@ func Test_ValidateAutoUpdate(t *testing.T) {
value: &portainer.AutoUpdateSettings{Interval: "1dd2hh3mm"}, value: &portainer.AutoUpdateSettings{Interval: "1dd2hh3mm"},
wantErr: true, wantErr: true,
}, },
{
name: "short interval value",
value: &portainer.AutoUpdateSettings{Interval: "1s"},
wantErr: true,
},
{
name: "valid webhook without interval",
value: &portainer.AutoUpdateSettings{Webhook: "8dce8c2f-9ca1-482b-ad20-271e86536ada"},
wantErr: false,
},
{ {
name: "valid auto update", name: "valid auto update",
value: &portainer.AutoUpdateSettings{ value: &portainer.AutoUpdateSettings{

View File

@ -4,10 +4,14 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNewService(t *testing.T) { func TestNewService(t *testing.T) {
fips.InitFIPS(false)
service := NewService(true) service := NewService(true)
require.NotNil(t, service) require.NotNil(t, service)
require.True(t, service.httpsClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo require.True(t, service.httpsClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo

View File

@ -6,11 +6,14 @@ import (
"testing" "testing"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestExecutePingOperationFailure(t *testing.T) { func TestExecutePingOperationFailure(t *testing.T) {
fips.InitFIPS(false)
host := "http://localhost:1" host := "http://localhost:1"
config := portainer.TLSConfiguration{ config := portainer.TLSConfiguration{
TLS: true, TLS: true,

View File

@ -1,20 +0,0 @@
package errors
import (
"errors"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
)
func TxResponse(err error, validResponse func() *httperror.HandlerError) *httperror.HandlerError {
if err != nil {
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err)
}
return validResponse()
}

View File

@ -26,11 +26,10 @@ func (handler *Handler) logout(w http.ResponseWriter, r *http.Request) *httperro
handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID) handler.KubernetesTokenCacheManager.RemoveUserFromCache(tokenData.ID)
handler.KubernetesClientFactory.ClearUserClientCache(strconv.Itoa(int(tokenData.ID))) handler.KubernetesClientFactory.ClearUserClientCache(strconv.Itoa(int(tokenData.ID)))
logoutcontext.Cancel(tokenData.Token) logoutcontext.Cancel(tokenData.Token)
handler.bouncer.RevokeJWT(tokenData.Token)
} }
security.RemoveAuthCookie(w) security.RemoveAuthCookie(w)
handler.bouncer.RevokeJWT(tokenData.Token)
return response.Empty(w) return response.Empty(w)
} }

View File

@ -0,0 +1,55 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/http/proxy/factory/kubernetes"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/stretchr/testify/require"
)
type mockBouncer struct {
security.BouncerService
}
func NewMockBouncer() *mockBouncer {
return &mockBouncer{BouncerService: testhelpers.NewTestRequestBouncer()}
}
func (*mockBouncer) CookieAuthLookup(r *http.Request) (*portainer.TokenData, error) {
return &portainer.TokenData{
ID: 1,
Username: "testuser",
Token: "valid-token",
}, nil
}
func TestLogout(t *testing.T) {
h := NewHandler(NewMockBouncer(), nil, nil, nil)
h.KubernetesTokenCacheManager = kubernetes.NewTokenCacheManager()
k, err := cli.NewClientFactory(nil, nil, nil, "", "", "")
require.NoError(t, err)
h.KubernetesClientFactory = k
rr := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/auth/logout", nil)
h.ServeHTTP(rr, req)
require.Equal(t, http.StatusNoContent, rr.Code)
}
func TestLogoutNoPanic(t *testing.T) {
h := NewHandler(testhelpers.NewTestRequestBouncer(), nil, nil, nil)
rr := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/auth/logout", nil)
h.ServeHTTP(rr, req)
require.Equal(t, http.StatusNoContent, rr.Code)
}

View File

@ -18,6 +18,7 @@ import (
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_restoreArchive_usingCombinationOfPasswords(t *testing.T) { func Test_restoreArchive_usingCombinationOfPasswords(t *testing.T) {
@ -69,8 +70,7 @@ func Test_restoreArchive_usingCombinationOfPasswords(t *testing.T) {
// restore // restore
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, err := prepareMultipartRequest(test.restorePassword, archive) r := prepareMultipartRequest(t, test.restorePassword, archive)
assert.Nil(t, err, "Shouldn't fail to write multipart form")
restoreErr := h.restore(w, r) restoreErr := h.restore(w, r)
assert.Equal(t, test.fails, restoreErr != nil, "Didn't meet expectation of failing restore handler") assert.Equal(t, test.fails, restoreErr != nil, "Didn't meet expectation of failing restore handler")
@ -101,8 +101,7 @@ func Test_restoreArchive_shouldFailIfSystemWasAlreadyInitialized(t *testing.T) {
// restore // restore
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, err := prepareMultipartRequest("password", archive) r := prepareMultipartRequest(t, "password", archive)
assert.Nil(t, err, "Shouldn't fail to write multipart form")
restoreErr := h.restore(w, r) restoreErr := h.restore(w, r)
assert.NotNil(t, restoreErr, "Should fail, because system it already initialized") assert.NotNil(t, restoreErr, "Should fail, because system it already initialized")
@ -117,31 +116,31 @@ func backup(t *testing.T, h *Handler, password string) []byte {
assert.Nil(t, backupErr, "Backup should not fail") assert.Nil(t, backupErr, "Backup should not fail")
response := w.Result() response := w.Result()
archive, _ := io.ReadAll(response.Body) archive, err := io.ReadAll(response.Body)
require.NoError(t, err)
response.Body.Close() response.Body.Close()
return archive return archive
} }
func prepareMultipartRequest(password string, file []byte) (*http.Request, error) { func prepareMultipartRequest(t *testing.T, password string, file []byte) *http.Request {
var body bytes.Buffer var body bytes.Buffer
w := multipart.NewWriter(&body) w := multipart.NewWriter(&body)
err := w.WriteField("password", password) err := w.WriteField("password", password)
if err != nil { require.NoError(t, err)
return nil, err
}
fw, err := w.CreateFormFile("file", "filename") fw, err := w.CreateFormFile("file", "filename")
if err != nil { require.NoError(t, err)
return nil, err
} _, err = io.Copy(fw, bytes.NewReader(file))
io.Copy(fw, bytes.NewReader(file)) require.NoError(t, err)
r := httptest.NewRequest(http.MethodPost, "http://localhost/", &body) r := httptest.NewRequest(http.MethodPost, "http://localhost/", &body)
r.Header.Set("Content-Type", w.FormDataContentType()) r.Header.Set("Content-Type", w.FormDataContentType())
w.Close() err = w.Close()
require.NoError(t, err)
return r, nil return r
} }

View File

@ -9,7 +9,6 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"testing" "testing"
"time" "time"
@ -25,6 +24,8 @@ import (
"github.com/segmentio/encoding/json" "github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
) )
func init() { func init() {
@ -112,15 +113,14 @@ func createTestFile(targetPath string) error {
} }
func prepareTestFolder(projectPath, filename string) error { func prepareTestFolder(projectPath, filename string) error {
err := os.MkdirAll(projectPath, fs.ModePerm) if err := os.MkdirAll(projectPath, fs.ModePerm); err != nil {
if err != nil {
return err return err
} }
return createTestFile(filepath.Join(projectPath, filename)) return createTestFile(filepath.Join(projectPath, filename))
} }
func singleAPIRequest(h *Handler, jwt string, is *assert.Assertions, expect string) { func singleAPIRequest(h *Handler, jwt string, expect string) error {
type response struct { type response struct {
FileContent string FileContent string
} }
@ -131,15 +131,25 @@ func singleAPIRequest(h *Handler, jwt string, is *assert.Assertions, expect stri
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h.ServeHTTP(rr, req) h.ServeHTTP(rr, req)
is.Equal(http.StatusOK, rr.Code) if rr.Code != http.StatusOK {
return errors.New("unexpected status code: " + http.StatusText(rr.Code))
}
body, err := io.ReadAll(rr.Body) body, err := io.ReadAll(rr.Body)
is.NoError(err, "ReadAll should not return error") if err != nil {
return err
}
var resp response var resp response
err = json.Unmarshal(body, &resp) if err := json.Unmarshal(body, &resp); err != nil {
is.NoError(err, "response should be list json") return err
is.Equal(resp.FileContent, expect) }
if resp.FileContent != expect {
return errors.New("unexpected file content: " + resp.FileContent + ", expected: " + expect)
}
return nil
} }
func Test_customTemplateGitFetch(t *testing.T) { func Test_customTemplateGitFetch(t *testing.T) {
@ -150,28 +160,29 @@ func Test_customTemplateGitFetch(t *testing.T) {
// create user(s) // create user(s)
user1 := &portainer.User{ID: 1, Username: "user-1", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()} user1 := &portainer.User{ID: 1, Username: "user-1", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()}
err := store.User().Create(user1) err := store.User().Create(user1)
is.NoError(err, "error creating user 1") require.NoError(t, err, "error creating user 1")
user2 := &portainer.User{ID: 2, Username: "user-2", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()} user2 := &portainer.User{ID: 2, Username: "user-2", Role: portainer.StandardUserRole, PortainerAuthorizations: authorization.DefaultPortainerAuthorizations()}
err = store.User().Create(user2) err = store.User().Create(user2)
is.NoError(err, "error creating user 2") require.NoError(t, err, "error creating user 2")
dir, err := os.Getwd() dir, err := os.Getwd()
is.NoError(err, "error to get working directory") require.NoError(t, err, "error to get working directory")
template1 := &portainer.CustomTemplate{ID: 1, Title: "custom-template-1", ProjectPath: filepath.Join(dir, "fixtures/custom_template_1"), GitConfig: &gittypes.RepoConfig{ConfigFilePath: "test-config-path.txt"}} template1 := &portainer.CustomTemplate{ID: 1, Title: "custom-template-1", ProjectPath: filepath.Join(dir, "fixtures/custom_template_1"), GitConfig: &gittypes.RepoConfig{ConfigFilePath: "test-config-path.txt"}}
err = store.CustomTemplateService.Create(template1) err = store.CustomTemplateService.Create(template1)
is.NoError(err, "error creating custom template 1") require.NoError(t, err, "error creating custom template 1")
// prepare testing folder // prepare testing folder
err = prepareTestFolder(template1.ProjectPath, template1.GitConfig.ConfigFilePath) err = prepareTestFolder(template1.ProjectPath, template1.GitConfig.ConfigFilePath)
is.NoError(err, "error creating testing folder") require.NoError(t, err, "error creating testing folder")
defer os.RemoveAll(filepath.Join(dir, "fixtures")) defer os.RemoveAll(filepath.Join(dir, "fixtures"))
// setup services // setup services
jwtService, err := jwt.NewService("1h", store) jwtService, err := jwt.NewService("1h", store)
is.NoError(err, "Error initiating jwt service") require.NoError(t, err, "Error initiating jwt service")
requestBouncer := security.NewRequestBouncer(store, jwtService, nil) requestBouncer := security.NewRequestBouncer(store, jwtService, nil)
gitService := &TestGitService{ gitService := &TestGitService{
@ -182,52 +193,55 @@ func Test_customTemplateGitFetch(t *testing.T) {
h := NewHandler(requestBouncer, store, fileService, gitService) h := NewHandler(requestBouncer, store, fileService, gitService)
// generate two standard users' tokens // generate two standard users' tokens
jwt1, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: user1.ID, Username: user1.Username, Role: user1.Role}) jwt1, _, err := jwtService.GenerateToken(&portainer.TokenData{ID: user1.ID, Username: user1.Username, Role: user1.Role})
jwt2, _, _ := jwtService.GenerateToken(&portainer.TokenData{ID: user2.ID, Username: user2.Username, Role: user2.Role}) require.NoError(t, err)
jwt2, _, err := jwtService.GenerateToken(&portainer.TokenData{ID: user2.ID, Username: user2.Username, Role: user2.Role})
require.NoError(t, err)
t.Run("can return the expected file content by a single call from one user", func(t *testing.T) { t.Run("can return the expected file content by a single call from one user", func(t *testing.T) {
singleAPIRequest(h, jwt1, is, "abcdefg") err := singleAPIRequest(h, jwt1, "abcdefg")
require.NoError(t, err)
}) })
t.Run("can return the expected file content by multiple calls from one user", func(t *testing.T) { t.Run("can return the expected file content by multiple calls from one user", func(t *testing.T) {
var wg sync.WaitGroup var g errgroup.Group
wg.Add(5)
for range 5 { for range 5 {
go func() { g.Go(func() error {
singleAPIRequest(h, jwt1, is, "abcdefg") return singleAPIRequest(h, jwt1, "abcdefg")
wg.Done() })
}()
} }
wg.Wait() err := g.Wait()
require.NoError(t, err)
}) })
t.Run("can return the expected file content by multiple calls from different users", func(t *testing.T) { t.Run("can return the expected file content by multiple calls from different users", func(t *testing.T) {
var wg sync.WaitGroup var g errgroup.Group
wg.Add(10)
for i := range 10 { for i := range 10 {
go func(j int) { g.Go(func() error {
if j%2 == 0 { if i%2 == 0 {
singleAPIRequest(h, jwt1, is, "abcdefg") return singleAPIRequest(h, jwt1, "abcdefg")
} else {
singleAPIRequest(h, jwt2, is, "abcdefg")
} }
wg.Done() return singleAPIRequest(h, jwt2, "abcdefg")
}(i) })
} }
wg.Wait() err := g.Wait()
require.NoError(t, err)
}) })
t.Run("can return the expected file content after a new commit is made", func(t *testing.T) { t.Run("can return the expected file content after a new commit is made", func(t *testing.T) {
singleAPIRequest(h, jwt1, is, "abcdefg") err := singleAPIRequest(h, jwt1, "abcdefg")
require.NoError(t, err)
testFileContent = "gfedcba" testFileContent = "gfedcba"
singleAPIRequest(h, jwt2, is, "gfedcba") err = singleAPIRequest(h, jwt2, "gfedcba")
require.NoError(t, err)
}) })
t.Run("restore git repository if it is failed to download the new git repository", func(t *testing.T) { t.Run("restore git repository if it is failed to download the new git repository", func(t *testing.T) {
@ -246,11 +260,11 @@ func Test_customTemplateGitFetch(t *testing.T) {
var errResp httperror.HandlerError var errResp httperror.HandlerError
err = json.NewDecoder(rr.Body).Decode(&errResp) err = json.NewDecoder(rr.Body).Decode(&errResp)
assert.NoError(t, err, "failed to parse error body") require.NoError(t, err, "failed to parse error body")
assert.FileExists(t, gitService.targetFilePath, "previous git repository is not restored") assert.FileExists(t, gitService.targetFilePath, "previous git repository is not restored")
fileContent, err := os.ReadFile(gitService.targetFilePath) fileContent, err := os.ReadFile(gitService.targetFilePath)
assert.NoError(t, err, "failed to read target file") require.NoError(t, err, "failed to read target file")
assert.Equal(t, "gfedcba", string(fileContent)) assert.Equal(t, "gfedcba", string(fileContent))
}) })
} }

View File

@ -5,8 +5,11 @@ import (
"strconv" "strconv"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
httperrors "github.com/portainer/portainer/api/http/errors" httperrors "github.com/portainer/portainer/api/http/errors"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
@ -32,31 +35,45 @@ func (handler *Handler) customTemplateInspect(w http.ResponseWriter, r *http.Req
return httperror.BadRequest("Invalid Custom template identifier route variable", err) return httperror.BadRequest("Invalid Custom template identifier route variable", err)
} }
customTemplate, err := handler.DataStore.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID)) var customTemplate *portainer.CustomTemplate
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
customTemplate, err = tx.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID))
if handler.DataStore.IsErrObjectNotFound(err) { if handler.DataStore.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err) return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err)
} else if err != nil { } else if err != nil {
return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err) return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err)
} }
resourceControl, err := tx.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl)
if err != nil {
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
}
securityContext, err := security.RetrieveRestrictedRequestContext(r) securityContext, err := security.RetrieveRestrictedRequestContext(r)
if err != nil { if err != nil {
return httperror.InternalServerError("Unable to retrieve user info from request context", err) return httperror.InternalServerError("Unable to retrieve user info from request context", err)
} }
resourceControl, err := handler.DataStore.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl) canEdit := userCanEditTemplate(customTemplate, securityContext)
if err != nil { hasAccess := false
return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err)
}
access := userCanEditTemplate(customTemplate, securityContext)
if !access {
return httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)
}
if resourceControl != nil { if resourceControl != nil {
customTemplate.ResourceControl = resourceControl customTemplate.ResourceControl = resourceControl
teamIDs := slicesx.Map(securityContext.UserMemberships, func(m portainer.TeamMembership) portainer.TeamID {
return m.TeamID
})
hasAccess = authorization.UserCanAccessResource(securityContext.UserID, teamIDs, resourceControl)
} }
return response.JSON(w, customTemplate) if canEdit || hasAccess {
return nil
}
return httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)
})
return response.TxResponse(w, customTemplate, err)
} }

View File

@ -0,0 +1,100 @@
package customtemplates
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/require"
)
func TestInspectHandler(t *testing.T) {
_, ds := datastore.MustNewTestStore(t, true, false)
require.NotNil(t, ds)
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
require.NoError(t, tx.User().Create(&portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole}))
require.NoError(t, tx.User().Create(&portainer.User{ID: 2, Username: "std2", Role: portainer.StandardUserRole}))
require.NoError(t, tx.User().Create(&portainer.User{ID: 3, Username: "std3", Role: portainer.StandardUserRole}))
require.NoError(t, tx.User().Create(&portainer.User{ID: 4, Username: "std4", Role: portainer.StandardUserRole}))
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ID: 1,
UserAccessPolicies: portainer.UserAccessPolicies{
2: portainer.AccessPolicy{RoleID: 0},
3: portainer.AccessPolicy{RoleID: 0},
}}))
require.NoError(t, tx.Team().Create(&portainer.Team{ID: 1}))
require.NoError(t, tx.TeamMembership().Create(&portainer.TeamMembership{ID: 1, UserID: 3, TeamID: 1, Role: portainer.TeamMember}))
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 2}))
require.NoError(t, tx.ResourceControl().Create(&portainer.ResourceControl{ID: 1, ResourceID: "2", Type: portainer.CustomTemplateResourceControl,
UserAccesses: []portainer.UserResourceAccess{{UserID: 2}},
TeamAccesses: []portainer.TeamResourceAccess{{TeamID: 1}},
}))
return nil
}))
handler := NewHandler(testhelpers.NewTestRequestBouncer(), ds, &TestFileService{}, nil)
test := func(templateID string, restrictedContext *security.RestrictedRequestContext) (*httptest.ResponseRecorder, *httperror.HandlerError) {
r := httptest.NewRequest(http.MethodGet, "/custom_templates/"+templateID, nil)
r = mux.SetURLVars(r, map[string]string{"id": templateID})
ctx := security.StoreRestrictedRequestContext(r, restrictedContext)
r = r.WithContext(ctx)
rr := httptest.NewRecorder()
return rr, handler.customTemplateInspect(rr, r)
}
t.Run("unknown id should get not found error", func(t *testing.T) {
_, r := test("0", &security.RestrictedRequestContext{UserID: 1})
require.NotNil(t, r)
require.Equal(t, http.StatusNotFound, r.StatusCode)
})
t.Run("admin should access adminonly template", func(t *testing.T) {
rr, r := test("1", &security.RestrictedRequestContext{UserID: 1, IsAdmin: true})
require.Nil(t, r)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var template portainer.CustomTemplate
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
})
t.Run("std should not access adminonly template", func(t *testing.T) {
_, r := test("1", &security.RestrictedRequestContext{UserID: 2})
require.NotNil(t, r)
require.Equal(t, http.StatusForbidden, r.StatusCode)
})
t.Run("std should access template via direct user access", func(t *testing.T) {
rr, r := test("2", &security.RestrictedRequestContext{UserID: 2})
require.Nil(t, r)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var template portainer.CustomTemplate
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
})
t.Run("std should access template via team access", func(t *testing.T) {
rr, r := test("2", &security.RestrictedRequestContext{UserID: 3, UserMemberships: []portainer.TeamMembership{{ID: 1, UserID: 3, TeamID: 1}}})
require.Nil(t, r)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var template portainer.CustomTemplate
require.NoError(t, json.NewDecoder(rr.Body).Decode(&template))
require.Equal(t, portainer.CustomTemplateID(2), template.ID)
})
t.Run("std should not access template without access", func(t *testing.T) {
_, r := test("2", &security.RestrictedRequestContext{UserID: 4})
require.NotNil(t, r)
require.Equal(t, http.StatusForbidden, r.StatusCode)
})
}

View File

@ -11,8 +11,7 @@ import (
"github.com/docker/docker/api/types/volume" "github.com/docker/docker/api/types/volume"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/docker" "github.com/portainer/portainer/api/docker/stats"
"github.com/portainer/portainer/api/http/errors"
"github.com/portainer/portainer/api/http/handler/docker/utils" "github.com/portainer/portainer/api/http/handler/docker/utils"
"github.com/portainer/portainer/api/http/middlewares" "github.com/portainer/portainer/api/http/middlewares"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
@ -26,7 +25,7 @@ type imagesCounters struct {
} }
type dashboardResponse struct { type dashboardResponse struct {
Containers docker.ContainerStats `json:"containers"` Containers stats.ContainerStats `json:"containers"`
Services int `json:"services"` Services int `json:"services"`
Images imagesCounters `json:"images"` Images imagesCounters `json:"images"`
Volumes int `json:"volumes"` Volumes int `json:"volumes"`
@ -144,13 +143,18 @@ func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.H
stackCount = len(stacks) stackCount = len(stacks)
} }
containersStats, err := stats.CalculateContainerStats(r.Context(), cli, info.Swarm.ControlAvailable, containers)
if err != nil {
return httperror.InternalServerError("Unable to retrieve Docker containers stats", err)
}
resp = dashboardResponse{ resp = dashboardResponse{
Images: imagesCounters{ Images: imagesCounters{
Total: len(images), Total: len(images),
Size: totalSize, Size: totalSize,
}, },
Services: len(services), Services: len(services),
Containers: docker.CalculateContainerStats(containers), Containers: containersStats,
Networks: len(networks), Networks: len(networks),
Volumes: len(volumes), Volumes: len(volumes),
Stacks: stackCount, Stacks: stackCount,
@ -159,7 +163,5 @@ func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.H
return nil return nil
}) })
return errors.TxResponse(err, func() *httperror.HandlerError { return response.TxResponse(w, resp, err)
return response.JSON(w, resp)
})
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestHandler_getDockerStacks(t *testing.T) { func TestHandler_getDockerStacks(t *testing.T) {
@ -69,7 +70,7 @@ func TestHandler_getDockerStacks(t *testing.T) {
stacksList, err := GetDockerStacks(datastore, &security.RestrictedRequestContext{ stacksList, err := GetDockerStacks(datastore, &security.RestrictedRequestContext{
IsAdmin: true, IsAdmin: true,
}, environment.ID, containers, services) }, environment.ID, containers, services)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, stacksList, 3) assert.Len(t, stacksList, 3)
expectedStacks := []StackViewModel{ expectedStacks := []StackViewModel{

View File

@ -10,6 +10,7 @@ import (
"github.com/portainer/portainer/api/roar" "github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
) )
type edgeGroupCreatePayload struct { type edgeGroupCreatePayload struct {
@ -111,5 +112,5 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
return nil return nil
}) })
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
} }

View File

@ -32,16 +32,8 @@ func (handler *Handler) edgeGroupDelete(w http.ResponseWriter, r *http.Request)
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
return deleteEdgeGroup(tx, portainer.EdgeGroupID(edgeGroupID)) return deleteEdgeGroup(tx, portainer.EdgeGroupID(edgeGroupID))
}) })
if err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func deleteEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) error { func deleteEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) error {

View File

@ -8,6 +8,7 @@ import (
"github.com/portainer/portainer/api/roar" "github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
) )
// @id EdgeGroupInspect // @id EdgeGroupInspect
@ -36,7 +37,7 @@ func (handler *Handler) edgeGroupInspect(w http.ResponseWriter, r *http.Request)
edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice() edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice()
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
} }
func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) { func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) {

View File

@ -105,7 +105,7 @@ func TestEmptyEdgeGroupInspectHandler(t *testing.T) {
// Make sure the frontend does not get a null value but a [] instead // Make sure the frontend does not get a null value but a [] instead
require.NotNil(t, responseGroup.Endpoints) require.NotNil(t, responseGroup.Endpoints)
require.Len(t, responseGroup.Endpoints, 0) require.Empty(t, responseGroup.Endpoints)
} }
func TestDynamicEdgeGroupInspectHandler(t *testing.T) { func TestDynamicEdgeGroupInspectHandler(t *testing.T) {

View File

@ -9,6 +9,7 @@ import (
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/roar" "github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
) )
type shadowedEdgeGroup struct { type shadowedEdgeGroup struct {
@ -44,7 +45,7 @@ func (handler *Handler) edgeGroupList(w http.ResponseWriter, r *http.Request) *h
return err return err
}) })
return txResponse(w, decoratedEdgeGroups, err) return response.TxResponse(w, decoratedEdgeGroups, err)
} }
func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error) { func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error) {

View File

@ -47,7 +47,7 @@ func Test_getEndpointTypes(t *testing.T) {
for _, test := range tests { for _, test := range tests {
ans, err := getEndpointTypes(datastore, roar.FromSlice(test.endpointIds)) ans, err := getEndpointTypes(datastore, roar.FromSlice(test.endpointIds))
assert.NoError(t, err, "getEndpointTypes shouldn't fail") require.NoError(t, err, "getEndpointTypes shouldn't fail")
assert.ElementsMatch(t, test.expected, ans, "getEndpointTypes expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans) assert.ElementsMatch(t, test.expected, ans, "getEndpointTypes expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
} }
@ -57,7 +57,7 @@ func Test_getEndpointTypes_failWhenEndpointDontExist(t *testing.T) {
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{})) datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
_, err := getEndpointTypes(datastore, roar.FromSlice([]portainer.EndpointID{1})) _, err := getEndpointTypes(datastore, roar.FromSlice([]portainer.EndpointID{1}))
assert.Error(t, err, "getEndpointTypes should fail") require.Error(t, err, "getEndpointTypes should fail")
} }
func TestEdgeGroupListHandler(t *testing.T) { func TestEdgeGroupListHandler(t *testing.T) {
@ -112,5 +112,5 @@ func TestEdgeGroupListHandler(t *testing.T) {
require.Len(t, responseGroups, 1) require.Len(t, responseGroups, 1)
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroups[0].Endpoints) require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroups[0].Endpoints)
require.Len(t, responseGroups[0].TrustedEndpoints, 0) require.Empty(t, responseGroups[0].TrustedEndpoints)
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/portainer/portainer/api/slicesx" "github.com/portainer/portainer/api/slicesx"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
) )
type edgeGroupUpdatePayload struct { type edgeGroupUpdatePayload struct {
@ -158,7 +159,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
return nil return nil
}) })
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
} }
func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, edgeGroups []portainer.EdgeGroup, edgeStacks []portainer.EdgeStack) error { func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, edgeGroups []portainer.EdgeGroup, edgeStacks []portainer.EdgeStack) error {

View File

@ -1,14 +1,12 @@
package edgegroups package edgegroups
import ( import (
"errors"
"net/http" "net/http"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@ -38,16 +36,3 @@ func NewHandler(bouncer security.BouncerService) *Handler {
return h return h
} }
func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError {
if err != nil {
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err)
}
return response.JSON(w, r)
}

View File

@ -15,6 +15,7 @@ import (
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
) )
@ -85,19 +86,18 @@ func (payload *edgeJobCreateFromFileContentPayload) Validate(r *http.Request) er
// @router /edge_jobs/create/string [post] // @router /edge_jobs/create/string [post]
func (handler *Handler) createEdgeJobFromFileContent(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) createEdgeJobFromFileContent(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload edgeJobCreateFromFileContentPayload var payload edgeJobCreateFromFileContentPayload
err := request.DecodeAndValidateJSONPayload(r, &payload) if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
var edgeJob *portainer.EdgeJob var edgeJob *portainer.EdgeJob
var err error
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, []byte(payload.FileContent)) edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, []byte(payload.FileContent))
return err return err
}) })
return txResponse(w, edgeJob, err) return response.TxResponse(w, edgeJob, err)
} }
func (handler *Handler) createEdgeJob(tx dataservices.DataStoreTx, payload *edgeJobBasePayload, fileContent []byte) (*portainer.EdgeJob, error) { func (handler *Handler) createEdgeJob(tx dataservices.DataStoreTx, payload *edgeJobBasePayload, fileContent []byte) (*portainer.EdgeJob, error) {
@ -191,19 +191,18 @@ func (payload *edgeJobCreateFromFilePayload) Validate(r *http.Request) error {
// @router /edge_jobs/create/file [post] // @router /edge_jobs/create/file [post]
func (handler *Handler) createEdgeJobFromFile(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) createEdgeJobFromFile(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
payload := &edgeJobCreateFromFilePayload{} payload := &edgeJobCreateFromFilePayload{}
err := payload.Validate(r) if err := payload.Validate(r); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
var edgeJob *portainer.EdgeJob var edgeJob *portainer.EdgeJob
var err error
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, payload.File) edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, payload.File)
return err return err
}) })
return txResponse(w, edgeJob, err) return response.TxResponse(w, edgeJob, err)
} }
func (handler *Handler) createEdgeJobObjectFromPayload(tx dataservices.DataStoreTx, payload *edgeJobBasePayload) *portainer.EdgeJob { func (handler *Handler) createEdgeJobObjectFromPayload(tx dataservices.DataStoreTx, payload *edgeJobBasePayload) *portainer.EdgeJob {

View File

@ -0,0 +1,158 @@
package edgejobs
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockFileService struct {
mock.Mock
portainer.FileService
}
func (m *mockFileService) StoreEdgeJobFileFromBytes(id string, file []byte) (string, error) {
args := m.Called(id, file)
return args.String(0), args.Error(1)
}
func (m *mockFileService) GetEdgeJobFolder(id string) string {
args := m.Called(id)
return args.String(0)
}
func (m *mockFileService) RemoveDirectory(path string) error {
args := m.Called(path)
return args.Error(0)
}
func initStore(t *testing.T) *datastore.Store {
_, store := datastore.MustNewTestStore(t, true, true)
require.NotNil(t, store)
require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error {
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{
ID: 1,
Name: "endpoint-1",
EdgeID: "edge-id-1",
GroupID: 1,
Type: portainer.EdgeAgentOnDockerEnvironment,
UserTrusted: true,
}))
require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{
ID: 2,
Name: "endpoint-2",
EdgeID: "edge-id-2",
GroupID: 1,
Type: portainer.EdgeAgentOnDockerEnvironment,
UserTrusted: false,
}))
return nil
}))
return store
}
func Test_edgeJobCreate_StringMethod_Success(t *testing.T) {
store := initStore(t)
fileService := &mockFileService{}
fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil)
handler := &Handler{
DataStore: store,
FileService: fileService,
}
payload := edgeJobCreateFromFileContentPayload{
edgeJobBasePayload: edgeJobBasePayload{
Name: "testjob",
CronExpression: "* * * * *",
Endpoints: []portainer.EndpointID{1, 2},
},
FileContent: "echo hello",
}
body, _ := json.Marshal(payload)
req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/string", bytes.NewReader(body))
req = mux.SetURLVars(req, map[string]string{"method": "string"})
w := httptest.NewRecorder()
// Call handler
errh := handler.edgeJobCreate(w, req)
require.Nil(t, errh)
require.Equal(t, http.StatusOK, w.Result().StatusCode)
// Get edge job ID from response
var resp struct {
ID int `json:"Id"`
}
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID))
require.NoError(t, err)
require.Len(t, edgeJob.Endpoints, 2)
require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1))
}
func Test_edgeJobCreate_FileMethod_Success(t *testing.T) {
store := initStore(t)
fileService := &mockFileService{}
fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil)
handler := &Handler{
DataStore: store,
FileService: fileService,
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.NoError(t, writer.WriteField("Name", "testjob"))
require.NoError(t, writer.WriteField("CronExpression", "* * * * *"))
require.NoError(t, writer.WriteField("Endpoints", "[1,2]"))
fileWriter, err := writer.CreateFormFile("file", "test.txt")
require.NoError(t, err)
_, err = io.Copy(fileWriter, strings.NewReader("echo hello"))
require.NoError(t, err)
require.NoError(t, writer.Close())
req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/file", &body)
req = mux.SetURLVars(req, map[string]string{"method": "file"})
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
handlerErr := handler.edgeJobCreate(w, req)
require.Nil(t, handlerErr)
require.Equal(t, http.StatusOK, w.Result().StatusCode)
var resp struct {
ID int `json:"Id"`
}
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID))
require.NoError(t, err)
require.Len(t, edgeJob.Endpoints, 2)
require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1))
}

View File

@ -1,7 +1,6 @@
package edgejobs package edgejobs
import ( import (
"errors"
"maps" "maps"
"net/http" "net/http"
"strconv" "strconv"
@ -35,18 +34,11 @@ func (handler *Handler) edgeJobDelete(w http.ResponseWriter, r *http.Request) *h
return httperror.BadRequest("Invalid Edge job identifier route variable", err) return httperror.BadRequest("Invalid Edge job identifier route variable", err)
} }
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
return handler.deleteEdgeJob(tx, portainer.EdgeJobID(edgeJobID)) return handler.deleteEdgeJob(tx, portainer.EdgeJobID(edgeJobID))
}); err != nil { })
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func (handler *Handler) deleteEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) error { func (handler *Handler) deleteEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) error {

View File

@ -1,7 +1,6 @@
package edgejobs package edgejobs
import ( import (
"errors"
"net/http" "net/http"
"slices" "slices"
"strconv" "strconv"
@ -54,7 +53,7 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request
} }
} }
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
updateEdgeJobFn := func(edgeJob *portainer.EdgeJob, endpointID portainer.EndpointID, endpointsFromGroups []portainer.EndpointID) error { updateEdgeJobFn := func(edgeJob *portainer.EdgeJob, endpointID portainer.EndpointID, endpointsFromGroups []portainer.EndpointID) error {
mutationFn(edgeJob, endpointID, endpointsFromGroups) mutationFn(edgeJob, endpointID, endpointsFromGroups)
@ -62,16 +61,9 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request
} }
return handler.clearEdgeJobTaskLogs(tx, portainer.EdgeJobID(edgeJobID), portainer.EndpointID(taskID), updateEdgeJobFn) return handler.clearEdgeJobTaskLogs(tx, portainer.EdgeJobID(edgeJobID), portainer.EndpointID(taskID), updateEdgeJobFn)
}); err != nil { })
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func (handler *Handler) clearEdgeJobTaskLogs(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, endpointID portainer.EndpointID, updateEdgeJob func(*portainer.EdgeJob, portainer.EndpointID, []portainer.EndpointID) error) error { func (handler *Handler) clearEdgeJobTaskLogs(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, endpointID portainer.EndpointID, updateEdgeJob func(*portainer.EdgeJob, portainer.EndpointID, []portainer.EndpointID) error) error {

View File

@ -1,7 +1,6 @@
package edgejobs package edgejobs
import ( import (
"errors"
"net/http" "net/http"
"slices" "slices"
@ -39,7 +38,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque
return httperror.BadRequest("Invalid Task identifier route variable", err) return httperror.BadRequest("Invalid Task identifier route variable", err)
} }
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
edgeJob, err := tx.EdgeJob().Read(portainer.EdgeJobID(edgeJobID)) edgeJob, err := tx.EdgeJob().Read(portainer.EdgeJobID(edgeJobID))
if tx.IsErrObjectNotFound(err) { if tx.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find an Edge job with the specified identifier inside the database", err) return httperror.NotFound("Unable to find an Edge job with the specified identifier inside the database", err)
@ -81,14 +80,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque
} }
return nil return nil
}); err != nil { })
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/edge"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
) )
type taskContainer struct { type taskContainer struct {
@ -49,6 +50,7 @@ func (handler *Handler) edgeJobTasksList(w http.ResponseWriter, r *http.Request)
return err return err
}) })
return response.TxFuncResponse(err, func() *httperror.HandlerError {
results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{ results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{
SearchAccessors: []filters.SearchAccessor[*taskContainer]{ SearchAccessors: []filters.SearchAccessor[*taskContainer]{
func(tc *taskContainer) (string, error) { func(tc *taskContainer) (string, error) {
@ -73,7 +75,8 @@ func (handler *Handler) edgeJobTasksList(w http.ResponseWriter, r *http.Request)
filters.ApplyFilterResultsHeaders(&w, results) filters.ApplyFilterResultsHeaders(&w, results)
return txResponse(w, results.Items, err) return response.JSON(w, results.Items)
})
} }
func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]*taskContainer, error) { func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]*taskContainer, error) {

View File

@ -11,6 +11,7 @@ import (
"github.com/portainer/portainer/api/datastore" "github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar" "github.com/portainer/portainer/api/roar"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -68,14 +69,16 @@ func Test_EdgeJobTasksListHandler(t *testing.T) {
tcStr := rr.Header().Get("x-total-count") tcStr := rr.Header().Get("x-total-count")
assert.NotEmpty(t, tcStr) assert.NotEmpty(t, tcStr)
totalCount, err := strconv.Atoi(tcStr) totalCount, err := strconv.Atoi(tcStr)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedCount, totalCount) assert.Equal(t, expectedCount, totalCount)
taStr := rr.Header().Get("x-total-available") taStr := rr.Header().Get("x-total-available")
assert.NotEmpty(t, taStr) assert.NotEmpty(t, taStr)
totalAvailable, err := strconv.Atoi(taStr) totalAvailable, err := strconv.Atoi(taStr)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, envCount, totalAvailable) assert.Equal(t, envCount, totalAvailable)
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/endpointutils"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/portainer/portainer/pkg/validate" "github.com/portainer/portainer/pkg/validate"
) )
@ -66,7 +67,7 @@ func (handler *Handler) edgeJobUpdate(w http.ResponseWriter, r *http.Request) *h
return err return err
}) })
return txResponse(w, edgeJob, err) return response.TxResponse(w, edgeJob, err)
} }
func (handler *Handler) updateEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, payload edgeJobUpdatePayload) (*portainer.EdgeJob, error) { func (handler *Handler) updateEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, payload edgeJobUpdatePayload) (*portainer.EdgeJob, error) {

View File

@ -1,14 +1,12 @@
package edgejobs package edgejobs
import ( import (
"errors"
"net/http" "net/http"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/http/security"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/response"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@ -60,16 +58,3 @@ func convertEndpointsToMetaObject(endpoints []portainer.EndpointID) map[portaine
return endpointsMap return endpointsMap
} }
func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError {
if err != nil {
var handlerError *httperror.HandlerError
if errors.As(err, &handlerError) {
return handlerError
}
return httperror.InternalServerError("Unexpected error", err)
}
return response.JSON(w, r)
}

View File

@ -1,7 +1,6 @@
package edgestacks package edgestacks
import ( import (
"errors"
"net/http" "net/http"
"strconv" "strconv"
@ -30,18 +29,11 @@ func (handler *Handler) edgeStackDelete(w http.ResponseWriter, r *http.Request)
return httperror.BadRequest("Invalid edge stack identifier route variable", err) return httperror.BadRequest("Invalid edge stack identifier route variable", err)
} }
if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
return handler.deleteEdgeStack(tx, portainer.EdgeStackID(edgeStackID)) return handler.deleteEdgeStack(tx, portainer.EdgeStackID(edgeStackID))
}); err != nil { })
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func (handler *Handler) deleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID) error { func (handler *Handler) deleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID) error {

View File

@ -96,12 +96,7 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
return nil return nil
}); err != nil { }); err != nil {
var httpErr *httperror.HandlerError return response.TxErrorResponse(err)
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err)
} }
if ok, _ := strconv.ParseBool(r.Header.Get("X-Portainer-No-Body")); ok { if ok, _ := strconv.ParseBool(r.Header.Get("X-Portainer-No-Body")); ok {

View File

@ -66,12 +66,7 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request)
stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload) stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload)
return err return err
}); err != nil { }); err != nil {
var httpErr *httperror.HandlerError return response.TxErrorResponse(err)
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err)
} }
if err := fillEdgeStackStatus(handler.DataStore, stack); err != nil { if err := fillEdgeStackStatus(handler.DataStore, stack); err != nil {
@ -99,7 +94,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
groupsIds := stack.EdgeGroups groupsIds := stack.EdgeGroups
if payload.EdgeGroups != nil { if payload.EdgeGroups != nil {
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack.ID, payload.EdgeGroups, relatedEndpointIds, relationConfig) newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack, payload.EdgeGroups, relatedEndpointIds, relationConfig)
if err != nil { if err != nil {
return nil, httperror.InternalServerError("Unable to handle edge groups change", err) return nil, httperror.InternalServerError("Unable to handle edge groups change", err)
} }
@ -136,7 +131,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
return stack, nil return stack, nil
} }
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) { func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStack *portainer.EdgeStack, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups) newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups)
if err != nil { if err != nil {
return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database") return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database")
@ -149,13 +144,13 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge
relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet) relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet)
if len(relatedEnvironmentsToRemove) > 0 { if len(relatedEnvironmentsToRemove) > 0 {
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID); err != nil { if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStack.ID); err != nil {
return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database") return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database")
} }
} }
if len(relatedEnvironmentsToAdd) > 0 { if len(relatedEnvironmentsToAdd) > 0 {
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID); err != nil { if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStack); err != nil {
return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database") return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database")
} }
} }

View File

@ -5,7 +5,9 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/testhelpers" "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_hasKubeEndpoint(t *testing.T) { func Test_hasKubeEndpoint(t *testing.T) {
@ -40,7 +42,7 @@ func Test_hasKubeEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
ans, err := hasKubeEndpoint(datastore.Endpoint(), test.endpointIds) ans, err := hasKubeEndpoint(datastore.Endpoint(), test.endpointIds)
assert.NoError(t, err, "hasKubeEndpoint shouldn't fail") require.NoError(t, err, "hasKubeEndpoint shouldn't fail")
assert.Equal(t, test.expected, ans, "hasKubeEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans) assert.Equal(t, test.expected, ans, "hasKubeEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
} }
@ -50,7 +52,7 @@ func Test_hasKubeEndpoint_failWhenEndpointDontExist(t *testing.T) {
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{})) datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
_, err := hasKubeEndpoint(datastore.Endpoint(), []portainer.EndpointID{1}) _, err := hasKubeEndpoint(datastore.Endpoint(), []portainer.EndpointID{1})
assert.Error(t, err, "hasKubeEndpoint should fail") require.Error(t, err, "hasKubeEndpoint should fail")
} }
func Test_hasDockerEndpoint(t *testing.T) { func Test_hasDockerEndpoint(t *testing.T) {
@ -85,7 +87,7 @@ func Test_hasDockerEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
ans, err := hasDockerEndpoint(datastore.Endpoint(), test.endpointIds) ans, err := hasDockerEndpoint(datastore.Endpoint(), test.endpointIds)
assert.NoError(t, err, "hasDockerEndpoint shouldn't fail") require.NoError(t, err, "hasDockerEndpoint shouldn't fail")
assert.Equal(t, test.expected, ans, "hasDockerEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans) assert.Equal(t, test.expected, ans, "hasDockerEndpoint expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
} }
@ -95,5 +97,5 @@ func Test_hasDockerEndpoint_failWhenEndpointDontExist(t *testing.T) {
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{})) datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
_, err := hasDockerEndpoint(datastore.Endpoint(), []portainer.EndpointID{1}) _, err := hasDockerEndpoint(datastore.Endpoint(), []portainer.EndpointID{1})
assert.Error(t, err, "hasDockerEndpoint should fail") require.Error(t, err, "hasDockerEndpoint should fail")
} }

View File

@ -49,26 +49,18 @@ func (payload *endpointGroupCreatePayload) Validate(r *http.Request) error {
// @router /endpoint_groups [post] // @router /endpoint_groups [post]
func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var payload endpointGroupCreatePayload var payload endpointGroupCreatePayload
err := request.DecodeAndValidateJSONPayload(r, &payload) if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil {
if err != nil {
return httperror.BadRequest("Invalid request payload", err) return httperror.BadRequest("Invalid request payload", err)
} }
var endpointGroup *portainer.EndpointGroup var endpointGroup *portainer.EndpointGroup
var err error
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
endpointGroup, err = handler.createEndpointGroup(tx, payload) endpointGroup, err = handler.createEndpointGroup(tx, payload)
return err return err
}) })
if err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err) return response.TxResponse(w, endpointGroup, err)
}
return response.JSON(w, endpointGroup)
} }
func (handler *Handler) createEndpointGroup(tx dataservices.DataStoreTx, payload endpointGroupCreatePayload) (*portainer.EndpointGroup, error) { func (handler *Handler) createEndpointGroup(tx dataservices.DataStoreTx, payload endpointGroupCreatePayload) (*portainer.EndpointGroup, error) {

View File

@ -37,16 +37,8 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
return handler.deleteEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID)) return handler.deleteEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID))
}) })
if err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func (handler *Handler) deleteEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID) error { func (handler *Handler) deleteEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID) error {

View File

@ -1,7 +1,6 @@
package endpointgroups package endpointgroups
import ( import (
"errors"
"net/http" "net/http"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
@ -39,16 +38,8 @@ func (handler *Handler) endpointGroupAddEndpoint(w http.ResponseWriter, r *http.
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
return handler.addEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) return handler.addEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID))
}) })
if err != nil {
var httpErr *httperror.HandlerError
if errors.As(err, &httpErr) {
return httpErr
}
return httperror.InternalServerError("Unexpected error", err) return response.TxEmptyResponse(w, err)
}
return response.Empty(w)
} }
func (handler *Handler) addEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error { func (handler *Handler) addEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error {

Some files were not shown because too many files have changed in this diff Show More