package raft

import (
	"bufio"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"runtime"
	"strconv"
	"strings"
	"time"
	"unicode"
)

// TCPClientSession holds the per-connection state of a TCP API client: the
// token/username established by LOGIN or AUTH, and the buffered I/O wrapped
// around conn.
type TCPClientSession struct {
	conn       net.Conn
	server     *KVServer
	token      string
	username   string
	reader     *bufio.Reader
	writer     *bufio.Writer
	remoteAddr string
}

// StartTCPServer listens on addr and serves the line-based TCP API from a
// background goroutine, spawning one goroutine per accepted connection. The
// only error returned is the one from net.Listen; Accept errors are logged
// and retried with a capped exponential backoff.
//
// When s.stopCh is non-nil, closing it closes the listener so the accept
// goroutine exits and the port is released.
// NOTE(review): assumes stopCh is closed (not sent on) to signal shutdown,
// which is what the non-blocking receive in the accept loop implies.
func (s *KVServer) StartTCPServer(addr string) error {
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	s.Raft.config.Logger.Info("TCP API server listening on %s", addr)

	// Accept blocks indefinitely, so without this watcher the accept loop
	// would never observe shutdown and the socket would never be closed.
	if stopCh := s.stopCh; stopCh != nil {
		go func() {
			<-stopCh
			listener.Close()
		}()
	}

	go func() {
		defer listener.Close()

		const maxBackoff = time.Second
		var backoff time.Duration
		for {
			conn, err := listener.Accept()
			if err != nil {
				// A receive from a nil channel is never ready, so with no
				// stop channel configured this always takes the default arm.
				select {
				case <-s.stopCh:
					return
				default:
				}

				// Back off on persistent errors (e.g. fd exhaustion) instead
				// of spinning and flooding the log.
				backoff *= 2
				if backoff == 0 {
					backoff = 5 * time.Millisecond
				}
				if backoff > maxBackoff {
					backoff = maxBackoff
				}
				s.Raft.config.Logger.Error("TCP Accept error: %v", err)

				select {
				case <-s.stopCh:
					return
				case <-time.After(backoff):
				}
				continue
			}
			backoff = 0
			go s.handleTCPConnection(conn)
		}
	}()
	return nil
}

