server.go

package raft

import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"igit.com/xbase/raft/db"
)

// KVServer wraps Raft to provide a distributed key-value store
type KVServer struct {
	Raft *Raft
	DB   *db.Engine
	CLI  *CLI

	stopCh   chan struct{}
	wg       sync.WaitGroup
	stopOnce sync.Once

	// leavingNodes tracks nodes that are currently being removed
	// to prevent auto-rejoin/discovery logic from interfering
	leavingNodes sync.Map
}

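// Illustrative usage (a sketch only; how a Config is normally constructed is
// not shown in this file, so the struct literal below is an assumption based
// on the NodeID, ListenAddr, and DataDir fields referenced elsewhere here):
//
//	cfg := &Config{NodeID: "n1", ListenAddr: "127.0.0.1:7001", DataDir: "/tmp/raft-n1"}
//	srv, err := NewKVServer(cfg)
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := srv.Start(); err != nil {
//		log.Fatal(err)
//	}
//	defer srv.Stop()
//
//	_ = srv.Set("greeting", "hello")        // replicated write, forwarded to the leader if needed
//	v, ok := srv.Get("greeting")            // local read, may be stale
//	v, ok, err = srv.GetLinear("greeting")  // linearizable read via ReadIndex
//	_, _ = v, ok
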
// NewKVServer creates a new KV server
func NewKVServer(config *Config) (*KVServer, error) {
	// Initialize DB Engine
	// Use a subdirectory for DB to avoid conflict with Raft logs if they share DataDir
	dbPath := config.DataDir + "/kv_engine"
	engine, err := db.NewEngine(dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create db engine: %w", err)
	}

	// Initialize LastAppliedIndex from DB to prevent re-applying entries
	config.LastAppliedIndex = engine.GetLastAppliedIndex()

	// Create stop channel early for use in callbacks
	stopCh := make(chan struct{})

	// Configure snapshot provider
	config.SnapshotProvider = func(minIncludeIndex uint64) ([]byte, error) {
		// Wait for DB to catch up to the requested index
		// This is critical for data integrity during compaction
		for engine.GetLastAppliedIndex() < minIncludeIndex {
			select {
			case <-stopCh:
				return nil, fmt.Errorf("server stopping")
			default:
				time.Sleep(10 * time.Millisecond)
			}
		}
		// Force sync to disk to ensure data durability before compaction
		// This prevents data loss if Raft logs are compacted but DB data is only in OS cache
		if err := engine.Sync(); err != nil {
			return nil, fmt.Errorf("failed to sync engine before snapshot: %w", err)
		}
		return engine.Snapshot()
	}

	// Configure get handler for remote reads
	config.GetHandler = func(key string) (string, bool) {
		return engine.Get(key)
	}

	applyCh := make(chan ApplyMsg, 1000) // Increase buffer for async processing
	transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
	r, err := NewRaft(config, transport, applyCh)
	if err != nil {
		engine.Close()
		return nil, err
	}

	s := &KVServer{
		Raft:   r,
		DB:     engine,
		CLI:    nil,
		stopCh: stopCh,
	}

	// Initialize CLI
	s.CLI = NewCLI(s)

	// Start applying entries
	go s.runApplyLoop(applyCh)

	// Start background maintenance loop
	s.wg.Add(1)
	go s.maintenanceLoop()

	return s, nil
}

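// Start starts the underlying Raft node and, if EnableCLI is set in the
// config, the interactive CLI.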
func (s *KVServer) Start() error {
	// Start CLI if enabled
	if s.Raft.config.EnableCLI {
		go s.CLI.Start()
	}
	return s.Raft.Start()
}

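// Stop shuts the server down at most once: it signals the background loops
// via stopCh and waits for them, then stops Raft and finally closes the DB
// engine, combining the errors if more than one step fails.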
func (s *KVServer) Stop() error {
	var err error
	s.stopOnce.Do(func() {
		// Stop maintenance loop
		if s.stopCh != nil {
			close(s.stopCh)
			s.wg.Wait()
		}
		// Stop Raft first
		if errRaft := s.Raft.Stop(); errRaft != nil {
			err = errRaft
		}
		// Close DB
		if s.DB != nil {
			if errDB := s.DB.Close(); errDB != nil {
				// Combine errors if both fail
				if err != nil {
					err = fmt.Errorf("raft stop error: %v, db close error: %v", err, errDB)
				} else {
					err = errDB
				}
			}
		}
	})
	return err
}

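// runApplyLoop consumes ApplyMsg values from applyCh until the channel is
// closed: committed commands are decoded and applied to the DB engine
// (skipping indexes the engine has already applied), and installed snapshots
// restore the DB state.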
func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
	for msg := range applyCh {
		if msg.CommandValid {
			// Optimization: Skip if already applied
			// We check this here to avoid unmarshalling and locking DB for known duplicates
			if msg.CommandIndex <= s.DB.GetLastAppliedIndex() {
				continue
			}
			var cmd KVCommand
			if err := json.Unmarshal(msg.Command, &cmd); err != nil {
				s.Raft.config.Logger.Error("Failed to unmarshal command: %v", err)
				continue
			}
			var err error
			switch cmd.Type {
			case KVSet:
				err = s.DB.Set(cmd.Key, cmd.Value, msg.CommandIndex)
			case KVDel:
				err = s.DB.Delete(cmd.Key, msg.CommandIndex)
			default:
				s.Raft.config.Logger.Error("Unknown command type: %d", cmd.Type)
			}
			if err != nil {
				s.Raft.config.Logger.Error("DB Apply failed: %v", err)
			}
		} else if msg.SnapshotValid {
			if err := s.DB.Restore(msg.Snapshot); err != nil {
				s.Raft.config.Logger.Error("DB Restore failed: %v", err)
			}
		}
	}
}

// Set sets a key-value pair
func (s *KVServer) Set(key, value string) error {
	cmd := KVCommand{
		Type:  KVSet,
		Key:   key,
		Value: value,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Del deletes a key
func (s *KVServer) Del(key string) error {
	cmd := KVCommand{
		Type: KVDel,
		Key:  key,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Get gets a value (local read, can be stale)
// For linearizable reads, use GetLinear instead
func (s *KVServer) Get(key string) (string, bool) {
	return s.DB.Get(key)
}

// GetLinear gets a value with linearizable consistency
// This ensures the read sees all writes committed before the read started
func (s *KVServer) GetLinear(key string) (string, bool, error) {
	// First, ensure we have up-to-date data via ReadIndex
	_, err := s.Raft.ReadIndex()
	if err != nil {
		// If we're not leader, try forwarding
		if errors.Is(err, ErrNotLeader) {
			return s.forwardGet(key)
		}
		return "", false, err
	}
	val, ok := s.DB.Get(key)
	return val, ok, nil
}

// forwardGet forwards a get request to the leader
func (s *KVServer) forwardGet(key string) (string, bool, error) {
	return s.Raft.ForwardGet(key)
}

// Join joins an existing cluster
func (s *KVServer) Join(nodeID, addr string) error {
	return s.Raft.AddNodeWithForward(nodeID, addr)
}

// Leave leaves the cluster
func (s *KVServer) Leave(nodeID string) error {
	// Mark node as leaving to prevent auto-rejoin
	s.leavingNodes.Store(nodeID, time.Now())
	// Auto-expire the leaving flag after a while
	go func() {
		time.Sleep(30 * time.Second)
		s.leavingNodes.Delete(nodeID)
	}()
	// Remove from RaftNode discovery key first to prevent auto-rejoin
	if err := s.removeNodeFromDiscovery(nodeID); err != nil {
		s.Raft.config.Logger.Warn("Failed to remove node from discovery key: %v", err)
		// Continue anyway, as the main goal is to leave the cluster
	}
	return s.Raft.RemoveNodeWithForward(nodeID)
}

// removeNodeFromDiscovery removes a node from the RaftNode key to prevent auto-rejoin
func (s *KVServer) removeNodeFromDiscovery(targetID string) error {
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return nil
	}
	parts := strings.Split(val, ";")
	var newParts []string
	changed := false
	for _, part := range parts {
		if part == "" {
			continue
		}
		kv := strings.SplitN(part, "=", 2)
		if len(kv) == 2 {
			if kv[0] == targetID {
				changed = true
				continue // Skip this node
			}
			newParts = append(newParts, part)
		}
	}
	if changed {
		newVal := strings.Join(newParts, ";")
		return s.Set("RaftNode", newVal)
	}
	return nil
}

// WaitForLeader waits until a leader is elected
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		leader := s.Raft.GetLeaderID()
		if leader != "" {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout waiting for leader")
}

// HealthCheck returns the health status of this server
func (s *KVServer) HealthCheck() HealthStatus {
	return s.Raft.HealthCheck()
}

// GetStats returns runtime statistics
func (s *KVServer) GetStats() Stats {
	return s.Raft.GetStats()
}

// GetMetrics returns runtime metrics
func (s *KVServer) GetMetrics() Metrics {
	return s.Raft.GetMetrics()
}

// TransferLeadership transfers leadership to the specified node
func (s *KVServer) TransferLeadership(targetID string) error {
	return s.Raft.TransferLeadership(targetID)
}

// GetClusterNodes returns current cluster membership
func (s *KVServer) GetClusterNodes() map[string]string {
	return s.Raft.GetClusterNodes()
}

// IsLeader returns true if this node is the leader
func (s *KVServer) IsLeader() bool {
	_, isLeader := s.Raft.GetState()
	return isLeader
}

// GetLeaderID returns the current leader ID
func (s *KVServer) GetLeaderID() string {
	return s.Raft.GetLeaderID()
}

// GetLogSize returns the raft log size
func (s *KVServer) GetLogSize() int64 {
	return s.Raft.log.GetLogSize()
}

// GetDBSize returns the db size
func (s *KVServer) GetDBSize() int64 {
	return s.DB.GetDBSize()
}

// WatchAll registers a watcher for all keys
func (s *KVServer) WatchAll(handler WatchHandler) {
	// s.FSM.WatchAll(handler)
	// TODO: Implement Watcher for DB
}

// Watch registers a watcher for a key
func (s *KVServer) Watch(key string, handler WatchHandler) {
	// s.FSM.Watch(key, handler)
	// TODO: Implement Watcher for DB
}

// Unwatch removes watchers for a key
func (s *KVServer) Unwatch(key string) {
	// s.FSM.Unwatch(key)
	// TODO: Implement Watcher for DB
}

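// maintenanceLoop runs until stopCh is closed, ticking once per second to
// publish this node's address and, on the leader, to reconcile cluster
// membership against the replicated RaftNode key.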
func (s *KVServer) maintenanceLoop() {
	defer s.wg.Done()
	// Check every 1 second for faster reaction
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.updateNodeInfo()
			s.checkConnections()
		}
	}
}

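// updateNodeInfo keeps the discovery keys current. Every node ensures
// "CreateNode/<NodeID>" maps to its own address, and the leader additionally
// maintains the aggregated "RaftNode" key: a ";"-separated list of "id=addr"
// pairs covering all known, non-leaving cluster members.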
func (s *KVServer) updateNodeInfo() {
	// 1. Ensure "CreateNode/<NodeID>" is set to our own address
	// We do this via Propose (Set) so it's replicated
	myID := s.Raft.config.NodeID
	myAddr := s.Raft.config.ListenAddr
	key := fmt.Sprintf("CreateNode/%s", myID)
	// Check whether an update is needed (avoid spamming logs/proposals)
	val, exists := s.Get(key)
	if !exists || val != myAddr {
		// Run in a goroutine to avoid blocking
		go func() {
			if err := s.Set(key, myAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
			}
		}()
	}
	// 2. Only the leader updates the RaftNode aggregation
	if s.IsLeader() {
		// Read the current RaftNode value to preserve history
		currentVal, _ := s.Get("RaftNode")
		knownNodes := make(map[string]string)
		if currentVal != "" {
			parts := strings.Split(currentVal, ";")
			for _, part := range parts {
				if part == "" {
					continue
				}
				kv := strings.SplitN(part, "=", 2)
				if len(kv) == 2 {
					knownNodes[kv[0]] = kv[1]
				}
			}
		}
		// Merge current cluster nodes
		changed := false
		currentCluster := s.GetClusterNodes()
		for id, addr := range currentCluster {
			// Skip nodes that are marked as leaving
			if _, leaving := s.leavingNodes.Load(id); leaving {
				continue
			}
			if knownNodes[id] != addr {
				knownNodes[id] = addr
				changed = true
			}
		}
		// If changed, update RaftNode
		if changed {
			var peers []string
			for id, addr := range knownNodes {
				peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
			}
			sort.Strings(peers)
			newVal := strings.Join(peers, ";")
			// Check again before writing, to avoid update loops if Get returned stale data
			if newVal != currentVal {
				go func(k, v string) {
					if err := s.Set(k, v); err != nil {
						s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
					}
				}("RaftNode", newVal)
			}
		}
	}
}

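// checkConnections runs on the leader and compares the membership recorded in
// the "RaftNode" key against the live cluster; any recorded node that is
// missing and not marked as leaving is asked to rejoin asynchronously.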
func (s *KVServer) checkConnections() {
	if !s.IsLeader() {
		return
	}
	// Read the RaftNode key to find potential members that are missing
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return
	}
	// Parse saved nodes
	savedParts := strings.Split(val, ";")
	currentNodes := s.GetClusterNodes()
	// Invert currentNodes for address lookups
	currentAddrs := make(map[string]bool)
	for _, addr := range currentNodes {
		currentAddrs[addr] = true
	}
	for _, part := range savedParts {
		if part == "" {
			continue
		}
		// Expect id=addr
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		id, addr := kv[0], kv[1]
		// Skip invalid addresses
		if strings.HasPrefix(addr, ".") || !strings.Contains(addr, ":") {
			continue
		}
		if !currentAddrs[addr] {
			// Skip nodes that are marked as leaving
			if _, leaving := s.leavingNodes.Load(id); leaving {
				continue
			}
			// This node was previously in the cluster but is now missing; try to add it back.
			// AddNodeWithForward does some non-blocking handling internally, but we still run
			// the rejoin in a goroutine so this loop is never blocked.
			go func(nodeID, nodeAddr string) {
				s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
				if err := s.Join(nodeID, nodeAddr); err != nil {
					s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
				}
			}(id, addr)
		}
	}
}