| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649 |
- package raft
- import (
- "bufio"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "runtime"
- "strconv"
- "strings"
- "time"
- )
- // TCPClientSession holds state for a TCP connection
- type TCPClientSession struct {
- conn net.Conn
- server *KVServer
- token string
- username string
- reader *bufio.Reader
- writer *bufio.Writer
- remoteAddr string
- }
- func (s *KVServer) StartTCPServer(addr string) error {
- listener, err := net.Listen("tcp", addr)
- if err != nil {
- return err
- }
- s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
- go func() {
- defer listener.Close()
- for {
- conn, err := listener.Accept()
- if err != nil {
- if s.stopCh != nil {
- select {
- case <-s.stopCh:
- return
- default:
- }
- }
- s.Raft.config.Logger.Error("TCP Accept error: %v", err)
- continue
- }
- go s.handleTCPConnection(conn)
- }
- }()
- return nil
- }
- func (s *KVServer) handleTCPConnection(conn net.Conn) {
- defer conn.Close()
- // Use larger buffers for high throughput
- reader := bufio.NewReaderSize(conn, 64*1024)
- writer := bufio.NewWriterSize(conn, 64*1024)
- session := &TCPClientSession{
- conn: conn,
- server: s,
- reader: reader,
- writer: writer,
- remoteAddr: conn.RemoteAddr().String(),
- }
- for {
- // Set Keep-Alive Deadline
- // Using a longer deadline avoids frequent syscalls if traffic is continuous
- conn.SetReadDeadline(time.Now().Add(60 * time.Second))
- // Read Request Line
- line, err := reader.ReadString('\n')
- if err != nil {
- return
- }
- line = strings.TrimSpace(line)
- if line == "" {
- continue
- }
- // Parse Headers
- var contentLength int
- for {
- hLine, err := reader.ReadString('\n')
- if err != nil {
- return
- }
- hLine = strings.TrimSpace(hLine)
- if hLine == "" {
- break
- }
- // Optimized Content-Length check
- // "Content-Length: 123"
- lowerHLine := strings.ToLower(hLine)
- if strings.HasPrefix(lowerHLine, "content-length:") {
- valStr := strings.TrimSpace(hLine[15:])
- if l, err := strconv.Atoi(valStr); err == nil {
- contentLength = l
- }
- }
- }
- // Read Body
- var body string
- if contentLength > 0 {
- // Optimization: Avoid allocation for small bodies if possible,
- // but Raft commands need to be passed as bytes/string anyway.
- buf := make([]byte, contentLength)
- if _, err := io.ReadFull(reader, buf); err != nil {
- return
- }
- body = string(buf)
- }
- // Implement Batching: Accumulate requests if they are ASET/SET
- // Ideally we need an internal buffer or channel here to queue up commands.
- // For zero-allocation batching in this loop structure, we can try to
- // eagerly read the next request if available in the buffer.
-
- // Note: Raft's Propose is thread-safe. The current serial loop is efficient for
- // minimizing context switches. Batching helps if we can merge multiple Proposes into one.
- // Since we haven't modified Raft to support ProposeBatch yet, we can't do true backend batching easily.
- // However, we can do "frontend batching" by checking buffer availability?
- // No, frontend batching requires Raft to accept a batch.
-
- // Without modifying Raft.Propose to accept []Command, "batching" here is limited.
- // BUT, we can at least pipeline the execution if we had concurrent workers,
- // but we replaced that with this serial loop for perf.
- // So for now, we just execute.
-
- // Execute Command Inline
- resp := s.executeCommandWithBody(session, line, body)
- // Write Response
- if _, err := writer.WriteString(resp + "\n"); err != nil {
- return
- }
- // Flush Optimization:
- // Only flush if there is no more data in the read buffer.
- // This automatically batches responses when requests are pipelined.
- if reader.Buffered() == 0 {
- if err := writer.Flush(); err != nil {
- return
- }
- }
- }
- }
- // Helper to bridge old logic
- func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
- parts := strings.Fields(line)
- if len(parts) == 0 {
- return ""
- }
-
- cmd := strings.ToUpper(parts[0])
-
- // If body is present, it overrides the value part of the command
- // We handle specific commands that use body
- if body != "" {
- switch cmd {
- case "SET", "ASET":
- if len(parts) < 2 {
- return "ERR usage: SET <key> (value in body)"
- }
- key := parts[1]
- // Value is body
- if cmd == "SET" {
- if err := s.SetAuthenticated(key, body, session.token); err != nil {
- return fmt.Sprintf("ERR %v", err)
- }
- return "OK"
- } else {
- if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
- return fmt.Sprintf("ERR %v", err)
- }
- return "OK"
- }
- }
- }
-
- // Fallback to legacy parsing if no body or command doesn't use body
- return s.executeCommand(session, line)
- }
- func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
- parts := strings.Fields(line)
- if len(parts) == 0 {
- return ""
- }
- cmd := strings.ToUpper(parts[0])
- var resp string
- switch cmd {
- case "LOGIN":
- if len(parts) < 3 {
- resp = "ERR usage: LOGIN <username> <password> [otp]"
- } else {
- user := parts[1]
- pass := parts[2]
- otp := ""
- if len(parts) > 3 {
- otp = parts[3]
- }
-
- // Extract IP
- ip := session.remoteAddr
- if host, _, err := net.SplitHostPort(ip); err == nil {
- ip = host
- }
- token, err := s.AuthManager.Login(user, pass, otp, ip)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = user
- resp = fmt.Sprintf("OK %s", token)
- }
- }
- case "AUTH":
- if len(parts) < 2 {
- resp = "ERR usage: AUTH <token>"
- } else {
- token := parts[1]
- // Verify token
- sess, err := s.AuthManager.GetSession(token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = sess.Username
- resp = "OK"
- }
- }
- case "LOGOUT":
- if session.token != "" {
- s.AuthManager.Logout(session.token)
- session.token = ""
- session.username = ""
- }
- resp = "OK"
- case "GET":
- if len(parts) < 2 {
- resp = "ERR usage: GET <key>"
- } else {
- key := parts[1]
- val, found, err := s.GetLinearAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else if !found {
- resp = "ERR not found"
- } else {
- resp = fmt.Sprintf("OK %s", val)
- }
- }
- case "SET":
- if len(parts) < 3 {
- resp = "ERR usage: SET <key> <value>"
- } else {
- key := parts[1]
- // Value might contain spaces, join the rest
- val := strings.Join(parts[2:], " ")
- // Use SetAuthenticated (Sync) by default for safety
- err := s.SetAuthenticated(key, val, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ASET":
- // Async SET for high performance
- if len(parts) < 3 {
- resp = "ERR usage: ASET <key> <value>"
- } else {
- key := parts[1]
- val := strings.Join(parts[2:], " ")
- err := s.SetAuthenticatedAsync(key, val, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "DEL":
- if len(parts) < 2 {
- resp = "ERR usage: DEL <key>"
- } else {
- key := parts[1]
- err := s.DelAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
-
- case "WHOAMI":
- if session.username == "" {
- resp = "Guest"
- } else {
- resp = session.username
- }
- case "HELP":
- helpText := `Available Commands:
- GET <key> - Get value
- SET <key> <value> - Set value
- DEL <key> - Delete value
- SEARCH <pattern> [limit] - Search keys (e.g. user.*)
- COUNT <pattern> - Count keys
- INFO - Show system stats
- WHOAMI - Show current user
- JOIN <id> <addr> - Add node (Root only)
- LEAVE <id> - Remove node (Root only)
- USER_LIST - List users (Admin)
- ROLE_LIST - List roles (Admin)
- LOGIN/LOGOUT/EXIT`
- resp = "OK " + helpText
- case "INFO":
- // Check permission (Admin only if auth enabled)
- if s.AuthManager.IsEnabled() {
- if !s.IsAdmin(session.token) {
- resp = "ERR Permission Denied: Admin access required"
- break
- }
- }
- // Gather stats
- stats := s.GetStats()
- health := s.HealthCheck()
- dbSize := s.GetDBSize()
- logSize := s.GetLogSize()
- var m runtime.MemStats
- runtime.ReadMemStats(&m)
- // Construct JSON response
- info := map[string]interface{}{
- "node": map[string]interface{}{
- "id": health.NodeID,
- "state": health.State,
- "term": health.Term,
- "leader": health.LeaderID,
- "healthy": health.IsHealthy,
- },
- "storage": map[string]interface{}{
- "db_size": dbSize,
- "log_size": logSize,
- "mem_alloc": m.Alloc,
- "mem_sys": m.Sys,
- "num_gc": m.NumGC,
- },
- "indices": map[string]interface{}{
- "commit_index": stats.CommitIndex,
- "applied_index": stats.LastApplied,
- "last_log_index": stats.LastLogIndex,
- "db_applied": s.DB.GetLastAppliedIndex(),
- },
- "cluster": stats.ClusterNodes,
- "cluster_size": stats.ClusterSize,
- }
- data, err := json.Marshal(info)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK " + string(data)
- }
- case "SEARCH":
- // Usage: SEARCH <pattern> [limit] [offset]
- if len(parts) < 2 {
- resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
- } else {
- pattern := parts[1]
- limit := 20
- offset := 0
- if len(parts) >= 3 {
- if l, err := strconv.Atoi(parts[2]); err == nil {
- limit = l
- }
- }
- if len(parts) >= 4 {
- if o, err := strconv.Atoi(parts[3]); err == nil {
- offset = o
- }
- }
- results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- data, _ := json.Marshal(results)
- resp = "OK " + string(data)
- }
- }
- case "COUNT":
- // Usage: COUNT <pattern>
- if len(parts) < 2 {
- resp = "ERR usage: COUNT <pattern>"
- } else {
- pattern := parts[1]
- count, err := s.CountAuthenticated(pattern, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = fmt.Sprintf("OK %d", count)
- }
- }
- case "JOIN":
- // Usage: JOIN <id> <addr>
- // Admin only
- if s.AuthManager.IsEnabled() {
- if !s.IsAdmin(session.token) {
- resp = "ERR Permission Denied: Admin access required"
- break
- }
- }
- if len(parts) < 3 {
- resp = "ERR usage: JOIN <id> <addr>"
- } else {
- err := s.Join(parts[1], parts[2])
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK Join request sent"
- }
- }
- case "LEAVE":
- // Usage: LEAVE <id>
- // Admin only
- if s.AuthManager.IsEnabled() {
- if !s.IsAdmin(session.token) {
- resp = "ERR Permission Denied: Admin access required"
- break
- }
- }
- if len(parts) < 2 {
- resp = "ERR usage: LEAVE <id>"
- } else {
- err := s.Leave(parts[1])
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK Leave request sent"
- }
- }
- // --- Admin Commands ---
- case "USER_LIST":
- users := s.AuthManager.ListUsers()
- data, err := json.Marshal(users)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- // Ensure it's a single line for TCP protocol simplicity
- jsonStr := string(data)
- resp = fmt.Sprintf("OK %s", jsonStr)
- }
- case "ROLE_LIST":
- roles := s.AuthManager.ListRoles()
- data, err := json.Marshal(roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = fmt.Sprintf("OK %s", string(data))
- }
- case "USER_CREATE":
- // Usage: USER_CREATE <username> <password> <role1,role2>
- if len(parts) < 3 {
- resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
- } else {
- u := parts[1]
- p := parts[2]
- var roles []string
- if len(parts) > 3 {
- roles = strings.Split(parts[3], ",")
- }
- // Use RegisterUser (sync)
- err := s.AuthManager.RegisterUser(u, p, roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ROLE_CREATE":
- // Usage: ROLE_CREATE <name>
- if len(parts) < 2 {
- resp = "ERR usage: ROLE_CREATE <name>"
- } else {
- name := parts[1]
- err := s.AuthManager.CreateRole(name)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ROLE_PERMISSION_ADD":
- // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
- // Actions: comma separated list of actions (read,write,admin,*)
- if len(parts) < 4 {
- resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
- } else {
- roleName := parts[1]
- pattern := parts[2]
- actionsStr := parts[3]
- actions := strings.Split(actionsStr, ",")
- rolePtr, err := s.AuthManager.GetRole(roleName)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- // Create a copy to modify
- role := *rolePtr
-
- // Deep copy permissions to avoid potential side effects on cached object
- originalPerms := role.Permissions
- role.Permissions = make([]Permission, len(originalPerms))
- copy(role.Permissions, originalPerms)
-
- newPerm := Permission{
- KeyPattern: pattern,
- Actions: actions,
- }
- // Upsert logic: Update if exists, Append if new
- found := false
- for i, p := range role.Permissions {
- if p.KeyPattern == pattern {
- role.Permissions[i] = newPerm
- found = true
- break
- }
- }
- if !found {
- role.Permissions = append(role.Permissions, newPerm)
- }
-
- err := s.AuthManager.UpdateRole(role)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- }
- case "ROLE_PERMISSION_REMOVE":
- // Usage: ROLE_PERMISSION_REMOVE <role> <pattern>
- if len(parts) < 3 {
- resp = "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
- } else {
- roleName := parts[1]
- pattern := parts[2]
- rolePtr, err := s.AuthManager.GetRole(roleName)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- role := *rolePtr
- originalPerms := role.Permissions
- newPerms := make([]Permission, 0, len(originalPerms))
-
- found := false
- for _, p := range originalPerms {
- if p.KeyPattern == pattern {
- found = true
- continue
- }
- newPerms = append(newPerms, p)
- }
-
- if !found {
- resp = "ERR permission not found"
- } else {
- role.Permissions = newPerms
- err := s.AuthManager.UpdateRole(role)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- }
- }
-
- case "USER_UNLOCK":
- // Usage: USER_UNLOCK <username>
- if s.AuthManager.IsEnabled() {
- if !s.IsAdmin(session.token) {
- resp = "ERR Permission Denied: Admin access required"
- break
- }
- }
- if len(parts) < 2 {
- resp = "ERR usage: USER_UNLOCK <username>"
- } else {
- // Manually clear the lock key
- userToUnlock := parts[1]
- // We use Del to remove the lock key
- err := s.Del("system.lock." + userToUnlock)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "EXIT", "QUIT":
- resp = "BYE"
- // Need signal to close connection after write
- // For simplicity, handle it in handleTCPConnection loop break,
- // but here we just return the string.
- // Actually, BYE handling is tricky in async writer.
- // Let's keep connection open or let client close it.
- // Or send special signal?
- // For now, simple return. Client will read BYE and close.
- default:
- s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
- resp = fmt.Sprintf("ERR unknown command: %s", cmd)
- }
- return resp
- }
|