package raft

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"syscall"
	"time"
)

// Transport defines the interface for RPC communication
type Transport interface {
	// Start starts the transport
	Start() error

	// Stop stops the transport
	Stop() error

	// RequestVote sends a RequestVote RPC to the target node
	RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error)

	// AppendEntries sends an AppendEntries RPC to the target node
	AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error)

	// InstallSnapshot sends an InstallSnapshot RPC to the target node
	InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error)

	// ForwardPropose forwards a propose request to the leader
	ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error)

	// ForwardAddNode forwards an AddNode request to the leader
	ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error)

	// ForwardRemoveNode forwards a RemoveNode request to the leader
	ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error)

	// TimeoutNow sends a TimeoutNow RPC for leadership transfer
	TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error)

	// ReadIndex sends a ReadIndex RPC for linearizable reads
	ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error)

	// ForwardGet sends a Get RPC for remote KV reads
	ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error)

	// SetRPCHandler sets the handler for incoming RPCs
	SetRPCHandler(handler RPCHandler)
}

// RPCHandler handles incoming RPCs
type RPCHandler interface {
	HandleRequestVote(args *RequestVoteArgs) *RequestVoteReply
	HandleAppendEntries(args *AppendEntriesArgs) *AppendEntriesReply
	HandleInstallSnapshot(args *InstallSnapshotArgs) *InstallSnapshotReply
	HandlePropose(args *ProposeArgs) *ProposeReply
	HandleAddNode(args *AddNodeArgs) *AddNodeReply
	HandleRemoveNode(args *RemoveNodeArgs) *RemoveNodeReply
	HandleTimeoutNow(args *TimeoutNowArgs) *TimeoutNowReply
	HandleReadIndex(args *ReadIndexArgs) *ReadIndexReply
	HandleGet(args *GetArgs) *GetReply
}

// Compile-time check that TCPTransport satisfies Transport.
var _ Transport = (*TCPTransport)(nil)

// tcpMaxMessageSize is the largest request or response body, in bytes, that
// either side of the connection will send or accept (10MB).
const tcpMaxMessageSize = 10 * 1024 * 1024

// tcpDialTimeout bounds how long establishing an outbound connection may take.
const tcpDialTimeout = 5 * time.Second

// tcpDefaultRPCTimeout is the per-RPC I/O deadline used when the caller's
// context carries no deadline of its own.
const tcpDefaultRPCTimeout = 5 * time.Second

// tcpServerReadTimeout is how long an inbound connection may sit idle before
// the server side closes it.
const tcpServerReadTimeout = 30 * time.Second

// tcpServerWriteTimeout bounds writing one response to an inbound connection.
const tcpServerWriteTimeout = 5 * time.Second

// TCPTransport implements Transport using raw TCP with binary protocol
// This is more efficient than HTTP for high-frequency RPCs
//
// Wire format, request:  [type(1)][length(4, big-endian)][body]
// Wire format, response: [length(4, big-endian)][body]
//
// Frames carry no request ID, so a connection can only have one RPC in
// flight at a time; sendTCPRPC serialises callers per target accordingly.
type TCPTransport struct {
	mu         sync.RWMutex
	localAddr  string
	handler    RPCHandler
	logger     Logger
	listener   net.Listener
	shutdownCh chan struct{}
	stopOnce   sync.Once

	// Single persistent connection per target
	conns map[string]net.Conn

	// sendLocks holds one single-slot semaphore per target. Holding the slot
	// grants exclusive use of that target's connection for one RPC exchange.
	sendLocks map[string]chan struct{}
}

// NewTCPTransport creates a new TCP transport that will listen on localAddr
// once Start is called.
//
// poolSize is accepted for interface compatibility but is not used: the
// transport keeps a single persistent connection per target. A nil logger is
// replaced with a NoopLogger.
func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTransport {
	if logger == nil {
		logger = &NoopLogger{}
	}
	return &TCPTransport{
		localAddr:  localAddr,
		logger:     logger,
		shutdownCh: make(chan struct{}),
		conns:      make(map[string]net.Conn),
		sendLocks:  make(map[string]chan struct{}),
	}
}

