Merge pull request #7876 from tstromberg/ssh-flaky-flaky

ssh_mock: Use sync.WaitGroup to defend against deferred mutations
pull/7894/head
Medya Ghazizadeh 2020-04-24 13:19:54 -07:00 committed by GitHub
commit c6e0f7ffd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 9 deletions

View File

@ -36,6 +36,7 @@ func TestNewSSHClient(t *testing.T) {
if err != nil {
t.Fatalf("NewSSHServer: %v", err)
}
port, err := s.Start()
if err != nil {
t.Fatalf("Error starting ssh server: %v", err)
@ -50,27 +51,30 @@ func TestNewSSHClient(t *testing.T) {
},
T: t,
}
c, err := NewSSHClient(d)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
t.Fatalf("Error creating client: %v", err)
}
defer c.Close()
sess, err := c.NewSession()
if err != nil {
t.Fatal("Error creating new session for ssh client")
t.Fatalf("Error creating new session: %v", err)
}
defer sess.Close()
cmd := "foo"
if err := sess.Run(cmd); err != nil {
t.Fatalf("Error running %q: %v", cmd, err)
t.Errorf("Error running %q: %v", cmd, err)
}
if !s.Connected {
t.Fatalf("Server not connected")
t.Errorf("mock ssh server is not connected")
}
if _, ok := s.Commands[cmd]; !ok {
t.Fatalf("Expected command: %s", cmd)
t.Errorf("Expected %q to be run, but it never was!", cmd)
}
}

View File

@ -24,6 +24,7 @@ import (
"io"
"net"
"strconv"
"sync"
"sync/atomic"
"testing"
@ -79,8 +80,11 @@ type execRequest struct {
// Serve loop, listen for connections and store the commands.
func (s *SSHServer) serve() {
s.t.Logf("Serving ...")
loop := 0
for {
s.t.Logf("Accepting...")
loop++
s.t.Logf("[loop %d] Accepting for %v...", loop, s)
c, err := s.listener.Accept()
if s.quit {
return
@ -95,13 +99,19 @@ func (s *SSHServer) serve() {
// handle an incoming ssh connection
func (s *SSHServer) handleIncomingConnection(c net.Conn) {
var wg sync.WaitGroup
_, chans, reqs, err := ssh.NewServerConn(c, s.Config)
if err != nil {
s.t.Logf("newserverconn error: %v", err)
return
}
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
wg.Add(1)
go func() {
ssh.DiscardRequests(reqs)
wg.Done()
}()
// Service the incoming Channel channel.
for newChannel := range chans {
@ -115,12 +125,15 @@ func (s *SSHServer) handleIncomingConnection(c net.Conn) {
}
s.Connected = true
for req := range requests {
s.handleRequest(channel, req)
s.handleRequest(channel, req, &wg)
}
}
wg.Wait()
}
func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request, wg *sync.WaitGroup) {
wg.Add(1)
go func() {
// Explicitly copy buffer contents to avoid data race
b := s.Transfers.Bytes()
@ -128,6 +141,7 @@ func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
s.t.Errorf("copy failed: %v", err)
}
channel.Close()
wg.Done()
}()
switch req.Type {
@ -144,12 +158,15 @@ func (s *SSHServer) handleRequest(channel ssh.Channel, req *ssh.Request) {
}
s.Commands[cmd.Command] = 1
s.t.Logf("returning output for %s ...", cmd.Command)
// Write specified command output as mocked ssh output
if val, err := s.GetCommandToOutput(cmd.Command); err == nil {
if _, err := channel.Write([]byte(val)); err != nil {
s.t.Errorf("Write failed: %v", err)
}
}
s.t.Logf("setting exit-status for %s ...", cmd.Command)
if _, err := channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}); err != nil {
s.t.Errorf("SendRequest failed: %v", err)
}