| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467 |
- package raft
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"runtime"
	"strconv"
	"strings"
	"time"
	"unicode"
)
- // TCPClientSession holds state for a TCP connection
- type TCPClientSession struct {
- conn net.Conn
- server *KVServer
- token string
- username string
- reader *bufio.Reader
- writer *bufio.Writer
- remoteAddr string
- }
- func (s *KVServer) StartTCPServer(addr string) error {
- listener, err := net.Listen("tcp", addr)
- if err != nil {
- return err
- }
- s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
- go func() {
- defer listener.Close()
- for {
- conn, err := listener.Accept()
- if err != nil {
- if s.stopCh != nil {
- select {
- case <-s.stopCh:
- return
- default:
- }
- }
- s.Raft.config.Logger.Error("TCP Accept error: %v", err)
- continue
- }
- go s.handleTCPConnection(conn)
- }
- }()
- return nil
- }
- func (s *KVServer) handleTCPConnection(conn net.Conn) {
- session := &TCPClientSession{
- conn: conn,
- server: s,
- reader: bufio.NewReader(conn),
- writer: bufio.NewWriter(conn),
- remoteAddr: conn.RemoteAddr().String(),
- }
- defer conn.Close()
- for {
- line, err := session.reader.ReadString('\n')
- if err != nil {
- return // Connection closed
- }
- line = strings.TrimSpace(line)
- if line == "" {
- continue
- }
- parts := strings.Fields(line)
- if len(parts) == 0 {
- continue
- }
- cmd := strings.ToUpper(parts[0])
- var resp string
- switch cmd {
- case "LOGIN":
- if len(parts) < 3 {
- resp = "ERR usage: LOGIN <username> <password> [otp]"
- } else {
- user := parts[1]
- pass := parts[2]
- otp := ""
- if len(parts) > 3 {
- otp = parts[3]
- }
-
- // Extract IP
- ip := session.remoteAddr
- if host, _, err := net.SplitHostPort(ip); err == nil {
- ip = host
- }
- token, err := s.AuthManager.Login(user, pass, otp, ip)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = user
- resp = fmt.Sprintf("OK %s", token)
- }
- }
- case "AUTH":
- if len(parts) < 2 {
- resp = "ERR usage: AUTH <token>"
- } else {
- token := parts[1]
- // Verify token
- sess, err := s.AuthManager.GetSession(token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = sess.Username
- resp = "OK"
- }
- }
- case "LOGOUT":
- if session.token != "" {
- s.AuthManager.Logout(session.token)
- session.token = ""
- session.username = ""
- }
- resp = "OK"
- case "GET":
- if len(parts) < 2 {
- resp = "ERR usage: GET <key>"
- } else {
- key := parts[1]
- val, found, err := s.GetLinearAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else if !found {
- resp = "ERR not found"
- } else {
- resp = fmt.Sprintf("OK %s", val)
- }
- }
- case "SET":
- if len(parts) < 3 {
- resp = "ERR usage: SET <key> <value>"
- } else {
- key := parts[1]
- // Value might contain spaces, join the rest
- val := strings.Join(parts[2:], " ")
- err := s.SetAuthenticated(key, val, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "DEL":
- if len(parts) < 2 {
- resp = "ERR usage: DEL <key>"
- } else {
- key := parts[1]
- err := s.DelAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
-
- case "WHOAMI":
- if session.username == "" {
- resp = "Guest"
- } else {
- resp = session.username
- }
- case "HELP":
- helpText := `Available Commands:
- GET <key> - Get value
- SET <key> <value> - Set value
- DEL <key> - Delete value
- SEARCH <pattern> [limit] - Search keys (e.g. user.*)
- COUNT <pattern> - Count keys
- INFO - Show system stats
- WHOAMI - Show current user
- JOIN <id> <addr> - Add node (Root only)
- LEAVE <id> - Remove node (Root only)
- USER_LIST - List users (Admin)
- ROLE_LIST - List roles (Admin)
- LOGIN/LOGOUT/EXIT`
- resp = "OK " + helpText
- case "INFO":
- // Check permission (Root only if auth enabled)
- if s.AuthManager.IsEnabled() {
- sess, err := s.AuthManager.GetSession(session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- break
- }
- if sess.Username != "root" {
- resp = "ERR Permission Denied: Root access required"
- break
- }
- }
- // Gather stats
- stats := s.GetStats()
- health := s.HealthCheck()
- dbSize := s.GetDBSize()
- logSize := s.GetLogSize()
- var m runtime.MemStats
- runtime.ReadMemStats(&m)
- // Construct JSON response
- info := map[string]interface{}{
- "node": map[string]interface{}{
- "id": health.NodeID,
- "state": health.State,
- "term": health.Term,
- "leader": health.LeaderID,
- "healthy": health.IsHealthy,
- },
- "storage": map[string]interface{}{
- "db_size": dbSize,
- "log_size": logSize,
- "mem_alloc": m.Alloc,
- "mem_sys": m.Sys,
- "num_gc": m.NumGC,
- },
- "indices": map[string]interface{}{
- "commit_index": stats.CommitIndex,
- "applied_index": stats.LastApplied,
- "last_log_index": stats.LastLogIndex,
- "db_applied": s.DB.GetLastAppliedIndex(),
- },
- "cluster": stats.ClusterNodes,
- "cluster_size": stats.ClusterSize,
- }
- data, err := json.Marshal(info)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK " + string(data)
- }
- case "SEARCH":
- // Usage: SEARCH <pattern> [limit] [offset]
- if len(parts) < 2 {
- resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
- } else {
- pattern := parts[1]
- limit := 20
- offset := 0
- if len(parts) >= 3 {
- if l, err := strconv.Atoi(parts[2]); err == nil {
- limit = l
- }
- }
- if len(parts) >= 4 {
- if o, err := strconv.Atoi(parts[3]); err == nil {
- offset = o
- }
- }
- results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- data, _ := json.Marshal(results)
- resp = "OK " + string(data)
- }
- }
- case "COUNT":
- // Usage: COUNT <pattern>
- if len(parts) < 2 {
- resp = "ERR usage: COUNT <pattern>"
- } else {
- pattern := parts[1]
- count, err := s.CountAuthenticated(pattern, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = fmt.Sprintf("OK %d", count)
- }
- }
- case "JOIN":
- // Usage: JOIN <id> <addr>
- // Admin only
- if s.AuthManager.IsEnabled() {
- sess, err := s.AuthManager.GetSession(session.token)
- if err != nil || sess.Username != "root" {
- resp = "ERR Permission Denied: Root access required"
- break
- }
- }
- if len(parts) < 3 {
- resp = "ERR usage: JOIN <id> <addr>"
- } else {
- err := s.Join(parts[1], parts[2])
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK Join request sent"
- }
- }
- case "LEAVE":
- // Usage: LEAVE <id>
- // Admin only
- if s.AuthManager.IsEnabled() {
- sess, err := s.AuthManager.GetSession(session.token)
- if err != nil || sess.Username != "root" {
- resp = "ERR Permission Denied: Root access required"
- break
- }
- }
- if len(parts) < 2 {
- resp = "ERR usage: LEAVE <id>"
- } else {
- err := s.Leave(parts[1])
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK Leave request sent"
- }
- }
- // --- Admin Commands ---
- case "USER_LIST":
- users := s.AuthManager.ListUsers()
- data, err := json.Marshal(users)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- // Ensure it's a single line for TCP protocol simplicity
- jsonStr := string(data)
- // JSON marshal might include newlines if indentation was used (it's not by default, but safe to check)
- // However, standard json.Marshal does not indent.
- resp = fmt.Sprintf("OK %s", jsonStr)
- }
- case "ROLE_LIST":
- roles := s.AuthManager.ListRoles()
- data, err := json.Marshal(roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = fmt.Sprintf("OK %s", string(data))
- }
- case "USER_CREATE":
- // Usage: USER_CREATE <username> <password> <role1,role2>
- if len(parts) < 3 {
- resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
- } else {
- u := parts[1]
- p := parts[2]
- var roles []string
- if len(parts) > 3 {
- roles = strings.Split(parts[3], ",")
- }
- // Use RegisterUser (sync)
- err := s.AuthManager.RegisterUser(u, p, roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ROLE_CREATE":
- // Usage: ROLE_CREATE <name>
- if len(parts) < 2 {
- resp = "ERR usage: ROLE_CREATE <name>"
- } else {
- name := parts[1]
- err := s.AuthManager.CreateRole(name)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ROLE_PERMISSION_ADD":
- // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
- // Actions: comma separated list of actions (read,write,admin,*)
- if len(parts) < 4 {
- resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
- } else {
- roleName := parts[1]
- pattern := parts[2]
- actionsStr := parts[3]
- actions := strings.Split(actionsStr, ",")
- rolePtr, err := s.AuthManager.GetRole(roleName)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- // Create a copy to modify
- role := *rolePtr
-
- // Deep copy permissions to avoid potential side effects on cached object
- originalPerms := role.Permissions
- role.Permissions = make([]Permission, len(originalPerms))
- copy(role.Permissions, originalPerms)
-
- newPerm := Permission{
- KeyPattern: pattern,
- Actions: actions,
- }
- role.Permissions = append(role.Permissions, newPerm)
-
- err := s.AuthManager.UpdateRole(role)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- }
-
- case "USER_UNLOCK":
- // Usage: USER_UNLOCK <username>
- if len(parts) < 2 {
- resp = "ERR usage: USER_UNLOCK <username>"
- } else {
- // Manually clear the lock key
- // Note: accessing server.Set directly bypasses auth check which is fine here
- // as the TCP session itself should be authenticated as admin ideally.
- // For now we trust the connected client has rights or we check session.
- // In real impl, check if session.username is root or has admin perm.
- userToUnlock := parts[1]
- // We use Del to remove the lock key
- err := s.Del("system.lock." + userToUnlock)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "EXIT", "QUIT":
- session.writer.WriteString("BYE\n")
- session.writer.Flush()
- return
- default:
- s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
- resp = fmt.Sprintf("ERR unknown command: %s", cmd)
- }
- session.writer.WriteString(resp + "\n")
- session.writer.Flush()
- }
- }
|