rpc.go 14 KB


  1. package raft
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net"
  7. "sync"
  8. "time"
  9. )
  10. // Transport defines the interface for RPC communication
  11. type Transport interface {
  12. // Start starts the transport
  13. Start() error
  14. // Stop stops the transport
  15. Stop() error
  16. // RequestVote sends a RequestVote RPC to the target node
  17. RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error)
  18. // AppendEntries sends an AppendEntries RPC to the target node
  19. AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error)
  20. // InstallSnapshot sends an InstallSnapshot RPC to the target node
  21. InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error)
  22. // ForwardPropose forwards a propose request to the leader
  23. ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error)
  24. // ForwardAddNode forwards an AddNode request to the leader
  25. ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error)
  26. // ForwardRemoveNode forwards a RemoveNode request to the leader
  27. ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error)
  28. // TimeoutNow sends a TimeoutNow RPC for leadership transfer
  29. TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error)
  30. // ReadIndex sends a ReadIndex RPC for linearizable reads
  31. ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error)
  32. // ForwardGet sends a Get RPC for remote KV reads
  33. ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error)
  34. // SetRPCHandler sets the handler for incoming RPCs
  35. SetRPCHandler(handler RPCHandler)
  36. }
  37. // RPCHandler handles incoming RPCs
  38. type RPCHandler interface {
  39. HandleRequestVote(args *RequestVoteArgs) *RequestVoteReply
  40. HandleAppendEntries(args *AppendEntriesArgs) *AppendEntriesReply
  41. HandleInstallSnapshot(args *InstallSnapshotArgs) *InstallSnapshotReply
  42. HandlePropose(args *ProposeArgs) *ProposeReply
  43. HandleAddNode(args *AddNodeArgs) *AddNodeReply
  44. HandleRemoveNode(args *RemoveNodeArgs) *RemoveNodeReply
  45. HandleTimeoutNow(args *TimeoutNowArgs) *TimeoutNowReply
  46. HandleReadIndex(args *ReadIndexArgs) *ReadIndexReply
  47. HandleGet(args *GetArgs) *GetReply
  48. }
