From 3f5eec4c4e3e5c518d8c35c6435382572aff321f Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Wed, 4 Mar 2026 06:20:07 +0000 Subject: [PATCH] Drop use of github.com/gorilla/mux mux is replaced with a simple wrapper around http.ServeMux with middleware chain support Unfortunately github.com/rootless-containers/rootlesskit/pkg/parent still uses it so we can't drop the indirect dep yet. Signed-off-by: Brad Davidson --- go.mod | 2 +- pkg/agent/https/https.go | 4 +- pkg/cli/agent/agent.go | 2 +- pkg/cli/server/server.go | 2 +- pkg/cluster/managed.go | 6 +- pkg/etcd/etcd.go | 12 +-- pkg/etcd/s3/s3_test.go | 155 +++++++++++++++----------------- pkg/etcd/snapshot_handler.go | 3 +- pkg/metrics/metrics.go | 2 +- pkg/nodepassword/validate.go | 7 +- pkg/profile/profile.go | 4 +- pkg/server/auth/auth.go | 2 +- pkg/server/handlers/handlers.go | 12 +-- pkg/server/handlers/router.go | 21 +++-- pkg/spegel/spegel.go | 8 +- pkg/util/mux/mux.go | 122 +++++++++++++++++++++++++ 16 files changed, 237 insertions(+), 127 deletions(-) create mode 100644 pkg/util/mux/mux.go diff --git a/go.mod b/go.mod index d62f6704c5e..53c41ed138f 100644 --- a/go.mod +++ b/go.mod @@ -101,7 +101,6 @@ require ( github.com/google/cadvisor v0.53.0 github.com/google/go-containerregistry v0.20.2 github.com/google/uuid v1.6.0 - github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c github.com/ipfs/go-ds-leveldb v0.5.0 @@ -299,6 +298,7 @@ require ( github.com/google/go-tpm v0.9.6 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20250820193118-f64d9cf942d6 // indirect + github.com/gorilla/mux v1.8.1 // indirect github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.0 // indirect diff --git a/pkg/agent/https/https.go b/pkg/agent/https/https.go index a688cca00dd..703cf6b8bfd 100644 --- a/pkg/agent/https/https.go +++ b/pkg/agent/https/https.go @@ -5,10 +5,10 @@ import ( "strconv" "sync" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/server/auth" "github.com/k3s-io/k3s/pkg/util" + "github.com/k3s-io/k3s/pkg/util/mux" "k8s.io/apiserver/pkg/server" "k8s.io/apiserver/pkg/server/options" ) @@ -25,7 +25,7 @@ var err error // Subsequent calls will return the same router. func Start(ctx context.Context, nodeConfig *config.Node, runtime *config.ControlRuntime) (*mux.Router, error) { once.Do(func() { - router = mux.NewRouter().SkipClean(true) + router = mux.NewRouter() config := &server.Config{} if runtime == nil { diff --git a/pkg/cli/agent/agent.go b/pkg/cli/agent/agent.go index 345ec364088..52680edd517 100644 --- a/pkg/cli/agent/agent.go +++ b/pkg/cli/agent/agent.go @@ -7,7 +7,6 @@ import ( "path/filepath" "sync" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent" "github.com/k3s-io/k3s/pkg/agent/https" "github.com/k3s-io/k3s/pkg/cli/cmds" @@ -20,6 +19,7 @@ import ( "github.com/k3s-io/k3s/pkg/spegel" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/util/permissions" "github.com/k3s-io/k3s/pkg/version" "github.com/k3s-io/k3s/pkg/vpn" diff --git a/pkg/cli/server/server.go b/pkg/cli/server/server.go index acfede36963..d4c6e19890d 100644 --- a/pkg/cli/server/server.go +++ b/pkg/cli/server/server.go @@ -11,7 +11,6 @@ import ( "time" systemd "github.com/coreos/go-systemd/v22/daemon" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent" "github.com/k3s-io/k3s/pkg/agent/https" "github.com/k3s-io/k3s/pkg/agent/loadbalancer" @@ -30,6 +29,7 @@ import ( "github.com/k3s-io/k3s/pkg/spegel" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/util/permissions" "github.com/k3s-io/k3s/pkg/version" "github.com/k3s-io/k3s/pkg/vpn" diff --git a/pkg/cluster/managed.go b/pkg/cluster/managed.go index d380eeeac74..6840788815e 100644 --- a/pkg/cluster/managed.go +++ b/pkg/cluster/managed.go @@ -13,11 +13,11 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/cluster/managed" "github.com/k3s-io/k3s/pkg/etcd" "github.com/k3s-io/k3s/pkg/nodepassword" "github.com/k3s-io/k3s/pkg/util" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -144,10 +144,10 @@ func (c *Cluster) deleteNodePasswdSecret(ctx context.Context) { // handlerNoEtcd wraps a handler with an error message indicating that etcd is not deployed. func handlerNoEtcd(handler http.Handler) http.Handler { - r := mux.NewRouter().SkipClean(true) + r := mux.NewRouter() // Wildcard route for anything after /db/ - r.HandleFunc("/db/{_:.*}", func(resp http.ResponseWriter, r *http.Request) { + r.HandleFunc("/db/", func(resp http.ResponseWriter, r *http.Request) { util.SendError(errors.New("etcd datastore disabled"), resp, r, http.StatusBadRequest) }) diff --git a/pkg/etcd/etcd.go b/pkg/etcd/etcd.go index dbd2c54916f..d1997caa5c1 100644 --- a/pkg/etcd/etcd.go +++ b/pkg/etcd/etcd.go @@ -19,7 +19,6 @@ import ( "time" "github.com/google/uuid" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/clientaccess" "github.com/k3s-io/k3s/pkg/cluster/managed" "github.com/k3s-io/k3s/pkg/daemons/config" @@ -31,6 +30,7 @@ import ( "github.com/k3s-io/k3s/pkg/signals" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/version" kine "github.com/k3s-io/kine/pkg/app" "github.com/k3s-io/kine/pkg/client" @@ -744,16 +744,16 @@ func (e *ETCD) setName(force bool) error { // handler wraps the handler with routes for database info func (e *ETCD) handler(next http.Handler) http.Handler { - r := mux.NewRouter().SkipClean(true) + r := mux.NewRouter() r.NotFoundHandler = next - ir := r.Path("/db/info").Subrouter() + ir := r.SubRouter("/db/info") ir.Use(auth.IsLocalOrHasRole(e.config, version.Program+":server")) - ir.Handle("", e.infoHandler()) + ir.Handle("/", e.infoHandler()) - sr := r.Path("/db/snapshot").Subrouter() + sr := r.SubRouter("/db/snapshot") sr.Use(auth.HasRole(e.config, version.Program+":server")) - sr.Handle("", e.snapshotHandler()) + sr.Handle("/", e.snapshotHandler()) return r } diff --git a/pkg/etcd/s3/s3_test.go b/pkg/etcd/s3/s3_test.go index 131757ced05..0a0361dda76 100644 --- a/pkg/etcd/s3/s3_test.go +++ b/pkg/etcd/s3/s3_test.go @@ -14,10 +14,10 @@ import ( "text/template" "time" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/etcd/snapshot" "github.com/k3s-io/k3s/pkg/util" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/tests/mock" "github.com/rancher/dynamiclistener/cert" "github.com/rancher/wrangler/v3/pkg/generated/controllers/core" @@ -1554,65 +1554,72 @@ func s3Router(t *testing.T) http.Handler { // badbucket returns 404 for all requests // authbucket returns 200 for HeadBucket, 403 for all others // others return 200 for objects with name prefix snapshot, 404 for all others - router := mux.NewRouter().SkipClean(true) + router := mux.NewRouter() // HeadBucket - router.Path("/{bucket}/").Methods(http.MethodHead).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - if vars["bucket"] == "badbucket" { - rw.WriteHeader(http.StatusNotFound) - } - }) // ListObjectsV2 - router.Path("/{bucket}/").Methods(http.MethodGet).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { - case "badbucket": - rw.WriteHeader(http.StatusNotFound) - case "authbucket": - rw.WriteHeader(http.StatusForbidden) - default: - prefix := r.URL.Query().Get("prefix") - filtered := []object{} - for _, object := range objects { - if strings.HasPrefix(object.Key, prefix) { - filtered = append(filtered, object) + router.HandleFunc("GET /{bucket}/{$}", func(rw http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodHead: + if r.PathValue("bucket") == "badbucket" { + rw.WriteHeader(http.StatusNotFound) + } + case http.MethodGet: + switch r.PathValue("bucket") { + case "badbucket": + rw.WriteHeader(http.StatusNotFound) + case "authbucket": + rw.WriteHeader(http.StatusForbidden) + default: + prefix := r.URL.Query().Get("prefix") + filtered := []object{} + for _, object := range objects { + if strings.HasPrefix(object.Key, prefix) { + filtered = append(filtered, object) + } + } + if err := listResponse.Execute(rw, bucket{Name: r.PathValue("bucket"), Prefix: prefix, Objects: filtered}); err != nil { + t.Errorf("Failed to generate ListObjectsV2 response, error = %v", err) + rw.WriteHeader(http.StatusInternalServerError) } } - if err := listResponse.Execute(rw, bucket{Name: vars["bucket"], Prefix: prefix, Objects: filtered}); err != nil { - t.Errorf("Failed to generate ListObjectsV2 response, error = %v", err) - rw.WriteHeader(http.StatusInternalServerError) + } + }) + // HeadObject + // GetObject + router.HandleFunc("GET /{bucket}/{object...}", func(rw http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodHead: + switch r.PathValue("bucket") { + case "badbucket": + rw.WriteHeader(http.StatusNotFound) + case "authbucket": + rw.WriteHeader(http.StatusForbidden) + default: + if strings.Contains(r.PathValue("object"), "bad") { + rw.WriteHeader(http.StatusNotFound) + } else { + rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) + } + } + case http.MethodGet: + switch r.PathValue("bucket") { + case "badbucket": + rw.WriteHeader(http.StatusNotFound) + case "authbucket": + rw.WriteHeader(http.StatusForbidden) + default: + if strings.Contains(r.PathValue("object"), "bad") { + rw.WriteHeader(http.StatusNotFound) + } else { + rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) + rw.Write([]byte("test snapshot file\n")) + } } } }) - // HeadObject - snapshot - router.Path("/{bucket}/{prefix:.*}snapshot-{snapshot}").Methods(http.MethodHead).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { - case "badbucket": - rw.WriteHeader(http.StatusNotFound) - case "authbucket": - rw.WriteHeader(http.StatusForbidden) - default: - rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) - } - }) - // GetObject - snapshot - router.Path("/{bucket}/{prefix:.*}snapshot-{snapshot}").Methods(http.MethodGet).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { - case "badbucket": - rw.WriteHeader(http.StatusNotFound) - case "authbucket": - rw.WriteHeader(http.StatusForbidden) - default: - rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) - rw.Write([]byte("test snapshot file\n")) - } - }) - // PutObject/DeleteObject - snapshot - router.Path("/{bucket}/{prefix:.*}snapshot-{snapshot}").Methods(http.MethodPut, http.MethodDelete).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { + // PutObject + router.HandleFunc("PUT /{bucket}/{object...}", func(rw http.ResponseWriter, r *http.Request) { + switch r.PathValue("bucket") { case "badbucket": rw.WriteHeader(http.StatusNotFound) case "authbucket": @@ -1623,45 +1630,27 @@ func s3Router(t *testing.T) http.Handler { } } }) - // HeadObject - snapshot metadata - router.Path("/{bucket}/{prefix:.*}.metadata/snapshot-{snapshot}").Methods(http.MethodHead).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { - case "badbucket": - rw.WriteHeader(http.StatusNotFound) - case "authbucket": - rw.WriteHeader(http.StatusForbidden) - default: - rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) - } - }) - // GetObject - snapshot metadata - router.Path("/{bucket}/{prefix:.*}.metadata/snapshot-{snapshot}").Methods(http.MethodGet).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { - case "badbucket": - rw.WriteHeader(http.StatusNotFound) - case "authbucket": - rw.WriteHeader(http.StatusForbidden) - default: - rw.Header().Add("last-modified", time.Now().In(gmt).Format(time.RFC1123)) - rw.Write([]byte("test snapshot metadata\n")) - } - }) - // PutObject/DeleteObject - snapshot metadata - router.Path("/{bucket}/{prefix:.*}.metadata/snapshot-{snapshot}").Methods(http.MethodPut, http.MethodDelete).HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - switch vars["bucket"] { + // DeleteObject + router.HandleFunc("DELETE /{bucket}/{object...}", func(rw http.ResponseWriter, r *http.Request) { + switch r.PathValue("bucket") { case "badbucket": rw.WriteHeader(http.StatusNotFound) case "authbucket": rw.WriteHeader(http.StatusForbidden) default: if r.Method == http.MethodDelete { - rw.WriteHeader(http.StatusNoContent) + if strings.Contains(r.PathValue("object"), "bad") { + rw.WriteHeader(http.StatusNotFound) + } else { + rw.WriteHeader(http.StatusNoContent) + } } } }) + router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + logrus.Errorf("Failed to match %q", r.URL) + rw.WriteHeader(http.StatusInternalServerError) + }) return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { scheme := "http" if r.TLS != nil { diff --git a/pkg/etcd/snapshot_handler.go b/pkg/etcd/snapshot_handler.go index 6e1ce856776..8439f47ce12 100644 --- a/pkg/etcd/snapshot_handler.go +++ b/pkg/etcd/snapshot_handler.go @@ -12,6 +12,7 @@ import ( "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" "github.com/sirupsen/logrus" + apierrors "k8s.io/apimachinery/pkg/api/errors" ) type SnapshotOperation string @@ -178,7 +179,7 @@ func (e *ETCD) withRequest(sr *SnapshotRequest) *ETCD { // getSnapshotRequest unmarshalls the snapshot operation request from a client. func getSnapshotRequest(req *http.Request) (*SnapshotRequest, error) { if req.Method != http.MethodPost { - return nil, http.ErrNotSupported + return nil, apierrors.NewMethodNotSupported(k3s.Resource("snapshot"), req.Method) } sr := &SnapshotRequest{} b, err := io.ReadAll(req.Body) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index c555312e0b1..c4aa8fcdc8b 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -4,11 +4,11 @@ import ( "context" "errors" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent/https" "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/etcd/snapshotmetrics" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/prometheus/client_golang/prometheus/promhttp" lassometrics "github.com/rancher/lasso/pkg/metrics" rdmetrics "github.com/rancher/remotedialer/metrics" diff --git a/pkg/nodepassword/validate.go b/pkg/nodepassword/validate.go index 9df977053cf..d9736371275 100644 --- a/pkg/nodepassword/validate.go +++ b/pkg/nodepassword/validate.go @@ -10,10 +10,10 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" + "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" @@ -105,13 +105,12 @@ func getNodeInfo(req *http.Request) (*nodeInfo, error) { return nil, errors.New("auth user not set") } - program := mux.Vars(req)["program"] - nodeName := req.Header.Get(program + "-Node-Name") + nodeName := req.Header.Get(version.Program + "-Node-Name") if nodeName == "" { return nil, errors.New("node name not set") } - nodePassword := req.Header.Get(program + "-Node-Password") + nodePassword := req.Header.Get(version.Program + "-Node-Password") if nodePassword == "" { return nil, errors.New("node password not set") } diff --git a/pkg/profile/profile.go b/pkg/profile/profile.go index 39c3929a580..729453e2f6c 100644 --- a/pkg/profile/profile.go +++ b/pkg/profile/profile.go @@ -5,9 +5,9 @@ import ( "errors" "net/http/pprof" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent/https" "github.com/k3s-io/k3s/pkg/daemons/config" + "github.com/k3s-io/k3s/pkg/util/mux" ) // DefaultProfiler the default instance of a performance profiling server @@ -33,6 +33,6 @@ func (c *Config) Start(ctx context.Context, nodeConfig *config.Node) error { mRouter.HandleFunc("/debug/pprof/profile", pprof.Profile) mRouter.HandleFunc("/debug/pprof/symbol", pprof.Symbol) mRouter.HandleFunc("/debug/pprof/trace", pprof.Trace) - mRouter.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index) + mRouter.HandleFunc("/debug/pprof/", pprof.Index) return nil } diff --git a/pkg/server/auth/auth.go b/pkg/server/auth/auth.go index c125908c26b..179b5387d0e 100644 --- a/pkg/server/auth/auth.go +++ b/pkg/server/auth/auth.go @@ -5,9 +5,9 @@ import ( "net" "net/http" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" "k8s.io/apiserver/pkg/apis/apiserver" diff --git a/pkg/server/handlers/handlers.go b/pkg/server/handlers/handlers.go index 79584650265..e1326e79592 100644 --- a/pkg/server/handlers/handlers.go +++ b/pkg/server/handlers/handlers.go @@ -13,18 +13,20 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/cli/cmds" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/etcd" "github.com/k3s-io/k3s/pkg/nodepassword" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util/errors" + "github.com/k3s-io/k3s/pkg/version" certutil "github.com/rancher/dynamiclistener/cert" "github.com/sirupsen/logrus" discoveryv1 "k8s.io/api/discovery/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/json" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" @@ -66,8 +68,7 @@ func ServingKubeletCert(control *config.Control, auth nodepassword.NodeAuthValid } ips := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} - program := mux.Vars(req)["program"] - if nodeIP := req.Header.Get(program + "-Node-IP"); nodeIP != "" { + if nodeIP := req.Header.Get(version.Program + "-Node-IP"); nodeIP != "" { for _, v := range strings.Split(nodeIP, ",") { ip := net.ParseIP(v) if ip == nil { @@ -115,9 +116,8 @@ func ClientKubeProxyCert(control *config.Control) http.Handler { func ClientControllerCert(control *config.Control) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - program := mux.Vars(req)["program"] signAndSend(resp, req, control.Runtime.ClientCA, control.Runtime.ClientCAKey, control.Runtime.ClientK3sControllerKey, certutil.Config{ - CommonName: "system:" + program + "-controller", + CommonName: "system:" + version.Program + "-controller", Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, }) }) @@ -314,7 +314,7 @@ func getCACertAndKey(caCertFile, caKeyFile string) ([]*x509.Certificate, crypto. // If the request is not a POST, or cannot be parsed as a request, an error is returned. func getCSR(req *http.Request) (*x509.CertificateRequest, error) { if req.Method != http.MethodPost { - return nil, mux.ErrMethodMismatch + return nil, apierrors.NewMethodNotSupported(schema.GroupResource{}, req.Method) } csrBytes, err := io.ReadAll(req.Body) if err != nil { diff --git a/pkg/server/handlers/router.go b/pkg/server/handlers/router.go index 45abb7598dc..4e2df6bf493 100644 --- a/pkg/server/handlers/router.go +++ b/pkg/server/handlers/router.go @@ -5,11 +5,11 @@ import ( "net/http" "path/filepath" - "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/cli/cmds" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/nodepassword" "github.com/k3s-io/k3s/pkg/server/auth" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/version" "k8s.io/apiserver/pkg/authentication/user" bootstrapapi "k8s.io/cluster-bootstrap/token/api" @@ -30,26 +30,26 @@ var ( func NewHandler(ctx context.Context, control *config.Control, cfg *cmds.Server) http.Handler { nodeAuth := nodepassword.GetNodeAuthValidator(ctx, control) - prefix := "/v1-{program}" - authed := mux.NewRouter().SkipClean(true) + prefix := "/v1-" + version.Program + authed := mux.NewRouter() authed.NotFoundHandler = APIServer(control, cfg) authed.Use(auth.HasRole(control, version.Program+":agent", user.NodesGroup, bootstrapapi.BootstrapDefaultGroup), auth.RequestInfo(), auth.MaxInFlight(maxNonMutatingAgentRequests, maxMutatingAgentRequests)) authed.Handle(prefix+"/serving-kubelet.crt", ServingKubeletCert(control, nodeAuth)) authed.Handle(prefix+"/client-kubelet.crt", ClientKubeletCert(control, nodeAuth)) authed.Handle(prefix+"/client-kube-proxy.crt", ClientKubeProxyCert(control)) - authed.Handle(prefix+"/client-{program}-controller.crt", ClientControllerCert(control)) + authed.Handle(prefix+"/client-"+version.Program+"-controller.crt", ClientControllerCert(control)) authed.Handle(prefix+"/client-ca.crt", File(control.Runtime.ClientCA)) authed.Handle(prefix+"/server-ca.crt", File(control.Runtime.ServerCA)) authed.Handle(prefix+"/apiservers", APIServers(control)) authed.Handle(prefix+"/config", Config(control, cfg)) authed.Handle(prefix+"/readyz", Readyz(control)) - nodeAuthed := mux.NewRouter().SkipClean(true) + nodeAuthed := mux.NewRouter() nodeAuthed.NotFoundHandler = authed nodeAuthed.Use(auth.HasRole(control, user.NodesGroup)) nodeAuthed.Handle(prefix+"/connect", control.Runtime.Tunnel) - serverAuthed := mux.NewRouter().SkipClean(true) + serverAuthed := mux.NewRouter() serverAuthed.NotFoundHandler = nodeAuthed serverAuthed.Use(auth.HasRole(control, version.Program+":server")) serverAuthed.Handle(prefix+"/encrypt/status", EncryptionStatus(control)) @@ -58,15 +58,14 @@ func NewHandler(ctx context.Context, control *config.Control, cfg *cmds.Server) serverAuthed.Handle(prefix+"/server-bootstrap", Bootstrap(control)) serverAuthed.Handle(prefix+"/token", TokenRequest(ctx, control)) - systemAuthed := mux.NewRouter().SkipClean(true) + systemAuthed := mux.NewRouter() systemAuthed.NotFoundHandler = serverAuthed - systemAuthed.MethodNotAllowedHandler = serverAuthed systemAuthed.Use(auth.HasRole(control, user.SystemPrivilegedGroup)) - systemAuthed.Methods(http.MethodConnect).Handler(control.Runtime.Tunnel) + systemAuthed.Handle("CONNECT /", control.Runtime.Tunnel) - router := mux.NewRouter().SkipClean(true) + router := mux.NewRouter() router.NotFoundHandler = systemAuthed - router.PathPrefix(staticURL).Handler(Static(staticURL, filepath.Join(control.DataDir, "static"))) + router.Handle(staticURL, Static(staticURL, filepath.Join(control.DataDir, "static"))) router.Handle("/cacerts", CACerts(control)) router.Handle("/ping", Ping()) diff --git a/pkg/spegel/spegel.go b/pkg/spegel/spegel.go index fef25a96873..f64b15e1704 100644 --- a/pkg/spegel/spegel.go +++ b/pkg/spegel/spegel.go @@ -18,7 +18,6 @@ import ( "github.com/containerd/containerd/v2/core/remotes/docker" "github.com/go-logr/logr" - "github.com/gorilla/mux" leveldb "github.com/ipfs/go-ds-leveldb" ipfslog "github.com/ipfs/go-log/v2" "github.com/k3s-io/k3s/pkg/agent/https" @@ -26,6 +25,7 @@ import ( "github.com/k3s-io/k3s/pkg/server/auth" "github.com/k3s-io/k3s/pkg/util/errors" "github.com/k3s-io/k3s/pkg/util/logger" + "github.com/k3s-io/k3s/pkg/util/mux" "github.com/k3s-io/k3s/pkg/version" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/crypto" @@ -294,10 +294,10 @@ func (c *Config) Start(ctx context.Context, nodeConfig *config.Node, criReadyCha if err != nil { return err } - mRouter.PathPrefix("/v2").Handler(regSvr.Handler) - sRouter := mRouter.PathPrefix("/v1-{program}/p2p").Subrouter() + mRouter.Handle("/v2/", regSvr.Handler) + sRouter := mRouter.SubRouter("/v1-" + version.Program + "/p2p") sRouter.Use(auth.MaxInFlight(maxNonMutatingPeerInfoRequests, maxMutatingPeerInfoRequests)) - sRouter.Handle("", c.peerInfo()) + sRouter.Handle("/", c.peerInfo()) // Wait up to 5 seconds for the p2p network to find peers. if err := wait.PollUntilContextTimeout(ctx, time.Second, resolveTimeout, true, func(ctx context.Context) (bool, error) { diff --git a/pkg/util/mux/mux.go b/pkg/util/mux/mux.go new file mode 100644 index 00000000000..f66d080d47c --- /dev/null +++ b/pkg/util/mux/mux.go @@ -0,0 +1,122 @@ +package mux + +import ( + "net/http" +) + +// MiddlewareFunc is the function signature of middlewares. The middleware is expected +// to return a handler that does work and then either writes a response, or calls the +// provided handler for additional processing. +type MiddlewareFunc func(http.Handler) http.Handler + +// Handler wraps http.Handler and allows selectively running middleware on a matched route +type Handler interface { + http.Handler + Matched() bool +} + +// muxHandler is a wrapper around http.Handler, +// used to differentiate between http.ServeMux internal +// handlers and handlers from this package. +type muxHandler struct { + handler http.Handler +} + +// ServeHTTP calls the wrapped function of the same name. +func (mh *muxHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + mh.handler.ServeHTTP(rw, req) +} + +func (mh *muxHandler) Matched() bool { + return true +} + +// rootHandler runs the router's NotFoundHandler if one is set, +// or returns a fixed NotFound error. Middlewares are run +// if the NotFound handler was registered as a match for the root path. +type rootHandler struct { + r *Router + rootMatch bool +} + +func (rh *rootHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if rh.r.NotFoundHandler != nil { + rh.r.NotFoundHandler.ServeHTTP(rw, req) + } else { + rw.WriteHeader(http.StatusNotFound) + } +} + +func (rh *rootHandler) Matched() bool { + return rh.rootMatch && rh.r.NotFoundHandler != nil +} + +// Router wraps http.ServeMux, adding functionality to call +// middlewares on matched requests. +type Router struct { + NotFoundHandler http.Handler + + sm *http.ServeMux + rootHandler *rootHandler + middlewares []MiddlewareFunc +} + +// NewRouter creates a new Router +func NewRouter() *Router { + r := &Router{sm: http.NewServeMux()} + r.rootHandler = &rootHandler{r: r} + r.sm.Handle("/", r.rootHandler) + return r +} + +// Use registers one or more middlewares. Middlewares are only run +// when a route is matched. +func (r *Router) Use(mwfs ...MiddlewareFunc) { + r.middlewares = append(r.middlewares, mwfs...) +} + +// SubRouter registers a route pattern, and returns a new router. Middlewares +// will only run if the subrouter pattern was matched. +// Note that the base pattern for the subrouter is NOT automatically +// prefixed to paths registered beneath it. +func (r *Router) SubRouter(pattern string) *Router { + sr := NewRouter() + r.Handle(pattern, sr) + return sr +} + +// Handle registers a route pattern to call a handler +func (r *Router) Handle(pattern string, handler http.Handler) { + handler = &muxHandler{handler: handler} + if pattern == "/" && r.NotFoundHandler == nil { + r.rootHandler.rootMatch = true + r.NotFoundHandler = handler + } else { + r.sm.Handle(pattern, handler) + } +} + +// HandleFunc registers a route pattern to call a handler function +func (r *Router) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { + r.Handle(pattern, http.HandlerFunc(handler)) +} + +// ServeHTTP handles the request, running middlewares if a registered pattern +// has been matched. +func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + next := http.Handler(r.sm) + // fix up path for CONNECT requests so that they do not get redirected + if req.Method == http.MethodConnect && req.URL.Path == "" { + req.URL.Path = "/" + } + // only run middlewares if this is a mux handler; other handlers are + // http.ServeMux internal handlers that indicate no pattern was matched, and + // we should not run middleware. + handler, _ := r.sm.Handler(req) + if h, ok := handler.(Handler); ok && h.Matched() { + for i := len(r.middlewares) - 1; i >= 0; i-- { + next = r.middlewares[i](next) + } + } + next.ServeHTTP(rw, req) +}