| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- package raft
import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"
)
- // TCPClientSession holds state for a TCP connection
- type TCPClientSession struct {
- conn net.Conn
- server *KVServer
- token string
- username string
- reader *bufio.Reader
- writer *bufio.Writer
- remoteAddr string
- }
- func (s *KVServer) StartTCPServer(addr string) error {
- listener, err := net.Listen("tcp", addr)
- if err != nil {
- return err
- }
- s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
- go func() {
- defer listener.Close()
- for {
- conn, err := listener.Accept()
- if err != nil {
- if s.stopCh != nil {
- select {
- case <-s.stopCh:
- return
- default:
- }
- }
- s.Raft.config.Logger.Error("TCP Accept error: %v", err)
- continue
- }
- go s.handleTCPConnection(conn)
- }
- }()
- return nil
- }
- func (s *KVServer) handleTCPConnection(conn net.Conn) {
- session := &TCPClientSession{
- conn: conn,
- server: s,
- reader: bufio.NewReader(conn),
- writer: bufio.NewWriter(conn),
- remoteAddr: conn.RemoteAddr().String(),
- }
- defer conn.Close()
- for {
- line, err := session.reader.ReadString('\n')
- if err != nil {
- return // Connection closed
- }
- line = strings.TrimSpace(line)
- if line == "" {
- continue
- }
- parts := strings.Fields(line)
- if len(parts) == 0 {
- continue
- }
- cmd := strings.ToUpper(parts[0])
- var resp string
- switch cmd {
- case "LOGIN":
- if len(parts) < 3 {
- resp = "ERR usage: LOGIN <username> <password> [otp]"
- } else {
- user := parts[1]
- pass := parts[2]
- otp := ""
- if len(parts) > 3 {
- otp = parts[3]
- }
-
- // Extract IP
- ip := session.remoteAddr
- if host, _, err := net.SplitHostPort(ip); err == nil {
- ip = host
- }
- token, err := s.AuthManager.Login(user, pass, otp, ip)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = user
- resp = fmt.Sprintf("OK %s", token)
- }
- }
- case "AUTH":
- if len(parts) < 2 {
- resp = "ERR usage: AUTH <token>"
- } else {
- token := parts[1]
- // Verify token
- sess, err := s.AuthManager.GetSession(token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- session.token = token
- session.username = sess.Username
- resp = "OK"
- }
- }
- case "LOGOUT":
- if session.token != "" {
- s.AuthManager.Logout(session.token)
- session.token = ""
- session.username = ""
- }
- resp = "OK"
- case "GET":
- if len(parts) < 2 {
- resp = "ERR usage: GET <key>"
- } else {
- key := parts[1]
- val, found, err := s.GetLinearAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else if !found {
- resp = "ERR not found"
- } else {
- resp = fmt.Sprintf("OK %s", val)
- }
- }
- case "SET":
- if len(parts) < 3 {
- resp = "ERR usage: SET <key> <value>"
- } else {
- key := parts[1]
- // Value might contain spaces, join the rest
- val := strings.Join(parts[2:], " ")
- err := s.SetAuthenticated(key, val, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "DEL":
- if len(parts) < 2 {
- resp = "ERR usage: DEL <key>"
- } else {
- key := parts[1]
- err := s.DelAuthenticated(key, session.token)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
-
- case "WHOAMI":
- if session.username == "" {
- resp = "Guest"
- } else {
- resp = session.username
- }
- // --- Admin Commands ---
- case "USER_LIST":
- users := s.AuthManager.ListUsers()
- data, err := json.Marshal(users)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- // Ensure it's a single line for TCP protocol simplicity
- jsonStr := string(data)
- // JSON marshal might include newlines if indentation was used (it's not by default, but safe to check)
- // However, standard json.Marshal does not indent.
- resp = fmt.Sprintf("OK %s", jsonStr)
- }
- case "ROLE_LIST":
- roles := s.AuthManager.ListRoles()
- data, err := json.Marshal(roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = fmt.Sprintf("OK %s", string(data))
- }
- case "USER_CREATE":
- // Usage: USER_CREATE <username> <password> <role1,role2>
- if len(parts) < 3 {
- resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
- } else {
- u := parts[1]
- p := parts[2]
- var roles []string
- if len(parts) > 3 {
- roles = strings.Split(parts[3], ",")
- }
- // Use RegisterUser (sync)
- err := s.AuthManager.RegisterUser(u, p, roles)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "ROLE_CREATE":
- // Usage: ROLE_CREATE <name>
- if len(parts) < 2 {
- resp = "ERR usage: ROLE_CREATE <name>"
- } else {
- name := parts[1]
- err := s.AuthManager.CreateRole(name)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
-
- case "USER_UNLOCK":
- // Usage: USER_UNLOCK <username>
- if len(parts) < 2 {
- resp = "ERR usage: USER_UNLOCK <username>"
- } else {
- // Manually clear the lock key
- // Note: accessing server.Set directly bypasses auth check which is fine here
- // as the TCP session itself should be authenticated as admin ideally.
- // For now we trust the connected client has rights or we check session.
- // In real impl, check if session.username is root or has admin perm.
- userToUnlock := parts[1]
- // We use Del to remove the lock key
- err := s.Del("system.lock." + userToUnlock)
- if err != nil {
- resp = fmt.Sprintf("ERR %v", err)
- } else {
- resp = "OK"
- }
- }
- case "EXIT", "QUIT":
- session.writer.WriteString("BYE\n")
- session.writer.Flush()
- return
- default:
- s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
- resp = fmt.Sprintf("ERR unknown command: %s", cmd)
- }
- session.writer.WriteString(resp + "\n")
- session.writer.Flush()
- }
- }
|