diff --git a/cmd/kubelet/app/server_bootstrap_test.go b/cmd/kubelet/app/server_bootstrap_test.go index 4217550125..5b686e2e55 100644 --- a/cmd/kubelet/app/server_bootstrap_test.go +++ b/cmd/kubelet/app/server_bootstrap_test.go @@ -93,7 +93,7 @@ func Test_buildClientCertificateManager(t *testing.T) { // get an expired CSR (simulating historical output) server.backdate = 2 * time.Hour - server.expectUserAgent = "FirstClient" + server.SetExpectUserAgent("FirstClient") ok, err := r.RotateCerts() if !ok || err != nil { t.Fatalf("unexpected rotation err: %t %v", ok, err) @@ -109,7 +109,7 @@ func Test_buildClientCertificateManager(t *testing.T) { // if m.Current() == nil, then we try again and get a valid // client server.backdate = 0 - server.expectUserAgent = "FirstClient" + server.SetExpectUserAgent("FirstClient") if ok, err := r.RotateCerts(); !ok || err != nil { t.Fatalf("unexpected rotation err: %t %v", ok, err) } @@ -122,7 +122,7 @@ func Test_buildClientCertificateManager(t *testing.T) { } // if m.Current() != nil, then we should use the second client - server.expectUserAgent = "SecondClient" + server.SetExpectUserAgent("SecondClient") if ok, err := r.RotateCerts(); !ok || err != nil { t.Fatalf("unexpected rotation err: %t %v", ok, err) } @@ -243,12 +243,24 @@ type csrSimulator struct { serverCA *x509.Certificate backdate time.Duration + userAgentLock sync.Mutex expectUserAgent string lock sync.Mutex csr *certapi.CertificateSigningRequest } +func (s *csrSimulator) SetExpectUserAgent(a string) { + s.userAgentLock.Lock() + defer s.userAgentLock.Unlock() + s.expectUserAgent = a +} +func (s *csrSimulator) ExpectUserAgent() string { + s.userAgentLock.Lock() + defer s.userAgentLock.Unlock() + return s.expectUserAgent +} + func (s *csrSimulator) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.lock.Lock() defer s.lock.Unlock() @@ -258,11 +270,12 @@ func (s *csrSimulator) ServeHTTP(w http.ResponseWriter, req *http.Request) { q := req.URL.Query() q.Del("timeout") q.Del("timeoutSeconds") + q.Del("allowWatchBookmarks") req.URL.RawQuery = q.Encode() t.Logf("Request %q %q %q", req.Method, req.URL, req.UserAgent()) - if len(s.expectUserAgent) > 0 && req.UserAgent() != s.expectUserAgent { + if a := s.ExpectUserAgent(); len(a) > 0 && req.UserAgent() != a { t.Errorf("Unexpected user agent: %s", req.UserAgent()) }