// TCPTransport implements Transport using raw TCP with a binary protocol.
// This is more efficient than HTTP for high-frequency RPCs.
//
// A request frame is [type(1)][length(4, big-endian)][body] and a reply frame
// is [length(4, big-endian)][body]; bodies are encoded with DefaultCodec and
// capped at 10MB. Idle outbound connections are pooled per target address.
type TCPTransport struct {
	// mu guards handler and connPool.
	mu sync.RWMutex
	// localAddr is the address Start listens on.
	localAddr string
	// handler serves inbound RPCs; installed via SetRPCHandler.
	handler RPCHandler
	logger  Logger
	// listener is assigned by Start and closed by Stop; it is not guarded by mu.
	listener net.Listener
	// shutdownCh is closed by Stop to tell background goroutines to exit.
	shutdownCh chan struct{}
	// Connection pool: maps a target address to a buffered channel (capacity
	// poolSize) holding idle outbound connections to that target.
	connPool map[string]chan net.Conn
	poolSize int
}
  62. // NewTCPTransport creates a new TCP transport
  63. func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTransport {
  64. if logger == nil {
  65. logger = &NoopLogger{}
  66. }
  67. if poolSize <= 0 {
  68. poolSize = 5
  69. }
  70. return &TCPTransport{
  71. localAddr: localAddr,
  72. logger: logger,
  73. shutdownCh: make(chan struct{}),
  74. connPool: make(map[string]chan net.Conn),
  75. poolSize: poolSize,
  76. }
  77. }
  78. // SetRPCHandler sets the handler for incoming RPCs
  79. func (t *TCPTransport) SetRPCHandler(handler RPCHandler) {
  80. t.mu.Lock()
  81. defer t.mu.Unlock()
  82. t.handler = handler
  83. }
  84. // Start starts the TCP server
  85. func (t *TCPTransport) Start() error {
  86. var err error
  87. t.listener, err = net.Listen("tcp", t.localAddr)
  88. if err != nil {
  89. return fmt.Errorf("failed to listen on %s: %w", t.localAddr, err)
  90. }
  91. go t.acceptLoop()
  92. t.logger.Info("TCP Transport started on %s", t.localAddr)
  93. return nil
  94. }
  95. // acceptLoop accepts incoming connections
  96. func (t *TCPTransport) acceptLoop() {
  97. for {
  98. select {
  99. case <-t.shutdownCh:
  100. return
  101. default:
  102. }
  103. conn, err := t.listener.Accept()
  104. if err != nil {
  105. select {
  106. case <-t.shutdownCh:
  107. return
  108. default:
  109. t.logger.Error("Accept error: %v", err)
  110. continue
  111. }
  112. }
  113. go t.handleConnection(conn)
  114. }
  115. }
  116. // handleConnection handles an incoming connection
  117. func (t *TCPTransport) handleConnection(conn net.Conn) {
  118. defer conn.Close()
  119. for {
  120. select {
  121. case <-t.shutdownCh:
  122. return
  123. default:
  124. }
  125. // Set read deadline
  126. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  127. // Read message type (1 byte)
  128. typeBuf := make([]byte, 1)
  129. if _, err := io.ReadFull(conn, typeBuf); err != nil {
  130. if err != io.EOF {
  131. t.logger.Debug("Read type error: %v", err)
  132. }
  133. return
  134. }
  135. // Read message length (4 bytes)
  136. lenBuf := make([]byte, 4)
  137. if _, err := io.ReadFull(conn, lenBuf); err != nil {
  138. t.logger.Debug("Read length error: %v", err)
  139. return
  140. }
  141. length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
  142. if length > 10*1024*1024 { // 10MB limit
  143. t.logger.Error("Message too large: %d", length)
  144. return
  145. }
  146. // Read message body
  147. body := make([]byte, length)
  148. if _, err := io.ReadFull(conn, body); err != nil {
  149. t.logger.Debug("Read body error: %v", err)
  150. return
  151. }
  152. // Handle message
  153. var response []byte
  154. var err error
  155. t.mu.RLock()
  156. handler := t.handler
  157. t.mu.RUnlock()
  158. if handler == nil {
  159. t.logger.Error("No handler registered")
  160. return
  161. }
  162. switch RPCType(typeBuf[0]) {
  163. case RPCRequestVote:
  164. var args RequestVoteArgs
  165. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  166. t.logger.Error("Unmarshal RequestVote error: %v", err)
  167. return
  168. }
  169. reply := handler.HandleRequestVote(&args)
  170. response, err = DefaultCodec.Marshal(reply)
  171. case RPCAppendEntries:
  172. var args AppendEntriesArgs
  173. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  174. t.logger.Error("Unmarshal AppendEntries error: %v", err)
  175. return
  176. }
  177. reply := handler.HandleAppendEntries(&args)
  178. response, err = DefaultCodec.Marshal(reply)
  179. case RPCInstallSnapshot:
  180. var args InstallSnapshotArgs
  181. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  182. t.logger.Error("Unmarshal InstallSnapshot error: %v", err)
  183. return
  184. }
  185. reply := handler.HandleInstallSnapshot(&args)
  186. response, err = DefaultCodec.Marshal(reply)
  187. case RPCPropose:
  188. var args ProposeArgs
  189. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  190. t.logger.Error("Unmarshal Propose error: %v", err)
  191. return
  192. }
  193. reply := handler.HandlePropose(&args)
  194. response, err = DefaultCodec.Marshal(reply)
  195. case RPCAddNode:
  196. var args AddNodeArgs
  197. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  198. t.logger.Error("Unmarshal AddNode error: %v", err)
  199. return
  200. }
  201. reply := handler.HandleAddNode(&args)
  202. response, err = DefaultCodec.Marshal(reply)
  203. case RPCRemoveNode:
  204. var args RemoveNodeArgs
  205. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  206. t.logger.Error("Unmarshal RemoveNode error: %v", err)
  207. return
  208. }
  209. reply := handler.HandleRemoveNode(&args)
  210. response, err = DefaultCodec.Marshal(reply)
  211. case RPCTimeoutNow:
  212. var args TimeoutNowArgs
  213. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  214. t.logger.Error("Unmarshal TimeoutNow error: %v", err)
  215. return
  216. }
  217. reply := handler.HandleTimeoutNow(&args)
  218. response, err = DefaultCodec.Marshal(reply)
  219. case RPCReadIndex:
  220. var args ReadIndexArgs
  221. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  222. t.logger.Error("Unmarshal ReadIndex error: %v", err)
  223. return
  224. }
  225. reply := handler.HandleReadIndex(&args)
  226. response, err = DefaultCodec.Marshal(reply)
  227. case RPCGet:
  228. var args GetArgs
  229. if err := DefaultCodec.Unmarshal(body, &args); err != nil {
  230. t.logger.Error("Unmarshal Get error: %v", err)
  231. return
  232. }
  233. reply := handler.HandleGet(&args)
  234. response, err = DefaultCodec.Marshal(reply)
  235. default:
  236. t.logger.Error("Unknown RPC type: %d", typeBuf[0])
  237. return
  238. }
  239. if err != nil {
  240. t.logger.Error("Marshal response error: %v", err)
  241. return
  242. }
  243. // Write response
  244. conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
  245. respLen := make([]byte, 4)
  246. respLen[0] = byte(len(response) >> 24)
  247. respLen[1] = byte(len(response) >> 16)
  248. respLen[2] = byte(len(response) >> 8)
  249. respLen[3] = byte(len(response))
  250. if _, err := conn.Write(respLen); err != nil {
  251. t.logger.Debug("Write response length error: %v", err)
  252. return
  253. }
  254. if _, err := conn.Write(response); err != nil {
  255. t.logger.Debug("Write response error: %v", err)
  256. return
  257. }
  258. }
  259. }
  260. // Stop stops the TCP server
  261. func (t *TCPTransport) Stop() error {
  262. close(t.shutdownCh)
  263. // Close all pooled connections
  264. t.mu.Lock()
  265. for _, pool := range t.connPool {
  266. close(pool)
  267. for conn := range pool {
  268. conn.Close()
  269. }
  270. }
  271. t.connPool = make(map[string]chan net.Conn)
  272. t.mu.Unlock()
  273. if t.listener != nil {
  274. return t.listener.Close()
  275. }
  276. return nil
  277. }
  278. // getConn gets a connection from the pool or creates a new one
  279. func (t *TCPTransport) getConn(target string) (net.Conn, error) {
  280. t.mu.Lock()
  281. pool, ok := t.connPool[target]
  282. if !ok {
  283. pool = make(chan net.Conn, t.poolSize)
  284. t.connPool[target] = pool
  285. }
  286. t.mu.Unlock()
  287. select {
  288. case conn := <-pool:
  289. return conn, nil
  290. default:
  291. return net.DialTimeout("tcp", target, 5*time.Second)
  292. }
  293. }
  294. // putConn returns a connection to the pool
  295. func (t *TCPTransport) putConn(target string, conn net.Conn) {
  296. t.mu.RLock()
  297. pool, ok := t.connPool[target]
  298. t.mu.RUnlock()
  299. if !ok {
  300. conn.Close()
  301. return
  302. }
  303. select {
  304. case pool <- conn:
  305. default:
  306. conn.Close()
  307. }
  308. }
  309. // sendTCPRPC sends an RPC over TCP
  310. func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
  311. conn, err := t.getConn(target)
  312. if err != nil {
  313. return fmt.Errorf("failed to get connection: %w", err)
  314. }
  315. data, err := DefaultCodec.Marshal(args)
  316. if err != nil {
  317. conn.Close()
  318. return fmt.Errorf("failed to marshal request: %w", err)
  319. }
  320. // Set deadline from context
  321. deadline, ok := ctx.Deadline()
  322. if !ok {
  323. deadline = time.Now().Add(5 * time.Second)
  324. }
  325. conn.SetDeadline(deadline)
  326. // Write message: [type(1)][length(4)][body]
  327. header := make([]byte, 5)
  328. header[0] = byte(rpcType)
  329. header[1] = byte(len(data) >> 24)
  330. header[2] = byte(len(data) >> 16)
  331. header[3] = byte(len(data) >> 8)
  332. header[4] = byte(len(data))
  333. if _, err := conn.Write(header); err != nil {
  334. conn.Close()
  335. return fmt.Errorf("failed to write header: %w", err)
  336. }
  337. if _, err := conn.Write(data); err != nil {
  338. conn.Close()
  339. return fmt.Errorf("failed to write body: %w", err)
  340. }
  341. // Read response length
  342. lenBuf := make([]byte, 4)
  343. if _, err := io.ReadFull(conn, lenBuf); err != nil {
  344. conn.Close()
  345. return fmt.Errorf("failed to read response length: %w", err)
  346. }
  347. length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
  348. if length > 10*1024*1024 {
  349. conn.Close()
  350. return fmt.Errorf("response too large: %d", length)
  351. }
  352. // Read response body
  353. respBody := make([]byte, length)
  354. if _, err := io.ReadFull(conn, respBody); err != nil {
  355. conn.Close()
  356. return fmt.Errorf("failed to read response body: %w", err)
  357. }
  358. if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
  359. conn.Close()
  360. return fmt.Errorf("failed to unmarshal response: %w", err)
  361. }
  362. // Return connection to pool
  363. t.putConn(target, conn)
  364. return nil
  365. }
  366. // RequestVote sends a RequestVote RPC
  367. func (t *TCPTransport) RequestVote(ctx context.Context, target string, args *RequestVoteArgs) (*RequestVoteReply, error) {
  368. var reply RequestVoteReply
  369. err := t.sendTCPRPC(ctx, target, RPCRequestVote, args, &reply)
  370. if err != nil {
  371. return nil, err
  372. }
  373. return &reply, nil
  374. }
  375. // AppendEntries sends an AppendEntries RPC
  376. func (t *TCPTransport) AppendEntries(ctx context.Context, target string, args *AppendEntriesArgs) (*AppendEntriesReply, error) {
  377. var reply AppendEntriesReply
  378. err := t.sendTCPRPC(ctx, target, RPCAppendEntries, args, &reply)
  379. if err != nil {
  380. return nil, err
  381. }
  382. return &reply, nil
  383. }
  384. // InstallSnapshot sends an InstallSnapshot RPC
  385. func (t *TCPTransport) InstallSnapshot(ctx context.Context, target string, args *InstallSnapshotArgs) (*InstallSnapshotReply, error) {
  386. var reply InstallSnapshotReply
  387. err := t.sendTCPRPC(ctx, target, RPCInstallSnapshot, args, &reply)
  388. if err != nil {
  389. return nil, err
  390. }
  391. return &reply, nil
  392. }
  393. // ForwardPropose forwards a propose request to the leader
  394. func (t *TCPTransport) ForwardPropose(ctx context.Context, target string, args *ProposeArgs) (*ProposeReply, error) {
  395. var reply ProposeReply
  396. err := t.sendTCPRPC(ctx, target, RPCPropose, args, &reply)
  397. if err != nil {
  398. return nil, err
  399. }
  400. return &reply, nil
  401. }
  402. // ForwardAddNode forwards an AddNode request to the leader
  403. func (t *TCPTransport) ForwardAddNode(ctx context.Context, target string, args *AddNodeArgs) (*AddNodeReply, error) {
  404. var reply AddNodeReply
  405. err := t.sendTCPRPC(ctx, target, RPCAddNode, args, &reply)
  406. if err != nil {
  407. return nil, err
  408. }
  409. return &reply, nil
  410. }
  411. // ForwardRemoveNode forwards a RemoveNode request to the leader
  412. func (t *TCPTransport) ForwardRemoveNode(ctx context.Context, target string, args *RemoveNodeArgs) (*RemoveNodeReply, error) {
  413. var reply RemoveNodeReply
  414. err := t.sendTCPRPC(ctx, target, RPCRemoveNode, args, &reply)
  415. if err != nil {
  416. return nil, err
  417. }
  418. return &reply, nil
  419. }
  420. // TimeoutNow sends a TimeoutNow RPC for leadership transfer
  421. func (t *TCPTransport) TimeoutNow(ctx context.Context, target string, args *TimeoutNowArgs) (*TimeoutNowReply, error) {
  422. var reply TimeoutNowReply
  423. err := t.sendTCPRPC(ctx, target, RPCTimeoutNow, args, &reply)
  424. if err != nil {
  425. return nil, err
  426. }
  427. return &reply, nil
  428. }
  429. // ReadIndex sends a ReadIndex RPC for linearizable reads
  430. func (t *TCPTransport) ReadIndex(ctx context.Context, target string, args *ReadIndexArgs) (*ReadIndexReply, error) {
  431. var reply ReadIndexReply
  432. err := t.sendTCPRPC(ctx, target, RPCReadIndex, args, &reply)
  433. if err != nil {
  434. return nil, err
  435. }
  436. return &reply, nil
  437. }
  438. // ForwardGet sends a Get RPC for remote KV reads
  439. func (t *TCPTransport) ForwardGet(ctx context.Context, target string, args *GetArgs) (*GetReply, error) {
  440. var reply GetReply
  441. err := t.sendTCPRPC(ctx, target, RPCGet, args, &reply)
  442. if err != nil {
  443. return nil, err
  444. }
  445. return &reply, nil
  446. }