||
- package raft
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"syscall"
	"time"
)
- // Transport defines the interface for RPC communication
- type Transport interface {
- // Start starts the transport
- Start() error
- // Stop stops the transport
- Stop() error
- // RequestVote sends a RequestVote RPC to the target node
- RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error)
- // AppendEntries sends an AppendEntries RPC to the target node
- AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error)
- // InstallSnapshot sends an InstallSnapshot RPC to the target node
- InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error)
- // ForwardPropose forwards a propose request to the leader
- ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error)
- // ForwardAddNode forwards an AddNode request to the leader
- ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error)
- // ForwardRemoveNode forwards a RemoveNode request to the leader
- ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error)
- // TimeoutNow sends a TimeoutNow RPC for leadership transfer
- TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error)
- // ReadIndex sends a ReadIndex RPC for linearizable reads
- ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error)
- // ForwardGet sends a Get RPC for remote KV reads
- ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error)
- // SetRPCHandler sets the handler for incoming RPCs
- SetRPCHandler(handler RPCHandler)
- }
- // RPCHandler handles incoming RPCs
- type RPCHandler interface {
- HandleRequestVote(args *RequestVoteArgs) *RequestVoteReply
- HandleAppendEntries(args *AppendEntriesArgs) *AppendEntriesReply
- HandleInstallSnapshot(args *InstallSnapshotArgs) *InstallSnapshotReply
- HandlePropose(args *ProposeArgs) *ProposeReply
- HandleAddNode(args *AddNodeArgs) *AddNodeReply
- HandleRemoveNode(args *RemoveNodeArgs) *RemoveNodeReply
- HandleTimeoutNow(args *TimeoutNowArgs) *TimeoutNowReply
- HandleReadIndex(args *ReadIndexArgs) *ReadIndexReply
- HandleGet(args *GetArgs) *GetReply
- }
- // TCPTransport implements Transport using raw TCP with binary protocol
- // This is more efficient than HTTP for high-frequency RPCs
- type TCPTransport struct {
- mu sync.RWMutex
- localAddr string
- handler RPCHandler
- logger Logger
- listener net.Listener
- shutdownCh chan struct{}
- // Single persistent connection per target
- conns map[string]net.Conn
- }
- // NewTCPTransport creates a new TCP transport
- func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTransport {
- if logger == nil {
- logger = &NoopLogger{}
- }
- return &TCPTransport{
- localAddr: localAddr,
- logger: logger,
- shutdownCh: make(chan struct{}),
- conns: make(map[string]net.Conn),
- }
- }
- // SetRPCHandler sets the handler for incoming RPCs
- func (t *TCPTransport) SetRPCHandler(handler RPCHandler) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.handler = handler
- }
- // Start starts the TCP server
- func (t *TCPTransport) Start() error {
- var err error
- t.listener, err = net.Listen("tcp", t.localAddr)
- if err != nil {
- return fmt.Errorf("failed to listen on %s: %w", t.localAddr, err)
- }
- go t.acceptLoop()
- t.logger.Info("TCP Transport started on %s", t.localAddr)
- return nil
- }
- // acceptLoop accepts incoming connections
- func (t *TCPTransport) acceptLoop() {
- for {
- select {
- case <-t.shutdownCh:
- return
- default:
- }
- conn, err := t.listener.Accept()
- if err != nil {
- select {
- case <-t.shutdownCh:
- return
- default:
- t.logger.Error("Accept error: %v", err)
- continue
- }
- }
- go t.handleConnection(conn)
- }
- }
- // handleConnection handles an incoming connection
- func (t *TCPTransport) handleConnection(conn net.Conn) {
- defer conn.Close()
- for {
- select {
- case <-t.shutdownCh:
- return
- default:
- }
- // Set read deadline
- conn.SetReadDeadline(time.Now().Add(30 * time.Second))
- // Read message type (1 byte)
- typeBuf := make([]byte, 1)
- if _, err := io.ReadFull(conn, typeBuf); err != nil {
- if err != io.EOF {
- t.logger.Debug("Read type error: %v", err)
- }
- return
- }
- // Read message length (4 bytes)
- lenBuf := make([]byte, 4)
- if _, err := io.ReadFull(conn, lenBuf); err != nil {
- t.logger.Debug("Read length error: %v", err)
- return
- }
- length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
- if length > 10*1024*1024 { // 10MB limit
- t.logger.Error("Message too large: %d", length)
- return
- }
- // Read message body
- body := make([]byte, length)
- if _, err := io.ReadFull(conn, body); err != nil {
- t.logger.Debug("Read body error: %v", err)
- return
- }
- // Handle message
- var response []byte
- var err error
- t.mu.RLock()
- handler := t.handler
- t.mu.RUnlock()
- if handler == nil {
- t.logger.Error("No handler registered")
- return
- }
- switch RPCType(typeBuf[0]) {
- case RPCRequestVote:
- var args RequestVoteArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal RequestVote error: %v", err)
- return
- }
- reply := handler.HandleRequestVote(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCAppendEntries:
- var args AppendEntriesArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal AppendEntries error: %v", err)
- return
- }
- reply := handler.HandleAppendEntries(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCInstallSnapshot:
- var args InstallSnapshotArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal InstallSnapshot error: %v", err)
- return
- }
- reply := handler.HandleInstallSnapshot(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCPropose:
- var args ProposeArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal Propose error: %v", err)
- return
- }
- reply := handler.HandlePropose(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCAddNode:
- var args AddNodeArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal AddNode error: %v", err)
- return
- }
- reply := handler.HandleAddNode(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCRemoveNode:
- var args RemoveNodeArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal RemoveNode error: %v", err)
- return
- }
- reply := handler.HandleRemoveNode(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCTimeoutNow:
- var args TimeoutNowArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal TimeoutNow error: %v", err)
- return
- }
- reply := handler.HandleTimeoutNow(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCReadIndex:
- var args ReadIndexArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal ReadIndex error: %v", err)
- return
- }
- reply := handler.HandleReadIndex(&args)
- response, err = DefaultCodec.Marshal(reply)
- case RPCGet:
- var args GetArgs
- if err := DefaultCodec.Unmarshal(body, &args); err != nil {
- t.logger.Error("Unmarshal Get error: %v", err)
- return
- }
- reply := handler.HandleGet(&args)
- response, err = DefaultCodec.Marshal(reply)
- default:
- t.logger.Error("Unknown RPC type: %d", typeBuf[0])
- return
- }
- if err != nil {
- t.logger.Error("Marshal response error: %v", err)
- return
- }
- // Write response
- conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
- respLen := make([]byte, 4)
- respLen[0] = byte(len(response) >> 24)
- respLen[1] = byte(len(response) >> 16)
- respLen[2] = byte(len(response) >> 8)
- respLen[3] = byte(len(response))
- if _, err := conn.Write(respLen); err != nil {
- t.logger.Debug("Write response length error: %v", err)
- return
- }
- if _, err := conn.Write(response); err != nil {
- t.logger.Debug("Write response error: %v", err)
- return
- }
- }
- }
- // Stop stops the TCP server
- func (t *TCPTransport) Stop() error {
- close(t.shutdownCh)
- // Close all connections
- t.mu.Lock()
- for _, conn := range t.conns {
- conn.Close()
- }
- t.conns = make(map[string]net.Conn)
- t.mu.Unlock()
- if t.listener != nil {
- return t.listener.Close()
- }
- return nil
- }
- // getConn gets the persistent connection or creates a new one
- func (t *TCPTransport) getConn(target string) (net.Conn, error) {
- t.mu.Lock()
- defer t.mu.Unlock()
- // Check existing connection
- if conn, ok := t.conns[target]; ok {
- return conn, nil
- }
- // Dial new connection
- conn, err := net.DialTimeout("tcp", target, 5*time.Second)
- if err != nil {
- return nil, err
- }
- t.conns[target] = conn
- return conn, nil
- }
- // closeConn closes and removes a connection from the map
- func (t *TCPTransport) closeConn(target string, conn net.Conn) {
- t.mu.Lock()
- defer t.mu.Unlock()
- // Only delete if it's the current connection
- if current, ok := t.conns[target]; ok && current == conn {
- delete(t.conns, target)
- conn.Close()
- }
- }
- // sendTCPRPC sends an RPC over TCP
- func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
- // Simple retry mechanism for stale connections
- maxRetries := 2
- var lastErr error
- for i := 0; i < maxRetries; i++ {
- conn, err := t.getConn(target)
- if err != nil {
- return fmt.Errorf("failed to get connection: %w", err)
- }
- data, err := DefaultCodec.Marshal(args)
- if err != nil {
- // Don't close conn here as we haven't touched it yet in this iteration
- return fmt.Errorf("failed to marshal request: %w", err)
- }
- // Set deadline from context
- deadline, ok := ctx.Deadline()
- if !ok {
- deadline = time.Now().Add(5 * time.Second)
- }
- conn.SetDeadline(deadline)
- // Write message: [type(1)][length(4)][body]
- header := make([]byte, 5)
- header[0] = byte(rpcType)
- header[1] = byte(len(data) >> 24)
- header[2] = byte(len(data) >> 16)
- header[3] = byte(len(data) >> 8)
- header[4] = byte(len(data))
- if _, err := conn.Write(header); err != nil {
- t.closeConn(target, conn)
- lastErr = fmt.Errorf("failed to write header: %w", err)
- // If this was a reused connection, retry with a new one
- if i < maxRetries-1 {
- continue
- }
- return lastErr
- }
- if _, err := conn.Write(data); err != nil {
- t.closeConn(target, conn)
- lastErr = fmt.Errorf("failed to write body: %w", err)
- // If write failed, it might be a broken pipe, but we already wrote header.
- // Retrying here is risky if the server processed the header but not body,
- // but for idempotent RPCs it might be okay. For safety, we only retry if it looks like a connection issue on write.
- // However, since we're in binary protocol, partial writes break the stream anyway.
- if i < maxRetries-1 {
- continue
- }
- return lastErr
- }
- // Read response length
- lenBuf := make([]byte, 4)
- if _, err := io.ReadFull(conn, lenBuf); err != nil {
- t.closeConn(target, conn)
- lastErr = fmt.Errorf("failed to read response length: %w", err)
- if i < maxRetries-1 && (err == io.EOF || isConnectionReset(err)) {
- continue
- }
- return lastErr
- }
- length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
- if length > 10*1024*1024 {
- t.closeConn(target, conn)
- return fmt.Errorf("response too large: %d", length)
- }
- // Read response body
- respBody := make([]byte, length)
- if _, err := io.ReadFull(conn, respBody); err != nil {
- t.closeConn(target, conn)
- return fmt.Errorf("failed to read response body: %w", err)
- }
- if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
- t.closeConn(target, conn)
- return fmt.Errorf("failed to unmarshal response: %w", err)
- }
- // Keep connection open
- return nil
- }
- return lastErr
- }
// isConnectionReset reports whether err indicates that the peer tore the
// connection down (ECONNRESET "connection reset by peer" or EPIPE
// "broken pipe").
//
// Socket errors normally arrive wrapped, e.g. *net.OpError around
// *os.SyscallError around a syscall.Errno, so the error chain is inspected
// with errors.Is rather than by comparing the text of the first wrapped error,
// which carries a "read: " / "write: " prefix and would never match. For
// errors that only carry the message, the text of a *net.OpError's inner error
// is still compared as a fallback. A nil err, or an OpError with a nil inner
// error, yields false.
func isConnectionReset(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) {
		return true
	}

	var opErr *net.OpError
	if errors.As(err, &opErr) && opErr.Err != nil {
		switch opErr.Err.Error() {
		case "connection reset by peer", "broken pipe":
			return true
		}
	}
	return false
}
- // RequestVote sends a RequestVote RPC
- func (t *TCPTransport) RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error) {
- var reply RequestVoteReply
- err := t.sendTCPRPC(ctx, target, RPCRequestVote, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // AppendEntries sends an AppendEntries RPC
- func (t *TCPTransport) AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error) {
- var reply AppendEntriesReply
- err := t.sendTCPRPC(ctx, target, RPCAppendEntries, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // InstallSnapshot sends an InstallSnapshot RPC
- func (t *TCPTransport) InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error) {
- var reply InstallSnapshotReply
- err := t.sendTCPRPC(ctx, target, RPCInstallSnapshot, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // ForwardPropose forwards a propose request to the leader
- func (t *TCPTransport) ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error) {
- var reply ProposeReply
- err := t.sendTCPRPC(ctx, target, RPCPropose, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // ForwardAddNode forwards an AddNode request to the leader
- func (t *TCPTransport) ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error) {
- var reply AddNodeReply
- err := t.sendTCPRPC(ctx, target, RPCAddNode, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // ForwardRemoveNode forwards a RemoveNode request to the leader
- func (t *TCPTransport) ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error) {
- var reply RemoveNodeReply
- err := t.sendTCPRPC(ctx, target, RPCRemoveNode, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // TimeoutNow sends a TimeoutNow RPC for leadership transfer
- func (t *TCPTransport) TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error) {
- var reply TimeoutNowReply
- err := t.sendTCPRPC(ctx, target, RPCTimeoutNow, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // ReadIndex sends a ReadIndex RPC for linearizable reads
- func (t *TCPTransport) ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error) {
- var reply ReadIndexReply
- err := t.sendTCPRPC(ctx, target, RPCReadIndex, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
- // ForwardGet sends a Get RPC for remote KV reads
- func (t *TCPTransport) ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error) {
- var reply GetReply
- err := t.sendTCPRPC(ctx, target, RPCGet, args, &reply)
- if err != nil {
- return nil, err
- }
- return &reply, nil
- }
|