// handleTCPConnection serves one client until a read/write error or timeout,
// an EXIT/QUIT command, or an unusable Content-Length header.
//
// Request format: a command line, zero or more "Name: value" header lines, a
// blank line, then Content-Length bytes of body when that header is present.
// Every response is a single line terminated by "\n". Requests run strictly
// in order on this goroutine; responses to pipelined requests are coalesced
// by flushing only when no further request bytes are already buffered.
func (s *KVServer) handleTCPConnection(conn net.Conn) {
	const (
		ioBufferSize        = 64 * 1024        // larger buffers for high throughput
		idleTimeout         = 60 * time.Second // read deadline, re-armed per request
		writeTimeout        = 60 * time.Second // bounds a client that stops reading
		maxBodySize         = 64 << 20         // cap on client-supplied Content-Length
		contentLengthPrefix = "content-length:"
	)

	defer conn.Close()

	reader := bufio.NewReaderSize(conn, ioBufferSize)
	writer := bufio.NewWriterSize(conn, ioBufferSize)

	// Runs before conn.Close (defers are LIFO): deliver responses still
	// buffered for requests that completed before the loop exited, e.g. a
	// client that half-closed mid-pipeline, or the final "BYE". Best effort:
	// the connection is being torn down either way.
	defer func() {
		conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		writer.Flush()
	}()

	session := &TCPClientSession{
		conn:       conn,
		server:     s,
		reader:     reader,
		writer:     writer,
		remoteAddr: conn.RemoteAddr().String(),
	}

	for {
		// A longer deadline keeps this cheap while traffic is continuous.
		conn.SetReadDeadline(time.Now().Add(idleTimeout))

		line, err := reader.ReadString('\n')
		if err != nil {
			return
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		// Headers run until the first blank line; only Content-Length is used.
		contentLength := 0
		for {
			hLine, err := reader.ReadString('\n')
			if err != nil {
				return
			}
			hLine = strings.TrimSpace(hLine)
			if hLine == "" {
				break
			}
			// Case-insensitive prefix match without allocating a lowered copy.
			if len(hLine) < len(contentLengthPrefix) ||
				!strings.EqualFold(hLine[:len(contentLengthPrefix)], contentLengthPrefix) {
				continue
			}
			n, err := strconv.Atoi(strings.TrimSpace(hLine[len(contentLengthPrefix):]))
			if err != nil || n < 0 || n > maxBodySize {
				// The body boundary is unknown, so the stream cannot be
				// resynchronized: report and drop the connection instead of
				// parsing body bytes as commands. The deferred flush sends it.
				writer.WriteString("ERR invalid Content-Length\n")
				return
			}
			contentLength = n
		}

		var body string
		if contentLength > 0 {
			buf := make([]byte, contentLength)
			if _, err := io.ReadFull(reader, buf); err != nil {
				return
			}
			body = string(buf)
		}

		resp := s.executeCommandWithBody(session, line, body)

		conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		if _, err := writer.WriteString(resp + "\n"); err != nil {
			return
		}

		// EXIT/QUIT: the deferred flush delivers "BYE", then the connection
		// closes. Verb extraction mirrors executeCommand (strings.Fields
		// splits on unicode.IsSpace, and the verb is upper-cased).
		verb := line
		if i := strings.IndexFunc(line, unicode.IsSpace); i >= 0 {
			verb = line[:i]
		}
		if v := strings.ToUpper(verb); v == "EXIT" || v == "QUIT" {
			return
		}

		// Flush only when no further request is already buffered, which
		// batches responses automatically when requests are pipelined.
		if reader.Buffered() == 0 {
			if err := writer.Flush(); err != nil {
				return
			}
		}
	}
}

// executeCommandWithBody dispatches one request. When body is non-empty,
// SET and ASET take their value from the body (so it may contain spaces or
// newlines) instead of from the request line; every other request, and any
// request without a body, is handled by executeCommand.
func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
	if body == "" {
		return s.executeCommand(session, line)
	}

	parts := strings.Fields(line)
	if len(parts) == 0 {
		return ""
	}
	cmd := strings.ToUpper(parts[0])
	if cmd != "SET" && cmd != "ASET" {
		return s.executeCommand(session, line)
	}
	if len(parts) < 2 {
		return fmt.Sprintf("ERR usage: %s <key> (value in body)", cmd)
	}

	key := parts[1]
	if cmd == "SET" {
		if err := s.SetAuthenticated(key, body, session.token); err != nil {
			return fmt.Sprintf("ERR %v", err)
		}
		return "OK"
	}
	if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
		return fmt.Sprintf("ERR %v", err)
	}
	return "OK"
}