// SetRPCHandler sets the handler for incoming RPCs
func (t *TCPTransport) SetRPCHandler(handler RPCHandler) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.handler = handler
}

// Start begins listening on the configured local address and launches the
// accept loop in a background goroutine. It returns an error if the address
// cannot be bound.
func (t *TCPTransport) Start() error {
	ln, err := net.Listen("tcp", t.localAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", t.localAddr, err)
	}

	// Publish the listener under the lock so Stop and acceptLoop observe it
	// without a data race.
	t.mu.Lock()
	t.listener = ln
	t.mu.Unlock()

	go t.acceptLoop()
	t.logger.Info("TCP Transport started on %s", t.localAddr)
	return nil
}

// acceptLoop accepts incoming connections until the transport is stopped,
// handing each one to its own handleConnection goroutine.
func (t *TCPTransport) acceptLoop() {
	t.mu.RLock()
	ln := t.listener
	t.mu.RUnlock()
	if ln == nil {
		return
	}

	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration

	for {
		select {
		case <-t.shutdownCh:
			return
		default:
		}

		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-t.shutdownCh:
				return
			default:
			}
			t.logger.Error("Accept error: %v", err)

			// Back off so a persistent failure (for example running out of
			// file descriptors) does not turn into a busy loop.
			switch {
			case backoff == 0:
				backoff = minBackoff
			case backoff < maxBackoff:
				backoff *= 2
			}
			select {
			case <-t.shutdownCh:
				return
			case <-time.After(backoff):
			}
			continue
		}

		backoff = 0
		go t.handleConnection(conn)
	}
}

// handleConnection serves requests from one inbound connection, one at a
// time, until the peer disconnects, an error occurs, the connection is idle
// for longer than tcpServerReadTimeout, or the transport is stopped. The
// connection is closed on return.
func (t *TCPTransport) handleConnection(conn net.Conn) {
	defer conn.Close()

	// Header scratch buffers are reused for every request on this connection.
	typeBuf := make([]byte, 1)
	lenBuf := make([]byte, 4)

	for {
		select {
		case <-t.shutdownCh:
			return
		default:
		}

		// Set read deadline
		if err := conn.SetReadDeadline(time.Now().Add(tcpServerReadTimeout)); err != nil {
			t.logger.Debug("Set read deadline error: %v", err)
			return
		}

		// Read message type (1 byte). A clean EOF here just means the peer
		// closed the connection between requests, so it is not logged.
		if _, err := io.ReadFull(conn, typeBuf); err != nil {
			if !errors.Is(err, io.EOF) {
				t.logger.Debug("Read type error: %v", err)
			}
			return
		}

		// Read message length (4 bytes)
		if _, err := io.ReadFull(conn, lenBuf); err != nil {
			t.logger.Debug("Read length error: %v", err)
			return
		}
		length := binary.BigEndian.Uint32(lenBuf)
		if length > tcpMaxMessageSize { // 10MB limit
			t.logger.Error("Message too large: %d", length)
			return
		}

		// Read message body
		body := make([]byte, length)
		if _, err := io.ReadFull(conn, body); err != nil {
			t.logger.Debug("Read body error: %v", err)
			return
		}

		t.mu.RLock()
		handler := t.handler
		t.mu.RUnlock()
		if handler == nil {
			t.logger.Error("No handler registered")
			return
		}

		// Handle message
		response, ok := t.dispatch(handler, typeBuf[0], body)
		if !ok {
			return
		}
		if len(response) > tcpMaxMessageSize {
			// The client rejects anything above the limit, so do not send it.
			t.logger.Error("Response too large: %d", len(response))
			return
		}

		// Write response
		if err := conn.SetWriteDeadline(time.Now().Add(tcpServerWriteTimeout)); err != nil {
			t.logger.Debug("Set write deadline error: %v", err)
			return
		}
		respLen := make([]byte, 4)
		binary.BigEndian.PutUint32(respLen, uint32(len(response)))
		if _, err := conn.Write(respLen); err != nil {
			t.logger.Debug("Write response length error: %v", err)
			return
		}
		if _, err := conn.Write(response); err != nil {
			t.logger.Debug("Write response error: %v", err)
			return
		}
	}
}

