tcp_server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. package raft
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "net"
  7. "runtime"
  8. "strconv"
  9. "strings"
  10. )
  11. // TCPClientSession holds state for a TCP connection
  12. type TCPClientSession struct {
  13. conn net.Conn
  14. server *KVServer
  15. token string
  16. username string
  17. reader *bufio.Reader
  18. writer *bufio.Writer
  19. remoteAddr string
  20. }
  21. func (s *KVServer) StartTCPServer(addr string) error {
  22. listener, err := net.Listen("tcp", addr)
  23. if err != nil {
  24. return err
  25. }
  26. s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
  27. go func() {
  28. defer listener.Close()
  29. for {
  30. conn, err := listener.Accept()
  31. if err != nil {
  32. if s.stopCh != nil {
  33. select {
  34. case <-s.stopCh:
  35. return
  36. default:
  37. }
  38. }
  39. s.Raft.config.Logger.Error("TCP Accept error: %v", err)
  40. continue
  41. }
  42. go s.handleTCPConnection(conn)
  43. }
  44. }()
  45. return nil
  46. }
  47. func (s *KVServer) handleTCPConnection(conn net.Conn) {
  48. session := &TCPClientSession{
  49. conn: conn,
  50. server: s,
  51. reader: bufio.NewReader(conn),
  52. writer: bufio.NewWriter(conn),
  53. remoteAddr: conn.RemoteAddr().String(),
  54. }
  55. defer conn.Close()
  56. for {
  57. line, err := session.reader.ReadString('\n')
  58. if err != nil {
  59. return // Connection closed
  60. }
  61. line = strings.TrimSpace(line)
  62. if line == "" {
  63. continue
  64. }
  65. parts := strings.Fields(line)
  66. if len(parts) == 0 {
  67. continue
  68. }
  69. cmd := strings.ToUpper(parts[0])
  70. var resp string
  71. switch cmd {
  72. case "LOGIN":
  73. if len(parts) < 3 {
  74. resp = "ERR usage: LOGIN <username> <password> [otp]"
  75. } else {
  76. user := parts[1]
  77. pass := parts[2]
  78. otp := ""
  79. if len(parts) > 3 {
  80. otp = parts[3]
  81. }
  82. // Extract IP
  83. ip := session.remoteAddr
  84. if host, _, err := net.SplitHostPort(ip); err == nil {
  85. ip = host
  86. }
  87. token, err := s.AuthManager.Login(user, pass, otp, ip)
  88. if err != nil {
  89. resp = fmt.Sprintf("ERR %v", err)
  90. } else {
  91. session.token = token
  92. session.username = user
  93. resp = fmt.Sprintf("OK %s", token)
  94. }
  95. }
  96. case "AUTH":
  97. if len(parts) < 2 {
  98. resp = "ERR usage: AUTH <token>"
  99. } else {
  100. token := parts[1]
  101. // Verify token
  102. sess, err := s.AuthManager.GetSession(token)
  103. if err != nil {
  104. resp = fmt.Sprintf("ERR %v", err)
  105. } else {
  106. session.token = token
  107. session.username = sess.Username
  108. resp = "OK"
  109. }
  110. }
  111. case "LOGOUT":
  112. if session.token != "" {
  113. s.AuthManager.Logout(session.token)
  114. session.token = ""
  115. session.username = ""
  116. }
  117. resp = "OK"
  118. case "GET":
  119. if len(parts) < 2 {
  120. resp = "ERR usage: GET <key>"
  121. } else {
  122. key := parts[1]
  123. val, found, err := s.GetLinearAuthenticated(key, session.token)
  124. if err != nil {
  125. resp = fmt.Sprintf("ERR %v", err)
  126. } else if !found {
  127. resp = "ERR not found"
  128. } else {
  129. resp = fmt.Sprintf("OK %s", val)
  130. }
  131. }
  132. case "SET":
  133. if len(parts) < 3 {
  134. resp = "ERR usage: SET <key> <value>"
  135. } else {
  136. key := parts[1]
  137. // Value might contain spaces, join the rest
  138. val := strings.Join(parts[2:], " ")
  139. err := s.SetAuthenticated(key, val, session.token)
  140. if err != nil {
  141. resp = fmt.Sprintf("ERR %v", err)
  142. } else {
  143. resp = "OK"
  144. }
  145. }
  146. case "DEL":
  147. if len(parts) < 2 {
  148. resp = "ERR usage: DEL <key>"
  149. } else {
  150. key := parts[1]
  151. err := s.DelAuthenticated(key, session.token)
  152. if err != nil {
  153. resp = fmt.Sprintf("ERR %v", err)
  154. } else {
  155. resp = "OK"
  156. }
  157. }
  158. case "WHOAMI":
  159. if session.username == "" {
  160. resp = "Guest"
  161. } else {
  162. resp = session.username
  163. }
  164. case "HELP":
  165. helpText := `Available Commands:
  166. GET <key> - Get value
  167. SET <key> <value> - Set value
  168. DEL <key> - Delete value
  169. SEARCH <pattern> [limit] - Search keys (e.g. user.*)
  170. COUNT <pattern> - Count keys
  171. INFO - Show system stats
  172. WHOAMI - Show current user
  173. JOIN <id> <addr> - Add node (Root only)
  174. LEAVE <id> - Remove node (Root only)
  175. USER_LIST - List users (Admin)
  176. ROLE_LIST - List roles (Admin)
  177. LOGIN/LOGOUT/EXIT`
  178. resp = "OK " + helpText
  179. case "INFO":
  180. // Check permission (Root only if auth enabled)
  181. if s.AuthManager.IsEnabled() {
  182. sess, err := s.AuthManager.GetSession(session.token)
  183. if err != nil {
  184. resp = fmt.Sprintf("ERR %v", err)
  185. break
  186. }
  187. if sess.Username != "root" {
  188. resp = "ERR Permission Denied: Root access required"
  189. break
  190. }
  191. }
  192. // Gather stats
  193. stats := s.GetStats()
  194. health := s.HealthCheck()
  195. dbSize := s.GetDBSize()
  196. logSize := s.GetLogSize()
  197. var m runtime.MemStats
  198. runtime.ReadMemStats(&m)
  199. // Construct JSON response
  200. info := map[string]interface{}{
  201. "node": map[string]interface{}{
  202. "id": health.NodeID,
  203. "state": health.State,
  204. "term": health.Term,
  205. "leader": health.LeaderID,
  206. "healthy": health.IsHealthy,
  207. },
  208. "storage": map[string]interface{}{
  209. "db_size": dbSize,
  210. "log_size": logSize,
  211. "mem_alloc": m.Alloc,
  212. "mem_sys": m.Sys,
  213. "num_gc": m.NumGC,
  214. },
  215. "indices": map[string]interface{}{
  216. "commit_index": stats.CommitIndex,
  217. "applied_index": stats.LastApplied,
  218. "last_log_index": stats.LastLogIndex,
  219. "db_applied": s.DB.GetLastAppliedIndex(),
  220. },
  221. "cluster": stats.ClusterNodes,
  222. "cluster_size": stats.ClusterSize,
  223. }
  224. data, err := json.Marshal(info)
  225. if err != nil {
  226. resp = fmt.Sprintf("ERR %v", err)
  227. } else {
  228. resp = "OK " + string(data)
  229. }
  230. case "SEARCH":
  231. // Usage: SEARCH <pattern> [limit] [offset]
  232. if len(parts) < 2 {
  233. resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
  234. } else {
  235. pattern := parts[1]
  236. limit := 20
  237. offset := 0
  238. if len(parts) >= 3 {
  239. if l, err := strconv.Atoi(parts[2]); err == nil {
  240. limit = l
  241. }
  242. }
  243. if len(parts) >= 4 {
  244. if o, err := strconv.Atoi(parts[3]); err == nil {
  245. offset = o
  246. }
  247. }
  248. results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
  249. if err != nil {
  250. resp = fmt.Sprintf("ERR %v", err)
  251. } else {
  252. data, _ := json.Marshal(results)
  253. resp = "OK " + string(data)
  254. }
  255. }
  256. case "COUNT":
  257. // Usage: COUNT <pattern>
  258. if len(parts) < 2 {
  259. resp = "ERR usage: COUNT <pattern>"
  260. } else {
  261. pattern := parts[1]
  262. count, err := s.CountAuthenticated(pattern, session.token)
  263. if err != nil {
  264. resp = fmt.Sprintf("ERR %v", err)
  265. } else {
  266. resp = fmt.Sprintf("OK %d", count)
  267. }
  268. }
  269. case "JOIN":
  270. // Usage: JOIN <id> <addr>
  271. // Admin only
  272. if s.AuthManager.IsEnabled() {
  273. sess, err := s.AuthManager.GetSession(session.token)
  274. if err != nil || sess.Username != "root" {
  275. resp = "ERR Permission Denied: Root access required"
  276. break
  277. }
  278. }
  279. if len(parts) < 3 {
  280. resp = "ERR usage: JOIN <id> <addr>"
  281. } else {
  282. err := s.Join(parts[1], parts[2])
  283. if err != nil {
  284. resp = fmt.Sprintf("ERR %v", err)
  285. } else {
  286. resp = "OK Join request sent"
  287. }
  288. }
  289. case "LEAVE":
  290. // Usage: LEAVE <id>
  291. // Admin only
  292. if s.AuthManager.IsEnabled() {
  293. sess, err := s.AuthManager.GetSession(session.token)
  294. if err != nil || sess.Username != "root" {
  295. resp = "ERR Permission Denied: Root access required"
  296. break
  297. }
  298. }
  299. if len(parts) < 2 {
  300. resp = "ERR usage: LEAVE <id>"
  301. } else {
  302. err := s.Leave(parts[1])
  303. if err != nil {
  304. resp = fmt.Sprintf("ERR %v", err)
  305. } else {
  306. resp = "OK Leave request sent"
  307. }
  308. }
  309. // --- Admin Commands ---
  310. case "USER_LIST":
  311. users := s.AuthManager.ListUsers()
  312. data, err := json.Marshal(users)
  313. if err != nil {
  314. resp = fmt.Sprintf("ERR %v", err)
  315. } else {
  316. // Ensure it's a single line for TCP protocol simplicity
  317. jsonStr := string(data)
  318. // JSON marshal might include newlines if indentation was used (it's not by default, but safe to check)
  319. // However, standard json.Marshal does not indent.
  320. resp = fmt.Sprintf("OK %s", jsonStr)
  321. }
  322. case "ROLE_LIST":
  323. roles := s.AuthManager.ListRoles()
  324. data, err := json.Marshal(roles)
  325. if err != nil {
  326. resp = fmt.Sprintf("ERR %v", err)
  327. } else {
  328. resp = fmt.Sprintf("OK %s", string(data))
  329. }
  330. case "USER_CREATE":
  331. // Usage: USER_CREATE <username> <password> <role1,role2>
  332. if len(parts) < 3 {
  333. resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
  334. } else {
  335. u := parts[1]
  336. p := parts[2]
  337. var roles []string
  338. if len(parts) > 3 {
  339. roles = strings.Split(parts[3], ",")
  340. }
  341. // Use RegisterUser (sync)
  342. err := s.AuthManager.RegisterUser(u, p, roles)
  343. if err != nil {
  344. resp = fmt.Sprintf("ERR %v", err)
  345. } else {
  346. resp = "OK"
  347. }
  348. }
  349. case "ROLE_CREATE":
  350. // Usage: ROLE_CREATE <name>
  351. if len(parts) < 2 {
  352. resp = "ERR usage: ROLE_CREATE <name>"
  353. } else {
  354. name := parts[1]
  355. err := s.AuthManager.CreateRole(name)
  356. if err != nil {
  357. resp = fmt.Sprintf("ERR %v", err)
  358. } else {
  359. resp = "OK"
  360. }
  361. }
  362. case "ROLE_PERMISSION_ADD":
  363. // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
  364. // Actions: comma separated list of actions (read,write,admin,*)
  365. if len(parts) < 4 {
  366. resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
  367. } else {
  368. roleName := parts[1]
  369. pattern := parts[2]
  370. actionsStr := parts[3]
  371. actions := strings.Split(actionsStr, ",")
  372. rolePtr, err := s.AuthManager.GetRole(roleName)
  373. if err != nil {
  374. resp = fmt.Sprintf("ERR %v", err)
  375. } else {
  376. // Create a copy to modify
  377. role := *rolePtr
  378. // Deep copy permissions to avoid potential side effects on cached object
  379. originalPerms := role.Permissions
  380. role.Permissions = make([]Permission, len(originalPerms))
  381. copy(role.Permissions, originalPerms)
  382. newPerm := Permission{
  383. KeyPattern: pattern,
  384. Actions: actions,
  385. }
  386. role.Permissions = append(role.Permissions, newPerm)
  387. err := s.AuthManager.UpdateRole(role)
  388. if err != nil {
  389. resp = fmt.Sprintf("ERR %v", err)
  390. } else {
  391. resp = "OK"
  392. }
  393. }
  394. }
  395. case "USER_UNLOCK":
  396. // Usage: USER_UNLOCK <username>
  397. if len(parts) < 2 {
  398. resp = "ERR usage: USER_UNLOCK <username>"
  399. } else {
  400. // Manually clear the lock key
  401. // Note: accessing server.Set directly bypasses auth check which is fine here
  402. // as the TCP session itself should be authenticated as admin ideally.
  403. // For now we trust the connected client has rights or we check session.
  404. // In real impl, check if session.username is root or has admin perm.
  405. userToUnlock := parts[1]
  406. // We use Del to remove the lock key
  407. err := s.Del("system.lock." + userToUnlock)
  408. if err != nil {
  409. resp = fmt.Sprintf("ERR %v", err)
  410. } else {
  411. resp = "OK"
  412. }
  413. }
  414. case "EXIT", "QUIT":
  415. session.writer.WriteString("BYE\n")
  416. session.writer.Flush()
  417. return
  418. default:
  419. s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
  420. resp = fmt.Sprintf("ERR unknown command: %s", cmd)
  421. }
  422. session.writer.WriteString(resp + "\n")
  423. session.writer.Flush()
  424. }
  425. }