rpc.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. package raft
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"syscall"
	"time"
)
  10. // Transport defines the interface for RPC communication
  11. type Transport interface {
  12. // Start starts the transport
  13. Start() error
  14. // Stop stops the transport
  15. Stop() error
  16. // RequestVote sends a RequestVote RPC to the target node
  17. RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error)
  18. // AppendEntries sends an AppendEntries RPC to the target node
  19. AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error)
  20. // InstallSnapshot sends an InstallSnapshot RPC to the target node
  21. InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error)
  22. // ForwardPropose forwards a propose request to the leader
  23. ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error)
  24. // ForwardAddNode forwards an AddNode request to the leader
  25. ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error)
  26. // ForwardRemoveNode forwards a RemoveNode request to the leader
  27. ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error)
  28. // TimeoutNow sends a TimeoutNow RPC for leadership transfer
  29. TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error)
  30. // ReadIndex sends a ReadIndex RPC for linearizable reads
  31. ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error)
  32. // ForwardGet sends a Get RPC for remote KV reads
  33. ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error)
  34. // SetRPCHandler sets the handler for incoming RPCs
  35. SetRPCHandler(handler RPCHandler)
  36. }
  37. // RPCHandler handles incoming RPCs
  38. type RPCHandler interface {
  39. HandleRequestVote(args *RequestVoteArgs) *RequestVoteReply
  40. HandleAppendEntries(args *AppendEntriesArgs) *AppendEntriesReply
  41. HandleInstallSnapshot(args *InstallSnapshotArgs) *InstallSnapshotReply
  42. HandlePropose(args *ProposeArgs) *ProposeReply
  43. HandleAddNode(args *AddNodeArgs) *AddNodeReply
  44. HandleRemoveNode(args *RemoveNodeArgs) *RemoveNodeReply
  45. HandleTimeoutNow(args *TimeoutNowArgs) *TimeoutNowReply
  46. HandleReadIndex(args *ReadIndexArgs) *ReadIndexReply
  47. HandleGet(args *GetArgs) *GetReply
  48. }
  49. // TCPTransport implements Transport using raw TCP with binary protocol
  50. // This is more efficient than HTTP for high-frequency RPCs
  51. type TCPTransport struct {
  52. mu sync.RWMutex
  53. localAddr string
  54. handler RPCHandler
  55. logger Logger
  56. listener net.Listener
  57. shutdownCh chan struct{}
  58. // Single persistent connection per target
  59. conns map[string]net.Conn
  60. }
  61. // NewTCPTransport creates a new TCP transport
  62. func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTransport {
  63. if logger == nil {
  64. logger = &NoopLogger{}
  65. }
  66. return &TCPTransport{
  67. localAddr: localAddr,
  68. logger: logger,
  69. shutdownCh: make(chan struct{}),
  70. conns: make(map[string]net.Conn),
  71. }
  72. }
  73. // SetRPCHandler sets the handler for incoming RPCs
  74. func (t *TCPTransport) SetRPCHandler(handler RPCHandler) {
  75. t.mu.Lock()
  76. defer t.mu.Unlock()
  77. t.handler = handler
  78. }
  79. // Start starts the TCP server
  80. func (t *TCPTransport) Start() error {
  81. var err error
  82. t.listener, err = net.Listen("tcp", t.localAddr)
  83. if err != nil {
  84. return fmt.Errorf("failed to listen on %s: %w", t.localAddr, err)
  85. }
  86. go t.acceptLoop()
  87. t.logger.Info("TCP Transport started on %s", t.localAddr)
  88. return nil
  89. }
  90. // acceptLoop accepts incoming connections
  91. func (t *TCPTransport) acceptLoop() {
  92. for {
  93. select {
  94. case <-t.shutdownCh:
  95. return
  96. default:
  97. }
  98. conn, err := t.listener.Accept()
  99. if err != nil {
  100. select {
  101. case <-t.shutdownCh:
  102. return
  103. default:
  104. t.logger.Error("Accept error: %v", err)
  105. continue
  106. }
  107. }
  108. go t.handleConnection(conn)
  109. }
  110. }
  111. // handleConnection handles an incoming connection
  112. func (t *TCPTransport) handleConnection(conn net.Conn) {
  113. defer conn.Close()
  114. for {
  115. select {
  116. case <-t.shutdownCh:
  117. return
  118. default:
  119. }
  120. // Set read deadline
  121. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  122. // Read message type (1 byte)
  123. typeBuf := make([]byte, 1)
  124. if _, err := io.ReadFull(conn, typeBuf); err != nil {
  125. if err != io.EOF {
  126. t.logger.Debug("Read type error: %v", err)
  127. }
  128. return
  129. }
  130. // Read message length (4 bytes)
  131. lenBuf := make([]byte, 4)
  132. if _, err := io.ReadFull(conn, lenBuf); err != nil {
  133. t.logger.Debug("Read length error: %v", err)
  134. return
  135. }
  136. length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
  137. if length > 10*1024*1024 { // 10MB limit
  138. t.logger.Error("Message too large: %d", length)
  139. return
  140. }
  141. // Read message body
  142. body := make([]byte, length)
  143. if _, err := io.ReadFull(conn, body); err != nil {
  144. t.logger.Debug("Read body error: %v", err)
  145. return
  146. }
  147. // Handle message
  148. var response []byte
  149. var err error
  150. t.mu.RLock()
  151. handler := t.handler
  152. t.mu.RUnlock()
  153. if handler == nil {
  154. t.logger.Error("No handler registered")
  155. return
  156. }
  157. switch RPCType(typeBuf[0]) {
  158. case RPCRequestVote:
  159. var args RequestVoteArgs
  160. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  161. t.logger.Error("Unmarshal RequestVote error: %v", err)
  162. return
  163. }
  164. reply := handler.HandleRequestVote(&args)
  165. response, err = DefaultCodec.Marshal(reply)
  166. case RPCAppendEntries:
  167. var args AppendEntriesArgs
  168. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  169. t.logger.Error("Unmarshal AppendEntries error: %v", err)
  170. return
  171. }
  172. reply := handler.HandleAppendEntries(&args)
  173. response, err = DefaultCodec.Marshal(reply)
  174. case RPCInstallSnapshot:
  175. var args InstallSnapshotArgs
  176. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  177. t.logger.Error("Unmarshal InstallSnapshot error: %v", err)
  178. return
  179. }
  180. reply := handler.HandleInstallSnapshot(&args)
  181. response, err = DefaultCodec.Marshal(reply)
  182. case RPCPropose:
  183. var args ProposeArgs
  184. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  185. t.logger.Error("Unmarshal Propose error: %v", err)
  186. return
  187. }
  188. reply := handler.HandlePropose(&args)
  189. response, err = DefaultCodec.Marshal(reply)
  190. case RPCAddNode:
  191. var args AddNodeArgs
  192. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  193. t.logger.Error("Unmarshal AddNode error: %v", err)
  194. return
  195. }
  196. reply := handler.HandleAddNode(&args)
  197. response, err = DefaultCodec.Marshal(reply)
  198. case RPCRemoveNode:
  199. var args RemoveNodeArgs
  200. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  201. t.logger.Error("Unmarshal RemoveNode error: %v", err)
  202. return
  203. }
  204. reply := handler.HandleRemoveNode(&args)
  205. response, err = DefaultCodec.Marshal(reply)
  206. case RPCTimeoutNow:
  207. var args TimeoutNowArgs
  208. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  209. t.logger.Error("Unmarshal TimeoutNow error: %v", err)
  210. return
  211. }
  212. reply := handler.HandleTimeoutNow(&args)
  213. response, err = DefaultCodec.Marshal(reply)
  214. case RPCReadIndex:
  215. var args ReadIndexArgs
  216. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  217. t.logger.Error("Unmarshal ReadIndex error: %v", err)
  218. return
  219. }
  220. reply := handler.HandleReadIndex(&args)
  221. response, err = DefaultCodec.Marshal(reply)
  222. case RPCGet:
  223. var args GetArgs
  224. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  225. t.logger.Error("Unmarshal Get error: %v", err)
  226. return
  227. }
  228. reply := handler.HandleGet(&args)
  229. response, err = DefaultCodec.Marshal(reply)
  230. default:
  231. t.logger.Error("Unknown RPC type: %d", typeBuf[0])
  232. return
  233. }
  234. if err != nil {
  235. t.logger.Error("Marshal response error: %v", err)
  236. return
  237. }
  238. // Write response
  239. conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
  240. respLen := make([]byte, 4)
  241. respLen[0] = byte(len(response) >> 24)
  242. respLen[1] = byte(len(response) >> 16)
  243. respLen[2] = byte(len(response) >> 8)
  244. respLen[3] = byte(len(response))
  245. if _, err := conn.Write(respLen); err != nil {
  246. t.logger.Debug("Write response length error: %v", err)
  247. return
  248. }
  249. if _, err := conn.Write(response); err != nil {
  250. t.logger.Debug("Write response error: %v", err)
  251. return
  252. }
  253. }
  254. }
  255. // Stop stops the TCP server
  256. func (t *TCPTransport) Stop() error {
  257. close(t.shutdownCh)
  258. // Close all connections
  259. t.mu.Lock()
  260. for _, conn := range t.conns {
  261. conn.Close()
  262. }
  263. t.conns = make(map[string]net.Conn)
  264. t.mu.Unlock()
  265. if t.listener != nil {
  266. return t.listener.Close()
  267. }
  268. return nil
  269. }
  270. // getConn gets the persistent connection or creates a new one
  271. func (t *TCPTransport) getConn(target string) (net.Conn, error) {
  272. t.mu.Lock()
  273. defer t.mu.Unlock()
  274. // Check existing connection
  275. if conn, ok := t.conns[target]; ok {
  276. return conn, nil
  277. }
  278. // Dial new connection
  279. conn, err := net.DialTimeout("tcp", target, 5*time.Second)
  280. if err != nil {
  281. return nil, err
  282. }
  283. t.conns[target] = conn
  284. return conn, nil
  285. }
  286. // closeConn closes and removes a connection from the map
  287. func (t *TCPTransport) closeConn(target string, conn net.Conn) {
  288. t.mu.Lock()
  289. defer t.mu.Unlock()
  290. // Only delete if it's the current connection
  291. if current, ok := t.conns[target]; ok && current == conn {
  292. delete(t.conns, target)
  293. conn.Close()
  294. }
  295. }
  296. // sendTCPRPC sends an RPC over TCP
  297. func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
  298. // Simple retry mechanism for stale connections
  299. maxRetries := 2
  300. var lastErr error
  301. for i := 0; i < maxRetries; i++ {
  302. conn, err := t.getConn(target)
  303. if err != nil {
  304. return fmt.Errorf("failed to get connection: %w", err)
  305. }
  306. data, err := DefaultCodec.Marshal(args)
  307. if err != nil {
  308. // Don't close conn here as we haven't touched it yet in this iteration
  309. return fmt.Errorf("failed to marshal request: %w", err)
  310. }
  311. // Set deadline from context
  312. deadline, ok := ctx.Deadline()
  313. if !ok {
  314. deadline = time.Now().Add(5 * time.Second)
  315. }
  316. conn.SetDeadline(deadline)
  317. // Write message: [type(1)][length(4)][body]
  318. header := make([]byte, 5)
  319. header[0] = byte(rpcType)
  320. header[1] = byte(len(data) >> 24)
  321. header[2] = byte(len(data) >> 16)
  322. header[3] = byte(len(data) >> 8)
  323. header[4] = byte(len(data))
  324. if _, err := conn.Write(header); err != nil {
  325. t.closeConn(target, conn)
  326. lastErr = fmt.Errorf("failed to write header: %w", err)
  327. // If this was a reused connection, retry with a new one
  328. if i < maxRetries-1 {
  329. continue
  330. }
  331. return lastErr
  332. }
  333. if _, err := conn.Write(data); err != nil {
  334. t.closeConn(target, conn)
  335. lastErr = fmt.Errorf("failed to write body: %w", err)
  336. // If write failed, it might be a broken pipe, but we already wrote header.
  337. // Retrying here is risky if the server processed the header but not body,
  338. // but for idempotent RPCs it might be okay. For safety, we only retry if it looks like a connection issue on write.
  339. // However, since we're in binary protocol, partial writes break the stream anyway.
  340. if i < maxRetries-1 {
  341. continue
  342. }
  343. return lastErr
  344. }
  345. // Read response length
  346. lenBuf := make([]byte, 4)
  347. if _, err := io.ReadFull(conn, lenBuf); err != nil {
  348. t.closeConn(target, conn)
  349. lastErr = fmt.Errorf("failed to read response length: %w", err)
  350. if i < maxRetries-1 && (err == io.EOF || isConnectionReset(err)) {
  351. continue
  352. }
  353. return lastErr
  354. }
  355. length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
  356. if length > 10*1024*1024 {
  357. t.closeConn(target, conn)
  358. return fmt.Errorf("response too large: %d", length)
  359. }
  360. // Read response body
  361. respBody := make([]byte, length)
  362. if _, err := io.ReadFull(conn, respBody); err != nil {
  363. t.closeConn(target, conn)
  364. return fmt.Errorf("failed to read response body: %w", err)
  365. }
  366. if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
  367. t.closeConn(target, conn)
  368. return fmt.Errorf("failed to unmarshal response: %w", err)
  369. }
  370. // Keep connection open
  371. return nil
  372. }
  373. return lastErr
  374. }
  375. func isConnectionReset(err error) bool {
  376. if opErr, ok := err.(*net.OpError); ok {
  377. return opErr.Err.Error() == "connection reset by peer" || opErr.Err.Error() == "broken pipe"
  378. }
  379. return false
  380. }
  381. // RequestVote sends a RequestVote RPC
  382. func (t *TCPTransport) RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error) {
  383. var reply RequestVoteReply
  384. err := t.sendTCPRPC(ctx, target, RPCRequestVote, args, &reply)
  385. if err != nil {
  386. return nil, err
  387. }
  388. return &reply, nil
  389. }
  390. // AppendEntries sends an AppendEntries RPC
  391. func (t *TCPTransport) AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error) {
  392. var reply AppendEntriesReply
  393. err := t.sendTCPRPC(ctx, target, RPCAppendEntries, args, &reply)
  394. if err != nil {
  395. return nil, err
  396. }
  397. return &reply, nil
  398. }
  399. // InstallSnapshot sends an InstallSnapshot RPC
  400. func (t *TCPTransport) InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error) {
  401. var reply InstallSnapshotReply
  402. err := t.sendTCPRPC(ctx, target, RPCInstallSnapshot, args, &reply)
  403. if err != nil {
  404. return nil, err
  405. }
  406. return &reply, nil
  407. }
  408. // ForwardPropose forwards a propose request to the leader
  409. func (t *TCPTransport) ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error) {
  410. var reply ProposeReply
  411. err := t.sendTCPRPC(ctx, target, RPCPropose, args, &reply)
  412. if err != nil {
  413. return nil, err
  414. }
  415. return &reply, nil
  416. }
  417. // ForwardAddNode forwards an AddNode request to the leader
  418. func (t *TCPTransport) ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error) {
  419. var reply AddNodeReply
  420. err := t.sendTCPRPC(ctx, target, RPCAddNode, args, &reply)
  421. if err != nil {
  422. return nil, err
  423. }
  424. return &reply, nil
  425. }
  426. // ForwardRemoveNode forwards a RemoveNode request to the leader
  427. func (t *TCPTransport) ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error) {
  428. var reply RemoveNodeReply
  429. err := t.sendTCPRPC(ctx, target, RPCRemoveNode, args, &reply)
  430. if err != nil {
  431. return nil, err
  432. }
  433. return &reply, nil
  434. }
  435. // TimeoutNow sends a TimeoutNow RPC for leadership transfer
  436. func (t *TCPTransport) TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error) {
  437. var reply TimeoutNowReply
  438. err := t.sendTCPRPC(ctx, target, RPCTimeoutNow, args, &reply)
  439. if err != nil {
  440. return nil, err
  441. }
  442. return &reply, nil
  443. }
  444. // ReadIndex sends a ReadIndex RPC for linearizable reads
  445. func (t *TCPTransport) ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error) {
  446. var reply ReadIndexReply
  447. err := t.sendTCPRPC(ctx, target, RPCReadIndex, args, &reply)
  448. if err != nil {
  449. return nil, err
  450. }
  451. return &reply, nil
  452. }
  453. // ForwardGet sends a Get RPC for remote KV reads
  454. func (t *TCPTransport) ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error) {
  455. var reply GetReply
  456. err := t.sendTCPRPC(ctx, target, RPCGet, args, &reply)
  457. if err != nil {
  458. return nil, err
  459. }
  460. return &reply, nil
  461. }