Merge pull request #7876 from tstromberg/ssh-flaky-flaky
ssh_mock: Use sync.WaitGroup to defend against deferred mutationspull/7894/head
commit
c6e0f7ffd7
|
@ -36,6 +36,7 @@ func TestNewSSHClient(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("NewSSHServer: %v", err)
|
||||
}
|
||||
|
||||
port, err := s.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting ssh server: %v", err)
|
||||
|
@ -50,27 +51,30 @@ func TestNewSSHClient(t *testing.T) {
|
|||
},
|
||||
T: t,
|
||||
}
|
||||
|
||||
c, err := NewSSHClient(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
t.Fatalf("Error creating client: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
sess, err := c.NewSession()
|
||||
if err != nil {
|
||||
t.Fatal("Error creating new session for ssh client")
|
||||
t.Fatalf("Error creating new session: %v", err)
|
||||
}
|
||||
defer sess.Close()
|
||||
|
||||
cmd := "foo"
|
||||
if err := sess.Run(cmd); err != nil {
|
||||
t.Fatalf("Error running %q: %v", cmd, err)
|
||||
t.Errorf("Error running %q: %v", cmd, err)
|
||||
}
|
||||
|
||||
if !s.Connected {
|
||||
t.Fatalf("Server not connected")
|
||||
t.Errorf("mock ssh server is not connected")
|
||||
}
|
||||
|
||||
if _, ok := s.Commands[cmd]; !ok {
|
||||
t.Fatalf("Expected command: %s", cmd)
|
||||
t.Errorf("Expected %q to be run, but it never was!", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
|
@ -79,8 +80,11 @@ type execRequest struct {
|
|||
|
||||
// Serve loop, listen for connections and store the commands.
|
||||
func (s *SSHServer) serve() {
|
||||
s.t.Logf("Serving ...")
|
||||
loop := 0
|
||||
for {
|
||||
s.t.Logf("Accepting...")
|
||||
loop++
|
||||
s.t.Logf("[loop %d] Accepting for %v...", loop, s)
|
||||
c, err := s.listener.Accept()
|
||||
if s.quit {
|
||||
return
|
||||
|
@ -95,13 +99,19 @@ func (s *SSHServer) serve() {
|
|||
|
||||
// handle an incoming ssh connection
|
||||
func (s *SSHServer) handleIncomingConnection(c net.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
_, chans, reqs, err := ssh.NewServerConn(c, s.Config)
|
||||
if err != nil {
|
||||
s.t.Logf("newserverconn error: %v", err)
|
||||
return
|
||||
}
|
||||
// The incoming Request channel must be serviced.
|
||||
go ssh.DiscardRequests(reqs)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(reqs)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Service the incoming Channel channel.
|
||||
for newChannel := range chans {
|
||||
|
@ -115,12 +125,15 @@ func (s *SSHServer) handleIncomingConnection(c net.Conn) {
|
|||
}
|
||||
s.Connected = true
|
||||
for req := range requests {
|
||||
s.handleRequest(channel, req)
|
||||
s.handleRequest(channel, req, &wg)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
|
||||
func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
// Explicitly copy buffer contents to avoid data race
|
||||
b := s.Transfers.Bytes()
|
||||
|
@ -128,6 +141,7 @@ func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
|
|||
s.t.Errorf("copy failed: %v", err)
|
||||
}
|
||||
channel.Close()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
switch req.Type {
|
||||
|
@ -144,12 +158,15 @@ func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
|
|||
}
|
||||
s.Commands[cmd.Command] = 1
|
||||
|
||||
s.t.Logf("returning output for %s ...", cmd.Command)
|
||||
// Write specified command output as mocked ssh output
|
||||
if val, err := s.GetCommandToOutput(cmd.Command); err == nil {
|
||||
if _, err := channel.Write([]byte(val)); err != nil {
|
||||
s.t.Errorf("Write failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.t.Logf("setting exit-status for %s ...", cmd.Command)
|
||||
if _, err := channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}); err != nil {
|
||||
s.t.Errorf("SendRequest failed: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue