tcp_server.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. package raft
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net"
  8. "runtime"
  9. "strconv"
  10. "strings"
  11. "time"
  12. )
  13. // TCPClientSession holds state for a TCP connection
  14. type TCPClientSession struct {
  15. conn net.Conn
  16. server *KVServer
  17. token string
  18. username string
  19. reader *bufio.Reader
  20. writer *bufio.Writer
  21. remoteAddr string
  22. }
  23. func (s *KVServer) StartTCPServer(addr string) error {
  24. listener, err := net.Listen("tcp", addr)
  25. if err != nil {
  26. return err
  27. }
  28. s.Raft.config.Logger.Info("TCP API server listening on %s", addr)
  29. go func() {
  30. defer listener.Close()
  31. for {
  32. conn, err := listener.Accept()
  33. if err != nil {
  34. if s.stopCh != nil {
  35. select {
  36. case <-s.stopCh:
  37. return
  38. default:
  39. }
  40. }
  41. s.Raft.config.Logger.Error("TCP Accept error: %v", err)
  42. continue
  43. }
  44. go s.handleTCPConnection(conn)
  45. }
  46. }()
  47. return nil
  48. }
  49. func (s *KVServer) handleTCPConnection(conn net.Conn) {
  50. defer conn.Close()
  51. // Use larger buffers for high throughput
  52. reader := bufio.NewReaderSize(conn, 64*1024)
  53. writer := bufio.NewWriterSize(conn, 64*1024)
  54. session := &TCPClientSession{
  55. conn: conn,
  56. server: s,
  57. reader: reader,
  58. writer: writer,
  59. remoteAddr: conn.RemoteAddr().String(),
  60. }
  61. for {
  62. // Set Keep-Alive Deadline
  63. // Using a longer deadline avoids frequent syscalls if traffic is continuous
  64. conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  65. // Read Request Line
  66. line, err := reader.ReadString('\n')
  67. if err != nil {
  68. return
  69. }
  70. line = strings.TrimSpace(line)
  71. if line == "" {
  72. continue
  73. }
  74. // Parse Headers
  75. var contentLength int
  76. for {
  77. hLine, err := reader.ReadString('\n')
  78. if err != nil {
  79. return
  80. }
  81. hLine = strings.TrimSpace(hLine)
  82. if hLine == "" {
  83. break
  84. }
  85. // Optimized Content-Length check
  86. // "Content-Length: 123"
  87. lowerHLine := strings.ToLower(hLine)
  88. if strings.HasPrefix(lowerHLine, "content-length:") {
  89. valStr := strings.TrimSpace(hLine[15:])
  90. if l, err := strconv.Atoi(valStr); err == nil {
  91. contentLength = l
  92. }
  93. }
  94. }
  95. // Read Body
  96. var body string
  97. if contentLength > 0 {
  98. // Optimization: Avoid allocation for small bodies if possible,
  99. // but Raft commands need to be passed as bytes/string anyway.
  100. buf := make([]byte, contentLength)
  101. if _, err := io.ReadFull(reader, buf); err != nil {
  102. return
  103. }
  104. body = string(buf)
  105. }
  106. // Implement Batching: Accumulate requests if they are ASET/SET
  107. // Ideally we need an internal buffer or channel here to queue up commands.
  108. // For zero-allocation batching in this loop structure, we can try to
  109. // eagerly read the next request if available in the buffer.
  110. // Note: Raft's Propose is thread-safe. The current serial loop is efficient for
  111. // minimizing context switches. Batching helps if we can merge multiple Proposes into one.
  112. // Since we haven't modified Raft to support ProposeBatch yet, we can't do true backend batching easily.
  113. // However, we can do "frontend batching" by checking buffer availability?
  114. // No, frontend batching requires Raft to accept a batch.
  115. // Without modifying Raft.Propose to accept []Command, "batching" here is limited.
  116. // BUT, we can at least pipeline the execution if we had concurrent workers,
  117. // but we replaced that with this serial loop for perf.
  118. // So for now, we just execute.
  119. // Execute Command Inline
  120. resp := s.executeCommandWithBody(session, line, body)
  121. // Write Response
  122. if _, err := writer.WriteString(resp + "\n"); err != nil {
  123. return
  124. }
  125. // Flush Optimization:
  126. // Only flush if there is no more data in the read buffer.
  127. // This automatically batches responses when requests are pipelined.
  128. if reader.Buffered() == 0 {
  129. if err := writer.Flush(); err != nil {
  130. return
  131. }
  132. }
  133. }
  134. }
  135. // Helper to bridge old logic
  136. func (s *KVServer) executeCommandWithBody(session *TCPClientSession, line string, body string) string {
  137. parts := strings.Fields(line)
  138. if len(parts) == 0 {
  139. return ""
  140. }
  141. cmd := strings.ToUpper(parts[0])
  142. // If body is present, it overrides the value part of the command
  143. // We handle specific commands that use body
  144. if body != "" {
  145. switch cmd {
  146. case "SET", "ASET":
  147. if len(parts) < 2 {
  148. return "ERR usage: SET <key> (value in body)"
  149. }
  150. key := parts[1]
  151. // Value is body
  152. if cmd == "SET" {
  153. if err := s.SetAuthenticated(key, body, session.token); err != nil {
  154. return fmt.Sprintf("ERR %v", err)
  155. }
  156. return "OK"
  157. } else {
  158. if err := s.SetAuthenticatedAsync(key, body, session.token); err != nil {
  159. return fmt.Sprintf("ERR %v", err)
  160. }
  161. return "OK"
  162. }
  163. }
  164. }
  165. // Fallback to legacy parsing if no body or command doesn't use body
  166. return s.executeCommand(session, line)
  167. }
  168. func (s *KVServer) executeCommand(session *TCPClientSession, line string) string {
  169. parts := strings.Fields(line)
  170. if len(parts) == 0 {
  171. return ""
  172. }
  173. cmd := strings.ToUpper(parts[0])
  174. var resp string
  175. switch cmd {
  176. case "LOGIN":
  177. if len(parts) < 3 {
  178. resp = "ERR usage: LOGIN <username> <password> [otp]"
  179. } else {
  180. user := parts[1]
  181. pass := parts[2]
  182. otp := ""
  183. if len(parts) > 3 {
  184. otp = parts[3]
  185. }
  186. // Extract IP
  187. ip := session.remoteAddr
  188. if host, _, err := net.SplitHostPort(ip); err == nil {
  189. ip = host
  190. }
  191. token, err := s.AuthManager.Login(user, pass, otp, ip)
  192. if err != nil {
  193. resp = fmt.Sprintf("ERR %v", err)
  194. } else {
  195. session.token = token
  196. session.username = user
  197. resp = fmt.Sprintf("OK %s", token)
  198. }
  199. }
  200. case "AUTH":
  201. if len(parts) < 2 {
  202. resp = "ERR usage: AUTH <token>"
  203. } else {
  204. token := parts[1]
  205. // Verify token
  206. sess, err := s.AuthManager.GetSession(token)
  207. if err != nil {
  208. resp = fmt.Sprintf("ERR %v", err)
  209. } else {
  210. session.token = token
  211. session.username = sess.Username
  212. resp = "OK"
  213. }
  214. }
  215. case "LOGOUT":
  216. if session.token != "" {
  217. s.AuthManager.Logout(session.token)
  218. session.token = ""
  219. session.username = ""
  220. }
  221. resp = "OK"
  222. case "GET":
  223. if len(parts) < 2 {
  224. resp = "ERR usage: GET <key>"
  225. } else {
  226. key := parts[1]
  227. val, found, err := s.GetLinearAuthenticated(key, session.token)
  228. if err != nil {
  229. resp = fmt.Sprintf("ERR %v", err)
  230. } else if !found {
  231. resp = "ERR not found"
  232. } else {
  233. resp = fmt.Sprintf("OK %s", val)
  234. }
  235. }
  236. case "SET":
  237. if len(parts) < 3 {
  238. resp = "ERR usage: SET <key> <value>"
  239. } else {
  240. key := parts[1]
  241. // Value might contain spaces, join the rest
  242. val := strings.Join(parts[2:], " ")
  243. // Use SetAuthenticated (Sync) by default for safety
  244. err := s.SetAuthenticated(key, val, session.token)
  245. if err != nil {
  246. resp = fmt.Sprintf("ERR %v", err)
  247. } else {
  248. resp = "OK"
  249. }
  250. }
  251. case "ASET":
  252. // Async SET for high performance
  253. if len(parts) < 3 {
  254. resp = "ERR usage: ASET <key> <value>"
  255. } else {
  256. key := parts[1]
  257. val := strings.Join(parts[2:], " ")
  258. err := s.SetAuthenticatedAsync(key, val, session.token)
  259. if err != nil {
  260. resp = fmt.Sprintf("ERR %v", err)
  261. } else {
  262. resp = "OK"
  263. }
  264. }
  265. case "DEL":
  266. if len(parts) < 2 {
  267. resp = "ERR usage: DEL <key>"
  268. } else {
  269. key := parts[1]
  270. err := s.DelAuthenticated(key, session.token)
  271. if err != nil {
  272. resp = fmt.Sprintf("ERR %v", err)
  273. } else {
  274. resp = "OK"
  275. }
  276. }
  277. case "WHOAMI":
  278. if session.username == "" {
  279. resp = "Guest"
  280. } else {
  281. resp = session.username
  282. }
  283. case "HELP":
  284. helpText := `Available Commands:
  285. GET <key> - Get value
  286. SET <key> <value> - Set value
  287. DEL <key> - Delete value
  288. SEARCH <pattern> [limit] - Search keys (e.g. user.*)
  289. COUNT <pattern> - Count keys
  290. INFO - Show system stats
  291. WHOAMI - Show current user
  292. JOIN <id> <addr> - Add node (Root only)
  293. LEAVE <id> - Remove node (Root only)
  294. USER_LIST - List users (Admin)
  295. ROLE_LIST - List roles (Admin)
  296. LOGIN/LOGOUT/EXIT`
  297. resp = "OK " + helpText
  298. case "INFO":
  299. // Check permission (Admin only if auth enabled)
  300. if s.AuthManager.IsEnabled() {
  301. if !s.IsAdmin(session.token) {
  302. resp = "ERR Permission Denied: Admin access required"
  303. break
  304. }
  305. }
  306. // Gather stats
  307. stats := s.GetStats()
  308. health := s.HealthCheck()
  309. dbSize := s.GetDBSize()
  310. logSize := s.GetLogSize()
  311. var m runtime.MemStats
  312. runtime.ReadMemStats(&m)
  313. // Construct JSON response
  314. info := map[string]interface{}{
  315. "node": map[string]interface{}{
  316. "id": health.NodeID,
  317. "state": health.State,
  318. "term": health.Term,
  319. "leader": health.LeaderID,
  320. "healthy": health.IsHealthy,
  321. },
  322. "storage": map[string]interface{}{
  323. "db_size": dbSize,
  324. "log_size": logSize,
  325. "mem_alloc": m.Alloc,
  326. "mem_sys": m.Sys,
  327. "num_gc": m.NumGC,
  328. },
  329. "indices": map[string]interface{}{
  330. "commit_index": stats.CommitIndex,
  331. "applied_index": stats.LastApplied,
  332. "last_log_index": stats.LastLogIndex,
  333. "db_applied": s.DB.GetLastAppliedIndex(),
  334. },
  335. "cluster": stats.ClusterNodes,
  336. "cluster_size": stats.ClusterSize,
  337. }
  338. data, err := json.Marshal(info)
  339. if err != nil {
  340. resp = fmt.Sprintf("ERR %v", err)
  341. } else {
  342. resp = "OK " + string(data)
  343. }
  344. case "SEARCH":
  345. // Usage: SEARCH <pattern> [limit] [offset]
  346. if len(parts) < 2 {
  347. resp = "ERR usage: SEARCH <pattern> [limit] [offset]"
  348. } else {
  349. pattern := parts[1]
  350. limit := 20
  351. offset := 0
  352. if len(parts) >= 3 {
  353. if l, err := strconv.Atoi(parts[2]); err == nil {
  354. limit = l
  355. }
  356. }
  357. if len(parts) >= 4 {
  358. if o, err := strconv.Atoi(parts[3]); err == nil {
  359. offset = o
  360. }
  361. }
  362. results, err := s.SearchAuthenticated(pattern, limit, offset, session.token)
  363. if err != nil {
  364. resp = fmt.Sprintf("ERR %v", err)
  365. } else {
  366. data, _ := json.Marshal(results)
  367. resp = "OK " + string(data)
  368. }
  369. }
  370. case "COUNT":
  371. // Usage: COUNT <pattern>
  372. if len(parts) < 2 {
  373. resp = "ERR usage: COUNT <pattern>"
  374. } else {
  375. pattern := parts[1]
  376. count, err := s.CountAuthenticated(pattern, session.token)
  377. if err != nil {
  378. resp = fmt.Sprintf("ERR %v", err)
  379. } else {
  380. resp = fmt.Sprintf("OK %d", count)
  381. }
  382. }
  383. case "JOIN":
  384. // Usage: JOIN <id> <addr>
  385. // Admin only
  386. if s.AuthManager.IsEnabled() {
  387. if !s.IsAdmin(session.token) {
  388. resp = "ERR Permission Denied: Admin access required"
  389. break
  390. }
  391. }
  392. if len(parts) < 3 {
  393. resp = "ERR usage: JOIN <id> <addr>"
  394. } else {
  395. err := s.Join(parts[1], parts[2])
  396. if err != nil {
  397. resp = fmt.Sprintf("ERR %v", err)
  398. } else {
  399. resp = "OK Join request sent"
  400. }
  401. }
  402. case "LEAVE":
  403. // Usage: LEAVE <id>
  404. // Admin only
  405. if s.AuthManager.IsEnabled() {
  406. if !s.IsAdmin(session.token) {
  407. resp = "ERR Permission Denied: Admin access required"
  408. break
  409. }
  410. }
  411. if len(parts) < 2 {
  412. resp = "ERR usage: LEAVE <id>"
  413. } else {
  414. err := s.Leave(parts[1])
  415. if err != nil {
  416. resp = fmt.Sprintf("ERR %v", err)
  417. } else {
  418. resp = "OK Leave request sent"
  419. }
  420. }
  421. // --- Admin Commands ---
  422. case "USER_LIST":
  423. users := s.AuthManager.ListUsers()
  424. data, err := json.Marshal(users)
  425. if err != nil {
  426. resp = fmt.Sprintf("ERR %v", err)
  427. } else {
  428. // Ensure it's a single line for TCP protocol simplicity
  429. jsonStr := string(data)
  430. resp = fmt.Sprintf("OK %s", jsonStr)
  431. }
  432. case "ROLE_LIST":
  433. roles := s.AuthManager.ListRoles()
  434. data, err := json.Marshal(roles)
  435. if err != nil {
  436. resp = fmt.Sprintf("ERR %v", err)
  437. } else {
  438. resp = fmt.Sprintf("OK %s", string(data))
  439. }
  440. case "USER_CREATE":
  441. // Usage: USER_CREATE <username> <password> <role1,role2>
  442. if len(parts) < 3 {
  443. resp = "ERR usage: USER_CREATE <user> <pass> [roles]"
  444. } else {
  445. u := parts[1]
  446. p := parts[2]
  447. var roles []string
  448. if len(parts) > 3 {
  449. roles = strings.Split(parts[3], ",")
  450. }
  451. // Use RegisterUser (sync)
  452. err := s.AuthManager.RegisterUser(u, p, roles)
  453. if err != nil {
  454. resp = fmt.Sprintf("ERR %v", err)
  455. } else {
  456. resp = "OK"
  457. }
  458. }
  459. case "ROLE_CREATE":
  460. // Usage: ROLE_CREATE <name>
  461. if len(parts) < 2 {
  462. resp = "ERR usage: ROLE_CREATE <name>"
  463. } else {
  464. name := parts[1]
  465. err := s.AuthManager.CreateRole(name)
  466. if err != nil {
  467. resp = fmt.Sprintf("ERR %v", err)
  468. } else {
  469. resp = "OK"
  470. }
  471. }
  472. case "ROLE_PERMISSION_ADD":
  473. // Usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>
  474. // Actions: comma separated list of actions (read,write,admin,*)
  475. if len(parts) < 4 {
  476. resp = "ERR usage: ROLE_PERMISSION_ADD <role> <pattern> <actions>"
  477. } else {
  478. roleName := parts[1]
  479. pattern := parts[2]
  480. actionsStr := parts[3]
  481. actions := strings.Split(actionsStr, ",")
  482. rolePtr, err := s.AuthManager.GetRole(roleName)
  483. if err != nil {
  484. resp = fmt.Sprintf("ERR %v", err)
  485. } else {
  486. // Create a copy to modify
  487. role := *rolePtr
  488. // Deep copy permissions to avoid potential side effects on cached object
  489. originalPerms := role.Permissions
  490. role.Permissions = make([]Permission, len(originalPerms))
  491. copy(role.Permissions, originalPerms)
  492. newPerm := Permission{
  493. KeyPattern: pattern,
  494. Actions: actions,
  495. }
  496. // Upsert logic: Update if exists, Append if new
  497. found := false
  498. for i, p := range role.Permissions {
  499. if p.KeyPattern == pattern {
  500. role.Permissions[i] = newPerm
  501. found = true
  502. break
  503. }
  504. }
  505. if !found {
  506. role.Permissions = append(role.Permissions, newPerm)
  507. }
  508. err := s.AuthManager.UpdateRole(role)
  509. if err != nil {
  510. resp = fmt.Sprintf("ERR %v", err)
  511. } else {
  512. resp = "OK"
  513. }
  514. }
  515. }
  516. case "ROLE_PERMISSION_REMOVE":
  517. // Usage: ROLE_PERMISSION_REMOVE <role> <pattern>
  518. if len(parts) < 3 {
  519. resp = "ERR usage: ROLE_PERMISSION_REMOVE <role> <pattern>"
  520. } else {
  521. roleName := parts[1]
  522. pattern := parts[2]
  523. rolePtr, err := s.AuthManager.GetRole(roleName)
  524. if err != nil {
  525. resp = fmt.Sprintf("ERR %v", err)
  526. } else {
  527. role := *rolePtr
  528. originalPerms := role.Permissions
  529. newPerms := make([]Permission, 0, len(originalPerms))
  530. found := false
  531. for _, p := range originalPerms {
  532. if p.KeyPattern == pattern {
  533. found = true
  534. continue
  535. }
  536. newPerms = append(newPerms, p)
  537. }
  538. if !found {
  539. resp = "ERR permission not found"
  540. } else {
  541. role.Permissions = newPerms
  542. err := s.AuthManager.UpdateRole(role)
  543. if err != nil {
  544. resp = fmt.Sprintf("ERR %v", err)
  545. } else {
  546. resp = "OK"
  547. }
  548. }
  549. }
  550. }
  551. case "USER_UNLOCK":
  552. // Usage: USER_UNLOCK <username>
  553. if s.AuthManager.IsEnabled() {
  554. if !s.IsAdmin(session.token) {
  555. resp = "ERR Permission Denied: Admin access required"
  556. break
  557. }
  558. }
  559. if len(parts) < 2 {
  560. resp = "ERR usage: USER_UNLOCK <username>"
  561. } else {
  562. // Manually clear the lock key
  563. userToUnlock := parts[1]
  564. // We use Del to remove the lock key
  565. err := s.Del("system.lock." + userToUnlock)
  566. if err != nil {
  567. resp = fmt.Sprintf("ERR %v", err)
  568. } else {
  569. resp = "OK"
  570. }
  571. }
  572. case "EXIT", "QUIT":
  573. resp = "BYE"
  574. // Need signal to close connection after write
  575. // For simplicity, handle it in handleTCPConnection loop break,
  576. // but here we just return the string.
  577. // Actually, BYE handling is tricky in async writer.
  578. // Let's keep connection open or let client close it.
  579. // Or send special signal?
  580. // For now, simple return. Client will read BYE and close.
  581. default:
  582. s.Raft.config.Logger.Warn("Unknown command received: %s (parts: %v)", cmd, parts)
  583. resp = fmt.Sprintf("ERR unknown command: %s", cmd)
  584. }
  585. return resp
  586. }