// dispatch decodes body according to the RPC type byte, invokes the matching
// handler method and encodes its reply. It returns ok == false, after logging
// the reason, when the type is unknown or decoding/encoding fails; the caller
// is expected to drop the connection in that case.
func (t *TCPTransport) dispatch(handler RPCHandler, kind byte, body []byte) (response []byte, ok bool) {
	var reply interface{}

	switch RPCType(kind) {
	case RPCRequestVote:
		var args RequestVoteArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal RequestVote error: %v", err)
			return nil, false
		}
		reply = handler.HandleRequestVote(&args)

	case RPCAppendEntries:
		var args AppendEntriesArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal AppendEntries error: %v", err)
			return nil, false
		}
		reply = handler.HandleAppendEntries(&args)

	case RPCInstallSnapshot:
		var args InstallSnapshotArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal InstallSnapshot error: %v", err)
			return nil, false
		}
		reply = handler.HandleInstallSnapshot(&args)

	case RPCPropose:
		var args ProposeArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal Propose error: %v", err)
			return nil, false
		}
		reply = handler.HandlePropose(&args)

	case RPCAddNode:
		var args AddNodeArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal AddNode error: %v", err)
			return nil, false
		}
		reply = handler.HandleAddNode(&args)

	case RPCRemoveNode:
		var args RemoveNodeArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal RemoveNode error: %v", err)
			return nil, false
		}
		reply = handler.HandleRemoveNode(&args)

	case RPCTimeoutNow:
		var args TimeoutNowArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal TimeoutNow error: %v", err)
			return nil, false
		}
		reply = handler.HandleTimeoutNow(&args)

	case RPCReadIndex:
		var args ReadIndexArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal ReadIndex error: %v", err)
			return nil, false
		}
		reply = handler.HandleReadIndex(&args)

	case RPCGet:
		var args GetArgs
		if err := DefaultCodec.Unmarshal(body, &args); err != nil {
			t.logger.Error("Unmarshal Get error: %v", err)
			return nil, false
		}
		reply = handler.HandleGet(&args)

	default:
		t.logger.Error("Unknown RPC type: %d", kind)
		return nil, false
	}

	response, err := DefaultCodec.Marshal(reply)
	if err != nil {
		t.logger.Error("Marshal response error: %v", err)
		return nil, false
	}
	return response, true
}

// Stop signals shutdown, closes every cached outbound connection and closes
// the listener. It is safe to call more than once; only the first call does
// any work and returns the listener's close error, later calls return nil.
func (t *TCPTransport) Stop() error {
	var err error
	t.stopOnce.Do(func() {
		close(t.shutdownCh)

		// Close all connections
		t.mu.Lock()
		for _, conn := range t.conns {
			conn.Close()
		}
		t.conns = make(map[string]net.Conn)
		ln := t.listener
		t.mu.Unlock()

		if ln != nil {
			err = ln.Close()
		}
	})
	return err
}

// getConn gets the persistent connection or creates a new one
func (t *TCPTransport) getConn(target string) (net.Conn, error) {
	return t.getConnContext(context.Background(), target)
}

