influxdb/meta/rpc.go

461 lines
12 KiB
Go

package meta
import (
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"time"
"github.com/gogo/protobuf/proto"
"github.com/hashicorp/raft"
"github.com/influxdb/influxdb/meta/internal"
)
// Limits applied to RPC traffic between cluster nodes.
const (
// MaxMessageSize is the upper bound on the declared size of an incoming
// request; a size at or above this value is treated as invalid.
MaxMessageSize = 1024 * 1024 * 1024
// leaderDialTimeout bounds how long dialing a remote node may take, both
// when proxying to the leader and when issuing an RPC call.
leaderDialTimeout = 10 * time.Second
)
// rpc handles request/response style messaging between cluster nodes.
type rpc struct {
// logger receives the error and trace output produced by the rpc methods.
logger *log.Logger
// tracingEnabled gates traceCluster: trace messages are only logged when
// it is true.
tracingEnabled bool
// store is the subset of meta store behaviour the rpc layer depends on.
store interface {
// cachedData returns the store's current copy of the meta data. Callers
// in this file treat a nil result as "no data yet" on the client side.
cachedData() *Data
// IsLeader reports whether this node is the leader; non-leaders proxy
// incoming RPC connections to the leader.
IsLeader() bool
// Leader returns the address of the current leader, or "" when there is
// none.
Leader() string
// Peers returns the addresses of the current raft peers.
Peers() ([]string, error)
// AddPeer adds host as a raft peer.
AddPeer(host string) error
// CreateNode registers a node for host. The join handler expects
// ErrNodeExists when the host is already registered.
CreateNode(host string) (*NodeInfo, error)
// NodeByHost looks up the node registered for host.
NodeByHost(host string) (*NodeInfo, error)
// WaitForDataChanged is called by blocking fetches before re-reading the
// cached data; presumably it blocks until the data changes.
WaitForDataChanged() error
}
}
// JoinResult is the outcome of a successful join call, built from the fields
// of the JoinResponse returned by the remote node.
type JoinResult struct {
// RaftEnabled mirrors the response's EnableRaft flag, which the join
// handler sets when the joining address is among the raft peers.
RaftEnabled bool
// RaftNodes is the list of raft peer addresses reported by the remote node.
RaftNodes []string
// NodeID is the ID assigned to the joining node.
NodeID uint64
}
// Reply is implemented by response messages that carry a ResponseHeader. It
// lets the server stamp the OK/Error status onto a response and lets the
// client check that status without knowing the concrete response type.
type Reply interface {
// GetHeader returns the response's header.
GetHeader() *internal.ResponseHeader
}
// proxyLeader proxies the connection to the current raft leader.
//
// An error response is written to conn when there is no known leader, when
// the leader cannot be dialed, when the mux header cannot be written, or when
// proxying itself fails.
func (r *rpc) proxyLeader(conn *net.TCPConn) {
	// Resolve the leader once so that the emptiness check and the dial agree
	// on the same address even if leadership changes in between.
	leader := r.store.Leader()
	if leader == "" {
		r.sendError(conn, "no leader")
		return
	}

	leaderConn, err := net.DialTimeout("tcp", leader, leaderDialTimeout)
	if err != nil {
		r.sendError(conn, fmt.Sprintf("dial leader: %v", err))
		return
	}
	defer leaderConn.Close()

	// The leader's mux needs the RPC marker byte before any payload; without
	// it there is nothing useful to proxy, so report the failure instead of
	// silently continuing.
	if _, err := leaderConn.Write([]byte{MuxRPCHeader}); err != nil {
		r.sendError(conn, fmt.Sprintf("write leader header: %v", err))
		return
	}

	if err := proxy(leaderConn.(*net.TCPConn), conn); err != nil {
		r.sendError(conn, fmt.Sprintf("leader proxy error: %v", err))
	}
}
// handleRPCConn reads a command from the connection, executes it and writes
// the response back. The connection is closed before returning.
//
// The wire format of a request is an 8-byte big-endian size followed by that
// many bytes, of which the first 8 are the big-endian RPC type and the rest
// is the protobuf-encoded request.
func (r *rpc) handleRPCConn(conn net.Conn) {
	defer conn.Close()

	// RPC connections should execute on the leader. If we are not the leader,
	// proxy the connection to the leader so that clients can connect to any
	// node in the cluster.
	r.traceCluster("rpc connection from: %v", conn.RemoteAddr())

	if !r.store.IsLeader() {
		// proxyLeader needs a *net.TCPConn; answer with an error rather than
		// panicking if the connection is of another type.
		tcpConn, ok := conn.(*net.TCPConn)
		if !ok {
			r.sendError(conn, fmt.Sprintf("unable to proxy connection of type %T", conn))
			return
		}
		r.proxyLeader(tcpConn)
		return
	}

	// Read and execute request.
	typ, resp, err := func() (internal.RPCType, proto.Message, error) {
		// Read request size.
		var sz uint64
		if err := binary.Read(conn, binary.BigEndian, &sz); err != nil {
			return internal.RPCType_Error, nil, fmt.Errorf("read size: %s", err)
		}

		// The payload must at least hold the 8-byte RPC type, otherwise
		// slicing the type out of the buffer below would panic.
		if sz < 8 {
			return internal.RPCType_Error, nil, fmt.Errorf("invalid message size: %d", sz)
		}

		if sz >= MaxMessageSize {
			return internal.RPCType_Error, nil, fmt.Errorf("max message size of %d exceeded: %d", MaxMessageSize, sz)
		}

		// Read request.
		buf := make([]byte, sz)
		if _, err := io.ReadFull(conn, buf); err != nil {
			return internal.RPCType_Error, nil, fmt.Errorf("read request: %s", err)
		}

		// Determine the RPC type
		rpcType := internal.RPCType(btou64(buf[0:8]))
		buf = buf[8:]

		r.traceCluster("recv %v request on: %v", rpcType, conn.RemoteAddr())
		switch rpcType {
		case internal.RPCType_FetchData:
			var req internal.FetchDataRequest
			if err := proto.Unmarshal(buf, &req); err != nil {
				return internal.RPCType_Error, nil, fmt.Errorf("fetch request unmarshal: %v", err)
			}
			resp, err := r.handleFetchData(&req)
			return rpcType, resp, err
		case internal.RPCType_Join:
			var req internal.JoinRequest
			if err := proto.Unmarshal(buf, &req); err != nil {
				return internal.RPCType_Error, nil, fmt.Errorf("join request unmarshal: %v", err)
			}
			resp, err := r.handleJoinRequest(&req)
			return rpcType, resp, err
		default:
			return internal.RPCType_Error, nil, fmt.Errorf("unknown rpc type:%v", rpcType)
		}
	}()

	// Handle unexpected RPC errors: whatever the handler returned is replaced
	// by a generic error response.
	if err != nil {
		resp = &internal.ErrorResponse{
			Header: &internal.ResponseHeader{
				OK: proto.Bool(false),
			},
		}
		typ = internal.RPCType_Error
	}

	// Set the status header and error message
	if reply, ok := resp.(Reply); ok {
		reply.GetHeader().OK = proto.Bool(err == nil)
		if err != nil {
			reply.GetHeader().Error = proto.String(err.Error())
		}
	}

	r.sendResponse(conn, typ, resp)
}
// sendResponse serializes resp, frames it together with typ via pack and
// writes the frame to conn. Marshal and write failures are logged; nothing is
// reported back to the caller.
func (r *rpc) sendResponse(conn net.Conn, typ internal.RPCType, resp proto.Message) {
	payload, marshalErr := proto.Marshal(resp)
	if marshalErr != nil {
		r.logger.Printf("unable to marshal response: %v", marshalErr)
		return
	}

	// Frame the payload and push it out in a single write.
	frame := r.pack(typ, payload)
	_, writeErr := conn.Write(frame)
	if writeErr == nil {
		return
	}
	r.logger.Printf("unable to write rpc response: %s", writeErr)
}
// sendError traces msg and sends it to conn as an ErrorResponse whose header
// is marked as not OK.
func (r *rpc) sendError(conn net.Conn, msg string) {
	// msg frequently embeds arbitrary error text, so it must be passed as an
	// argument rather than as the format string: a stray '%' would otherwise
	// be interpreted as a formatting verb.
	r.traceCluster("%s", msg)
	resp := &internal.ErrorResponse{
		Header: &internal.ResponseHeader{
			OK:    proto.Bool(false),
			Error: proto.String(msg),
		},
	}

	r.sendResponse(conn, internal.RPCType_Error, resp)
}
// handleFetchData handles a request for the current node's meta data.
//
// When the cached data's index differs from the index in the request, the
// data is marshaled into the response. When the indexes match, a blocking
// request waits for the store to signal a change and checks again, while a
// non-blocking request is answered immediately with no data attached.
func (r *rpc) handleFetchData(req *internal.FetchDataRequest) (*internal.FetchDataResponse, error) {
	snapshot := r.store.cachedData()

	// Only blocking requests wait; each wake-up re-reads the cached data.
	for snapshot.Index == req.GetIndex() && req.GetBlocking() {
		if waitErr := r.store.WaitForDataChanged(); waitErr != nil {
			return nil, waitErr
		}
		snapshot = r.store.cachedData()
	}

	// The payload stays nil when the caller already has this index.
	var payload []byte
	if snapshot.Index != req.GetIndex() {
		encoded, marshalErr := snapshot.MarshalBinary()
		if marshalErr != nil {
			return nil, marshalErr
		}
		payload = encoded
	}

	header := &internal.ResponseHeader{
		OK: proto.Bool(true),
	}
	return &internal.FetchDataResponse{
		Header: header,
		Index:  proto.Uint64(snapshot.Index),
		Term:   proto.Uint64(snapshot.Term),
		Data:   payload,
	}, nil
}
// handleJoinRequest handles a request to join the cluster.
//
// The requesting address is registered as a node (or looked up if it is
// already registered) and, while the cluster has fewer than MaxRaftNodes raft
// peers, added as a raft peer. The response carries the node's ID, the
// current raft peers and whether the requester is one of them.
//
// An error is returned when the request carries no address or when any store
// operation fails.
func (r *rpc) handleJoinRequest(req *internal.JoinRequest) (*internal.JoinResponse, error) {
	// The address arrives over the network; refuse the request instead of
	// dereferencing a nil pointer when it is absent.
	if req.Addr == nil {
		return nil, fmt.Errorf("join request: missing address")
	}
	addr := *req.Addr

	r.traceCluster("join request from: %v", addr)

	// attempt to create the node
	node, err := r.store.CreateNode(addr)
	if err == ErrNodeExists {
		// if it exists, return the existing node
		node, err = r.store.NodeByHost(addr)
		if err != nil {
			return nil, err
		}
		if node == nil {
			return nil, fmt.Errorf("node by host: no node found for %s", addr)
		}
		r.logger.Printf("existing node re-joined: id=%v addr=%v", node.ID, node.Host)
	} else if err != nil {
		return nil, fmt.Errorf("create node: %v", err)
	}

	var nodeID uint64
	if node != nil {
		nodeID = node.ID
	}

	peers, err := r.store.Peers()
	if err != nil {
		return nil, fmt.Errorf("list peers: %v", err)
	}

	// If we have fewer than MaxRaftNodes peers, add the requester as a raft
	// peer unless it already is one.
	if len(peers) < MaxRaftNodes && !raft.PeerContained(peers, addr) {
		r.logger.Printf("adding new raft peer: nodeId=%v addr=%v", nodeID, addr)
		if err := r.store.AddPeer(addr); err != nil {
			return nil, fmt.Errorf("add peer: %v", err)
		}
	}

	// get the current raft peers; the list may have changed just above
	peers, err = r.store.Peers()
	if err != nil {
		return nil, fmt.Errorf("list peers: %v", err)
	}

	return &internal.JoinResponse{
		Header: &internal.ResponseHeader{
			OK: proto.Bool(true),
		},
		EnableRaft: proto.Bool(raft.PeerContained(peers, addr)),
		RaftNodes:  peers,
		NodeID:     proto.Uint64(nodeID),
	}, nil
}
// pack frames an RPC payload for the wire. The result is laid out as an
// 8-byte size (payload length plus the 8 bytes of the type field), the 8-byte
// RPC type, and finally the payload itself.
func (r *rpc) pack(typ internal.RPCType, b []byte) []byte {
	sizeField := u64tob(uint64(len(b)) + 8)
	typeField := u64tob(uint64(typ))

	frame := make([]byte, 0, len(sizeField)+len(typeField)+len(b))
	frame = append(frame, sizeField...)
	frame = append(frame, typeField...)
	frame = append(frame, b...)
	return frame
}
// fetchMetaData returns the latest copy of the meta store data from the
// current leader.
//
// The index and term of the locally cached data (zero when nothing is cached)
// are sent along with the request, and blocking is forwarded as the request's
// Blocking flag. A nil *Data with a nil error is returned when the leader's
// response carries no data.
func (r *rpc) fetchMetaData(blocking bool) (*Data, error) {
	assert(r.store != nil, "store is nil")

	// Retrieve the current known leader.
	leader := r.store.Leader()
	if leader == "" {
		return nil, errors.New("no leader")
	}

	// Report what we already have so the leader can tell whether it is stale.
	var index, term uint64
	if data := r.store.cachedData(); data != nil {
		index = data.Index
		term = data.Term
	}

	resp, err := r.call(leader, &internal.FetchDataRequest{
		Index:    proto.Uint64(index),
		Term:     proto.Uint64(term),
		Blocking: proto.Bool(blocking),
	})
	if err != nil {
		return nil, err
	}

	switch t := resp.(type) {
	case *internal.FetchDataResponse:
		// If data is nil, the leader had nothing newer than what we sent.
		if t.GetData() == nil {
			return nil, nil
		}

		ms := &Data{}
		if err := ms.UnmarshalBinary(t.GetData()); err != nil {
			return nil, fmt.Errorf("rpc unmarshal metadata: %v", err)
		}
		return ms, nil
	case *internal.ErrorResponse:
		return nil, fmt.Errorf("rpc failed: %s", t.GetHeader().GetError())
	default:
		return nil, fmt.Errorf("rpc failed: unknown response type: %v", t.String())
	}
}
// join asks the node at remoteAddr to admit this node into its cluster,
// announcing localAddr as this node's cluster address. On success the
// response fields are returned as a JoinResult.
func (r *rpc) join(localAddr, remoteAddr string) (*JoinResult, error) {
	resp, err := r.call(remoteAddr, &internal.JoinRequest{
		Addr: proto.String(localAddr),
	})
	if err != nil {
		return nil, err
	}

	if joined, ok := resp.(*internal.JoinResponse); ok {
		result := &JoinResult{
			RaftEnabled: joined.GetEnableRaft(),
			RaftNodes:   joined.GetRaftNodes(),
			NodeID:      joined.GetNodeID(),
		}
		return result, nil
	}

	if failed, ok := resp.(*internal.ErrorResponse); ok {
		return nil, fmt.Errorf("rpc failed: %s", failed.GetHeader().GetError())
	}

	return nil, fmt.Errorf("rpc failed: unknown response type: %v", resp.String())
}
// call sends an encoded request to the node at dest and returns the decoded
// response.
//
// Only JoinRequest and FetchDataRequest messages are supported. The whole
// response is read until the remote side closes the connection and must
// consist of an 8-byte size, an 8-byte response type and the protobuf
// payload. An error is returned when the exchange fails, when the response is
// malformed, or when the response header is not marked OK.
func (r *rpc) call(dest string, req proto.Message) (proto.Message, error) {
	// Determine type of request
	var rpcType internal.RPCType
	switch t := req.(type) {
	case *internal.JoinRequest:
		rpcType = internal.RPCType_Join
	case *internal.FetchDataRequest:
		rpcType = internal.RPCType_FetchData
	default:
		return nil, fmt.Errorf("unknown rpc request type: %v", t)
	}

	// Create a connection to the destination node.
	conn, err := net.DialTimeout("tcp", dest, leaderDialTimeout)
	if err != nil {
		return nil, fmt.Errorf("rpc dial: %v", err)
	}
	defer conn.Close()

	// Write a marker byte for rpc messages.
	if _, err := conn.Write([]byte{MuxRPCHeader}); err != nil {
		return nil, err
	}

	b, err := proto.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("rpc marshal: %v", err)
	}

	// Write request size & bytes.
	if _, err := conn.Write(r.pack(rpcType, b)); err != nil {
		return nil, fmt.Errorf("write %v rpc: %s", rpcType, err)
	}

	data, err := ioutil.ReadAll(conn)
	if err != nil {
		return nil, fmt.Errorf("read %v rpc: %v", rpcType, err)
	}

	// Should always have a size and type
	if exp := 16; len(data) < exp {
		r.traceCluster("recv: %v", string(data))
		return nil, fmt.Errorf("rpc %v failed: short read: got %v, exp %v", rpcType, len(data), exp)
	}

	// Compare in uint64 space: converting the wire-supplied size to int could
	// truncate or go negative and make a bogus size look valid.
	sz := btou64(data[0:8])
	if got := uint64(len(data) - 8); got != sz {
		r.traceCluster("recv: %v", string(data))
		return nil, fmt.Errorf("rpc %v failed: short read: got %v, exp %v", rpcType, got, sz)
	}

	// See what response type we got back, could get a general error response
	respType := internal.RPCType(btou64(data[8:16]))
	payload := data[16:]

	var resp proto.Message
	switch respType {
	case internal.RPCType_Join:
		resp = &internal.JoinResponse{}
	case internal.RPCType_FetchData:
		resp = &internal.FetchDataResponse{}
	case internal.RPCType_Error:
		resp = &internal.ErrorResponse{}
	default:
		return nil, fmt.Errorf("unknown rpc response type: %v", respType)
	}

	if err := proto.Unmarshal(payload, resp); err != nil {
		return nil, fmt.Errorf("rpc unmarshal: %v", err)
	}

	// Surface a failed status header as a Go error.
	if reply, ok := resp.(Reply); ok {
		if !reply.GetHeader().GetOK() {
			return nil, fmt.Errorf("rpc %v failed: %s", respType, reply.GetHeader().GetError())
		}
	}

	return resp, nil
}
// traceCluster logs a formatted message prefixed with "rpc: ", but only when
// tracing is enabled on r; otherwise it does nothing.
func (r *rpc) traceCluster(msg string, args ...interface{}) {
	if !r.tracingEnabled {
		return
	}
	r.logger.Printf("rpc: "+msg, args...)
}
// u64tob encodes v as an 8-byte big-endian slice.
func u64tob(v uint64) []byte {
	var encoded [8]byte
	binary.BigEndian.PutUint64(encoded[:], v)
	return encoded[:]
}
// btou64 decodes the first 8 bytes of b as a big-endian uint64. It panics
// when b holds fewer than 8 bytes.
func btou64(b []byte) uint64 {
	order := binary.BigEndian
	decoded := order.Uint64(b)
	return decoded
}
// contains reports whether e is an element of s.
func contains(s []string, e string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == e {
			return true
		}
	}
	return false
}