// executeCommand parses and runs one line-protocol command for session and
// returns the reply without its trailing newline.
//
// Replies start with "OK" or "ERR", except WHOAMI (bare username or "Guest")
// and EXIT/QUIT ("BYE"). An empty line yields "". Arguments are split on
// whitespace, so inline SET/ASET values are re-joined with single spaces.
// INFO, JOIN, LEAVE and USER_UNLOCK check admin access here when auth is
// enabled. Other commands pass session.token to the KVServer/AuthManager
// method they call; per the original comments, CreateUser, UpdateUser,
// CreateRole and UpdateRole perform their own permission checks.
func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
	parts := strings.Fields(line)
	if len(parts) == 0 {
		return ""
	}
	cmd := strings.ToUpper(parts[0])

	const (
		errNotAuthenticated = "ERR not authenticated"
		errAdminRequired    = "ERR Permission Denied: Admin access required"
	)
	errResp := func(err error) string { return fmt.Sprintf("ERR %v", err) }
	// adminDenied reports whether an admin-only command must be refused.
	adminDenied := func() bool {
		return s.AuthManager.IsEnabled() && !s.IsAdmin(session.token)
	}

	switch cmd {
	case "LOGIN":
		if len(parts) < 3 {
			return "ERR usage: LOGIN <user> <pass> [otp]"
		}
		user, pass := parts[1], parts[2]
		otp := ""
		if len(parts) > 3 {
			otp = parts[3]
		}
		// Pass only the host part of the client address to Login.
		ip := session.remoteAddr
		if host, _, err := net.SplitHostPort(ip); err == nil {
			ip = host
		}
		token, err := s.AuthManager.Login(user, pass, otp, ip)
		if err != nil {
			return errResp(err)
		}
		session.token = token
		session.username = user
		return fmt.Sprintf("OK %s", token)

	case "AUTH":
		if len(parts) < 2 {
			return "ERR usage: AUTH <token>"
		}
		token := parts[1]
		sess, err := s.AuthManager.GetSession(token)
		if err != nil {
			return errResp(err)
		}
		session.token = token
		session.username = sess.Username
		return "OK"

	case "LOGOUT":
		if session.token != "" {
			// Any result from Logout is ignored; the local session is cleared regardless.
			s.AuthManager.Logout(session.token)
			session.token = ""
			session.username = ""
		}
		return "OK"

	case "GET":
		if len(parts) < 2 {
			return "ERR usage: GET <key>"
		}
		val, found, err := s.GetLinearAuthenticated(parts[1], session.token)
		if err != nil {
			return errResp(err)
		}
		if !found {
			return "ERR not found"
		}
		return fmt.Sprintf("OK %s", val)

	case "SET":
		if len(parts) < 3 {
			return "ERR usage: SET <key> <value>"
		}
		// Synchronous write by default for safety.
		if err := s.SetAuthenticated(parts[1], strings.Join(parts[2:], " "), session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "ASET":
		if len(parts) < 3 {
			return "ERR usage: ASET <key> <value>"
		}
		// Asynchronous write for high throughput.
		if err := s.SetAuthenticatedAsync(parts[1], strings.Join(parts[2:], " "), session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "DEL":
		if len(parts) < 2 {
			return "ERR usage: DEL <key>"
		}
		if err := s.DelAuthenticated(parts[1], session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "MFA-GENERATE":
		if session.token == "" {
			return errNotAuthenticated
		}
		secret, err := GenerateMFASecret()
		if err != nil {
			return errResp(err)
		}
		url := GenerateOTPAuthURL(session.username, secret, "RaftKV")
		return fmt.Sprintf("OK %s %s", secret, url)

	case "MFA-ENABLE":
		if len(parts) < 3 {
			return "ERR usage: MFA-ENABLE <secret> <code>"
		}
		if session.token == "" {
			return errNotAuthenticated
		}
		secret, code := parts[1], parts[2]
		if !ValidateTOTP(secret, code) {
			return "ERR invalid code"
		}
		if err := s.SetUserMFA(session.username, secret, true, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "MFA-DISABLE":
		if session.token == "" {
			return errNotAuthenticated
		}
		if err := s.SetUserMFA(session.username, "", false, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "MFA-STATUS":
		if session.token == "" {
			return errNotAuthenticated
		}
		user, err := s.AuthManager.GetUser(session.username)
		if err != nil {
			return errResp(err)
		}
		if user.MFAEnabled {
			return "OK enabled"
		}
		return "OK disabled"

	case "WHOAMI":
		if session.username == "" {
			return "Guest"
		}
		return session.username

	case "HELP":
		// One line on purpose: each reply is terminated by a single "\n",
		// so embedded newlines would desynchronize line-based clients.
		const helpText = "Available Commands: " +
			"GET <key> - Get value; " +
			"SET <key> <value> - Set value; " +
			"DEL <key> - Delete value; " +
			"SEARCH <pattern> [limit] [offset] - Search keys (e.g. user.*); " +
			"COUNT <pattern> - Count keys; " +
			"INFO - Show system stats; " +
			"WHOAMI - Show current user; " +
			"JOIN <node_id> <addr> - Add node (Root only); " +
			"LEAVE <node_id> - Remove node (Root only); " +
			"USER_LIST - List users (Admin); " +
			"ROLE_LIST - List roles (Admin); " +
			"LOGIN/LOGOUT/EXIT"
		return "OK " + helpText

	case "INFO":
		// Strict: admin access required when auth is enabled.
		if adminDenied() {
			return errAdminRequired
		}
		stats := s.GetStats()
		health := s.HealthCheck()
		dbSize := s.GetDBSize()
		logSize := s.GetLogSize()
		var m runtime.MemStats
		runtime.ReadMemStats(&m)

		info := map[string]interface{}{
			"node": map[string]interface{}{
				"id":      health.NodeID,
				"state":   health.State,
				"term":    health.Term,
				"leader":  health.LeaderID,
				"healthy": health.IsHealthy,
			},
			"storage": map[string]interface{}{
				"db_size":   dbSize,
				"log_size":  logSize,
				"mem_alloc": m.Alloc,
				"mem_sys":   m.Sys,
				"num_gc":    m.NumGC,
			},
			"indices": map[string]interface{}{
				"commit_index":   stats.CommitIndex,
				"applied_index":  stats.LastApplied,
				"last_log_index": stats.LastLogIndex,
				"db_applied":     s.DB.GetLastAppliedIndex(),
			},
			"cluster":      stats.ClusterNodes,
			"cluster_size": stats.ClusterSize,
		}
		data, err := json.Marshal(info)
		if err != nil {
			return errResp(err)
		}
		return "OK " + string(data)

	case "SEARCH":
		if len(parts) < 2 {
			return "ERR usage: SEARCH <pattern> [limit] [offset]"
		}
		pattern := parts[1]
		// Unparseable limit/offset values fall back to the defaults.
		limit, offset := 20, 0
		if len(parts) >= 3 {
			if l, err := strconv.Atoi(parts[2]); err == nil {
				limit = l
			}
		}
		if len(parts) >= 4 {
			if o, err := strconv.Atoi(parts[3]); err == nil {
				offset = o
			}
		}
		results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
		if err != nil {
			return errResp(err)
		}
		data, err := json.Marshal(results)
		if err != nil {
			return errResp(err)
		}
		return "OK " + string(data)

	case "COUNT":
		if len(parts) < 2 {
			return "ERR usage: COUNT <pattern>"
		}
		count, err := s.CountAuthenticated(parts[1], session.token)
		if err != nil {
			return errResp(err)
		}
		return fmt.Sprintf("OK %d", count)

	case "JOIN":
		if adminDenied() {
			return errAdminRequired
		}
		if len(parts) < 3 {
			return "ERR usage: JOIN <node_id> <addr>"
		}
		if err := s.Join(parts[1], parts[2]); err != nil {
			return errResp(err)
		}
		return "OK Join request sent"

	case "LEAVE":
		if adminDenied() {
			return errAdminRequired
		}
		if len(parts) < 2 {
			return "ERR usage: LEAVE <node_id>"
		}
		if err := s.Leave(parts[1]); err != nil {
			return errResp(err)
		}
		return "OK Leave request sent"

	// --- Admin Commands ---
	case "USER_LIST":
		users, err := s.ListUsers(session.token)
		if err != nil {
			return errResp(err)
		}
		// json.Marshal output has no raw newlines, so the reply stays one line.
		data, err := json.Marshal(users)
		if err != nil {
			return errResp(err)
		}
		return "OK " + string(data)

	case "ROLE_LIST":
		roles, err := s.ListRoles(session.token)
		if err != nil {
			return errResp(err)
		}
		data, err := json.Marshal(roles)
		if err != nil {
			return errResp(err)
		}
		return "OK " + string(data)

	case "USER_CREATE":
		if len(parts) < 3 {
			return "ERR usage: USER_CREATE <user> <pass> [roles]"
		}
		var roles []string
		if len(parts) > 3 {
			roles = strings.Split(parts[3], ",")
		}
		if err := s.CreateUser(parts[1], parts[2], roles, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "USER_UPDATE":
		// "-" for the password or roles argument keeps the current value.
		if len(parts) < 4 {
			return "ERR usage: USER_UPDATE <user> <pass|-> <roles|->"
		}
		username, newPass, newRolesStr := parts[1], parts[2], parts[3]

		userPtr, err := s.AuthManager.GetUser(username)
		if err != nil {
			return errResp(err)
		}
		// Modify a copy; UpdateUser persists it after its permission check.
		user := *userPtr
		if newPass != "-" {
			// Random salt: a timestamp salt is predictable and can repeat.
			saltBytes := make([]byte, 16)
			if _, err := rand.Read(saltBytes); err != nil {
				return errResp(err)
			}
			salt := hex.EncodeToString(saltBytes)
			user.Salt = salt
			user.PasswordHash = HashPassword(newPass, salt)
		}
		if newRolesStr != "-" {
			// strings.Fields never yields an empty token, so this is non-empty.
			user.Roles = strings.Split(newRolesStr, ",")
		}
		if err := s.UpdateUser(user, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "ROLE_CREATE":
		if len(parts) < 2 {
			return "ERR usage: ROLE_CREATE <name>"
		}
		if err := s.CreateRole(parts[1], session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "ROLE_PERMISSION_ADD":
		// actions: comma-separated list (read,write,admin,*).
		// min/max: optional numeric constraints; "-" or "null" means none.
		if len(parts) < 4 {
			return "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions> [min] [max]"
		}
		roleName, pattern := parts[1], parts[2]
		actions := strings.Split(parts[3], ",")

		var minVal, maxVal *float64
		if len(parts) > 4 && parts[4] != "-" && parts[4] != "null" {
			v, err := strconv.ParseFloat(parts[4], 64)
			if err != nil {
				return "ERR invalid min value"
			}
			minVal = &v
		}
		if len(parts) > 5 && parts[5] != "-" && parts[5] != "null" {
			v, err := strconv.ParseFloat(parts[5], 64)
			if err != nil {
				return "ERR invalid max value"
			}
			maxVal = &v
		}

		rolePtr, err := s.AuthManager.GetRole(roleName)
		if err != nil {
			return errResp(err)
		}
		role := *rolePtr
		// Copy the slice so the upsert below cannot write through to the
		// role object returned by GetRole.
		role.Permissions = make([]Permission, len(rolePtr.Permissions))
		copy(role.Permissions, rolePtr.Permissions)

		newPerm := Permission{
			KeyPattern: pattern,
			Actions:    actions,
		}
		if minVal != nil || maxVal != nil {
			newPerm.Constraint = &Constraint{Min: minVal, Max: maxVal}
		}

		// Upsert keyed by KeyPattern.
		replaced := false
		for i := range role.Permissions {
			if role.Permissions[i].KeyPattern == pattern {
				role.Permissions[i] = newPerm
				replaced = true
				break
			}
		}
		if !replaced {
			role.Permissions = append(role.Permissions, newPerm)
		}

		// UpdateRole performs the delegation check.
		if err := s.UpdateRole(role, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "ROLE_PERMISSION_REMOVE":
		if len(parts) < 3 {
			return "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
		}
		roleName, pattern := parts[1], parts[2]

		rolePtr, err := s.AuthManager.GetRole(roleName)
		if err != nil {
			return errResp(err)
		}
		role := *rolePtr
		kept := make([]Permission, 0, len(rolePtr.Permissions))
		found := false
		for _, p := range rolePtr.Permissions {
			if p.KeyPattern == pattern {
				found = true
				continue
			}
			kept = append(kept, p)
		}
		if !found {
			return "ERR permission not found"
		}
		role.Permissions = kept

		// UpdateRole performs the delegation check.
		if err := s.UpdateRole(role, session.token); err != nil {
			return errResp(err)
		}
		return "OK"

	case "USER_UNLOCK":
		if adminDenied() {
			return errAdminRequired
		}
		if len(parts) < 2 {
			return "ERR usage: USER_UNLOCK <user>"
		}
		// DelSync so the lock key is removed before we reply.
		if err := s.DelSync("system.lock." + parts[1]); err != nil {
			return errResp(err)
		}
		return "OK"

	case "EXIT", "QUIT":
		// handleTCPConnection flushes this reply and closes the connection.
		return "BYE"

	default:
		// Log only the verb and argument count: arguments of a mistyped
		// command (e.g. a misspelled LOGIN) may contain credentials.
		s.Raft.config.Logger.Warn("Unknown command received: %s (%d args)", cmd, len(parts)-1)
		return fmt.Sprintf("ERR unknown command: %s", cmd)
	}
}