server.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. package raft
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"time"

	"igit.com/xbase/raft/db"
)
  14. // KVServer wraps Raft to provide a distributed key-value store
  15. type KVServer struct {
  16. Raft *Raft
  17. DB *db.Engine
  18. CLI *CLI
  19. Watcher *WebHookWatcher
  20. httpServer *http.Server
  21. stopCh chan struct{}
  22. wg sync.WaitGroup
  23. stopOnce sync.Once
  24. // leavingNodes tracks nodes that are currently being removed
  25. // to prevent auto-rejoin/discovery logic from interfering
  26. leavingNodes sync.Map
  27. }
  28. // NewKVServer creates a new KV server
  29. func NewKVServer(config *Config) (*KVServer, error) {
  30. // Initialize DB Engine
  31. // Use a subdirectory for DB to avoid conflict with Raft logs if they share DataDir
  32. dbPath := config.DataDir + "/kv_engine"
  33. engine, err := db.NewEngine(dbPath)
  34. if err != nil {
  35. return nil, fmt.Errorf("failed to create db engine: %w", err)
  36. }
  37. // Initialize LastAppliedIndex from DB to prevent re-applying entries
  38. config.LastAppliedIndex = engine.GetLastAppliedIndex()
  39. // Create stop channel early for use in callbacks
  40. stopCh := make(chan struct{})
  41. // Configure snapshot provider
  42. config.SnapshotProvider = func(minIncludeIndex uint64) ([]byte, error) {
  43. // Wait for DB to catch up to the requested index
  44. // This is critical for data integrity during compaction
  45. for engine.GetLastAppliedIndex() < minIncludeIndex {
  46. select {
  47. case <-stopCh:
  48. return nil, fmt.Errorf("server stopping")
  49. default:
  50. time.Sleep(10 * time.Millisecond)
  51. }
  52. }
  53. // Force sync to disk to ensure data durability before compaction
  54. // This prevents data loss if Raft logs are compacted but DB data is only in OS cache
  55. if err := engine.Sync(); err != nil {
  56. return nil, fmt.Errorf("failed to sync engine before snapshot: %w", err)
  57. }
  58. return engine.Snapshot()
  59. }
  60. // Configure get handler for remote reads
  61. config.GetHandler = func(key string) (string, bool) {
  62. return engine.Get(key)
  63. }
  64. applyCh := make(chan ApplyMsg, 1000) // Increase buffer for async processing
  65. transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
  66. // Initialize WebHookWatcher
  67. // 5 workers, 3 retries
  68. watcher := NewWebHookWatcher(5, 3, config.Logger)
  69. r, err := NewRaft(config, transport, applyCh)
  70. if err != nil {
  71. engine.Close()
  72. return nil, err
  73. }
  74. s := &KVServer{
  75. Raft: r,
  76. DB: engine,
  77. CLI: nil,
  78. Watcher: watcher,
  79. stopCh: stopCh,
  80. }
  81. // Initialize CLI
  82. s.CLI = NewCLI(s)
  83. // Start applying entries
  84. go s.runApplyLoop(applyCh)
  85. // Start background maintenance loop
  86. s.wg.Add(1)
  87. go s.maintenanceLoop()
  88. return s, nil
  89. }
  90. func (s *KVServer) Start() error {
  91. // Start CLI if enabled
  92. if s.Raft.config.EnableCLI {
  93. go s.CLI.Start()
  94. }
  95. // Start HTTP Server if configured
  96. if s.Raft.config.HTTPAddr != "" {
  97. if err := s.startHTTPServer(s.Raft.config.HTTPAddr); err != nil {
  98. s.Raft.config.Logger.Warn("Failed to start HTTP server: %v", err)
  99. }
  100. }
  101. return s.Raft.Start()
  102. }
  103. func (s *KVServer) Stop() error {
  104. var err error
  105. s.stopOnce.Do(func() {
  106. // Stop maintenance loop
  107. if s.stopCh != nil {
  108. close(s.stopCh)
  109. s.wg.Wait()
  110. }
  111. // Stop Watcher
  112. if s.Watcher != nil {
  113. s.Watcher.Stop()
  114. }
  115. // Stop HTTP Server
  116. if s.httpServer != nil {
  117. s.httpServer.Close()
  118. }
  119. // Stop Raft first
  120. if errRaft := s.Raft.Stop(); errRaft != nil {
  121. err = errRaft
  122. }
  123. // Close DB
  124. if s.DB != nil {
  125. if errDB := s.DB.Close(); errDB != nil {
  126. // Combine errors if both fail
  127. if err != nil {
  128. err = fmt.Errorf("raft stop error: %v, db close error: %v", err, errDB)
  129. } else {
  130. err = errDB
  131. }
  132. }
  133. }
  134. })
  135. return err
  136. }
  137. func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
  138. for msg := range applyCh {
  139. if msg.CommandValid {
  140. // Optimization: Skip if already applied
  141. // We check this here to avoid unmarshalling and locking DB for known duplicates
  142. if msg.CommandIndex <= s.DB.GetLastAppliedIndex() {
  143. continue
  144. }
  145. var cmd KVCommand
  146. if err := json.Unmarshal(msg.Command, &cmd); err != nil {
  147. s.Raft.config.Logger.Error("Failed to unmarshal command: %v", err)
  148. continue
  149. }
  150. var err error
  151. switch cmd.Type {
  152. case KVSet:
  153. err = s.DB.Set(cmd.Key, cmd.Value, msg.CommandIndex)
  154. if err == nil {
  155. s.Watcher.Notify(cmd.Key, cmd.Value, KVSet)
  156. }
  157. case KVDel:
  158. err = s.DB.Delete(cmd.Key, msg.CommandIndex)
  159. if err == nil {
  160. s.Watcher.Notify(cmd.Key, "", KVDel)
  161. }
  162. default:
  163. s.Raft.config.Logger.Error("Unknown command type: %d", cmd.Type)
  164. }
  165. if err != nil {
  166. s.Raft.config.Logger.Error("DB Apply failed: %v", err)
  167. }
  168. } else if msg.SnapshotValid {
  169. if err := s.DB.Restore(msg.Snapshot); err != nil {
  170. s.Raft.config.Logger.Error("DB Restore failed: %v", err)
  171. }
  172. }
  173. }
  174. }
  175. // Set sets a key-value pair
  176. func (s *KVServer) Set(key, value string) error {
  177. cmd := KVCommand{
  178. Type: KVSet,
  179. Key: key,
  180. Value: value,
  181. }
  182. data, err := json.Marshal(cmd)
  183. if err != nil {
  184. return err
  185. }
  186. _, _, err = s.Raft.ProposeWithForward(data)
  187. return err
  188. }
  189. // Del deletes a key
  190. func (s *KVServer) Del(key string) error {
  191. cmd := KVCommand{
  192. Type: KVDel,
  193. Key: key,
  194. }
  195. data, err := json.Marshal(cmd)
  196. if err != nil {
  197. return err
  198. }
  199. _, _, err = s.Raft.ProposeWithForward(data)
  200. return err
  201. }
  202. // Get gets a value (local read, can be stale)
  203. // For linearizable reads, use GetLinear instead
  204. func (s *KVServer) Get(key string) (string, bool) {
  205. return s.DB.Get(key)
  206. }
  207. // GetLinear gets a value with linearizable consistency
  208. // This ensures the read sees all writes committed before the read started
  209. func (s *KVServer) GetLinear(key string) (string, bool, error) {
  210. // First, ensure we have up-to-date data via ReadIndex
  211. _, err := s.Raft.ReadIndex()
  212. if err != nil {
  213. // If we're not leader, try forwarding
  214. if errors.Is(err, ErrNotLeader) {
  215. return s.forwardGet(key)
  216. }
  217. return "", false, err
  218. }
  219. val, ok := s.DB.Get(key)
  220. return val, ok, nil
  221. }
  222. // forwardGet forwards a get request to the leader
  223. func (s *KVServer) forwardGet(key string) (string, bool, error) {
  224. return s.Raft.ForwardGet(key)
  225. }
  226. // Join joins an existing cluster
  227. func (s *KVServer) Join(nodeID, addr string) error {
  228. return s.Raft.AddNodeWithForward(nodeID, addr)
  229. }
  230. // Leave leaves the cluster
  231. func (s *KVServer) Leave(nodeID string) error {
  232. // Mark node as leaving to prevent auto-rejoin
  233. s.leavingNodes.Store(nodeID, time.Now())
  234. // Auto-expire the leaving flag after a while
  235. go func() {
  236. time.Sleep(30 * time.Second)
  237. s.leavingNodes.Delete(nodeID)
  238. }()
  239. // Remove from RaftNode discovery key first to prevent auto-rejoin
  240. if err := s.removeNodeFromDiscovery(nodeID); err != nil {
  241. s.Raft.config.Logger.Warn("Failed to remove node from discovery key: %v", err)
  242. // Continue anyway, as the main goal is to leave the cluster
  243. }
  244. return s.Raft.RemoveNodeWithForward(nodeID)
  245. }
  246. // removeNodeFromDiscovery removes a node from the RaftNode key to prevent auto-rejoin
  247. func (s *KVServer) removeNodeFromDiscovery(targetID string) error {
  248. val, ok := s.Get("RaftNode")
  249. if !ok || val == "" {
  250. return nil
  251. }
  252. parts := strings.Split(val, ";")
  253. var newParts []string
  254. changed := false
  255. for _, part := range parts {
  256. if part == "" {
  257. continue
  258. }
  259. kv := strings.SplitN(part, "=", 2)
  260. if len(kv) == 2 {
  261. if kv[0] == targetID {
  262. changed = true
  263. continue // Skip this node
  264. }
  265. newParts = append(newParts, part)
  266. }
  267. }
  268. if changed {
  269. newVal := strings.Join(newParts, ";")
  270. return s.Set("RaftNode", newVal)
  271. }
  272. return nil
  273. }
  274. // WaitForLeader waits until a leader is elected
  275. func (s *KVServer) WaitForLeader(timeout time.Duration) error {
  276. deadline := time.Now().Add(timeout)
  277. for time.Now().Before(deadline) {
  278. leader := s.Raft.GetLeaderID()
  279. if leader != "" {
  280. return nil
  281. }
  282. time.Sleep(100 * time.Millisecond)
  283. }
  284. return fmt.Errorf("timeout waiting for leader")
  285. }
  286. // HealthCheck returns the health status of this server
  287. func (s *KVServer) HealthCheck() HealthStatus {
  288. return s.Raft.HealthCheck()
  289. }
  290. // GetStats returns runtime statistics
  291. func (s *KVServer) GetStats() Stats {
  292. return s.Raft.GetStats()
  293. }
  294. // GetMetrics returns runtime metrics
  295. func (s *KVServer) GetMetrics() Metrics {
  296. return s.Raft.GetMetrics()
  297. }
  298. // TransferLeadership transfers leadership to the specified node
  299. func (s *KVServer) TransferLeadership(targetID string) error {
  300. return s.Raft.TransferLeadership(targetID)
  301. }
  302. // GetClusterNodes returns current cluster membership
  303. func (s *KVServer) GetClusterNodes() map[string]string {
  304. return s.Raft.GetClusterNodes()
  305. }
  306. // IsLeader returns true if this node is the leader
  307. func (s *KVServer) IsLeader() bool {
  308. _, isLeader := s.Raft.GetState()
  309. return isLeader
  310. }
  311. // GetLeaderID returns the current leader ID
  312. func (s *KVServer) GetLeaderID() string {
  313. return s.Raft.GetLeaderID()
  314. }
  315. // GetLogSize returns the raft log size
  316. func (s *KVServer) GetLogSize() int64 {
  317. return s.Raft.log.GetLogSize()
  318. }
  319. // GetDBSize returns the db size
  320. func (s *KVServer) GetDBSize() int64 {
  321. return s.DB.GetDBSize()
  322. }
  323. // WatchURL registers a webhook url for a key
  324. func (s *KVServer) WatchURL(key, url string) {
  325. s.Watcher.Subscribe(key, url)
  326. }
  327. // UnwatchURL removes a webhook url for a key
  328. func (s *KVServer) UnwatchURL(key, url string) {
  329. s.Watcher.Unsubscribe(key, url)
  330. }
  331. // WatchAll registers a watcher for all keys
  332. func (s *KVServer) WatchAll(handler WatchHandler) {
  333. // s.FSM.WatchAll(handler)
  334. // TODO: Implement Watcher for DB
  335. }
  336. // Watch registers a watcher for a key
  337. func (s *KVServer) Watch(key string, handler WatchHandler) {
  338. // s.FSM.Watch(key, handler)
  339. // TODO: Implement Watcher for DB
  340. }
  341. // Unwatch removes watchers for a key
  342. func (s *KVServer) Unwatch(key string) {
  343. // s.FSM.Unwatch(key)
  344. // TODO: Implement Watcher for DB
  345. }
  346. func (s *KVServer) maintenanceLoop() {
  347. defer s.wg.Done()
  348. // Check every 1 second for faster reaction
  349. ticker := time.NewTicker(1 * time.Second)
  350. defer ticker.Stop()
  351. for {
  352. select {
  353. case <-s.stopCh:
  354. return
  355. case <-ticker.C:
  356. s.updateNodeInfo()
  357. s.checkConnections()
  358. }
  359. }
  360. }
  361. func (s *KVServer) updateNodeInfo() {
  362. // 1. Ensure "CreateNode/<NodeID>" is set to self address
  363. // We do this via Propose (Set) so it's replicated
  364. myID := s.Raft.config.NodeID
  365. myAddr := s.Raft.config.ListenAddr
  366. key := fmt.Sprintf("CreateNode/%s", myID)
  367. // Check if we need to update (avoid spamming logs/proposals)
  368. val, exists := s.Get(key)
  369. if !exists || val != myAddr {
  370. // Run in goroutine to avoid blocking
  371. go func() {
  372. if err := s.Set(key, myAddr); err != nil {
  373. s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
  374. }
  375. }()
  376. }
  377. // 2. Only leader updates RaftNode aggregation
  378. if s.IsLeader() {
  379. // Read current RaftNode to preserve history
  380. currentVal, _ := s.Get("RaftNode")
  381. knownNodes := make(map[string]string)
  382. if currentVal != "" {
  383. parts := strings.Split(currentVal, ";")
  384. for _, part := range parts {
  385. if part == "" { continue }
  386. kv := strings.SplitN(part, "=", 2)
  387. if len(kv) == 2 {
  388. knownNodes[kv[0]] = kv[1]
  389. }
  390. }
  391. }
  392. // Merge current cluster nodes
  393. changed := false
  394. currentCluster := s.GetClusterNodes()
  395. for id, addr := range currentCluster {
  396. // Skip nodes that are marked as leaving
  397. if _, leaving := s.leavingNodes.Load(id); leaving {
  398. continue
  399. }
  400. if knownNodes[id] != addr {
  401. knownNodes[id] = addr
  402. changed = true
  403. }
  404. }
  405. // If changed, update RaftNode
  406. if changed {
  407. var peers []string
  408. for id, addr := range knownNodes {
  409. peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
  410. }
  411. sort.Strings(peers)
  412. newVal := strings.Join(peers, ";")
  413. // Check again if we need to write to avoid loops if Get returned stale
  414. if newVal != currentVal {
  415. go func(k, v string) {
  416. if err := s.Set(k, v); err != nil {
  417. s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
  418. }
  419. }("RaftNode", newVal)
  420. }
  421. }
  422. }
  423. }
  424. func (s *KVServer) checkConnections() {
  425. if !s.IsLeader() {
  426. return
  427. }
  428. // Read RaftNode key to find potential members that are missing
  429. val, ok := s.Get("RaftNode")
  430. if !ok || val == "" {
  431. return
  432. }
  433. // Parse saved nodes
  434. savedParts := strings.Split(val, ";")
  435. currentNodes := s.GetClusterNodes()
  436. // Invert currentNodes for address check
  437. currentAddrs := make(map[string]bool)
  438. for _, addr := range currentNodes {
  439. currentAddrs[addr] = true
  440. }
  441. for _, part := range savedParts {
  442. if part == "" {
  443. continue
  444. }
  445. // Expect id=addr
  446. kv := strings.SplitN(part, "=", 2)
  447. if len(kv) != 2 {
  448. continue
  449. }
  450. id, addr := kv[0], kv[1]
  451. // Skip invalid addresses
  452. if strings.HasPrefix(addr, ".") || !strings.Contains(addr, ":") {
  453. continue
  454. }
  455. if !currentAddrs[addr] {
  456. // Skip nodes that are marked as leaving
  457. if _, leaving := s.leavingNodes.Load(id); leaving {
  458. continue
  459. }
  460. // Found a node that was previously in the cluster but is now missing
  461. // Try to add it back
  462. // We use AddNodeWithForward which handles non-blocking internally somewhat,
  463. // but we should run this in goroutine to not block the loop
  464. go func(nodeID, nodeAddr string) {
  465. // Try to add node
  466. s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
  467. if err := s.Join(nodeID, nodeAddr); err != nil {
  468. s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
  469. }
  470. }(id, addr)
  471. }
  472. }
  473. }
  474. // startHTTPServer starts the HTTP API server
  475. func (s *KVServer) startHTTPServer(addr string) error {
  476. mux := http.NewServeMux()
  477. // KV API
  478. mux.HandleFunc("/kv", func(w http.ResponseWriter, r *http.Request) {
  479. switch r.Method {
  480. case http.MethodGet:
  481. key := r.URL.Query().Get("key")
  482. if key == "" {
  483. http.Error(w, "missing key", http.StatusBadRequest)
  484. return
  485. }
  486. val, found, err := s.GetLinear(key)
  487. if err != nil {
  488. http.Error(w, err.Error(), http.StatusInternalServerError)
  489. return
  490. }
  491. if !found {
  492. http.Error(w, "not found", http.StatusNotFound)
  493. return
  494. }
  495. w.Write([]byte(val))
  496. case http.MethodPost:
  497. body, _ := io.ReadAll(r.Body)
  498. var req struct {
  499. Key string `json:"key"`
  500. Value string `json:"value"`
  501. }
  502. if err := json.Unmarshal(body, &req); err != nil {
  503. http.Error(w, "invalid json", http.StatusBadRequest)
  504. return
  505. }
  506. if err := s.Set(req.Key, req.Value); err != nil {
  507. http.Error(w, err.Error(), http.StatusInternalServerError)
  508. return
  509. }
  510. w.WriteHeader(http.StatusOK)
  511. case http.MethodDelete:
  512. key := r.URL.Query().Get("key")
  513. if key == "" {
  514. http.Error(w, "missing key", http.StatusBadRequest)
  515. return
  516. }
  517. if err := s.Del(key); err != nil {
  518. http.Error(w, err.Error(), http.StatusInternalServerError)
  519. return
  520. }
  521. w.WriteHeader(http.StatusOK)
  522. default:
  523. http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
  524. }
  525. })
  526. // Watcher API
  527. mux.HandleFunc("/watch", func(w http.ResponseWriter, r *http.Request) {
  528. if r.Method != http.MethodPost {
  529. http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
  530. return
  531. }
  532. body, _ := io.ReadAll(r.Body)
  533. var req struct {
  534. Key string `json:"key"`
  535. URL string `json:"url"`
  536. }
  537. if err := json.Unmarshal(body, &req); err != nil {
  538. http.Error(w, "invalid json", http.StatusBadRequest)
  539. return
  540. }
  541. if req.Key == "" || req.URL == "" {
  542. http.Error(w, "missing key or url", http.StatusBadRequest)
  543. return
  544. }
  545. s.WatchURL(req.Key, req.URL)
  546. w.WriteHeader(http.StatusOK)
  547. })
  548. mux.HandleFunc("/unwatch", func(w http.ResponseWriter, r *http.Request) {
  549. if r.Method != http.MethodPost {
  550. http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
  551. return
  552. }
  553. body, _ := io.ReadAll(r.Body)
  554. var req struct {
  555. Key string `json:"key"`
  556. URL string `json:"url"`
  557. }
  558. if err := json.Unmarshal(body, &req); err != nil {
  559. http.Error(w, "invalid json", http.StatusBadRequest)
  560. return
  561. }
  562. s.UnwatchURL(req.Key, req.URL)
  563. w.WriteHeader(http.StatusOK)
  564. })
  565. s.httpServer = &http.Server{
  566. Addr: addr,
  567. Handler: mux,
  568. }
  569. go func() {
  570. s.Raft.config.Logger.Info("HTTP API server listening on %s", addr)
  571. if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  572. s.Raft.config.Logger.Error("HTTP server failed: %v", err)
  573. }
  574. }()
  575. return nil
  576. }