Sfoglia il codice sorgente

节点之间的链接启用链接池,加上basic测试

xbase 2 settimane fa
parent
commit
2c322564d5

+ 76 - 0
example/basic/basic.md

@@ -0,0 +1,76 @@
+# Raft Basic Test Scenario
+
+This directory contains a basic test setup for the Raft cluster with 4 nodes.
+
+## Scenario Description
+
+- **Cluster Setup**:
+    - **Initial Cluster**: Node 1 and Node 2 form the initial cluster.
+    - **Dynamic Members**: Node 3 and Node 4 are started as standalone nodes and can be added to the cluster dynamically by running the `join` command on the Leader.
+- **Configuration**:
+    - **Log Compaction**: Disabled (`LogCompactionEnabled = false`). The binary log will grow indefinitely.
+    - **Memory**: Only keys and metadata are cached in memory. Values are stored on disk (default engine behavior).
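+- **Data Directory**: `../../data/node{1..4}`, resolved relative to each node's working directory (`example/basic/node{1..4}`), i.e. `example/data/node{1..4}`.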
+
+## Setup and Usage
+
+### 1. Start Nodes
+
+Open 4 separate terminals and run the following commands:
+
+**Terminal 1 (Node 1):**
+```bash
+cd example/basic/node1
+go run main.go
+```
+
+**Terminal 2 (Node 2):**
+```bash
+cd example/basic/node2
+go run main.go
+```
+
+**Terminal 3 (Node 3):**
+```bash
+cd example/basic/node3
+go run main.go
+```
+
+**Terminal 4 (Node 4):**
+```bash
+cd example/basic/node4
+go run main.go
+```
+
+### 2. CLI Commands
+
+Each node provides an interactive CLI with the following commands:
+
+| Command | Description | Example |
+| :--- | :--- | :--- |
+| `set <key> <val>` | Set a key-value pair. Requests are forwarded to Leader. | `set user:1 bob` |
+| `get <key>` | Get a value by key (Linearizable Read). | `get user:1` |
+| `del <key>` | Delete a key. | `del user:1` |
+| `search <pattern> [limit] [offset]` | Search keys matching a pattern (executed internally as `key like "<pattern>" LIMIT <limit> OFFSET <offset>`; defaults: limit 20, offset 0). | `search user:* 10 0` |
+| `demodata <count> <pattern>` | Generate demo data. Every `*` in the pattern is replaced with the item number (1 to `<count>`). | `demodata 100 user.name.u*` |
+| `stats` | Show current node status, term, and indices. | `stats` |
+| `binlog` | Show the Raft log's CommitIndex, AppliedIndex, and LastLogIndex. | `binlog` |
+| `db` | Show the last index applied to the DB (LastAppliedIndex). | `db` |
+| `join <nodeID> <addr>` | (Leader Only) Add a new node to the cluster. | `join node3 127.0.0.1:9003` |
+| `leave <nodeID>` | (Leader Only) Remove a node from the cluster. | `leave node3` |
+| `help` | Show this help message. | `help` |
+
+### 3. Test Workflow
+
+1.  **Verify Initial Cluster**: Check `stats` on Node 1 and Node 2. One should be Leader, the other Follower.
+2.  **Generate Data**: On Node 1 (or any node), run `demodata 100 user.*`.
+3.  **Read Data**: Verify data availability on Node 2 using `get user.1` or `search user.*`.
+4.  **Expand Cluster**:
+    - Determine the Leader (e.g., Node 1).
+    - On Leader, run: `join node3 127.0.0.1:9003`.
+    - On Leader, run: `join node4 127.0.0.1:9004`.
+5.  **Verify Replication**: Check if Node 3 and Node 4 have the data using `search` or `get`.
+6.  **Delete Data**: Run `del user.1`. Verify it's gone on all nodes.
+7.  **Inspect Logs**: Use `binlog` and `db` to see the commit indices matching across the cluster.
+
+

+ 237 - 0
example/basic/common/cli.go

@@ -0,0 +1,237 @@
+package common
+
+import (
+	"bufio"
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+	"text/tabwriter"
+	"time"
+
+	"igit.com/xbase/raft"
+)
+
+const (
+	ColorReset  = "\033[0m"
+	ColorDim    = "\033[90m" // Dark Gray
+	ColorRed    = "\033[31m"
+	ColorGreen  = "\033[32m"
+	ColorYellow = "\033[33m"
+	ColorBlue   = "\033[34m"
+	ColorCyan   = "\033[36m"
+)
+
+func StartCLI(server *raft.KVServer, nodeID string) {
+	fmt.Printf("Node %s%s%s CLI Started\n", ColorGreen, nodeID, ColorReset)
+	fmt.Println("Type 'help' for commands.")
+
+	// State Monitor Loop
+	go func() {
+		var lastState string
+		var lastTerm uint64
+		stats := server.GetStats()
+		lastState = stats.State
+		lastTerm = stats.Term
+
+		ticker := time.NewTicker(100 * time.Millisecond)
+		defer ticker.Stop()
+
+		for range ticker.C {
+			stats := server.GetStats()
+			if stats.State != lastState || stats.Term != lastTerm {
+				fmt.Printf("\n%s[State Change] %s (Term %d) -> %s (Term %d)%s\n> ",
+					ColorYellow, lastState, lastTerm, stats.State, stats.Term, ColorReset)
+				lastState = stats.State
+				lastTerm = stats.Term
+			}
+		}
+	}()
+
+	scanner := bufio.NewScanner(os.Stdin)
+	fmt.Print("> ")
+	for scanner.Scan() {
+		text := strings.TrimSpace(scanner.Text())
+		if text == "" {
+			fmt.Print("> ")
+			continue
+		}
+		parts := strings.Fields(text)
+		cmd := strings.ToLower(parts[0])
+
+		switch cmd {
+		case "set":
+			if len(parts) != 3 {
+				fmt.Println("Usage: set <key> <value>")
+				break
+			}
+			key, val := parts[1], parts[2]
+			if err := server.Set(key, val); err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+			} else {
+				fmt.Printf("%sOK%s\n", ColorGreen, ColorReset)
+			}
+
+		case "get":
+			if len(parts) != 2 {
+				fmt.Println("Usage: get <key>")
+				break
+			}
+			key := parts[1]
+			if val, ok, err := server.GetLinear(key); err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+			} else if !ok {
+				fmt.Printf("%sNot Found%s\n", ColorYellow, ColorReset)
+			} else {
+				fmt.Printf("%s%s%s = %s%s%s\n", ColorCyan, key, ColorReset, ColorYellow, val, ColorReset)
+			}
+
+		case "del", "delete":
+			if len(parts) != 2 {
+				fmt.Println("Usage: del <key>")
+				break
+			}
+			key := parts[1]
+			if err := server.Del(key); err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+			} else {
+				fmt.Printf("%sDeleted%s\n", ColorGreen, ColorReset)
+			}
+
+		case "demodata":
+			if len(parts) != 3 {
+				fmt.Println("Usage: demodata <count> <pattern> (e.g. demodata 100 user.*)")
+				break
+			}
+			count, err := strconv.Atoi(parts[1])
+			if err != nil {
+				fmt.Printf("Invalid count: %v\n", err)
+				break
+			}
+			pattern := parts[2]
+			fmt.Printf("Generating %d items with pattern '%s'...\n", count, pattern)
+			
+			start := time.Now()
+			success := 0
+			for i := 1; i <= count; i++ {
+				key := strings.Replace(pattern, "*", strconv.Itoa(i), -1)
+				val := fmt.Sprintf("val-%s", key) // Simple value derivation
+				if err := server.Set(key, val); err != nil {
+					fmt.Printf("Failed at %d: %v\n", i, err)
+				} else {
+					success++
+				}
+				if i%100 == 0 {
+					fmt.Printf("Progress: %d/%d\r", i, count)
+				}
+			}
+			duration := time.Since(start)
+			fmt.Printf("\n%sDone!%s inserted %d/%d items in %v (Avg: %v/op)\n", 
+				ColorGreen, ColorReset, success, count, duration, duration/time.Duration(count))
+
+		case "search":
+			// search <pattern> [limit] [offset]
+			// Internally constructs: key like "<pattern>" LIMIT <limit> OFFSET <offset>
+			if len(parts) < 2 {
+				fmt.Println("Usage: search <pattern> [limit] [offset]")
+				break
+			}
+			pattern := parts[1]
+			limit := 20 // Default limit
+			offset := 0
+			
+			if len(parts) >= 3 {
+				if l, err := strconv.Atoi(parts[2]); err == nil {
+					limit = l
+				}
+			}
+			if len(parts) >= 4 {
+				if o, err := strconv.Atoi(parts[3]); err == nil {
+					offset = o
+				}
+			}
+
+			// Construct SQL for DB Engine
+			sql := fmt.Sprintf("key like \"%s\" LIMIT %d OFFSET %d", pattern, limit, offset)
+			results, err := server.DB.Query(sql)
+			if err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+				break
+			}
+
+			// Print Table
+			w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+			fmt.Fprintln(w, "Key\tValue\tCommitIndex")
+			fmt.Fprintln(w, "---\t-----\t-----------")
+			for _, r := range results {
+				fmt.Fprintf(w, "%s\t%s\t%d\n", r.Key, r.Value, r.CommitIndex)
+			}
+			w.Flush()
+			fmt.Printf("%sFound %d records (showing max %d)%s\n", ColorDim, len(results), limit, ColorReset)
+
+		case "binlog":
+			// Show Raft Log Stats
+			stats := server.GetStats() // Basic stats
+			fmt.Printf("Raft Log CommitIndex: %s%d%s\n", ColorGreen, stats.CommitIndex, ColorReset)
+			fmt.Printf("Raft Log AppliedIndex: %s%d%s\n", ColorGreen, stats.LastApplied, ColorReset)
+			fmt.Printf("Raft LastLogIndex: %s%d%s\n", ColorGreen, stats.LastLogIndex, ColorReset)
+
+		case "db":
+			// Show DB Stats
+			idx := server.DB.GetLastAppliedIndex()
+			fmt.Printf("DB LastAppliedIndex: %s%d%s\n", ColorGreen, idx, ColorReset)
+
+		case "stats":
+			stats := server.GetStats()
+			health := server.HealthCheck()
+			fmt.Printf("Node: %s\nState: %s\nLeader: %s\nTerm: %d\nHealthy: %v\n", 
+				health.NodeID, health.State, health.LeaderID, health.Term, health.IsHealthy)
+			fmt.Printf("CommitIndex: %d, Applied: %d\n", stats.CommitIndex, stats.LastApplied)
+			fmt.Printf("Cluster Size: %d\n", stats.ClusterSize)
+			fmt.Println("Cluster Nodes:")
+			for id, addr := range stats.ClusterNodes {
+				fmt.Printf("  - %s: %s\n", id, addr)
+			}
+
+		case "join":
+			if len(parts) != 3 {
+				fmt.Println("Usage: join <nodeID> <addr>")
+				break
+			}
+			if err := server.Join(parts[1], parts[2]); err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+			} else {
+				fmt.Printf("%sJoined node %s%s\n", ColorGreen, parts[1], ColorReset)
+			}
+
+		case "leave":
+			if len(parts) != 2 {
+				fmt.Println("Usage: leave <nodeID>")
+				break
+			}
+			if err := server.Leave(parts[1]); err != nil {
+				fmt.Printf("%sError:%s %v\n", ColorRed, ColorReset, err)
+			} else {
+				fmt.Printf("%sRemoved node %s%s\n", ColorGreen, parts[1], ColorReset)
+			}
+
+		case "help":
+			fmt.Println("Commands:")
+			fmt.Println("  set <key> <val>           Set value")
+			fmt.Println("  get <key>                 Get value")
+			fmt.Println("  del <key>                 Delete key")
+			fmt.Println("  demodata <n> <pat>        Generate n items (e.g. 'demodata 100 user.*')")
+			fmt.Println("  search <pat> [lim] [off]  Search keys (e.g. 'search user.* 10 0')")
+			fmt.Println("  binlog                    Show Raft log indices")
+			fmt.Println("  db                        Show DB applied index")
+			fmt.Println("  stats                     Show node stats")
+			fmt.Println("  join <id> <addr>          Add node to cluster")
+			fmt.Println("  leave <id>                Remove node from cluster")
+
+		default:
+			fmt.Println("Unknown command. Type 'help'.")
+		}
+		fmt.Print("> ")
+	}
+}
+

+ 53 - 0
example/basic/node1/main.go

@@ -0,0 +1,53 @@
+package main
+
+import (
+	"log"
+	"os"
+
+	"igit.com/xbase/raft"
+	"igit.com/xbase/raft/example/basic/common"
+)
+
+func main() {
+	// Configuration
+	nodeID := "node1"
+	addr := "127.0.0.1:9001"
+	dataDir := "../../data/node1"
+
+	// Initial Cluster configuration (Node 1 + Node 2)
+	clusterNodes := map[string]string{
+		"node1": "127.0.0.1:9001",
+		"node2": "127.0.0.1:9002",
+	}
+
+	config := raft.DefaultConfig()
+	config.NodeID = nodeID
+	config.ListenAddr = addr
+	config.DataDir = dataDir
+	config.ClusterNodes = clusterNodes
+	config.LogCompactionEnabled = false // Requirement: No bin log compression
+	
+	// Console Logger
+	config.Logger = raft.NewConsoleLogger(nodeID, 1) // Info level
+
+	// Ensure data directory exists
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		log.Fatalf("Failed to create data directory: %v", err)
+	}
+
+	// Create KV Server
+	server, err := raft.NewKVServer(config)
+	if err != nil {
+		log.Fatalf("Failed to create server: %v", err)
+	}
+
+	// Start server
+	if err := server.Start(); err != nil {
+		log.Fatalf("Failed to start server: %v", err)
+	}
+	defer server.Stop()
+
+	// Start CLI
+	common.StartCLI(server, nodeID)
+}
+

+ 48 - 0
example/basic/node2/main.go

@@ -0,0 +1,48 @@
+package main
+
+import (
+	"log"
+	"os"
+
+	"igit.com/xbase/raft"
+	"igit.com/xbase/raft/example/basic/common"
+)
+
+func main() {
+	// Configuration
+	nodeID := "node2"
+	addr := "127.0.0.1:9002"
+	dataDir := "../../data/node2"
+
+	// Initial Cluster configuration (Node 1 + Node 2)
+	clusterNodes := map[string]string{
+		"node1": "127.0.0.1:9001",
+		"node2": "127.0.0.1:9002",
+	}
+
+	config := raft.DefaultConfig()
+	config.NodeID = nodeID
+	config.ListenAddr = addr
+	config.DataDir = dataDir
+	config.ClusterNodes = clusterNodes
+	config.LogCompactionEnabled = false
+	
+	config.Logger = raft.NewConsoleLogger(nodeID, 1)
+
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		log.Fatalf("Failed to create data directory: %v", err)
+	}
+
+	server, err := raft.NewKVServer(config)
+	if err != nil {
+		log.Fatalf("Failed to create server: %v", err)
+	}
+
+	if err := server.Start(); err != nil {
+		log.Fatalf("Failed to start server: %v", err)
+	}
+	defer server.Stop()
+
+	common.StartCLI(server, nodeID)
+}
+

+ 47 - 0
example/basic/node3/main.go

@@ -0,0 +1,47 @@
+package main
+
+import (
+	"log"
+	"os"
+
+	"igit.com/xbase/raft"
+	"igit.com/xbase/raft/example/basic/common"
+)
+
+func main() {
+	// Configuration
+	nodeID := "node3"
+	addr := "127.0.0.1:9003"
+	dataDir := "../../data/node3"
+
+	// Standalone configuration (will be joined to cluster later)
+	clusterNodes := map[string]string{
+		"node3": "127.0.0.1:9003",
+	}
+
+	config := raft.DefaultConfig()
+	config.NodeID = nodeID
+	config.ListenAddr = addr
+	config.DataDir = dataDir
+	config.ClusterNodes = clusterNodes
+	config.LogCompactionEnabled = false
+	
+	config.Logger = raft.NewConsoleLogger(nodeID, 1)
+
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		log.Fatalf("Failed to create data directory: %v", err)
+	}
+
+	server, err := raft.NewKVServer(config)
+	if err != nil {
+		log.Fatalf("Failed to create server: %v", err)
+	}
+
+	if err := server.Start(); err != nil {
+		log.Fatalf("Failed to start server: %v", err)
+	}
+	defer server.Stop()
+
+	common.StartCLI(server, nodeID)
+}
+

+ 47 - 0
example/basic/node4/main.go

@@ -0,0 +1,47 @@
+package main
+
+import (
+	"log"
+	"os"
+
+	"igit.com/xbase/raft"
+	"igit.com/xbase/raft/example/basic/common"
+)
+
+func main() {
+	// Configuration
+	nodeID := "node4"
+	addr := "127.0.0.1:9004"
+	dataDir := "../../data/node4"
+
+	// Standalone configuration
+	clusterNodes := map[string]string{
+		"node4": "127.0.0.1:9004",
+	}
+
+	config := raft.DefaultConfig()
+	config.NodeID = nodeID
+	config.ListenAddr = addr
+	config.DataDir = dataDir
+	config.ClusterNodes = clusterNodes
+	config.LogCompactionEnabled = false
+	
+	config.Logger = raft.NewConsoleLogger(nodeID, 1)
+
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		log.Fatalf("Failed to create data directory: %v", err)
+	}
+
+	server, err := raft.NewKVServer(config)
+	if err != nil {
+		log.Fatalf("Failed to create server: %v", err)
+	}
+
+	if err := server.Start(); err != nil {
+		log.Fatalf("Failed to start server: %v", err)
+	}
+	defer server.Stop()
+
+	common.StartCLI(server, nodeID)
+}
+

+ 110 - 92
rpc.go

@@ -72,9 +72,8 @@ type TCPTransport struct {
 	listener   net.Listener
 	shutdownCh chan struct{}
 
-	// Connection pool
-	connPool map[string]chan net.Conn
-	poolSize int
+	// Single persistent connection per target
+	conns map[string]net.Conn
 }
 
 // NewTCPTransport creates a new TCP transport
@@ -82,16 +81,12 @@ func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTranspor
 	if logger == nil {
 		logger = &NoopLogger{}
 	}
-	if poolSize <= 0 {
-		poolSize = 5
-	}
 
 	return &TCPTransport{
 		localAddr:  localAddr,
 		logger:     logger,
 		shutdownCh: make(chan struct{}),
-		connPool:   make(map[string]chan net.Conn),
-		poolSize:   poolSize,
+		conns:      make(map[string]net.Conn),
 	}
 }
 
@@ -312,15 +307,12 @@ func (t *TCPTransport) handleConnection(conn net.Conn) {
 func (t *TCPTransport) Stop() error {
 	close(t.shutdownCh)
 
-	// Close all pooled connections
+	// Close all connections
 	t.mu.Lock()
-	for _, pool := range t.connPool {
-		close(pool)
-		for conn := range pool {
-			conn.Close()
-		}
+	for _, conn := range t.conns {
+		conn.Close()
 	}
-	t.connPool = make(map[string]chan net.Conn)
+	t.conns = make(map[string]net.Conn)
 	t.mu.Unlock()
 
 	if t.listener != nil {
@@ -329,107 +321,133 @@ func (t *TCPTransport) Stop() error {
 	return nil
 }
 
-// getConn gets a connection from the pool or creates a new one
+// getConn gets the persistent connection or creates a new one
 func (t *TCPTransport) getConn(target string) (net.Conn, error) {
 	t.mu.Lock()
-	pool, ok := t.connPool[target]
-	if !ok {
-		pool = make(chan net.Conn, t.poolSize)
-		t.connPool[target] = pool
-	}
-	t.mu.Unlock()
+	defer t.mu.Unlock()
 
-	select {
-	case conn := <-pool:
+	// Check existing connection
+	if conn, ok := t.conns[target]; ok {
 		return conn, nil
-	default:
-		return net.DialTimeout("tcp", target, 5*time.Second)
 	}
-}
 
-// putConn returns a connection to the pool
-func (t *TCPTransport) putConn(target string, conn net.Conn) {
-	t.mu.RLock()
-	pool, ok := t.connPool[target]
-	t.mu.RUnlock()
-
-	if !ok {
-		conn.Close()
-		return
+	// Dial new connection
+	conn, err := net.DialTimeout("tcp", target, 5*time.Second)
+	if err != nil {
+		return nil, err
 	}
 
-	select {
-	case pool <- conn:
-	default:
+	t.conns[target] = conn
+	return conn, nil
+}
+
+// closeConn closes and removes a connection from the map
+func (t *TCPTransport) closeConn(target string, conn net.Conn) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	// Only delete if it's the current connection
+	if current, ok := t.conns[target]; ok && current == conn {
+		delete(t.conns, target)
 		conn.Close()
 	}
 }
 
 // sendTCPRPC sends an RPC over TCP
 func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
-	conn, err := t.getConn(target)
-	if err != nil {
-		return fmt.Errorf("failed to get connection: %w", err)
-	}
+	// Simple retry mechanism for stale connections
+	maxRetries := 2
+	var lastErr error
 
-	data, err := DefaultCodec.Marshal(args)
-	if err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to marshal request: %w", err)
-	}
+	for i := 0; i < maxRetries; i++ {
+		conn, err := t.getConn(target)
+		if err != nil {
+			return fmt.Errorf("failed to get connection: %w", err)
+		}
 
-	// Set deadline from context
-	deadline, ok := ctx.Deadline()
-	if !ok {
-		deadline = time.Now().Add(5 * time.Second)
-	}
-	conn.SetDeadline(deadline)
+		data, err := DefaultCodec.Marshal(args)
+		if err != nil {
+			// Don't close conn here as we haven't touched it yet in this iteration
+			return fmt.Errorf("failed to marshal request: %w", err)
+		}
+
+		// Set deadline from context
+		deadline, ok := ctx.Deadline()
+		if !ok {
+			deadline = time.Now().Add(5 * time.Second)
+		}
+		conn.SetDeadline(deadline)
+
+		// Write message: [type(1)][length(4)][body]
+		header := make([]byte, 5)
+		header[0] = byte(rpcType)
+		header[1] = byte(len(data) >> 24)
+		header[2] = byte(len(data) >> 16)
+		header[3] = byte(len(data) >> 8)
+		header[4] = byte(len(data))
+
+		if _, err := conn.Write(header); err != nil {
+			t.closeConn(target, conn)
+			lastErr = fmt.Errorf("failed to write header: %w", err)
+			// If this was a reused connection, retry with a new one
+			if i < maxRetries-1 {
+				continue
+			}
+			return lastErr
+		}
+		if _, err := conn.Write(data); err != nil {
+			t.closeConn(target, conn)
+			lastErr = fmt.Errorf("failed to write body: %w", err)
+			// If write failed, it might be a broken pipe, but we already wrote header.
+			// Retrying here is risky if the server processed the header but not body,
+			// but for idempotent RPCs it might be okay. For safety, we only retry if it looks like a connection issue on write.
+			// However, since we're in binary protocol, partial writes break the stream anyway.
+			if i < maxRetries-1 {
+				continue
+			}
+			return lastErr
+		}
 
-	// Write message: [type(1)][length(4)][body]
-	header := make([]byte, 5)
-	header[0] = byte(rpcType)
-	header[1] = byte(len(data) >> 24)
-	header[2] = byte(len(data) >> 16)
-	header[3] = byte(len(data) >> 8)
-	header[4] = byte(len(data))
+		// Read response length
+		lenBuf := make([]byte, 4)
+		if _, err := io.ReadFull(conn, lenBuf); err != nil {
+			t.closeConn(target, conn)
+			lastErr = fmt.Errorf("failed to read response length: %w", err)
+			if i < maxRetries-1 && (err == io.EOF || isConnectionReset(err)) {
+				continue
+			}
+			return lastErr
+		}
 
-	if _, err := conn.Write(header); err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to write header: %w", err)
-	}
-	if _, err := conn.Write(data); err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to write body: %w", err)
-	}
+		length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
+		if length > 10*1024*1024 {
+			t.closeConn(target, conn)
+			return fmt.Errorf("response too large: %d", length)
+		}
 
-	// Read response length
-	lenBuf := make([]byte, 4)
-	if _, err := io.ReadFull(conn, lenBuf); err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to read response length: %w", err)
-	}
+		// Read response body
+		respBody := make([]byte, length)
+		if _, err := io.ReadFull(conn, respBody); err != nil {
+			t.closeConn(target, conn)
+			return fmt.Errorf("failed to read response body: %w", err)
+		}
 
-	length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
-	if length > 10*1024*1024 {
-		conn.Close()
-		return fmt.Errorf("response too large: %d", length)
-	}
+		if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
+			t.closeConn(target, conn)
+			return fmt.Errorf("failed to unmarshal response: %w", err)
+		}
 
-	// Read response body
-	respBody := make([]byte, length)
-	if _, err := io.ReadFull(conn, respBody); err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to read response body: %w", err)
+		// Keep connection open
+		return nil
 	}
+	return lastErr
+}
 
-	if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
-		conn.Close()
-		return fmt.Errorf("failed to unmarshal response: %w", err)
+func isConnectionReset(err error) bool {
+	if opErr, ok := err.(*net.OpError); ok {
+		return opErr.Err.Error() == "connection reset by peer" || opErr.Err.Error() == "broken pipe"
 	}
-
-	// Return connection to pool
-	t.putConn(target, conn)
-	return nil
+	return false
 }
 
 // RequestVote sends a RequestVote RPC