|
|
@@ -4,10 +4,12 @@ import (
|
|
|
"bufio"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
+ "io"
|
|
|
"net"
|
|
|
"runtime"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
// TCPClientSession holds state for a TCP connection
|
|
|
@@ -50,137 +52,258 @@ func (s *KVServer) StartTCPServer(addr string) error {
|
|
|
}
|
|
|
|
|
|
func (s *KVServer) handleTCPConnection(conn net.Conn) {
|
|
|
+ defer conn.Close()
|
|
|
+
|
|
|
+ // Use larger buffers for high throughput
|
|
|
+ reader := bufio.NewReaderSize(conn, 64*1024)
|
|
|
+ writer := bufio.NewWriterSize(conn, 64*1024)
|
|
|
+
|
|
|
session := &TCPClientSession{
|
|
|
conn: conn,
|
|
|
server: s,
|
|
|
- reader: bufio.NewReader(conn),
|
|
|
- writer: bufio.NewWriter(conn),
|
|
|
+ reader: reader,
|
|
|
+ writer: writer,
|
|
|
remoteAddr: conn.RemoteAddr().String(),
|
|
|
}
|
|
|
- defer conn.Close()
|
|
|
|
|
|
for {
|
|
|
- line, err := session.reader.ReadString('\n')
|
|
|
+ // Set Keep-Alive Deadline
|
|
|
+ // Using a longer deadline avoids frequent syscalls if traffic is continuous
|
|
|
+ conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
|
+
|
|
|
+ // Read Request Line
|
|
|
+ line, err := reader.ReadString('\n')
|
|
|
if err != nil {
|
|
|
- return // Connection closed
|
|
|
+ return
|
|
|
}
|
|
|
+
|
|
|
line = strings.TrimSpace(line)
|
|
|
if line == "" {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- parts := strings.Fields(line)
|
|
|
- if len(parts) == 0 {
|
|
|
- continue
|
|
|
+ // Parse Headers
|
|
|
+ var contentLength int
|
|
|
+ for {
|
|
|
+ hLine, err := reader.ReadString('\n')
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ hLine = strings.TrimSpace(hLine)
|
|
|
+ if hLine == "" {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ // Optimized Content-Length check
|
|
|
+ // "Content-Length: 123"
|
|
|
+ lowerHLine := strings.ToLower(hLine)
|
|
|
+ if strings.HasPrefix(lowerHLine, "content-length:") {
|
|
|
+ valStr := strings.TrimSpace(hLine[15:])
|
|
|
+ if l, err := strconv.Atoi(valStr); err == nil {
|
|
|
+ contentLength = l
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Read Body
|
|
|
+ var body string
|
|
|
+ if contentLength > 0 {
|
|
|
+ // Optimization: Avoid allocation for small bodies if possible,
|
|
|
+ // but Raft commands need to be passed as bytes/string anyway.
|
|
|
+ buf := make([]byte, contentLength)
|
|
|
+ if _, err := io.ReadFull(reader, buf); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ body = string(buf)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Execute Command Inline
|
|
|
+ // This avoids the overhead of launching a goroutine per request
|
|
|
+ // and the complexity/contention of a response channel.
|
|
|
+ // For ASET (Async), this returns almost immediately.
|
|
|
+ resp := s.executeCommandWithBody(session, line, body)
|
|
|
+
|
|
|
+ // Write Response
|
|
|
+ if _, err := writer.WriteString(resp + "\n"); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- cmd := strings.ToUpper(parts[0])
|
|
|
- var resp string
|
|
|
+ // Flush Optimization:
|
|
|
+ // Only flush if there is no more data in the read buffer.
|
|
|
+ // This automatically batches responses when requests are pipelined.
|
|
|
+ if reader.Buffered() == 0 {
|
|
|
+ if err := writer.Flush(); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
+// Helper to bridge old logic
|
|
|
+func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
|
|
|
+ parts := strings.Fields(line)
|
|
|
+ if len(parts) == 0 {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ cmd := strings.ToUpper(parts[0])
|
|
|
+
|
|
|
+ // If body is present, it overrides the value part of the command
|
|
|
+ // We handle specific commands that use body
|
|
|
+ if body != "" {
|
|
|
switch cmd {
|
|
|
- case "LOGIN":
|
|
|
- if len(parts) < 3 {
|
|
|
- resp = "ERR usage: LOGIN <username> <password> [otp]"
|
|
|
- } else {
|
|
|
- user := parts[1]
|
|
|
- pass := parts[2]
|
|
|
- otp := ""
|
|
|
- if len(parts) > 3 {
|
|
|
- otp = parts[3]
|
|
|
+ case "SET", "ASET":
|
|
|
+ if len(parts) < 2 {
|
|
|
+ return "ERR usage: SET <key> (value in body)"
|
|
|
+ }
|
|
|
+ key := parts[1]
|
|
|
+ // Value is body
|
|
|
+ if cmd == "SET" {
|
|
|
+ if err := s.SetAuthenticated(key, body, session.token); err != nil {
|
|
|
+ return fmt.Sprintf("ERR %v", err)
|
|
|
}
|
|
|
-
|
|
|
- // Extract IP
|
|
|
- ip := session.remoteAddr
|
|
|
- if host, _, err := net.SplitHostPort(ip); err == nil {
|
|
|
- ip = host
|
|
|
+ return "OK"
|
|
|
+ } else {
|
|
|
+ if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
|
|
|
+ return fmt.Sprintf("ERR %v", err)
|
|
|
}
|
|
|
+ return "OK"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Fallback to legacy parsing if no body or command doesn't use body
|
|
|
+ return s.executeCommand(session, line)
|
|
|
+}
|
|
|
|
|
|
- token, err := s.AuthManager.Login(user, pass, otp, ip)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- session.token = token
|
|
|
- session.username = user
|
|
|
- resp = fmt.Sprintf("OK %s", token)
|
|
|
- }
|
|
|
+func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
|
|
|
+ parts := strings.Fields(line)
|
|
|
+ if len(parts) == 0 {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ cmd := strings.ToUpper(parts[0])
|
|
|
+ var resp string
|
|
|
+
|
|
|
+ switch cmd {
|
|
|
+ case "LOGIN":
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: LOGIN <username> <password> [otp]"
|
|
|
+ } else {
|
|
|
+ user := parts[1]
|
|
|
+ pass := parts[2]
|
|
|
+ otp := ""
|
|
|
+ if len(parts) > 3 {
|
|
|
+ otp = parts[3]
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extract IP
|
|
|
+ ip := session.remoteAddr
|
|
|
+ if host, _, err := net.SplitHostPort(ip); err == nil {
|
|
|
+ ip = host
|
|
|
}
|
|
|
|
|
|
- case "AUTH":
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: AUTH <token>"
|
|
|
+ token, err := s.AuthManager.Login(user, pass, otp, ip)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- token := parts[1]
|
|
|
- // Verify token
|
|
|
- sess, err := s.AuthManager.GetSession(token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- session.token = token
|
|
|
- session.username = sess.Username
|
|
|
- resp = "OK"
|
|
|
- }
|
|
|
+ session.token = token
|
|
|
+ session.username = user
|
|
|
+ resp = fmt.Sprintf("OK %s", token)
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "LOGOUT":
|
|
|
- if session.token != "" {
|
|
|
- s.AuthManager.Logout(session.token)
|
|
|
- session.token = ""
|
|
|
- session.username = ""
|
|
|
+ case "AUTH":
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: AUTH <token>"
|
|
|
+ } else {
|
|
|
+ token := parts[1]
|
|
|
+ // Verify token
|
|
|
+ sess, err := s.AuthManager.GetSession(token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else {
|
|
|
+ session.token = token
|
|
|
+ session.username = sess.Username
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
- resp = "OK"
|
|
|
+ }
|
|
|
|
|
|
- case "GET":
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: GET <key>"
|
|
|
+ case "LOGOUT":
|
|
|
+ if session.token != "" {
|
|
|
+ s.AuthManager.Logout(session.token)
|
|
|
+ session.token = ""
|
|
|
+ session.username = ""
|
|
|
+ }
|
|
|
+ resp = "OK"
|
|
|
+
|
|
|
+ case "GET":
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: GET <key>"
|
|
|
+ } else {
|
|
|
+ key := parts[1]
|
|
|
+ val, found, err := s.GetLinearAuthenticated(key, session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else if !found {
|
|
|
+ resp = "ERR not found"
|
|
|
} else {
|
|
|
- key := parts[1]
|
|
|
- val, found, err := s.GetLinearAuthenticated(key, session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else if !found {
|
|
|
- resp = "ERR not found"
|
|
|
- } else {
|
|
|
- resp = fmt.Sprintf("OK %s", val)
|
|
|
- }
|
|
|
+ resp = fmt.Sprintf("OK %s", val)
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "SET":
|
|
|
- if len(parts) < 3 {
|
|
|
- resp = "ERR usage: SET <key> <value>"
|
|
|
+ case "SET":
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: SET <key> <value>"
|
|
|
+ } else {
|
|
|
+ key := parts[1]
|
|
|
+ // Value might contain spaces, join the rest
|
|
|
+ val := strings.Join(parts[2:], " ")
|
|
|
+ // Use SetAuthenticated (Sync) by default for safety
|
|
|
+ err := s.SetAuthenticated(key, val, session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- key := parts[1]
|
|
|
- // Value might contain spaces, join the rest
|
|
|
- val := strings.Join(parts[2:], " ")
|
|
|
- err := s.SetAuthenticated(key, val, session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK"
|
|
|
- }
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "DEL":
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: DEL <key>"
|
|
|
+ case "ASET":
|
|
|
+ // Async SET for high performance
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: ASET <key> <value>"
|
|
|
+ } else {
|
|
|
+ key := parts[1]
|
|
|
+ val := strings.Join(parts[2:], " ")
|
|
|
+ err := s.SetAuthenticatedAsync(key, val, session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- key := parts[1]
|
|
|
- err := s.DelAuthenticated(key, session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK"
|
|
|
- }
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
-
|
|
|
- case "WHOAMI":
|
|
|
- if session.username == "" {
|
|
|
- resp = "Guest"
|
|
|
+ }
|
|
|
+
|
|
|
+ case "DEL":
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: DEL <key>"
|
|
|
+ } else {
|
|
|
+ key := parts[1]
|
|
|
+ err := s.DelAuthenticated(key, session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- resp = session.username
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ case "WHOAMI":
|
|
|
+ if session.username == "" {
|
|
|
+ resp = "Guest"
|
|
|
+ } else {
|
|
|
+ resp = session.username
|
|
|
+ }
|
|
|
|
|
|
- case "HELP":
|
|
|
- helpText := `Available Commands:
|
|
|
+ case "HELP":
|
|
|
+ helpText := `Available Commands:
|
|
|
GET <key> - Get value
|
|
|
SET <key> <value> - Set value
|
|
|
DEL <key> - Delete value
|
|
|
@@ -193,234 +316,282 @@ func (s *KVServer) handleTCPConnection(conn net.Conn) {
|
|
|
USER_LIST - List users (Admin)
|
|
|
ROLE_LIST - List roles (Admin)
|
|
|
LOGIN/LOGOUT/EXIT`
|
|
|
- resp = "OK " + helpText
|
|
|
+ resp = "OK " + helpText
|
|
|
|
|
|
- case "INFO":
|
|
|
- // Check permission (Root only if auth enabled)
|
|
|
- if s.AuthManager.IsEnabled() {
|
|
|
- sess, err := s.AuthManager.GetSession(session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- break
|
|
|
+ case "INFO":
|
|
|
+ // Check permission (Root only if auth enabled)
|
|
|
+ if s.AuthManager.IsEnabled() {
|
|
|
+ sess, err := s.AuthManager.GetSession(session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if sess.Username != "root" {
|
|
|
+ resp = "ERR Permission Denied: Root access required"
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Gather stats
|
|
|
+ stats := s.GetStats()
|
|
|
+ health := s.HealthCheck()
|
|
|
+ dbSize := s.GetDBSize()
|
|
|
+ logSize := s.GetLogSize()
|
|
|
+ var m runtime.MemStats
|
|
|
+ runtime.ReadMemStats(&m)
|
|
|
+
|
|
|
+ // Construct JSON response
|
|
|
+ info := map[string]interface{}{
|
|
|
+ "node": map[string]interface{}{
|
|
|
+ "id": health.NodeID,
|
|
|
+ "state": health.State,
|
|
|
+ "term": health.Term,
|
|
|
+ "leader": health.LeaderID,
|
|
|
+ "healthy": health.IsHealthy,
|
|
|
+ },
|
|
|
+ "storage": map[string]interface{}{
|
|
|
+ "db_size": dbSize,
|
|
|
+ "log_size": logSize,
|
|
|
+ "mem_alloc": m.Alloc,
|
|
|
+ "mem_sys": m.Sys,
|
|
|
+ "num_gc": m.NumGC,
|
|
|
+ },
|
|
|
+ "indices": map[string]interface{}{
|
|
|
+ "commit_index": stats.CommitIndex,
|
|
|
+ "applied_index": stats.LastApplied,
|
|
|
+ "last_log_index": stats.LastLogIndex,
|
|
|
+ "db_applied": s.DB.GetLastAppliedIndex(),
|
|
|
+ },
|
|
|
+ "cluster": stats.ClusterNodes,
|
|
|
+ "cluster_size": stats.ClusterSize,
|
|
|
+ }
|
|
|
+
|
|
|
+ data, err := json.Marshal(info)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else {
|
|
|
+ resp = "OK " + string(data)
|
|
|
+ }
|
|
|
+
|
|
|
+ case "SEARCH":
|
|
|
+ // Usage: SEARCH <pattern> [limit] [offset]
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
|
|
|
+ } else {
|
|
|
+ pattern := parts[1]
|
|
|
+ limit := 20
|
|
|
+ offset := 0
|
|
|
+ if len(parts) >= 3 {
|
|
|
+ if l, err := strconv.Atoi(parts[2]); err == nil {
|
|
|
+ limit = l
|
|
|
}
|
|
|
- if sess.Username != "root" {
|
|
|
- resp = "ERR Permission Denied: Root access required"
|
|
|
- break
|
|
|
+ }
|
|
|
+ if len(parts) >= 4 {
|
|
|
+ if o, err := strconv.Atoi(parts[3]); err == nil {
|
|
|
+ offset = o
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Gather stats
|
|
|
- stats := s.GetStats()
|
|
|
- health := s.HealthCheck()
|
|
|
- dbSize := s.GetDBSize()
|
|
|
- logSize := s.GetLogSize()
|
|
|
- var m runtime.MemStats
|
|
|
- runtime.ReadMemStats(&m)
|
|
|
-
|
|
|
- // Construct JSON response
|
|
|
- info := map[string]interface{}{
|
|
|
- "node": map[string]interface{}{
|
|
|
- "id": health.NodeID,
|
|
|
- "state": health.State,
|
|
|
- "term": health.Term,
|
|
|
- "leader": health.LeaderID,
|
|
|
- "healthy": health.IsHealthy,
|
|
|
- },
|
|
|
- "storage": map[string]interface{}{
|
|
|
- "db_size": dbSize,
|
|
|
- "log_size": logSize,
|
|
|
- "mem_alloc": m.Alloc,
|
|
|
- "mem_sys": m.Sys,
|
|
|
- "num_gc": m.NumGC,
|
|
|
- },
|
|
|
- "indices": map[string]interface{}{
|
|
|
- "commit_index": stats.CommitIndex,
|
|
|
- "applied_index": stats.LastApplied,
|
|
|
- "last_log_index": stats.LastLogIndex,
|
|
|
- "db_applied": s.DB.GetLastAppliedIndex(),
|
|
|
- },
|
|
|
- "cluster": stats.ClusterNodes,
|
|
|
- "cluster_size": stats.ClusterSize,
|
|
|
- }
|
|
|
-
|
|
|
- data, err := json.Marshal(info)
|
|
|
+ results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
|
|
|
if err != nil {
|
|
|
resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
+ data, _ := json.Marshal(results)
|
|
|
resp = "OK " + string(data)
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "SEARCH":
|
|
|
- // Usage: SEARCH <pattern> [limit] [offset]
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
|
|
|
- } else {
|
|
|
- pattern := parts[1]
|
|
|
- limit := 20
|
|
|
- offset := 0
|
|
|
- if len(parts) >= 3 {
|
|
|
- if l, err := strconv.Atoi(parts[2]); err == nil {
|
|
|
- limit = l
|
|
|
- }
|
|
|
- }
|
|
|
- if len(parts) >= 4 {
|
|
|
- if o, err := strconv.Atoi(parts[3]); err == nil {
|
|
|
- offset = o
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- data, _ := json.Marshal(results)
|
|
|
- resp = "OK " + string(data)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- case "COUNT":
|
|
|
- // Usage: COUNT <pattern>
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: COUNT <pattern>"
|
|
|
+ case "COUNT":
|
|
|
+ // Usage: COUNT <pattern>
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: COUNT <pattern>"
|
|
|
+ } else {
|
|
|
+ pattern := parts[1]
|
|
|
+ count, err := s.CountAuthenticated(pattern, session.token)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- pattern := parts[1]
|
|
|
- count, err := s.CountAuthenticated(pattern, session.token)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = fmt.Sprintf("OK %d", count)
|
|
|
- }
|
|
|
+ resp = fmt.Sprintf("OK %d", count)
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "JOIN":
|
|
|
- // Usage: JOIN <id> <addr>
|
|
|
- // Admin only
|
|
|
- if s.AuthManager.IsEnabled() {
|
|
|
- sess, err := s.AuthManager.GetSession(session.token)
|
|
|
- if err != nil || sess.Username != "root" {
|
|
|
- resp = "ERR Permission Denied: Root access required"
|
|
|
- break
|
|
|
- }
|
|
|
+ case "JOIN":
|
|
|
+ // Usage: JOIN <id> <addr>
|
|
|
+ // Admin only
|
|
|
+ if s.AuthManager.IsEnabled() {
|
|
|
+ sess, err := s.AuthManager.GetSession(session.token)
|
|
|
+ if err != nil || sess.Username != "root" {
|
|
|
+ resp = "ERR Permission Denied: Root access required"
|
|
|
+ break
|
|
|
}
|
|
|
- if len(parts) < 3 {
|
|
|
- resp = "ERR usage: JOIN <id> <addr>"
|
|
|
+ }
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: JOIN <id> <addr>"
|
|
|
+ } else {
|
|
|
+ err := s.Join(parts[1], parts[2])
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- err := s.Join(parts[1], parts[2])
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK Join request sent"
|
|
|
- }
|
|
|
+ resp = "OK Join request sent"
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "LEAVE":
|
|
|
- // Usage: LEAVE <id>
|
|
|
- // Admin only
|
|
|
- if s.AuthManager.IsEnabled() {
|
|
|
- sess, err := s.AuthManager.GetSession(session.token)
|
|
|
- if err != nil || sess.Username != "root" {
|
|
|
- resp = "ERR Permission Denied: Root access required"
|
|
|
- break
|
|
|
- }
|
|
|
+ case "LEAVE":
|
|
|
+ // Usage: LEAVE <id>
|
|
|
+ // Admin only
|
|
|
+ if s.AuthManager.IsEnabled() {
|
|
|
+ sess, err := s.AuthManager.GetSession(session.token)
|
|
|
+ if err != nil || sess.Username != "root" {
|
|
|
+ resp = "ERR Permission Denied: Root access required"
|
|
|
+ break
|
|
|
}
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: LEAVE <id>"
|
|
|
+ }
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: LEAVE <id>"
|
|
|
+ } else {
|
|
|
+ err := s.Leave(parts[1])
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- err := s.Leave(parts[1])
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK Leave request sent"
|
|
|
- }
|
|
|
+ resp = "OK Leave request sent"
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ // --- Admin Commands ---
|
|
|
+
|
|
|
+ case "USER_LIST":
|
|
|
+ users := s.AuthManager.ListUsers()
|
|
|
+ data, err := json.Marshal(users)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else {
|
|
|
+ // Ensure it's a single line for TCP protocol simplicity
|
|
|
+ jsonStr := string(data)
|
|
|
+ resp = fmt.Sprintf("OK %s", jsonStr)
|
|
|
+ }
|
|
|
|
|
|
- // --- Admin Commands ---
|
|
|
+ case "ROLE_LIST":
|
|
|
+ roles := s.AuthManager.ListRoles()
|
|
|
+ data, err := json.Marshal(roles)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else {
|
|
|
+ resp = fmt.Sprintf("OK %s", string(data))
|
|
|
+ }
|
|
|
|
|
|
- case "USER_LIST":
|
|
|
- users := s.AuthManager.ListUsers()
|
|
|
- data, err := json.Marshal(users)
|
|
|
+ case "USER_CREATE":
|
|
|
+ // Usage: USER_CREATE <username> <password> <role1,role2>
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
|
|
|
+ } else {
|
|
|
+ u := parts[1]
|
|
|
+ p := parts[2]
|
|
|
+ var roles []string
|
|
|
+ if len(parts) > 3 {
|
|
|
+ roles = strings.Split(parts[3], ",")
|
|
|
+ }
|
|
|
+ // Use RegisterUser (sync)
|
|
|
+ err := s.AuthManager.RegisterUser(u, p, roles)
|
|
|
if err != nil {
|
|
|
resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- // Ensure it's a single line for TCP protocol simplicity
|
|
|
- jsonStr := string(data)
|
|
|
- // JSON marshal might include newlines if indentation was used (it's not by default, but safe to check)
|
|
|
- // However, standard json.Marshal does not indent.
|
|
|
- resp = fmt.Sprintf("OK %s", jsonStr)
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "ROLE_LIST":
|
|
|
- roles := s.AuthManager.ListRoles()
|
|
|
- data, err := json.Marshal(roles)
|
|
|
+ case "ROLE_CREATE":
|
|
|
+ // Usage: ROLE_CREATE <name>
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: ROLE_CREATE <name>"
|
|
|
+ } else {
|
|
|
+ name := parts[1]
|
|
|
+ err := s.AuthManager.CreateRole(name)
|
|
|
if err != nil {
|
|
|
resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- resp = fmt.Sprintf("OK %s", string(data))
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "USER_CREATE":
|
|
|
- // Usage: USER_CREATE <username> <password> <role1,role2>
|
|
|
- if len(parts) < 3 {
|
|
|
- resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
|
|
|
+ case "ROLE_PERMISSION_ADD":
|
|
|
+ // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
|
|
|
+ // Actions: comma separated list of actions (read,write,admin,*)
|
|
|
+ if len(parts) < 4 {
|
|
|
+ resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
|
|
|
+ } else {
|
|
|
+ roleName := parts[1]
|
|
|
+ pattern := parts[2]
|
|
|
+ actionsStr := parts[3]
|
|
|
+ actions := strings.Split(actionsStr, ",")
|
|
|
+
|
|
|
+ rolePtr, err := s.AuthManager.GetRole(roleName)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- u := parts[1]
|
|
|
- p := parts[2]
|
|
|
- var roles []string
|
|
|
- if len(parts) > 3 {
|
|
|
- roles = strings.Split(parts[3], ",")
|
|
|
- }
|
|
|
- // Use RegisterUser (sync)
|
|
|
- err := s.AuthManager.RegisterUser(u, p, roles)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK"
|
|
|
+ // Create a copy to modify
|
|
|
+ role := *rolePtr
|
|
|
+
|
|
|
+ // Deep copy permissions to avoid potential side effects on cached object
|
|
|
+ originalPerms := role.Permissions
|
|
|
+ role.Permissions = make([]Permission, len(originalPerms))
|
|
|
+ copy(role.Permissions, originalPerms)
|
|
|
+
|
|
|
+ newPerm := Permission{
|
|
|
+ KeyPattern: pattern,
|
|
|
+ Actions: actions,
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- case "ROLE_CREATE":
|
|
|
- // Usage: ROLE_CREATE <name>
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: ROLE_CREATE <name>"
|
|
|
- } else {
|
|
|
- name := parts[1]
|
|
|
- err := s.AuthManager.CreateRole(name)
|
|
|
+ // Upsert logic: Update if exists, Append if new
|
|
|
+ found := false
|
|
|
+ for i, p := range role.Permissions {
|
|
|
+ if p.KeyPattern == pattern {
|
|
|
+ role.Permissions[i] = newPerm
|
|
|
+ found = true
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if !found {
|
|
|
+ role.Permissions = append(role.Permissions, newPerm)
|
|
|
+ }
|
|
|
+
|
|
|
+ err := s.AuthManager.UpdateRole(role)
|
|
|
if err != nil {
|
|
|
resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
resp = "OK"
|
|
|
}
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- case "ROLE_PERMISSION_ADD":
|
|
|
- // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
|
|
|
- // Actions: comma separated list of actions (read,write,admin,*)
|
|
|
- if len(parts) < 4 {
|
|
|
- resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
|
|
|
- } else {
|
|
|
- roleName := parts[1]
|
|
|
- pattern := parts[2]
|
|
|
- actionsStr := parts[3]
|
|
|
- actions := strings.Split(actionsStr, ",")
|
|
|
+ case "ROLE_PERMISSION_REMOVE":
|
|
|
+ // Usage: ROLE_PERMISSION_REMOVE <role> <pattern>
|
|
|
+ if len(parts) < 3 {
|
|
|
+ resp = "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
|
|
|
+ } else {
|
|
|
+ roleName := parts[1]
|
|
|
+ pattern := parts[2]
|
|
|
|
|
|
- rolePtr, err := s.AuthManager.GetRole(roleName)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- // Create a copy to modify
|
|
|
- role := *rolePtr
|
|
|
-
|
|
|
- // Deep copy permissions to avoid potential side effects on cached object
|
|
|
- originalPerms := role.Permissions
|
|
|
- role.Permissions = make([]Permission, len(originalPerms))
|
|
|
- copy(role.Permissions, originalPerms)
|
|
|
-
|
|
|
- newPerm := Permission{
|
|
|
- KeyPattern: pattern,
|
|
|
- Actions: actions,
|
|
|
+ rolePtr, err := s.AuthManager.GetRole(roleName)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
+ } else {
|
|
|
+ role := *rolePtr
|
|
|
+ originalPerms := role.Permissions
|
|
|
+ newPerms := make([]Permission, 0, len(originalPerms))
|
|
|
+
|
|
|
+ found := false
|
|
|
+ for _, p := range originalPerms {
|
|
|
+ if p.KeyPattern == pattern {
|
|
|
+ found = true
|
|
|
+ continue
|
|
|
}
|
|
|
- role.Permissions = append(role.Permissions, newPerm)
|
|
|
-
|
|
|
+ newPerms = append(newPerms, p)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !found {
|
|
|
+ resp = "ERR permission not found"
|
|
|
+ } else {
|
|
|
+ role.Permissions = newPerms
|
|
|
err := s.AuthManager.UpdateRole(role)
|
|
|
if err != nil {
|
|
|
resp = fmt.Sprintf("ERR %v", err)
|
|
|
@@ -429,39 +600,41 @@ func (s *KVServer) handleTCPConnection(conn net.Conn) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- case "USER_UNLOCK":
|
|
|
- // Usage: USER_UNLOCK <username>
|
|
|
- if len(parts) < 2 {
|
|
|
- resp = "ERR usage: USER_UNLOCK <username>"
|
|
|
+ }
|
|
|
+
|
|
|
+ case "USER_UNLOCK":
|
|
|
+ // Usage: USER_UNLOCK <username>
|
|
|
+ if len(parts) < 2 {
|
|
|
+ resp = "ERR usage: USER_UNLOCK <username>"
|
|
|
+ } else {
|
|
|
+ // Manually clear the lock key
|
|
|
+ // Note: accessing server.Set directly bypasses auth check which is fine here
|
|
|
+ // as the TCP session itself should be authenticated as admin ideally.
|
|
|
+ // For now we trust the connected client has rights or we check session.
|
|
|
+ // In real impl, check if session.username is root or has admin perm.
|
|
|
+ userToUnlock := parts[1]
|
|
|
+ // We use Del to remove the lock key
|
|
|
+ err := s.Del("system.lock." + userToUnlock)
|
|
|
+ if err != nil {
|
|
|
+ resp = fmt.Sprintf("ERR %v", err)
|
|
|
} else {
|
|
|
- // Manually clear the lock key
|
|
|
- // Note: accessing server.Set directly bypasses auth check which is fine here
|
|
|
- // as the TCP session itself should be authenticated as admin ideally.
|
|
|
- // For now we trust the connected client has rights or we check session.
|
|
|
- // In real impl, check if session.username is root or has admin perm.
|
|
|
- userToUnlock := parts[1]
|
|
|
- // We use Del to remove the lock key
|
|
|
- err := s.Del("system.lock." + userToUnlock)
|
|
|
- if err != nil {
|
|
|
- resp = fmt.Sprintf("ERR %v", err)
|
|
|
- } else {
|
|
|
- resp = "OK"
|
|
|
- }
|
|
|
+ resp = "OK"
|
|
|
}
|
|
|
-
|
|
|
- case "EXIT", "QUIT":
|
|
|
- session.writer.WriteString("BYE\n")
|
|
|
- session.writer.Flush()
|
|
|
- return
|
|
|
-
|
|
|
- default:
|
|
|
- s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
|
|
|
- resp = fmt.Sprintf("ERR unknown command: %s", cmd)
|
|
|
}
|
|
|
|
|
|
- session.writer.WriteString(resp + "\n")
|
|
|
- session.writer.Flush()
|
|
|
+ case "EXIT", "QUIT":
|
|
|
+ resp = "BYE"
|
|
|
+ // Need signal to close connection after write
|
|
|
+ // For simplicity, handle it in handleTCPConnection loop break,
|
|
|
+ // but here we just return the string.
|
|
|
+ // Actually, BYE handling is tricky in async writer.
|
|
|
+ // Let's keep connection open or let client close it.
|
|
|
+ // Or send special signal?
|
|
|
+ // For now, simple return. Client will read BYE and close.
|
|
|
+
|
|
|
+ default:
|
|
|
+ s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
|
|
|
+ resp = fmt.Sprintf("ERR unknown command: %s", cmd)
|
|
|
}
|
|
|
+ return resp
|
|
|
}
|
|
|
-
|