// tcp_server.go
  1. package raft
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net"
  8. "runtime"
  9. "strconv"
  10. "strings"
  11. "time"
  12. )
  13. // TCPClientSession holds state for a TCP connection
  14. type TCPClientSession struct {
  15. conn net.Conn
  16. server *KVServer
  17. token string
  18. username string
  19. reader *bufio.Reader
  20. writer *bufio.Writer
  21. remoteAddr string
  22. }
  23. func (s *KVServer) StartTCPServer(addr string) error {
  24. listener, err := net.Listen("tcp", addr)
  25. if err != nil {
  26. return err
  27. }
  28. s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
  29. go func() {
  30. defer listener.Close()
  31. for {
  32. conn, err := listener.Accept()
  33. if err != nil {
  34. if s.stopCh != nil {
  35. select {
  36. case <-s.stopCh:
  37. return
  38. default:
  39. }
  40. }
  41. s.Raft.config.Logger.Error("TCP Accept error: %v", err)
  42. continue
  43. }
  44. go s.handleTCPConnection(conn)
  45. }
  46. }()
  47. return nil
  48. }
  49. func (s *KVServer) handleTCPConnection(conn net.Conn) {
  50. defer conn.Close()
  51. // Use larger buffers for high throughput
  52. reader := bufio.NewReaderSize(conn, 64*1024)
  53. writer := bufio.NewWriterSize(conn, 64*1024)
  54. session := &TCPClientSession{
  55. conn: conn,
  56. server: s,
  57. reader: reader,
  58. writer: writer,
  59. remoteAddr: conn.RemoteAddr().String(),
  60. }
  61. for {
  62. // Set Keep-Alive Deadline
  63. // Using a longer deadline avoids frequent syscalls if traffic is continuous
  64. conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  65. // Read Request Line
  66. line, err := reader.ReadString('\n')
  67. if err != nil {
  68. return
  69. }
  70. line = strings.TrimSpace(line)
  71. if line == "" {
  72. continue
  73. }
  74. // Parse Headers
  75. var contentLength int
  76. for {
  77. hLine, err := reader.ReadString('\n')
  78. if err != nil {
  79. return
  80. }
  81. hLine = strings.TrimSpace(hLine)
  82. if hLine == "" {
  83. break
  84. }
  85. // Optimized Content-Length check
  86. // "Content-Length: 123"
  87. lowerHLine := strings.ToLower(hLine)
  88. if strings.HasPrefix(lowerHLine, "content-length:") {
  89. valStr := strings.TrimSpace(hLine[15:])
  90. if l, err := strconv.Atoi(valStr); err == nil {
  91. contentLength = l
  92. }
  93. }
  94. }
  95. // Read Body
  96. var body string
  97. if contentLength > 0 {
  98. // Optimization: Avoid allocation for small bodies if possible,
  99. // but Raft commands need to be passed as bytes/string anyway.
  100. buf := make([]byte, contentLength)
  101. if _, err := io.ReadFull(reader, buf); err != nil {
  102. return
  103. }
  104. body = string(buf)
  105. }
  106. // Implement Batching: Accumulate requests if they are ASET/SET
  107. // Ideally we need an internal buffer or channel here to queue up commands.
  108. // For zero-allocation batching in this loop structure, we can try to
  109. // eagerly read the next request if available in the buffer.
  110. // Note: Raft's Propose is thread-safe. The current serial loop is efficient for
  111. // minimizing context switches. Batching helps if we can merge multiple Proposes into one.
  112. // Since we haven't modified Raft to support ProposeBatch yet, we can't do true backend batching easily.
  113. // However, we can do "frontend batching" by checking buffer availability?
  114. // No, frontend batching requires Raft to accept a batch.
  115. // Without modifying Raft.Propose to accept []Command, "batching" here is limited.
  116. // BUT, we can at least pipeline the execution if we had concurrent workers,
  117. // but we replaced that with this serial loop for perf.
  118. // So for now, we just execute.
  119. // Execute Command Inline
  120. resp := s.executeCommandWithBody(session, line, body)
  121. // Write Response
  122. if _, err := writer.WriteString(resp + "\n"); err != nil {
  123. return
  124. }
  125. // Flush Optimization:
  126. // Only flush if there is no more data in the read buffer.
  127. // This automatically batches responses when requests are pipelined.
  128. if reader.Buffered() == 0 {
  129. if err := writer.Flush(); err != nil {
  130. return
  131. }
  132. }
  133. }
  134. }
  135. // Helper to bridge old logic
  136. func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
  137. parts := strings.Fields(line)
  138. if len(parts) == 0 {
  139. return ""
  140. }
  141. cmd := strings.ToUpper(parts[0])
  142. // If body is present, it overrides the value part of the command
  143. // We handle specific commands that use body
  144. if body != "" {
  145. switch cmd {
  146. case "SET", "ASET":
  147. if len(parts) < 2 {
  148. return "ERR usage: SET <key> (value in body)"
  149. }
  150. key := parts[1]
  151. // Value is body
  152. if cmd == "SET" {
  153. if err := s.SetAuthenticated(key, body, session.token); err != nil {
  154. return fmt.Sprintf("ERR %v", err)
  155. }
  156. return "OK"
  157. } else {
  158. if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
  159. return fmt.Sprintf("ERR %v", err)
  160. }
  161. return "OK"
  162. }
  163. }
  164. }
  165. // Fallback to legacy parsing if no body or command doesn't use body
  166. return s.executeCommand(session, line)
  167. }
  168. func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
  169. parts := strings.Fields(line)
  170. if len(parts) == 0 {
  171. return ""
  172. }
  173. cmd := strings.ToUpper(parts[0])
  174. var resp string
  175. switch cmd {
  176. case "LOGIN":
  177. if len(parts) < 3 {
  178. resp = "ERR usage: LOGIN <username> <password> [otp]"
  179. } else {
  180. user := parts[1]
  181. pass := parts[2]
  182. otp := ""
  183. if len(parts) > 3 {
  184. otp = parts[3]
  185. }
  186. // Extract IP
  187. ip := session.remoteAddr
  188. if host, _, err := net.SplitHostPort(ip); err == nil {
  189. ip = host
  190. }
  191. token, err := s.AuthManager.Login(user, pass, otp, ip)
  192. if err != nil {
  193. resp = fmt.Sprintf("ERR %v", err)
  194. } else {
  195. session.token = token
  196. session.username = user
  197. resp = fmt.Sprintf("OK %s", token)
  198. }
  199. }
  200. case "AUTH":
  201. if len(parts) < 2 {
  202. resp = "ERR usage: AUTH <token>"
  203. } else {
  204. token := parts[1]
  205. // Verify token
  206. sess, err := s.AuthManager.GetSession(token)
  207. if err != nil {
  208. resp = fmt.Sprintf("ERR %v", err)
  209. } else {
  210. session.token = token
  211. session.username = sess.Username
  212. resp = "OK"
  213. }
  214. }
  215. case "LOGOUT":
  216. if session.token != "" {
  217. s.AuthManager.Logout(session.token)
  218. session.token = ""
  219. session.username = ""
  220. }
  221. resp = "OK"
  222. case "GET":
  223. if len(parts) < 2 {
  224. resp = "ERR usage: GET <key>"
  225. } else {
  226. key := parts[1]
  227. val, found, err := s.GetLinearAuthenticated(key, session.token)
  228. if err != nil {
  229. resp = fmt.Sprintf("ERR %v", err)
  230. } else if !found {
  231. resp = "ERR not found"
  232. } else {
  233. resp = fmt.Sprintf("OK %s", val)
  234. }
  235. }
  236. case "SET":
  237. if len(parts) < 3 {
  238. resp = "ERR usage: SET <key> <value>"
  239. } else {
  240. key := parts[1]
  241. // Value might contain spaces, join the rest
  242. val := strings.Join(parts[2:], " ")
  243. // Use SetAuthenticated (Sync) by default for safety
  244. err := s.SetAuthenticated(key, val, session.token)
  245. if err != nil {
  246. resp = fmt.Sprintf("ERR %v", err)
  247. } else {
  248. resp = "OK"
  249. }
  250. }
  251. case "ASET":
  252. // Async SET for high performance
  253. if len(parts) < 3 {
  254. resp = "ERR usage: ASET <key> <value>"
  255. } else {
  256. key := parts[1]
  257. val := strings.Join(parts[2:], " ")
  258. err := s.SetAuthenticatedAsync(key, val, session.token)
  259. if err != nil {
  260. resp = fmt.Sprintf("ERR %v", err)
  261. } else {
  262. resp = "OK"
  263. }
  264. }
  265. case "DEL":
  266. if len(parts) < 2 {
  267. resp = "ERR usage: DEL <key>"
  268. } else {
  269. key := parts[1]
  270. err := s.DelAuthenticated(key, session.token)
  271. if err != nil {
  272. resp = fmt.Sprintf("ERR %v", err)
  273. } else {
  274. resp = "OK"
  275. }
  276. }
  277. case "WHOAMI":
  278. if session.username == "" {
  279. resp = "Guest"
  280. } else {
  281. resp = session.username
  282. }
  283. case "HELP":
  284. helpText := `Available Commands:
  285. GET <key> - Get value
  286. SET <key> <value> - Set value
  287. DEL <key> - Delete value
  288. SEARCH <pattern> [limit] - Search keys (e.g. user.*)
  289. COUNT <pattern> - Count keys
  290. INFO - Show system stats
  291. WHOAMI - Show current user
  292. JOIN <id> <addr> - Add node (Root only)
  293. LEAVE <id> - Remove node (Root only)
  294. USER_LIST - List users (Admin)
  295. ROLE_LIST - List roles (Admin)
  296. LOGIN/LOGOUT/EXIT`
  297. resp = "OK " + helpText
  298. case "INFO":
  299. // Check permission (Root only if auth enabled)
  300. if s.AuthManager.IsEnabled() {
  301. sess, err := s.AuthManager.GetSession(session.token)
  302. if err != nil {
  303. resp = fmt.Sprintf("ERR %v", err)
  304. break
  305. }
  306. if sess.Username != "root" {
  307. resp = "ERR Permission Denied: Root access required"
  308. break
  309. }
  310. }
  311. // Gather stats
  312. stats := s.GetStats()
  313. health := s.HealthCheck()
  314. dbSize := s.GetDBSize()
  315. logSize := s.GetLogSize()
  316. var m runtime.MemStats
  317. runtime.ReadMemStats(&m)
  318. // Construct JSON response
  319. info := map[string]interface{}{
  320. "node": map[string]interface{}{
  321. "id": health.NodeID,
  322. "state": health.State,
  323. "term": health.Term,
  324. "leader": health.LeaderID,
  325. "healthy": health.IsHealthy,
  326. },
  327. "storage": map[string]interface{}{
  328. "db_size": dbSize,
  329. "log_size": logSize,
  330. "mem_alloc": m.Alloc,
  331. "mem_sys": m.Sys,
  332. "num_gc": m.NumGC,
  333. },
  334. "indices": map[string]interface{}{
  335. "commit_index": stats.CommitIndex,
  336. "applied_index": stats.LastApplied,
  337. "last_log_index": stats.LastLogIndex,
  338. "db_applied": s.DB.GetLastAppliedIndex(),
  339. },
  340. "cluster": stats.ClusterNodes,
  341. "cluster_size": stats.ClusterSize,
  342. }
  343. data, err := json.Marshal(info)
  344. if err != nil {
  345. resp = fmt.Sprintf("ERR %v", err)
  346. } else {
  347. resp = "OK " + string(data)
  348. }
  349. case "SEARCH":
  350. // Usage: SEARCH <pattern> [limit] [offset]
  351. if len(parts) < 2 {
  352. resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
  353. } else {
  354. pattern := parts[1]
  355. limit := 20
  356. offset := 0
  357. if len(parts) >= 3 {
  358. if l, err := strconv.Atoi(parts[2]); err == nil {
  359. limit = l
  360. }
  361. }
  362. if len(parts) >= 4 {
  363. if o, err := strconv.Atoi(parts[3]); err == nil {
  364. offset = o
  365. }
  366. }
  367. results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
  368. if err != nil {
  369. resp = fmt.Sprintf("ERR %v", err)
  370. } else {
  371. data, _ := json.Marshal(results)
  372. resp = "OK " + string(data)
  373. }
  374. }
  375. case "COUNT":
  376. // Usage: COUNT <pattern>
  377. if len(parts) < 2 {
  378. resp = "ERR usage: COUNT <pattern>"
  379. } else {
  380. pattern := parts[1]
  381. count, err := s.CountAuthenticated(pattern, session.token)
  382. if err != nil {
  383. resp = fmt.Sprintf("ERR %v", err)
  384. } else {
  385. resp = fmt.Sprintf("OK %d", count)
  386. }
  387. }
  388. case "JOIN":
  389. // Usage: JOIN <id> <addr>
  390. // Admin only
  391. if s.AuthManager.IsEnabled() {
  392. sess, err := s.AuthManager.GetSession(session.token)
  393. if err != nil || sess.Username != "root" {
  394. resp = "ERR Permission Denied: Root access required"
  395. break
  396. }
  397. }
  398. if len(parts) < 3 {
  399. resp = "ERR usage: JOIN <id> <addr>"
  400. } else {
  401. err := s.Join(parts[1], parts[2])
  402. if err != nil {
  403. resp = fmt.Sprintf("ERR %v", err)
  404. } else {
  405. resp = "OK Join request sent"
  406. }
  407. }
  408. case "LEAVE":
  409. // Usage: LEAVE <id>
  410. // Admin only
  411. if s.AuthManager.IsEnabled() {
  412. sess, err := s.AuthManager.GetSession(session.token)
  413. if err != nil || sess.Username != "root" {
  414. resp = "ERR Permission Denied: Root access required"
  415. break
  416. }
  417. }
  418. if len(parts) < 2 {
  419. resp = "ERR usage: LEAVE <id>"
  420. } else {
  421. err := s.Leave(parts[1])
  422. if err != nil {
  423. resp = fmt.Sprintf("ERR %v", err)
  424. } else {
  425. resp = "OK Leave request sent"
  426. }
  427. }
  428. // --- Admin Commands ---
  429. case "USER_LIST":
  430. users := s.AuthManager.ListUsers()
  431. data, err := json.Marshal(users)
  432. if err != nil {
  433. resp = fmt.Sprintf("ERR %v", err)
  434. } else {
  435. // Ensure it's a single line for TCP protocol simplicity
  436. jsonStr := string(data)
  437. resp = fmt.Sprintf("OK %s", jsonStr)
  438. }
  439. case "ROLE_LIST":
  440. roles := s.AuthManager.ListRoles()
  441. data, err := json.Marshal(roles)
  442. if err != nil {
  443. resp = fmt.Sprintf("ERR %v", err)
  444. } else {
  445. resp = fmt.Sprintf("OK %s", string(data))
  446. }
  447. case "USER_CREATE":
  448. // Usage: USER_CREATE <username> <password> <role1,role2>
  449. if len(parts) < 3 {
  450. resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
  451. } else {
  452. u := parts[1]
  453. p := parts[2]
  454. var roles []string
  455. if len(parts) > 3 {
  456. roles = strings.Split(parts[3], ",")
  457. }
  458. // Use RegisterUser (sync)
  459. err := s.AuthManager.RegisterUser(u, p, roles)
  460. if err != nil {
  461. resp = fmt.Sprintf("ERR %v", err)
  462. } else {
  463. resp = "OK"
  464. }
  465. }
  466. case "ROLE_CREATE":
  467. // Usage: ROLE_CREATE <name>
  468. if len(parts) < 2 {
  469. resp = "ERR usage: ROLE_CREATE <name>"
  470. } else {
  471. name := parts[1]
  472. err := s.AuthManager.CreateRole(name)
  473. if err != nil {
  474. resp = fmt.Sprintf("ERR %v", err)
  475. } else {
  476. resp = "OK"
  477. }
  478. }
  479. case "ROLE_PERMISSION_ADD":
  480. // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
  481. // Actions: comma separated list of actions (read,write,admin,*)
  482. if len(parts) < 4 {
  483. resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
  484. } else {
  485. roleName := parts[1]
  486. pattern := parts[2]
  487. actionsStr := parts[3]
  488. actions := strings.Split(actionsStr, ",")
  489. rolePtr, err := s.AuthManager.GetRole(roleName)
  490. if err != nil {
  491. resp = fmt.Sprintf("ERR %v", err)
  492. } else {
  493. // Create a copy to modify
  494. role := *rolePtr
  495. // Deep copy permissions to avoid potential side effects on cached object
  496. originalPerms := role.Permissions
  497. role.Permissions = make([]Permission, len(originalPerms))
  498. copy(role.Permissions, originalPerms)
  499. newPerm := Permission{
  500. KeyPattern: pattern,
  501. Actions: actions,
  502. }
  503. // Upsert logic: Update if exists, Append if new
  504. found := false
  505. for i, p := range role.Permissions {
  506. if p.KeyPattern == pattern {
  507. role.Permissions[i] = newPerm
  508. found = true
  509. break
  510. }
  511. }
  512. if !found {
  513. role.Permissions = append(role.Permissions, newPerm)
  514. }
  515. err := s.AuthManager.UpdateRole(role)
  516. if err != nil {
  517. resp = fmt.Sprintf("ERR %v", err)
  518. } else {
  519. resp = "OK"
  520. }
  521. }
  522. }
  523. case "ROLE_PERMISSION_REMOVE":
  524. // Usage: ROLE_PERMISSION_REMOVE <role> <pattern>
  525. if len(parts) < 3 {
  526. resp = "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
  527. } else {
  528. roleName := parts[1]
  529. pattern := parts[2]
  530. rolePtr, err := s.AuthManager.GetRole(roleName)
  531. if err != nil {
  532. resp = fmt.Sprintf("ERR %v", err)
  533. } else {
  534. role := *rolePtr
  535. originalPerms := role.Permissions
  536. newPerms := make([]Permission, 0, len(originalPerms))
  537. found := false
  538. for _, p := range originalPerms {
  539. if p.KeyPattern == pattern {
  540. found = true
  541. continue
  542. }
  543. newPerms = append(newPerms, p)
  544. }
  545. if !found {
  546. resp = "ERR permission not found"
  547. } else {
  548. role.Permissions = newPerms
  549. err := s.AuthManager.UpdateRole(role)
  550. if err != nil {
  551. resp = fmt.Sprintf("ERR %v", err)
  552. } else {
  553. resp = "OK"
  554. }
  555. }
  556. }
  557. }
  558. case "USER_UNLOCK":
  559. // Usage: USER_UNLOCK <username>
  560. if len(parts) < 2 {
  561. resp = "ERR usage: USER_UNLOCK <username>"
  562. } else {
  563. // Manually clear the lock key
  564. // Note: accessing server.Set directly bypasses auth check which is fine here
  565. // as the TCP session itself should be authenticated as admin ideally.
  566. // For now we trust the connected client has rights or we check session.
  567. // In real impl, check if session.username is root or has admin perm.
  568. userToUnlock := parts[1]
  569. // We use Del to remove the lock key
  570. err := s.Del("system.lock." + userToUnlock)
  571. if err != nil {
  572. resp = fmt.Sprintf("ERR %v", err)
  573. } else {
  574. resp = "OK"
  575. }
  576. }
  577. case "EXIT", "QUIT":
  578. resp = "BYE"
  579. // Need signal to close connection after write
  580. // For simplicity, handle it in handleTCPConnection loop break,
  581. // but here we just return the string.
  582. // Actually, BYE handling is tricky in async writer.
  583. // Let's keep connection open or let client close it.
  584. // Or send special signal?
  585. // For now, simple return. Client will read BYE and close.
  586. default:
  587. s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
  588. resp = fmt.Sprintf("ERR unknown command: %s", cmd)
  589. }
  590. return resp
  591. }