// getConnContext returns the cached connection to target, dialing a new one
// when none exists. The dial honours ctx and is capped at tcpDialTimeout. It
// fails if the transport has been stopped.
func (t *TCPTransport) getConnContext(ctx context.Context, target string) (net.Conn, error) {
	// Check existing connection
	t.mu.RLock()
	conn, ok := t.conns[target]
	t.mu.RUnlock()
	if ok {
		return conn, nil
	}

	select {
	case <-t.shutdownCh:
		return nil, errors.New("transport stopped")
	default:
	}

	// Dial without holding t.mu: a slow or unreachable peer must not block
	// RPCs to other targets or the inbound path, which needs t.mu to read
	// the handler.
	dialer := net.Dialer{Timeout: tcpDialTimeout}
	fresh, err := dialer.DialContext(ctx, "tcp", target)
	if err != nil {
		return nil, err
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	// Stop closes shutdownCh before it takes t.mu, so checking here ensures a
	// connection dialed during shutdown is never left behind in the map.
	select {
	case <-t.shutdownCh:
		fresh.Close()
		return nil, errors.New("transport stopped")
	default:
	}

	// Another goroutine may have connected while we were dialing; keep the
	// connection that is already registered.
	if existing, ok := t.conns[target]; ok {
		fresh.Close()
		return existing, nil
	}
	if t.conns == nil {
		t.conns = make(map[string]net.Conn)
	}
	t.conns[target] = fresh
	return fresh, nil
}

// closeConn closes and removes a connection from the map
func (t *TCPTransport) closeConn(target string, conn net.Conn) {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Only delete if it's the current connection
	if current, ok := t.conns[target]; ok && current == conn {
		delete(t.conns, target)
		conn.Close()
	}
}

// acquireTarget obtains exclusive use of target's connection. It blocks until
// the slot is free, ctx is done, or the transport is stopped. On success the
// returned release function must be called exactly once.
func (t *TCPTransport) acquireTarget(ctx context.Context, target string) (release func(), err error) {
	t.mu.Lock()
	if t.sendLocks == nil {
		t.sendLocks = make(map[string]chan struct{})
	}
	slot, ok := t.sendLocks[target]
	if !ok {
		slot = make(chan struct{}, 1)
		t.sendLocks[target] = slot
	}
	t.mu.Unlock()

	select {
	case slot <- struct{}{}:
		return func() { <-slot }, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-t.shutdownCh:
		return nil, errors.New("transport stopped")
	}
}

// sendTCPRPC performs one request/response exchange with target.
//
// args is encoded with DefaultCodec and sent as [type(1)][length(4)][body];
// the response is read as [length(4)][body] and decoded into reply. I/O uses
// the deadline from ctx, or tcpDefaultRPCTimeout from now if ctx has none.
// Exchanges to the same target are serialised. A failed write, or an EOF or
// connection reset while waiting for the response, is retried once on a fresh
// connection since it usually indicates a stale cached connection.
func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	// Encode once, up front: the payload is identical for every attempt.
	data, err := DefaultCodec.Marshal(args)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}
	if len(data) > tcpMaxMessageSize {
		// The receiver drops the connection for anything above the limit.
		return fmt.Errorf("request too large: %d", len(data))
	}

	// Write message: [type(1)][length(4)][body]
	header := make([]byte, 5)
	header[0] = byte(rpcType)
	binary.BigEndian.PutUint32(header[1:], uint32(len(data)))

	release, err := t.acquireTarget(ctx, target)
	if err != nil {
		return err
	}
	defer release()

	// Simple retry mechanism for stale connections
	const maxRetries = 2
	var lastErr error

	for i := 0; i < maxRetries; i++ {
		canRetry := i < maxRetries-1

		conn, err := t.getConnContext(ctx, target)
		if err != nil {
			return fmt.Errorf("failed to get connection: %w", err)
		}

		// Set deadline from context
		deadline, ok := ctx.Deadline()
		if !ok {
			deadline = time.Now().Add(tcpDefaultRPCTimeout)
		}
		if err := conn.SetDeadline(deadline); err != nil {
			t.closeConn(target, conn)
			lastErr = fmt.Errorf("failed to set deadline: %w", err)
			if canRetry {
				continue
			}
			return lastErr
		}

		if _, err := conn.Write(header); err != nil {
			t.closeConn(target, conn)
			lastErr = fmt.Errorf("failed to write header: %w", err)
			// If this was a reused connection, retry with a new one
			if canRetry {
				continue
			}
			return lastErr
		}

		if _, err := conn.Write(data); err != nil {
			t.closeConn(target, conn)
			lastErr = fmt.Errorf("failed to write body: %w", err)
			// A partial frame leaves the stream unusable, so the connection
			// is discarded and the whole request is re-sent on a new one.
			if canRetry {
				continue
			}
			return lastErr
		}

		// Read response length
		lenBuf := make([]byte, 4)
		if _, err := io.ReadFull(conn, lenBuf); err != nil {
			t.closeConn(target, conn)
			lastErr = fmt.Errorf("failed to read response length: %w", err)
			if canRetry && (errors.Is(err, io.EOF) || isConnectionReset(err)) {
				continue
			}
			return lastErr
		}

		length := binary.BigEndian.Uint32(lenBuf)
		if length > tcpMaxMessageSize {
			t.closeConn(target, conn)
			return fmt.Errorf("response too large: %d", length)
		}

		// Read response body
		respBody := make([]byte, length)
		if _, err := io.ReadFull(conn, respBody); err != nil {
			t.closeConn(target, conn)
			return fmt.Errorf("failed to read response body: %w", err)
		}

		if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
			t.closeConn(target, conn)
			return fmt.Errorf("failed to unmarshal response: %w", err)
		}

		// Keep connection open
		return nil
	}

	return lastErr
}

