package meta

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"time"

	"github.com/gogo/protobuf/proto"
	"github.com/hashicorp/raft"
	"github.com/influxdb/influxdb/meta/internal"
)

const (
	// MaxMessageSize is the largest message size, in bytes, we accept
	// before treating the encoded size as invalid.
	MaxMessageSize = 1024 * 1024 * 1024

	leaderDialTimeout = 10 * time.Second
)

// rpc handles request/response style messaging between cluster nodes.
type rpc struct {
	logger         *log.Logger
	tracingEnabled bool

	store interface {
		cachedData() *Data
		IsLeader() bool
		Leader() string
		Peers() ([]string, error)
		AddPeer(host string) error
		CreateNode(host string) (*NodeInfo, error)
		NodeByHost(host string) (*NodeInfo, error)
		WaitForDataChanged() error
	}
}

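// JoinResult defines the results of a request to join an existing cluster.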
type JoinResult struct {
	RaftEnabled bool
	RaftNodes   []string
	NodeID      uint64
}

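// Reply is implemented by all RPC responses and exposes their response header.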
type Reply interface {
	GetHeader() *internal.ResponseHeader
}

// proxyLeader proxies the connection to the current raft leader.
func (r *rpc) proxyLeader(conn *net.TCPConn) {
	if r.store.Leader() == "" {
		r.sendError(conn, "no leader")
		return
	}

	leaderConn, err := net.DialTimeout("tcp", r.store.Leader(), leaderDialTimeout)
	if err != nil {
		r.sendError(conn, fmt.Sprintf("dial leader: %v", err))
		return
	}
	defer leaderConn.Close()

	// Write the mux header so the leader routes this connection to its RPC
	// handler, then splice the two connections together.
	if _, err := leaderConn.Write([]byte{MuxRPCHeader}); err != nil {
		r.sendError(conn, fmt.Sprintf("write mux header to leader: %v", err))
		return
	}
	if err := proxy(leaderConn.(*net.TCPConn), conn); err != nil {
		r.sendError(conn, fmt.Sprintf("leader proxy error: %v", err))
	}
}

// handleRPCConn reads a command from the connection and executes it.
func (r *rpc) handleRPCConn(conn net.Conn) {
	defer conn.Close()

	// RPC connections should execute on the leader. If we are not the leader,
	// proxy the connection to the leader so that clients can connect to any
	// node in the cluster.
	r.traceCluster("rpc connection from: %v", conn.RemoteAddr())

	if !r.store.IsLeader() {
		r.proxyLeader(conn.(*net.TCPConn))
		return
	}

	// Read and execute request.
	typ, resp, err := func() (internal.RPCType, proto.Message, error) {
		// Read request size.
		var sz uint64
		if err := binary.Read(conn, binary.BigEndian, &sz); err != nil {
			return internal.RPCType_Error, nil, fmt.Errorf("read size: %s", err)
		}

		// The size must at least cover the 8-byte RPC type.
		if sz < 8 {
			return internal.RPCType_Error, nil, fmt.Errorf("invalid message size: %d", sz)
		}

		if sz >= MaxMessageSize {
			return internal.RPCType_Error, nil, fmt.Errorf("max message size of %d exceeded: %d", MaxMessageSize, sz)
		}

		// Read request.
		buf := make([]byte, sz)
		if _, err := io.ReadFull(conn, buf); err != nil {
			return internal.RPCType_Error, nil, fmt.Errorf("read request: %s", err)
		}

		// Determine the RPC type.
		rpcType := internal.RPCType(btou64(buf[0:8]))
		buf = buf[8:]

		r.traceCluster("recv %v request on: %v", rpcType, conn.RemoteAddr())
		switch rpcType {
		case internal.RPCType_FetchData:
			var req internal.FetchDataRequest
			if err := proto.Unmarshal(buf, &req); err != nil {
				return internal.RPCType_Error, nil, fmt.Errorf("fetch request unmarshal: %v", err)
			}
			resp, err := r.handleFetchData(&req)
			return rpcType, resp, err
		case internal.RPCType_Join:
			var req internal.JoinRequest
			if err := proto.Unmarshal(buf, &req); err != nil {
				return internal.RPCType_Error, nil, fmt.Errorf("join request unmarshal: %v", err)
			}
			resp, err := r.handleJoinRequest(&req)
			return rpcType, resp, err
		default:
			return internal.RPCType_Error, nil, fmt.Errorf("unknown rpc type: %v", rpcType)
		}
	}()

	// Handle unexpected RPC errors.
	if err != nil {
		resp = &internal.ErrorResponse{
			Header: &internal.ResponseHeader{
				OK: proto.Bool(false),
			},
		}
		typ = internal.RPCType_Error
	}

	// Set the status header and error message.
	if reply, ok := resp.(Reply); ok {
		reply.GetHeader().OK = proto.Bool(err == nil)
		if err != nil {
			reply.GetHeader().Error = proto.String(err.Error())
		}
	}

	r.sendResponse(conn, typ, resp)
}

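// sendResponse marshals resp and writes it, framed by pack, back to the connection.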
func (r *rpc) sendResponse(conn net.Conn, typ internal.RPCType, resp proto.Message) {
	// Marshal the response back to a protobuf.
	buf, err := proto.Marshal(resp)
	if err != nil {
		r.logger.Printf("unable to marshal response: %v", err)
		return
	}

	// Encode response back to connection.
	if _, err := conn.Write(r.pack(typ, buf)); err != nil {
		r.logger.Printf("unable to write rpc response: %s", err)
	}
}

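// sendError writes an ErrorResponse with the given message back to the connection.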
func (r *rpc) sendError(conn net.Conn, msg string) {
	r.traceCluster(msg)
	resp := &internal.ErrorResponse{
		Header: &internal.ResponseHeader{
			OK:    proto.Bool(false),
			Error: proto.String(msg),
		},
	}

	r.sendResponse(conn, internal.RPCType_Error, resp)
}

// handleFetchData handles a request for the current node's meta data.
func (r *rpc) handleFetchData(req *internal.FetchDataRequest) (*internal.FetchDataResponse, error) {
	var (
		b    []byte
		data *Data
		err  error
	)

	for {
		data = r.store.cachedData()
		if data.Index != req.GetIndex() {
			b, err = data.MarshalBinary()
			if err != nil {
				return nil, err
			}
			break
		}

		// A non-blocking request returns immediately, with nil data, when the
		// client's index is already current.
		if !req.GetBlocking() {
			break
		}

		// Otherwise, wait for the meta data to change and re-check.
		if err := r.store.WaitForDataChanged(); err != nil {
			return nil, err
		}
	}

	return &internal.FetchDataResponse{
		Header: &internal.ResponseHeader{
			OK: proto.Bool(true),
		},
		Index: proto.Uint64(data.Index),
		Term:  proto.Uint64(data.Term),
		Data:  b,
	}, nil
}

// handleJoinRequest handles a request to join the cluster.
func (r *rpc) handleJoinRequest(req *internal.JoinRequest) (*internal.JoinResponse, error) {
	r.traceCluster("join request from: %v", *req.Addr)

	node, err := func() (*NodeInfo, error) {
		// Attempt to create the node.
		node, err := r.store.CreateNode(*req.Addr)
		// If it already exists, return the existing node.
		if err == ErrNodeExists {
			node, err = r.store.NodeByHost(*req.Addr)
			if err != nil {
				return node, err
			}
			r.logger.Printf("existing node re-joined: id=%v addr=%v", node.ID, node.Host)
		} else if err != nil {
			return nil, fmt.Errorf("create node: %v", err)
		}

		peers, err := r.store.Peers()
		if err != nil {
			return nil, fmt.Errorf("list peers: %v", err)
		}

		// If we have fewer than MaxRaftNodes peers, add the joining node as a
		// raft peer unless it is one already.
		if len(peers) < MaxRaftNodes && !raft.PeerContained(peers, *req.Addr) {
			r.logger.Printf("adding new raft peer: nodeId=%v addr=%v", node.ID, *req.Addr)
			if err = r.store.AddPeer(*req.Addr); err != nil {
				return node, fmt.Errorf("add peer: %v", err)
			}
		}
		return node, nil
	}()

	nodeID := uint64(0)
	if node != nil {
		nodeID = node.ID
	}

	if err != nil {
		return nil, err
	}

	// Get the current raft peers.
	peers, err := r.store.Peers()
	if err != nil {
		return nil, fmt.Errorf("list peers: %v", err)
	}

	return &internal.JoinResponse{
		Header: &internal.ResponseHeader{
			OK: proto.Bool(true),
		},
		EnableRaft: proto.Bool(raft.PeerContained(peers, *req.Addr)),
		RaftNodes:  peers,
		NodeID:     proto.Uint64(nodeID),
	}, nil
}

// pack returns a TLV-style byte slice encoding the size of the payload,
// the RPC type, and the RPC data.
func (r *rpc) pack(typ internal.RPCType, b []byte) []byte {
	// The leading size covers the 8-byte type plus the payload.
	buf := u64tob(uint64(len(b)) + 8)
	buf = append(buf, u64tob(uint64(typ))...)
	buf = append(buf, b...)
	return buf
}

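// For reference, the frame produced by pack (and expected by handleRPCConn
// and call) is laid out as:
//
//	[size: 8 bytes][type: 8 bytes][payload: size - 8 bytes]
//
// where size counts the type and payload but not itself.
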
// fetchMetaData returns the latest copy of the meta store data from the
// current leader.
func (r *rpc) fetchMetaData(blocking bool) (*Data, error) {
	assert(r.store != nil, "store is nil")

	// Retrieve the current known leader.
	leader := r.store.Leader()
	if leader == "" {
		return nil, errors.New("no leader")
	}

	var index, term uint64
	data := r.store.cachedData()
	if data != nil {
		index = data.Index
		term = data.Term
	}
	resp, err := r.call(leader, &internal.FetchDataRequest{
		Index:    proto.Uint64(index),
		Term:     proto.Uint64(term),
		Blocking: proto.Bool(blocking),
	})
	if err != nil {
		return nil, err
	}

	switch t := resp.(type) {
	case *internal.FetchDataResponse:
		// If data is nil, then the term and index we sent match the leader's.
		if t.GetData() == nil {
			return nil, nil
		}
		ms := &Data{}
		if err := ms.UnmarshalBinary(t.GetData()); err != nil {
			return nil, fmt.Errorf("rpc unmarshal metadata: %v", err)
		}
		return ms, nil
	case *internal.ErrorResponse:
		return nil, fmt.Errorf("rpc failed: %s", t.GetHeader().GetError())
	default:
		return nil, fmt.Errorf("rpc failed: unknown response type: %v", t.String())
	}
}

// join attempts to join a cluster at remoteAddr using localAddr as the
// current node's cluster address.
func (r *rpc) join(localAddr, remoteAddr string) (*JoinResult, error) {
	req := &internal.JoinRequest{
		Addr: proto.String(localAddr),
	}

	resp, err := r.call(remoteAddr, req)
	if err != nil {
		return nil, err
	}

	switch t := resp.(type) {
	case *internal.JoinResponse:
		return &JoinResult{
			RaftEnabled: t.GetEnableRaft(),
			RaftNodes:   t.GetRaftNodes(),
			NodeID:      t.GetNodeID(),
		}, nil
	case *internal.ErrorResponse:
		return nil, fmt.Errorf("rpc failed: %s", t.GetHeader().GetError())
	default:
		return nil, fmt.Errorf("rpc failed: unknown response type: %v", t.String())
	}
}

// call sends an encoded request to the remote leader and returns
// an encoded response value.
func (r *rpc) call(dest string, req proto.Message) (proto.Message, error) {
	// Determine the type of request.
	var rpcType internal.RPCType
	switch t := req.(type) {
	case *internal.JoinRequest:
		rpcType = internal.RPCType_Join
	case *internal.FetchDataRequest:
		rpcType = internal.RPCType_FetchData
	default:
		return nil, fmt.Errorf("unknown rpc request type: %v", t)
	}

	// Create a connection to the leader.
	conn, err := net.DialTimeout("tcp", dest, leaderDialTimeout)
	if err != nil {
		return nil, fmt.Errorf("rpc dial: %v", err)
	}
	defer conn.Close()

	// Write a marker byte for rpc messages.
	if _, err = conn.Write([]byte{MuxRPCHeader}); err != nil {
		return nil, err
	}

	b, err := proto.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("rpc marshal: %v", err)
	}

	// Write request size & bytes.
	if _, err := conn.Write(r.pack(rpcType, b)); err != nil {
		return nil, fmt.Errorf("write %v rpc: %s", rpcType, err)
	}

	data, err := ioutil.ReadAll(conn)
	if err != nil {
		return nil, fmt.Errorf("read %v rpc: %v", rpcType, err)
	}

	// A response should always carry at least a size and a type.
	if exp := 16; len(data) < exp {
		r.traceCluster("recv: %v", string(data))
		return nil, fmt.Errorf("rpc %v failed: short read: got %v, exp %v", rpcType, len(data), exp)
	}

	sz := btou64(data[0:8])
	if len(data[8:]) != int(sz) {
		r.traceCluster("recv: %v", string(data))
		return nil, fmt.Errorf("rpc %v failed: short read: got %v, exp %v", rpcType, len(data[8:]), sz)
	}

	// See what response type we got back; it could be a general error response.
	rpcType = internal.RPCType(btou64(data[8:16]))
	data = data[16:]

	var resp proto.Message
	switch rpcType {
	case internal.RPCType_Join:
		resp = &internal.JoinResponse{}
	case internal.RPCType_FetchData:
		resp = &internal.FetchDataResponse{}
	case internal.RPCType_Error:
		resp = &internal.ErrorResponse{}
	default:
		return nil, fmt.Errorf("unknown rpc response type: %v", rpcType)
	}

	if err := proto.Unmarshal(data, resp); err != nil {
		return nil, fmt.Errorf("rpc unmarshal: %v", err)
	}

	if reply, ok := resp.(Reply); ok {
		if !reply.GetHeader().GetOK() {
			return nil, fmt.Errorf("rpc %v failed: %s", rpcType, reply.GetHeader().GetError())
		}
	}

	return resp, nil
}

// traceCluster logs the formatted message when cluster tracing is enabled.
func (r *rpc) traceCluster(msg string, args ...interface{}) {
	if r.tracingEnabled {
		r.logger.Printf("rpc: "+msg, args...)
	}
}

// u64tob encodes v as an 8-byte big-endian slice.
func u64tob(v uint64) []byte {
	b := make([]byte, 8)
	binary.BigEndian.PutUint64(b, v)
	return b
}

// btou64 decodes an 8-byte big-endian slice as a uint64.
func btou64(b []byte) uint64 {
	return binary.BigEndian.Uint64(b)
}

// contains reports whether s contains the string e.
func contains(s []string, e string) bool {
	for _, a := range s {
		if a == e {
			return true
		}
	}
	return false
}