|
@@ -72,9 +72,8 @@ type TCPTransport struct {
|
|
|
listener net.Listener
|
|
listener net.Listener
|
|
|
shutdownCh chan struct{}
|
|
shutdownCh chan struct{}
|
|
|
|
|
|
|
|
- // Connection pool
|
|
|
|
|
- connPool map[string]chan net.Conn
|
|
|
|
|
- poolSize int
|
|
|
|
|
|
|
+ // Single persistent connection per target
|
|
|
|
|
+ conns map[string]net.Conn
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// NewTCPTransport creates a new TCP transport
|
|
// NewTCPTransport creates a new TCP transport
|
|
@@ -82,16 +81,12 @@ func NewTCPTransport(localAddr string, poolSize int, logger Logger) *TCPTranspor
|
|
|
if logger == nil {
|
|
if logger == nil {
|
|
|
logger = &NoopLogger{}
|
|
logger = &NoopLogger{}
|
|
|
}
|
|
}
|
|
|
- if poolSize <= 0 {
|
|
|
|
|
- poolSize = 5
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
return &TCPTransport{
|
|
return &TCPTransport{
|
|
|
localAddr: localAddr,
|
|
localAddr: localAddr,
|
|
|
logger: logger,
|
|
logger: logger,
|
|
|
shutdownCh: make(chan struct{}),
|
|
shutdownCh: make(chan struct{}),
|
|
|
- connPool: make(map[string]chan net.Conn),
|
|
|
|
|
- poolSize: poolSize,
|
|
|
|
|
|
|
+ conns: make(map[string]net.Conn),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -312,15 +307,12 @@ func (t *TCPTransport) handleConnection(conn net.Conn) {
|
|
|
func (t *TCPTransport) Stop() error {
|
|
func (t *TCPTransport) Stop() error {
|
|
|
close(t.shutdownCh)
|
|
close(t.shutdownCh)
|
|
|
|
|
|
|
|
- // Close all pooled connections
|
|
|
|
|
|
|
+ // Close all connections
|
|
|
t.mu.Lock()
|
|
t.mu.Lock()
|
|
|
- for _, pool := range t.connPool {
|
|
|
|
|
- close(pool)
|
|
|
|
|
- for conn := range pool {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for _, conn := range t.conns {
|
|
|
|
|
+ conn.Close()
|
|
|
}
|
|
}
|
|
|
- t.connPool = make(map[string]chan net.Conn)
|
|
|
|
|
|
|
+ t.conns = make(map[string]net.Conn)
|
|
|
t.mu.Unlock()
|
|
t.mu.Unlock()
|
|
|
|
|
|
|
|
if t.listener != nil {
|
|
if t.listener != nil {
|
|
@@ -329,107 +321,133 @@ func (t *TCPTransport) Stop() error {
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// getConn gets a connection from the pool or creates a new one
|
|
|
|
|
|
|
+// getConn gets the persistent connection or creates a new one
|
|
|
func (t *TCPTransport) getConn(target string) (net.Conn, error) {
|
|
func (t *TCPTransport) getConn(target string) (net.Conn, error) {
|
|
|
t.mu.Lock()
|
|
t.mu.Lock()
|
|
|
- pool, ok := t.connPool[target]
|
|
|
|
|
- if !ok {
|
|
|
|
|
- pool = make(chan net.Conn, t.poolSize)
|
|
|
|
|
- t.connPool[target] = pool
|
|
|
|
|
- }
|
|
|
|
|
- t.mu.Unlock()
|
|
|
|
|
|
|
+ defer t.mu.Unlock()
|
|
|
|
|
|
|
|
- select {
|
|
|
|
|
- case conn := <-pool:
|
|
|
|
|
|
|
+ // Check existing connection
|
|
|
|
|
+ if conn, ok := t.conns[target]; ok {
|
|
|
return conn, nil
|
|
return conn, nil
|
|
|
- default:
|
|
|
|
|
- return net.DialTimeout("tcp", target, 5*time.Second)
|
|
|
|
|
}
|
|
}
|
|
|
-}
|
|
|
|
|
|
|
|
|
|
-// putConn returns a connection to the pool
|
|
|
|
|
-func (t *TCPTransport) putConn(target string, conn net.Conn) {
|
|
|
|
|
- t.mu.RLock()
|
|
|
|
|
- pool, ok := t.connPool[target]
|
|
|
|
|
- t.mu.RUnlock()
|
|
|
|
|
-
|
|
|
|
|
- if !ok {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return
|
|
|
|
|
|
|
+ // Dial new connection
|
|
|
|
|
+ conn, err := net.DialTimeout("tcp", target, 5*time.Second)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- select {
|
|
|
|
|
- case pool <- conn:
|
|
|
|
|
- default:
|
|
|
|
|
|
|
+ t.conns[target] = conn
|
|
|
|
|
+ return conn, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// closeConn closes and removes a connection from the map
|
|
|
|
|
+func (t *TCPTransport) closeConn(target string, conn net.Conn) {
|
|
|
|
|
+ t.mu.Lock()
|
|
|
|
|
+ defer t.mu.Unlock()
|
|
|
|
|
+
|
|
|
|
|
+ // Only delete if it's the current connection
|
|
|
|
|
+ if current, ok := t.conns[target]; ok && current == conn {
|
|
|
|
|
+ delete(t.conns, target)
|
|
|
conn.Close()
|
|
conn.Close()
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// sendTCPRPC sends an RPC over TCP
|
|
// sendTCPRPC sends an RPC over TCP
|
|
|
func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
|
|
func (t *TCPTransport) sendTCPRPC(ctx context.Context, target string, rpcType RPCType, args interface{}, reply interface{}) error {
|
|
|
- conn, err := t.getConn(target)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return fmt.Errorf("failed to get connection: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Simple retry mechanism for stale connections
|
|
|
|
|
+ maxRetries := 2
|
|
|
|
|
+ var lastErr error
|
|
|
|
|
|
|
|
- data, err := DefaultCodec.Marshal(args)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to marshal request: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ for i := 0; i < maxRetries; i++ {
|
|
|
|
|
+ conn, err := t.getConn(target)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return fmt.Errorf("failed to get connection: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Set deadline from context
|
|
|
|
|
- deadline, ok := ctx.Deadline()
|
|
|
|
|
- if !ok {
|
|
|
|
|
- deadline = time.Now().Add(5 * time.Second)
|
|
|
|
|
- }
|
|
|
|
|
- conn.SetDeadline(deadline)
|
|
|
|
|
|
|
+ data, err := DefaultCodec.Marshal(args)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // Don't close conn here as we haven't touched it yet in this iteration
|
|
|
|
|
+ return fmt.Errorf("failed to marshal request: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Set deadline from context
|
|
|
|
|
+ deadline, ok := ctx.Deadline()
|
|
|
|
|
+ if !ok {
|
|
|
|
|
+ deadline = time.Now().Add(5 * time.Second)
|
|
|
|
|
+ }
|
|
|
|
|
+ conn.SetDeadline(deadline)
|
|
|
|
|
+
|
|
|
|
|
+ // Write message: [type(1)][length(4)][body]
|
|
|
|
|
+ header := make([]byte, 5)
|
|
|
|
|
+ header[0] = byte(rpcType)
|
|
|
|
|
+ header[1] = byte(len(data) >> 24)
|
|
|
|
|
+ header[2] = byte(len(data) >> 16)
|
|
|
|
|
+ header[3] = byte(len(data) >> 8)
|
|
|
|
|
+ header[4] = byte(len(data))
|
|
|
|
|
+
|
|
|
|
|
+ if _, err := conn.Write(header); err != nil {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ lastErr = fmt.Errorf("failed to write header: %w", err)
|
|
|
|
|
+ // If this was a reused connection, retry with a new one
|
|
|
|
|
+ if i < maxRetries-1 {
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ return lastErr
|
|
|
|
|
+ }
|
|
|
|
|
+ if _, err := conn.Write(data); err != nil {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ lastErr = fmt.Errorf("failed to write body: %w", err)
|
|
|
|
|
+ // If write failed, it might be a broken pipe, but we already wrote header.
|
|
|
|
|
+ // Retrying here is risky if the server processed the header but not body,
|
|
|
|
|
+ // but for idempotent RPCs it might be okay. For safety, we only retry if it looks like a connection issue on write.
|
|
|
|
|
+ // However, since we're in binary protocol, partial writes break the stream anyway.
|
|
|
|
|
+ if i < maxRetries-1 {
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ return lastErr
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Write message: [type(1)][length(4)][body]
|
|
|
|
|
- header := make([]byte, 5)
|
|
|
|
|
- header[0] = byte(rpcType)
|
|
|
|
|
- header[1] = byte(len(data) >> 24)
|
|
|
|
|
- header[2] = byte(len(data) >> 16)
|
|
|
|
|
- header[3] = byte(len(data) >> 8)
|
|
|
|
|
- header[4] = byte(len(data))
|
|
|
|
|
|
|
+ // Read response length
|
|
|
|
|
+ lenBuf := make([]byte, 4)
|
|
|
|
|
+ if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ lastErr = fmt.Errorf("failed to read response length: %w", err)
|
|
|
|
|
+ if i < maxRetries-1 && (err == io.EOF || isConnectionReset(err)) {
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+ return lastErr
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if _, err := conn.Write(header); err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to write header: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
- if _, err := conn.Write(data); err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to write body: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
|
|
|
|
|
+ if length > 10*1024*1024 {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ return fmt.Errorf("response too large: %d", length)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Read response length
|
|
|
|
|
- lenBuf := make([]byte, 4)
|
|
|
|
|
- if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to read response length: %w", err)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Read response body
|
|
|
|
|
+ respBody := make([]byte, length)
|
|
|
|
|
+ if _, err := io.ReadFull(conn, respBody); err != nil {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ return fmt.Errorf("failed to read response body: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- length := uint32(lenBuf[0])<<24 | uint32(lenBuf[1])<<16 | uint32(lenBuf[2])<<8 | uint32(lenBuf[3])
|
|
|
|
|
- if length > 10*1024*1024 {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("response too large: %d", length)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
|
|
|
|
|
+ t.closeConn(target, conn)
|
|
|
|
|
+ return fmt.Errorf("failed to unmarshal response: %w", err)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Read response body
|
|
|
|
|
- respBody := make([]byte, length)
|
|
|
|
|
- if _, err := io.ReadFull(conn, respBody); err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to read response body: %w", err)
|
|
|
|
|
|
|
+ // Keep connection open
|
|
|
|
|
+ return nil
|
|
|
}
|
|
}
|
|
|
|
|
+ return lastErr
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- if err := DefaultCodec.Unmarshal(respBody, reply); err != nil {
|
|
|
|
|
- conn.Close()
|
|
|
|
|
- return fmt.Errorf("failed to unmarshal response: %w", err)
|
|
|
|
|
|
|
// isConnectionReset reports whether err was ultimately caused by the peer
// resetting the connection ("connection reset by peer") or by writing to a
// connection the peer has already closed ("broken pipe").
//
// Socket errors normally arrive nested, e.g. *net.OpError wrapping
// *os.SyscallError wrapping syscall.Errno, and callers may add further
// fmt.Errorf("%w") layers. Only the innermost error carries the bare
// message, so the Unwrap chain is followed to its end before comparing.
// Intermediate errors are never formatted: (*net.OpError).Error panics when
// its Err field is nil.
//
// NOTE(review): matching on message text depends on the platform's errno
// strings; errors.Is(err, syscall.ECONNRESET) / syscall.EPIPE would be the
// sturdier check if this file can import errors and syscall.
func isConnectionReset(err error) bool {
	for err != nil {
		wrapper, ok := err.(interface{ Unwrap() error })
		if !ok {
			// Innermost error: compare its message.
			switch err.Error() {
			case "connection reset by peer", "broken pipe":
				return true
			}
			return false
		}
		err = wrapper.Unwrap()
	}
	return false
}
|
|
|
|
|
|
|
|
// RequestVote sends a RequestVote RPC
|
|
// RequestVote sends a RequestVote RPC
|