From a17e336993dce9f0bede53f311c51c5ac31e9b03 Mon Sep 17 00:00:00 2001 From: Erik Wilson Date: Wed, 24 Jul 2019 00:22:31 -0700 Subject: [PATCH] Use go tcpproxy --- pkg/agent/loadbalancer/config.go | 38 ++++ pkg/agent/loadbalancer/loadbalancer.go | 142 +++++++++++++++ pkg/agent/loadbalancer/loadbalancer_test.go | 183 ++++++++++++++++++++ pkg/agent/loadbalancer/servers.go | 57 ++++++ pkg/agent/loadbalancer/utility.go | 49 ++++++ pkg/agent/run.go | 17 +- pkg/agent/tunnel/tunnel.go | 8 +- pkg/cli/cmds/agent.go | 1 + pkg/cli/server/server.go | 1 + 9 files changed, 491 insertions(+), 5 deletions(-) create mode 100644 pkg/agent/loadbalancer/config.go create mode 100644 pkg/agent/loadbalancer/loadbalancer.go create mode 100644 pkg/agent/loadbalancer/loadbalancer_test.go create mode 100644 pkg/agent/loadbalancer/servers.go create mode 100644 pkg/agent/loadbalancer/utility.go diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go new file mode 100644 index 0000000000..9c3ce95098 --- /dev/null +++ b/pkg/agent/loadbalancer/config.go @@ -0,0 +1,38 @@ +package loadbalancer + +import ( + "encoding/json" + "io/ioutil" + + "github.com/rancher/k3s/pkg/agent/util" +) + +func (lb *LoadBalancer) writeConfig() error { + configOut, err := json.MarshalIndent(lb, "", " ") + if err != nil { + return err + } + if err := util.WriteFile(lb.configFile, string(configOut)); err != nil { + return err + } + return nil +} + +func (lb *LoadBalancer) updateConfig() error { + writeConfig := true + if configBytes, err := ioutil.ReadFile(lb.configFile); err == nil { + config := &LoadBalancer{} + if err := json.Unmarshal(configBytes, config); err == nil { + if config.ServerURL == lb.ServerURL { + writeConfig = false + lb.setServers(config.ServerAddresses) + } + } + } + if writeConfig { + if err := lb.writeConfig(); err != nil { + return err + } + } + return nil +} diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go new file mode 100644 index 0000000000..b15b6154aa --- /dev/null +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -0,0 +1,142 @@ +package loadbalancer + +import ( + "context" + "errors" + "net" + "path/filepath" + "sync" + + "github.com/google/tcpproxy" + "github.com/rancher/k3s/pkg/cli/cmds" + "github.com/sirupsen/logrus" +) + +type LoadBalancer struct { + mutex sync.Mutex + dialer *net.Dialer + proxy *tcpproxy.Proxy + + configFile string + localAddress string + localServerURL string + originalServerAddress string + ServerURL string + ServerAddresses []string + randomServers []string + currentServerAddress string + nextServerIndex int +} + +const ( + serviceName = "k3s-agent-load-balancer" +) + +func Setup(ctx context.Context, cfg cmds.Agent) (_lb *LoadBalancer, _err error) { + if cfg.DisableLoadBalancer { + return nil, nil + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + defer func() { + if _err != nil { + logrus.Warnf("Error starting load balancer: %s", _err) + if listener != nil { + listener.Close() + } + } + }() + if err != nil { + return nil, err + } + localAddress := listener.Addr().String() + + originalServerAddress, localServerURL, err := parseURL(cfg.ServerURL, localAddress) + if err != nil { + return nil, err + } + + lb := &LoadBalancer{ + dialer: &net.Dialer{}, + configFile: filepath.Join(cfg.DataDir, "etc", serviceName+".json"), + localAddress: localAddress, + localServerURL: localServerURL, + originalServerAddress: originalServerAddress, + ServerURL: cfg.ServerURL, + } + + lb.setServers([]string{lb.originalServerAddress}) + + lb.proxy = &tcpproxy.Proxy{ + ListenFunc: func(string, string) (net.Listener, error) { + return listener, nil + }, + } + lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ + Addr: serviceName, + DialContext: lb.dialContext, + }) + + if err := lb.updateConfig(); err != nil { + return nil, err + } + if err := lb.proxy.Start(); err != nil { + return nil, err + } + logrus.Infof("Running load balancer %s -> %v", lb.localAddress, lb.randomServers) + + return lb, nil +} + +func (lb *LoadBalancer) Update(serverAddresses []string) { + if lb == nil { + return + } + if !lb.setServers(serverAddresses) { + return + } + logrus.Infof("Updating load balancer server addresses -> %v", lb.randomServers) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer config: %s", err) + } +} + +func (lb *LoadBalancer) LoadBalancerServerURL() string { + if lb == nil { + return "" + } + return lb.localServerURL +} + +func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) { + startIndex := lb.nextServerIndex + for { + targetServer := lb.currentServerAddress + + conn, err := lb.dialer.DialContext(ctx, network, targetServer) + if err == nil { + return conn, nil + } + logrus.Warnf("Dial error from load balancer: %s", err) + + newServer, err := lb.nextServer(targetServer) + if err != nil { + return nil, err + } + if targetServer != newServer { + logrus.Warnf("Dial context in load balancer failed over to %s", newServer) + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + + maxIndex := len(lb.randomServers) + if startIndex > maxIndex { + startIndex = maxIndex + } + if lb.nextServerIndex == startIndex { + return nil, errors.New("all servers failed") + } + } +} diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go new file mode 100644 index 0000000000..41b5cbaa2c --- /dev/null +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -0,0 +1,183 @@ +package loadbalancer + +import ( + "bufio" + "context" + "errors" + "fmt" + "io/ioutil" + "net" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/rancher/k3s/pkg/cli/cmds" +) + +type server struct { + listener net.Listener + conns []net.Conn + prefix string +} + +func createServer(prefix string) (*server, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + s := &server{ + prefix: prefix, + listener: listener, + } + go s.serve() + return s, nil +} + +func (s *server) serve() { + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + s.conns = append(s.conns, conn) + go s.echo(conn) + } +} + +func (s *server) close() { + s.listener.Close() + for _, conn := range s.conns { + conn.Close() + } +} + +func (s *server) echo(conn net.Conn) { + for { + result, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + return + } + conn.Write([]byte(s.prefix + ":" + result)) + } +} + +func ping(conn net.Conn) (string, error) { + fmt.Fprintf(conn, "ping\n") + result, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + return "", err + } + return strings.TrimSpace(result), nil +} + +func assertEqual(t *testing.T, a interface{}, b interface{}) { + if a != b { + t.Fatalf("[ %v != %v ]", a, b) + } +} + +func assertNotEqual(t *testing.T, a interface{}, b interface{}) { + if a == b { + t.Fatalf("[ %v == %v ]", a, b) + } +} + +func TestFailOver(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "lb-test") + if err != nil { + assertEqual(t, err, nil) + } + defer os.RemoveAll(tmpDir) + + ogServe, err := createServer("og") + if err != nil { + assertEqual(t, err, nil) + } + + lbServe, err := createServer("lb") + if err != nil { + assertEqual(t, err, nil) + } + + cfg := cmds.Agent{ + ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()), + DataDir: tmpDir, + } + + lb, err := Setup(context.Background(), cfg) + if err != nil { + assertEqual(t, err, nil) + } + + parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) + if err != nil { + assertEqual(t, err, nil) + } + localAddress := parsedURL.Host + + lb.Update([]string{lbServe.listener.Addr().String()}) + + conn1, err := net.Dial("tcp", localAddress) + if err != nil { + assertEqual(t, err, nil) + } + result1, err := ping(conn1) + if err != nil { + assertEqual(t, err, nil) + } + assertEqual(t, result1, "lb:ping") + + lbServe.close() + + _, err = ping(conn1) + assertNotEqual(t, err, nil) + + conn2, err := net.Dial("tcp", localAddress) + if err != nil { + assertEqual(t, err, nil) + } + result2, err := ping(conn2) + if err != nil { + assertEqual(t, err, nil) + } + assertEqual(t, result2, "og:ping") +} + +func TestFailFast(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "lb-test") + if err != nil { + assertEqual(t, err, nil) + } + defer os.RemoveAll(tmpDir) + + cfg := cmds.Agent{ + ServerURL: "http://127.0.0.1:-1/", + DataDir: tmpDir, + } + + lb, err := Setup(context.Background(), cfg) + if err != nil { + assertEqual(t, err, nil) + } + + conn, err := net.Dial("tcp", lb.localAddress) + if err != nil { + assertEqual(t, err, nil) + } + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(10 * time.Millisecond) + + select { + case err := <-done: + assertNotEqual(t, err, nil) + case <-timeout: + t.Fatal(errors.New("time out")) + } +} diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go new file mode 100644 index 0000000000..cb0393e733 --- /dev/null +++ b/pkg/agent/loadbalancer/servers.go @@ -0,0 +1,57 @@ +package loadbalancer + +import ( + "errors" + "math/rand" + "reflect" +) + +func (lb *LoadBalancer) setServers(serverAddresses []string) bool { + serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.originalServerAddress) + if len(serverAddresses) == 0 { + return false + } + + lb.mutex.Lock() + defer lb.mutex.Unlock() + + if reflect.DeepEqual(serverAddresses, lb.ServerAddresses) { + return false + } + + lb.ServerAddresses = serverAddresses + lb.randomServers = append([]string{}, lb.ServerAddresses...) + rand.Shuffle(len(lb.randomServers), func(i, j int) { + lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] + }) + if !hasOriginalServer { + lb.randomServers = append(lb.randomServers, lb.originalServerAddress) + } + lb.currentServerAddress = lb.randomServers[0] + lb.nextServerIndex = 1 + + return true +} + +func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + if len(lb.randomServers) == 0 { + return "", errors.New("No servers in load balancer proxy list") + } + if len(lb.randomServers) == 1 { + return lb.currentServerAddress, nil + } + if failedServer != lb.currentServerAddress { + return lb.currentServerAddress, nil + } + if lb.nextServerIndex >= len(lb.randomServers) { + lb.nextServerIndex = 0 + } + + lb.currentServerAddress = lb.randomServers[lb.nextServerIndex] + lb.nextServerIndex++ + + return lb.currentServerAddress, nil +} diff --git a/pkg/agent/loadbalancer/utility.go b/pkg/agent/loadbalancer/utility.go new file mode 100644 index 0000000000..a462da2e23 --- /dev/null +++ b/pkg/agent/loadbalancer/utility.go @@ -0,0 +1,49 @@ +package loadbalancer + +import ( + "errors" + "net/url" + "sort" + "strings" +) + +func parseURL(serverURL, newHost string) (string, string, error) { + parsedURL, err := url.Parse(serverURL) + if err != nil { + return "", "", err + } + if parsedURL.Host == "" { + return "", "", errors.New("Initial server URL host is not defined for load balancer") + } + address := parsedURL.Host + if parsedURL.Port() == "" { + if strings.ToLower(parsedURL.Scheme) == "http" { + address += ":80" + } + if strings.ToLower(parsedURL.Scheme) == "https" { + address += ":443" + } + } + parsedURL.Host = newHost + return address, parsedURL.String(), nil +} + +func sortServers(input []string, search string) ([]string, bool) { + result := []string{} + found := false + skip := map[string]bool{"": true} + + for _, entry := range input { + if skip[entry] { + continue + } + if search == entry { + found = true + } + skip[entry] = true + result = append(result, entry) + } + + sort.Strings(result) + return result, found +} diff --git a/pkg/agent/run.go b/pkg/agent/run.go index 9a4521149f..d6f6c6dc29 100644 --- a/pkg/agent/run.go +++ b/pkg/agent/run.go @@ -12,6 +12,7 @@ import ( "github.com/rancher/k3s/pkg/agent/config" "github.com/rancher/k3s/pkg/agent/containerd" "github.com/rancher/k3s/pkg/agent/flannel" + "github.com/rancher/k3s/pkg/agent/loadbalancer" "github.com/rancher/k3s/pkg/agent/syssetup" "github.com/rancher/k3s/pkg/agent/tunnel" "github.com/rancher/k3s/pkg/cli/cmds" @@ -21,7 +22,7 @@ import ( "github.com/sirupsen/logrus" ) -func run(ctx context.Context, cfg cmds.Agent) error { +func run(ctx context.Context, cfg cmds.Agent, lb *loadbalancer.LoadBalancer) error { nodeConfig := config.Get(ctx, cfg) if err := config.HostnameCheck(cfg); err != nil { @@ -47,7 +48,7 @@ func run(ctx context.Context, cfg cmds.Agent) error { return err } - if err := tunnel.Setup(ctx, nodeConfig); err != nil { + if err := tunnel.Setup(ctx, nodeConfig, lb.Update); err != nil { return err } @@ -77,11 +78,20 @@ func Run(ctx context.Context, cfg cmds.Agent) error { } cfg.DataDir = filepath.Join(cfg.DataDir, "agent") + os.MkdirAll(cfg.DataDir, 0700) if cfg.ClusterSecret != "" { cfg.Token = "K10node:" + cfg.ClusterSecret } + lb, err := loadbalancer.Setup(ctx, cfg) + if err != nil { + return err + } + if lb != nil { + cfg.ServerURL = lb.LoadBalancerServerURL() + } + for { tmpFile, err := clientaccess.AgentAccessInfoToTempKubeConfig("", cfg.ServerURL, cfg.Token) if err != nil { @@ -97,8 +107,7 @@ func Run(ctx context.Context, cfg cmds.Agent) error { break } - os.MkdirAll(cfg.DataDir, 0700) - return run(ctx, cfg) + return run(ctx, cfg, lb) } func validate() error { diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index be39e21f19..c139129557 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -53,7 +53,7 @@ func getAddresses(endpoint *v1.Endpoints) []string { return serverAddresses } -func Setup(ctx context.Context, config *config.Node) error { +func Setup(ctx context.Context, config *config.Node, onChange func([]string)) error { restConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigNode) if err != nil { return err @@ -74,6 +74,9 @@ func Setup(ctx context.Context, config *config.Node) error { endpoint, _ := client.CoreV1().Endpoints("default").Get("kubernetes", metav1.GetOptions{}) if endpoint != nil { addresses = getAddresses(endpoint) + if onChange != nil { + onChange(addresses) + } } disconnect := map[string]context.CancelFunc{} @@ -120,6 +123,9 @@ func Setup(ctx context.Context, config *config.Node) error { } addresses = newAddresses logrus.Infof("Tunnel endpoint watch event: %v", addresses) + if onChange != nil { + onChange(addresses) + } validEndpoint := map[string]bool{} diff --git a/pkg/cli/cmds/agent.go b/pkg/cli/cmds/agent.go index a82f830774..e1364dda12 100644 --- a/pkg/cli/cmds/agent.go +++ b/pkg/cli/cmds/agent.go @@ -11,6 +11,7 @@ type Agent struct { Token string TokenFile string ServerURL string + DisableLoadBalancer bool ResolvConf string DataDir string NodeIP string diff --git a/pkg/cli/server/server.go b/pkg/cli/server/server.go index ac8c05f9d9..fe4d39b1ca 100644 --- a/pkg/cli/server/server.go +++ b/pkg/cli/server/server.go @@ -214,6 +214,7 @@ func run(app *cli.Context, cfg *cmds.Server) error { agentConfig.ServerURL = url agentConfig.Token = token agentConfig.Labels = append(agentConfig.Labels, "node-role.kubernetes.io/master=true") + agentConfig.DisableLoadBalancer = true return agent.Run(ctx, agentConfig) }