tcp_server.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. package raft
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net"
  8. "runtime"
  9. "strconv"
  10. "strings"
  11. "time"
  12. )
  13. // TCPClientSession holds state for a TCP connection
  14. type TCPClientSession struct {
  15. conn net.Conn
  16. server *KVServer
  17. token string
  18. username string
  19. reader *bufio.Reader
  20. writer *bufio.Writer
  21. remoteAddr string
  22. }
  23. func (s *KVServer) StartTCPServer(addr string) error {
  24. listener, err := net.Listen("tcp", addr)
  25. if err != nil {
  26. return err
  27. }
  28. s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
  29. go func() {
  30. defer listener.Close()
  31. for {
  32. conn, err := listener.Accept()
  33. if err != nil {
  34. if s.stopCh != nil {
  35. select {
  36. case <-s.stopCh:
  37. return
  38. default:
  39. }
  40. }
  41. s.Raft.config.Logger.Error("TCP Accept error: %v", err)
  42. continue
  43. }
  44. go s.handleTCPConnection(conn)
  45. }
  46. }()
  47. return nil
  48. }
  49. func (s *KVServer) handleTCPConnection(conn net.Conn) {
  50. defer conn.Close()
  51. // Use larger buffers for high throughput
  52. reader := bufio.NewReaderSize(conn, 64*1024)
  53. writer := bufio.NewWriterSize(conn, 64*1024)
  54. session := &TCPClientSession{
  55. conn: conn,
  56. server: s,
  57. reader: reader,
  58. writer: writer,
  59. remoteAddr: conn.RemoteAddr().String(),
  60. }
  61. for {
  62. // Set Keep-Alive Deadline
  63. // Using a longer deadline avoids frequent syscalls if traffic is continuous
  64. conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  65. // Read Request Line
  66. line, err := reader.ReadString('\n')
  67. if err != nil {
  68. return
  69. }
  70. line = strings.TrimSpace(line)
  71. if line == "" {
  72. continue
  73. }
  74. // Parse Headers
  75. var contentLength int
  76. for {
  77. hLine, err := reader.ReadString('\n')
  78. if err != nil {
  79. return
  80. }
  81. hLine = strings.TrimSpace(hLine)
  82. if hLine == "" {
  83. break
  84. }
  85. // Optimized Content-Length check
  86. // "Content-Length: 123"
  87. lowerHLine := strings.ToLower(hLine)
  88. if strings.HasPrefix(lowerHLine, "content-length:") {
  89. valStr := strings.TrimSpace(hLine[15:])
  90. if l, err := strconv.Atoi(valStr); err == nil {
  91. contentLength = l
  92. }
  93. }
  94. }
  95. // Read Body
  96. var body string
  97. if contentLength > 0 {
  98. // Optimization: Avoid allocation for small bodies if possible,
  99. // but Raft commands need to be passed as bytes/string anyway.
  100. buf := make([]byte, contentLength)
  101. if _, err := io.ReadFull(reader, buf); err != nil {
  102. return
  103. }
  104. body = string(buf)
  105. }
  106. // Execute Command Inline
  107. // This avoids the overhead of launching a goroutine per request
  108. // and the complexity/contention of a response channel.
  109. // For ASET (Async), this returns almost immediately.
  110. resp := s.executeCommandWithBody(session, line, body)
  111. // Write Response
  112. if _, err := writer.WriteString(resp + "\n"); err != nil {
  113. return
  114. }
  115. // Flush Optimization:
  116. // Only flush if there is no more data in the read buffer.
  117. // This automatically batches responses when requests are pipelined.
  118. if reader.Buffered() == 0 {
  119. if err := writer.Flush(); err != nil {
  120. return
  121. }
  122. }
  123. }
  124. }
  125. // Helper to bridge old logic
  126. func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
  127. parts := strings.Fields(line)
  128. if len(parts) == 0 {
  129. return ""
  130. }
  131. cmd := strings.ToUpper(parts[0])
  132. // If body is present, it overrides the value part of the command
  133. // We handle specific commands that use body
  134. if body != "" {
  135. switch cmd {
  136. case "SET", "ASET":
  137. if len(parts) < 2 {
  138. return "ERR usage: SET <key> (value in body)"
  139. }
  140. key := parts[1]
  141. // Value is body
  142. if cmd == "SET" {
  143. if err := s.SetAuthenticated(key, body, session.token); err != nil {
  144. return fmt.Sprintf("ERR %v", err)
  145. }
  146. return "OK"
  147. } else {
  148. if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
  149. return fmt.Sprintf("ERR %v", err)
  150. }
  151. return "OK"
  152. }
  153. }
  154. }
  155. // Fallback to legacy parsing if no body or command doesn't use body
  156. return s.executeCommand(session, line)
  157. }
  158. func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
  159. parts := strings.Fields(line)
  160. if len(parts) == 0 {
  161. return ""
  162. }
  163. cmd := strings.ToUpper(parts[0])
  164. var resp string
  165. switch cmd {
  166. case "LOGIN":
  167. if len(parts) < 3 {
  168. resp = "ERR usage: LOGIN <username> <password> [otp]"
  169. } else {
  170. user := parts[1]
  171. pass := parts[2]
  172. otp := ""
  173. if len(parts) > 3 {
  174. otp = parts[3]
  175. }
  176. // Extract IP
  177. ip := session.remoteAddr
  178. if host, _, err := net.SplitHostPort(ip); err == nil {
  179. ip = host
  180. }
  181. token, err := s.AuthManager.Login(user, pass, otp, ip)
  182. if err != nil {
  183. resp = fmt.Sprintf("ERR %v", err)
  184. } else {
  185. session.token = token
  186. session.username = user
  187. resp = fmt.Sprintf("OK %s", token)
  188. }
  189. }
  190. case "AUTH":
  191. if len(parts) < 2 {
  192. resp = "ERR usage: AUTH <token>"
  193. } else {
  194. token := parts[1]
  195. // Verify token
  196. sess, err := s.AuthManager.GetSession(token)
  197. if err != nil {
  198. resp = fmt.Sprintf("ERR %v", err)
  199. } else {
  200. session.token = token
  201. session.username = sess.Username
  202. resp = "OK"
  203. }
  204. }
  205. case "LOGOUT":
  206. if session.token != "" {
  207. s.AuthManager.Logout(session.token)
  208. session.token = ""
  209. session.username = ""
  210. }
  211. resp = "OK"
  212. case "GET":
  213. if len(parts) < 2 {
  214. resp = "ERR usage: GET <key>"
  215. } else {
  216. key := parts[1]
  217. val, found, err := s.GetLinearAuthenticated(key, session.token)
  218. if err != nil {
  219. resp = fmt.Sprintf("ERR %v", err)
  220. } else if !found {
  221. resp = "ERR not found"
  222. } else {
  223. resp = fmt.Sprintf("OK %s", val)
  224. }
  225. }
  226. case "SET":
  227. if len(parts) < 3 {
  228. resp = "ERR usage: SET <key> <value>"
  229. } else {
  230. key := parts[1]
  231. // Value might contain spaces, join the rest
  232. val := strings.Join(parts[2:], " ")
  233. // Use SetAuthenticated (Sync) by default for safety
  234. err := s.SetAuthenticated(key, val, session.token)
  235. if err != nil {
  236. resp = fmt.Sprintf("ERR %v", err)
  237. } else {
  238. resp = "OK"
  239. }
  240. }
  241. case "ASET":
  242. // Async SET for high performance
  243. if len(parts) < 3 {
  244. resp = "ERR usage: ASET <key> <value>"
  245. } else {
  246. key := parts[1]
  247. val := strings.Join(parts[2:], " ")
  248. err := s.SetAuthenticatedAsync(key, val, session.token)
  249. if err != nil {
  250. resp = fmt.Sprintf("ERR %v", err)
  251. } else {
  252. resp = "OK"
  253. }
  254. }
  255. case "DEL":
  256. if len(parts) < 2 {
  257. resp = "ERR usage: DEL <key>"
  258. } else {
  259. key := parts[1]
  260. err := s.DelAuthenticated(key, session.token)
  261. if err != nil {
  262. resp = fmt.Sprintf("ERR %v", err)
  263. } else {
  264. resp = "OK"
  265. }
  266. }
  267. case "WHOAMI":
  268. if session.username == "" {
  269. resp = "Guest"
  270. } else {
  271. resp = session.username
  272. }
  273. case "HELP":
  274. helpText := `Available Commands:
  275. GET <key> - Get value
  276. SET <key> <value> - Set value
  277. DEL <key> - Delete value
  278. SEARCH <pattern> [limit] - Search keys (e.g. user.*)
  279. COUNT <pattern> - Count keys
  280. INFO - Show system stats
  281. WHOAMI - Show current user
  282. JOIN <id> <addr> - Add node (Root only)
  283. LEAVE <id> - Remove node (Root only)
  284. USER_LIST - List users (Admin)
  285. ROLE_LIST - List roles (Admin)
  286. LOGIN/LOGOUT/EXIT`
  287. resp = "OK " + helpText
  288. case "INFO":
  289. // Check permission (Root only if auth enabled)
  290. if s.AuthManager.IsEnabled() {
  291. sess, err := s.AuthManager.GetSession(session.token)
  292. if err != nil {
  293. resp = fmt.Sprintf("ERR %v", err)
  294. break
  295. }
  296. if sess.Username != "root" {
  297. resp = "ERR Permission Denied: Root access required"
  298. break
  299. }
  300. }
  301. // Gather stats
  302. stats := s.GetStats()
  303. health := s.HealthCheck()
  304. dbSize := s.GetDBSize()
  305. logSize := s.GetLogSize()
  306. var m runtime.MemStats
  307. runtime.ReadMemStats(&m)
  308. // Construct JSON response
  309. info := map[string]interface{}{
  310. "node": map[string]interface{}{
  311. "id": health.NodeID,
  312. "state": health.State,
  313. "term": health.Term,
  314. "leader": health.LeaderID,
  315. "healthy": health.IsHealthy,
  316. },
  317. "storage": map[string]interface{}{
  318. "db_size": dbSize,
  319. "log_size": logSize,
  320. "mem_alloc": m.Alloc,
  321. "mem_sys": m.Sys,
  322. "num_gc": m.NumGC,
  323. },
  324. "indices": map[string]interface{}{
  325. "commit_index": stats.CommitIndex,
  326. "applied_index": stats.LastApplied,
  327. "last_log_index": stats.LastLogIndex,
  328. "db_applied": s.DB.GetLastAppliedIndex(),
  329. },
  330. "cluster": stats.ClusterNodes,
  331. "cluster_size": stats.ClusterSize,
  332. }
  333. data, err := json.Marshal(info)
  334. if err != nil {
  335. resp = fmt.Sprintf("ERR %v", err)
  336. } else {
  337. resp = "OK " + string(data)
  338. }
  339. case "SEARCH":
  340. // Usage: SEARCH <pattern> [limit] [offset]
  341. if len(parts) < 2 {
  342. resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
  343. } else {
  344. pattern := parts[1]
  345. limit := 20
  346. offset := 0
  347. if len(parts) >= 3 {
  348. if l, err := strconv.Atoi(parts[2]); err == nil {
  349. limit = l
  350. }
  351. }
  352. if len(parts) >= 4 {
  353. if o, err := strconv.Atoi(parts[3]); err == nil {
  354. offset = o
  355. }
  356. }
  357. results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
  358. if err != nil {
  359. resp = fmt.Sprintf("ERR %v", err)
  360. } else {
  361. data, _ := json.Marshal(results)
  362. resp = "OK " + string(data)
  363. }
  364. }
  365. case "COUNT":
  366. // Usage: COUNT <pattern>
  367. if len(parts) < 2 {
  368. resp = "ERR usage: COUNT <pattern>"
  369. } else {
  370. pattern := parts[1]
  371. count, err := s.CountAuthenticated(pattern, session.token)
  372. if err != nil {
  373. resp = fmt.Sprintf("ERR %v", err)
  374. } else {
  375. resp = fmt.Sprintf("OK %d", count)
  376. }
  377. }
  378. case "JOIN":
  379. // Usage: JOIN <id> <addr>
  380. // Admin only
  381. if s.AuthManager.IsEnabled() {
  382. sess, err := s.AuthManager.GetSession(session.token)
  383. if err != nil || sess.Username != "root" {
  384. resp = "ERR Permission Denied: Root access required"
  385. break
  386. }
  387. }
  388. if len(parts) < 3 {
  389. resp = "ERR usage: JOIN <id> <addr>"
  390. } else {
  391. err := s.Join(parts[1], parts[2])
  392. if err != nil {
  393. resp = fmt.Sprintf("ERR %v", err)
  394. } else {
  395. resp = "OK Join request sent"
  396. }
  397. }
  398. case "LEAVE":
  399. // Usage: LEAVE <id>
  400. // Admin only
  401. if s.AuthManager.IsEnabled() {
  402. sess, err := s.AuthManager.GetSession(session.token)
  403. if err != nil || sess.Username != "root" {
  404. resp = "ERR Permission Denied: Root access required"
  405. break
  406. }
  407. }
  408. if len(parts) < 2 {
  409. resp = "ERR usage: LEAVE <id>"
  410. } else {
  411. err := s.Leave(parts[1])
  412. if err != nil {
  413. resp = fmt.Sprintf("ERR %v", err)
  414. } else {
  415. resp = "OK Leave request sent"
  416. }
  417. }
  418. // --- Admin Commands ---
  419. case "USER_LIST":
  420. users := s.AuthManager.ListUsers()
  421. data, err := json.Marshal(users)
  422. if err != nil {
  423. resp = fmt.Sprintf("ERR %v", err)
  424. } else {
  425. // Ensure it's a single line for TCP protocol simplicity
  426. jsonStr := string(data)
  427. resp = fmt.Sprintf("OK %s", jsonStr)
  428. }
  429. case "ROLE_LIST":
  430. roles := s.AuthManager.ListRoles()
  431. data, err := json.Marshal(roles)
  432. if err != nil {
  433. resp = fmt.Sprintf("ERR %v", err)
  434. } else {
  435. resp = fmt.Sprintf("OK %s", string(data))
  436. }
  437. case "USER_CREATE":
  438. // Usage: USER_CREATE <username> <password> <role1,role2>
  439. if len(parts) < 3 {
  440. resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
  441. } else {
  442. u := parts[1]
  443. p := parts[2]
  444. var roles []string
  445. if len(parts) > 3 {
  446. roles = strings.Split(parts[3], ",")
  447. }
  448. // Use RegisterUser (sync)
  449. err := s.AuthManager.RegisterUser(u, p, roles)
  450. if err != nil {
  451. resp = fmt.Sprintf("ERR %v", err)
  452. } else {
  453. resp = "OK"
  454. }
  455. }
  456. case "ROLE_CREATE":
  457. // Usage: ROLE_CREATE <name>
  458. if len(parts) < 2 {
  459. resp = "ERR usage: ROLE_CREATE <name>"
  460. } else {
  461. name := parts[1]
  462. err := s.AuthManager.CreateRole(name)
  463. if err != nil {
  464. resp = fmt.Sprintf("ERR %v", err)
  465. } else {
  466. resp = "OK"
  467. }
  468. }
  469. case "ROLE_PERMISSION_ADD":
  470. // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
  471. // Actions: comma separated list of actions (read,write,admin,*)
  472. if len(parts) < 4 {
  473. resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
  474. } else {
  475. roleName := parts[1]
  476. pattern := parts[2]
  477. actionsStr := parts[3]
  478. actions := strings.Split(actionsStr, ",")
  479. rolePtr, err := s.AuthManager.GetRole(roleName)
  480. if err != nil {
  481. resp = fmt.Sprintf("ERR %v", err)
  482. } else {
  483. // Create a copy to modify
  484. role := *rolePtr
  485. // Deep copy permissions to avoid potential side effects on cached object
  486. originalPerms := role.Permissions
  487. role.Permissions = make([]Permission, len(originalPerms))
  488. copy(role.Permissions, originalPerms)
  489. newPerm := Permission{
  490. KeyPattern: pattern,
  491. Actions: actions,
  492. }
  493. // Upsert logic: Update if exists, Append if new
  494. found := false
  495. for i, p := range role.Permissions {
  496. if p.KeyPattern == pattern {
  497. role.Permissions[i] = newPerm
  498. found = true
  499. break
  500. }
  501. }
  502. if !found {
  503. role.Permissions = append(role.Permissions, newPerm)
  504. }
  505. err := s.AuthManager.UpdateRole(role)
  506. if err != nil {
  507. resp = fmt.Sprintf("ERR %v", err)
  508. } else {
  509. resp = "OK"
  510. }
  511. }
  512. }
  513. case "ROLE_PERMISSION_REMOVE":
  514. // Usage: ROLE_PERMISSION_REMOVE <role> <pattern>
  515. if len(parts) < 3 {
  516. resp = "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
  517. } else {
  518. roleName := parts[1]
  519. pattern := parts[2]
  520. rolePtr, err := s.AuthManager.GetRole(roleName)
  521. if err != nil {
  522. resp = fmt.Sprintf("ERR %v", err)
  523. } else {
  524. role := *rolePtr
  525. originalPerms := role.Permissions
  526. newPerms := make([]Permission, 0, len(originalPerms))
  527. found := false
  528. for _, p := range originalPerms {
  529. if p.KeyPattern == pattern {
  530. found = true
  531. continue
  532. }
  533. newPerms = append(newPerms, p)
  534. }
  535. if !found {
  536. resp = "ERR permission not found"
  537. } else {
  538. role.Permissions = newPerms
  539. err := s.AuthManager.UpdateRole(role)
  540. if err != nil {
  541. resp = fmt.Sprintf("ERR %v", err)
  542. } else {
  543. resp = "OK"
  544. }
  545. }
  546. }
  547. }
  548. case "USER_UNLOCK":
  549. // Usage: USER_UNLOCK <username>
  550. if len(parts) < 2 {
  551. resp = "ERR usage: USER_UNLOCK <username>"
  552. } else {
  553. // Manually clear the lock key
  554. // Note: accessing server.Set directly bypasses auth check which is fine here
  555. // as the TCP session itself should be authenticated as admin ideally.
  556. // For now we trust the connected client has rights or we check session.
  557. // In real impl, check if session.username is root or has admin perm.
  558. userToUnlock := parts[1]
  559. // We use Del to remove the lock key
  560. err := s.Del("system.lock." + userToUnlock)
  561. if err != nil {
  562. resp = fmt.Sprintf("ERR %v", err)
  563. } else {
  564. resp = "OK"
  565. }
  566. }
  567. case "EXIT", "QUIT":
  568. resp = "BYE"
  569. // Need signal to close connection after write
  570. // For simplicity, handle it in handleTCPConnection loop break,
  571. // but here we just return the string.
  572. // Actually, BYE handling is tricky in async writer.
  573. // Let's keep connection open or let client close it.
  574. // Or send special signal?
  575. // For now, simple return. Client will read BYE and close.
  576. default:
  577. s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
  578. resp = fmt.Sprintf("ERR unknown command: %s", cmd)
  579. }
  580. return resp
  581. }