// isConnectionReset reports whether err, or any error it wraps, is a
// "connection reset by peer" or "broken pipe" socket error.
func isConnectionReset(err error) bool {
	return errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE)
}

// RequestVote sends a RequestVote RPC
func (t *TCPTransport) RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error) {
	var reply RequestVoteReply
	if err := t.sendTCPRPC(ctx, target, RPCRequestVote, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// AppendEntries sends an AppendEntries RPC
func (t *TCPTransport) AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error) {
	var reply AppendEntriesReply
	if err := t.sendTCPRPC(ctx, target, RPCAppendEntries, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// InstallSnapshot sends an InstallSnapshot RPC
func (t *TCPTransport) InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error) {
	var reply InstallSnapshotReply
	if err := t.sendTCPRPC(ctx, target, RPCInstallSnapshot, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// ForwardPropose forwards a propose request to the leader
func (t *TCPTransport) ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error) {
	var reply ProposeReply
	if err := t.sendTCPRPC(ctx, target, RPCPropose, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// ForwardAddNode forwards an AddNode request to the leader
func (t *TCPTransport) ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error) {
	var reply AddNodeReply
	if err := t.sendTCPRPC(ctx, target, RPCAddNode, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// ForwardRemoveNode forwards a RemoveNode request to the leader
func (t *TCPTransport) ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error) {
	var reply RemoveNodeReply
	if err := t.sendTCPRPC(ctx, target, RPCRemoveNode, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// TimeoutNow sends a TimeoutNow RPC for leadership transfer
func (t *TCPTransport) TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error) {
	var reply TimeoutNowReply
	if err := t.sendTCPRPC(ctx, target, RPCTimeoutNow, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// ReadIndex sends a ReadIndex RPC for linearizable reads
func (t *TCPTransport) ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error) {
	var reply ReadIndexReply
	if err := t.sendTCPRPC(ctx, target, RPCReadIndex, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}

// ForwardGet sends a Get RPC for remote KV reads
func (t *TCPTransport) ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error) {
	var reply GetReply
	if err := t.sendTCPRPC(ctx, target, RPCGet, args, &reply); err != nil {
		return nil, err
	}
	return &reply, nil
}