server.go

package raft

import (
    "encoding/json"
    "errors"
    "fmt"
    "sort"
    "strings"
    "sync"
    "time"

    "igit.com/xbase/raft/db"
)

// KVServer wraps Raft to provide a distributed key-value store.
type KVServer struct {
    Raft *Raft
    DB   *db.Engine

    stopCh   chan struct{}
    wg       sync.WaitGroup
    stopOnce sync.Once
}

// NewKVServer creates a new KV server.
func NewKVServer(config *Config) (*KVServer, error) {
    // Initialize the DB engine.
    // Use a subdirectory for the DB to avoid conflicts with Raft logs if they share DataDir.
    dbPath := config.DataDir + "/kv_engine"
    engine, err := db.NewEngine(dbPath)
    if err != nil {
        return nil, fmt.Errorf("failed to create db engine: %w", err)
    }

    // Configure the snapshot provider.
    config.SnapshotProvider = func() ([]byte, error) {
        return engine.Snapshot()
    }

    // Configure the get handler for remote reads.
    config.GetHandler = func(key string) (string, bool) {
        return engine.Get(key)
    }

    applyCh := make(chan ApplyMsg, 1000) // larger buffer for async processing
    transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
    r, err := NewRaft(config, transport, applyCh)
    if err != nil {
        engine.Close()
        return nil, err
    }

    s := &KVServer{
        Raft: r,
        DB:   engine,
    }

    // Start applying committed entries.
    go s.runApplyLoop(applyCh)

    // Start the background maintenance loop.
    s.stopCh = make(chan struct{})
    s.wg.Add(1)
    go s.maintenanceLoop()

    return s, nil
}
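
// A minimal startup sketch (illustrative only; the Config fields shown are
// assumed from their use in this file, and paths/addresses are placeholders):
//
//	cfg := &Config{
//		NodeID:     "node1",
//		ListenAddr: "127.0.0.1:7001",
//		DataDir:    "/var/lib/raft/node1",
//	}
//	srv, err := NewKVServer(cfg)
//	if err != nil {
//		return err
//	}
//	if err := srv.Start(); err != nil {
//		return err
//	}
//	defer srv.Stop()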

func (s *KVServer) Start() error {
    return s.Raft.Start()
}

func (s *KVServer) Stop() error {
    var err error
    s.stopOnce.Do(func() {
        // Stop the maintenance loop.
        if s.stopCh != nil {
            close(s.stopCh)
            s.wg.Wait()
        }
        // Stop Raft first.
        if errRaft := s.Raft.Stop(); errRaft != nil {
            err = errRaft
        }
        // Close the DB.
        if s.DB != nil {
            if errDB := s.DB.Close(); errDB != nil {
                // Combine errors if both fail.
                if err != nil {
                    err = fmt.Errorf("raft stop error: %v, db close error: %v", err, errDB)
                } else {
                    err = errDB
                }
            }
        }
    })
    return err
}

func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
    for msg := range applyCh {
        if msg.CommandValid {
            // Optimization: skip entries that have already been applied.
            // Checking here avoids unmarshalling and locking the DB for known duplicates.
            if msg.CommandIndex <= s.DB.GetLastAppliedIndex() {
                continue
            }
            var cmd KVCommand
            if err := json.Unmarshal(msg.Command, &cmd); err != nil {
                s.Raft.config.Logger.Error("Failed to unmarshal command: %v", err)
                continue
            }
            var err error
            switch cmd.Type {
            case KVSet:
                err = s.DB.Set(cmd.Key, cmd.Value, msg.CommandIndex)
            case KVDel:
                err = s.DB.Delete(cmd.Key, msg.CommandIndex)
            default:
                s.Raft.config.Logger.Error("Unknown command type: %d", cmd.Type)
            }
            if err != nil {
                s.Raft.config.Logger.Error("DB apply failed: %v", err)
            }
        } else if msg.SnapshotValid {
            if err := s.DB.Restore(msg.Snapshot); err != nil {
                s.Raft.config.Logger.Error("DB restore failed: %v", err)
            }
        }
    }
}
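
// For reference: the apply loop above expects each command to be the JSON
// encoding of a KVCommand, as produced by Set/Del below. KVCommand is defined
// elsewhere in the package; the field names and the numeric values shown for
// KVSet/KVDel are assumptions for illustration only:
//
//	{"Type": 0, "Key": "user/42", "Value": "alice"}   // hypothetical KVSet
//	{"Type": 1, "Key": "user/42"}                     // hypothetical KVDel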

// Set sets a key-value pair.
func (s *KVServer) Set(key, value string) error {
    cmd := KVCommand{
        Type:  KVSet,
        Key:   key,
        Value: value,
    }
    data, err := json.Marshal(cmd)
    if err != nil {
        return err
    }
    _, _, err = s.Raft.ProposeWithForward(data)
    return err
}

// Del deletes a key.
func (s *KVServer) Del(key string) error {
    cmd := KVCommand{
        Type: KVDel,
        Key:  key,
    }
    data, err := json.Marshal(cmd)
    if err != nil {
        return err
    }
    _, _, err = s.Raft.ProposeWithForward(data)
    return err
}

// Get gets a value (local read, can be stale).
// For linearizable reads, use GetLinear instead.
func (s *KVServer) Get(key string) (string, bool) {
    return s.DB.Get(key)
}

// GetLinear gets a value with linearizable consistency.
// It ensures the read sees all writes committed before the read started.
func (s *KVServer) GetLinear(key string) (string, bool, error) {
    // First, ensure we have up-to-date data via ReadIndex.
    _, err := s.Raft.ReadIndex()
    if err != nil {
        // If we're not the leader, try forwarding.
        if errors.Is(err, ErrNotLeader) {
            return s.forwardGet(key)
        }
        return "", false, err
    }
    val, ok := s.DB.Get(key)
    return val, ok, nil
}
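
// Read-path sketch (illustrative only): Get serves whatever the local engine
// has applied and may lag the leader, while GetLinear first establishes a
// ReadIndex (or forwards to the leader), so it observes all writes committed
// before the read started.
//
//	v, ok := srv.Get("config/feature")             // fast, possibly stale
//	v, ok, err := srv.GetLinear("config/feature")  // linearizable, may forward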

// forwardGet forwards a get request to the leader.
func (s *KVServer) forwardGet(key string) (string, bool, error) {
    return s.Raft.ForwardGet(key)
}

// Join joins an existing cluster.
func (s *KVServer) Join(nodeID, addr string) error {
    return s.Raft.AddNodeWithForward(nodeID, addr)
}

// Leave leaves the cluster.
func (s *KVServer) Leave(nodeID string) error {
    return s.Raft.RemoveNodeWithForward(nodeID)
}

// WaitForLeader waits until a leader is elected.
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
    deadline := time.Now().Add(timeout)
    for time.Now().Before(deadline) {
        leader := s.Raft.GetLeaderID()
        if leader != "" {
            return nil
        }
        time.Sleep(100 * time.Millisecond)
    }
    return fmt.Errorf("timeout waiting for leader")
}
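
// Typical client sequence after booting a node (sketch; the timeout is an
// arbitrary example value):
//
//	if err := srv.WaitForLeader(5 * time.Second); err != nil {
//		return err
//	}
//	if err := srv.Set("greeting", "hello"); err != nil {
//		return err
//	}
//	val, ok := srv.Get("greeting")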

// HealthCheck returns the health status of this server.
func (s *KVServer) HealthCheck() HealthStatus {
    return s.Raft.HealthCheck()
}

// GetStats returns runtime statistics.
func (s *KVServer) GetStats() Stats {
    return s.Raft.GetStats()
}

// GetMetrics returns runtime metrics.
func (s *KVServer) GetMetrics() Metrics {
    return s.Raft.GetMetrics()
}

// TransferLeadership transfers leadership to the specified node.
func (s *KVServer) TransferLeadership(targetID string) error {
    return s.Raft.TransferLeadership(targetID)
}

// GetClusterNodes returns the current cluster membership.
func (s *KVServer) GetClusterNodes() map[string]string {
    return s.Raft.GetClusterNodes()
}

// IsLeader returns true if this node is the leader.
func (s *KVServer) IsLeader() bool {
    _, isLeader := s.Raft.GetState()
    return isLeader
}

// GetLeaderID returns the current leader ID.
func (s *KVServer) GetLeaderID() string {
    return s.Raft.GetLeaderID()
}

// WatchAll registers a watcher for all keys.
func (s *KVServer) WatchAll(handler WatchHandler) {
    // s.FSM.WatchAll(handler)
    // TODO: implement watchers for the DB engine
}

// Watch registers a watcher for a key.
func (s *KVServer) Watch(key string, handler WatchHandler) {
    // s.FSM.Watch(key, handler)
    // TODO: implement watchers for the DB engine
}

// Unwatch removes watchers for a key.
func (s *KVServer) Unwatch(key string) {
    // s.FSM.Unwatch(key)
    // TODO: implement watchers for the DB engine
}

func (s *KVServer) maintenanceLoop() {
    defer s.wg.Done()
    // Check every second for faster reaction.
    ticker := time.NewTicker(1 * time.Second)
    defer ticker.Stop()
    for {
        select {
        case <-s.stopCh:
            return
        case <-ticker.C:
            s.updateNodeInfo()
            s.checkConnections()
        }
    }
}

func (s *KVServer) updateNodeInfo() {
    // 1. Ensure "CreateNode/<NodeID>" is set to this node's address.
    // This goes through Propose (Set) so it is replicated.
    myID := s.Raft.config.NodeID
    myAddr := s.Raft.config.ListenAddr
    key := fmt.Sprintf("CreateNode/%s", myID)

    // Check whether an update is needed (avoid spamming logs/proposals).
    val, exists := s.Get(key)
    if !exists || val != myAddr {
        // Run in a goroutine to avoid blocking the maintenance loop.
        go func() {
            if err := s.Set(key, myAddr); err != nil {
                s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
            }
        }()
    }

    // 2. Only the leader updates the RaftNode aggregation.
    if !s.IsLeader() {
        return
    }

    // Read the current RaftNode value to preserve history.
    currentVal, _ := s.Get("RaftNode")
    knownNodes := make(map[string]string)
    if currentVal != "" {
        for _, part := range strings.Split(currentVal, ";") {
            if part == "" {
                continue
            }
            kv := strings.SplitN(part, "=", 2)
            if len(kv) == 2 {
                knownNodes[kv[0]] = kv[1]
            }
        }
    }

    // Merge in the current cluster nodes.
    changed := false
    for id, addr := range s.GetClusterNodes() {
        if knownNodes[id] != addr {
            knownNodes[id] = addr
            changed = true
        }
    }

    // If anything changed, update RaftNode.
    if changed {
        var peers []string
        for id, addr := range knownNodes {
            peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
        }
        sort.Strings(peers)
        newVal := strings.Join(peers, ";")
        // Check again before writing to avoid loops if Get returned stale data.
        if newVal != currentVal {
            go func(k, v string) {
                if err := s.Set(k, v); err != nil {
                    s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
                }
            }("RaftNode", newVal)
        }
    }
}
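
// With three nodes, the RaftNode key maintained above holds a sorted,
// semicolon-separated membership list, for example:
//
//	node1=10.0.0.1:7001;node2=10.0.0.2:7001;node3=10.0.0.3:7001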

func (s *KVServer) checkConnections() {
    if !s.IsLeader() {
        return
    }

    // Read the RaftNode key to find previously known members that are missing.
    val, ok := s.Get("RaftNode")
    if !ok || val == "" {
        return
    }

    // Parse the saved nodes.
    savedParts := strings.Split(val, ";")
    currentNodes := s.GetClusterNodes()

    // Invert currentNodes for an address lookup.
    currentAddrs := make(map[string]bool)
    for _, addr := range currentNodes {
        currentAddrs[addr] = true
    }

    for _, part := range savedParts {
        if part == "" {
            continue
        }
        // Expect "id=addr".
        kv := strings.SplitN(part, "=", 2)
        if len(kv) != 2 {
            continue
        }
        id, addr := kv[0], kv[1]
        if !currentAddrs[addr] {
            // Found a node that was previously in the cluster but is now missing;
            // try to add it back. AddNodeWithForward is largely non-blocking, but
            // run it in a goroutine so this loop is never delayed.
            go func(nodeID, nodeAddr string) {
                s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
                if err := s.Join(nodeID, nodeAddr); err != nil {
                    s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
                }
            }(id, addr)
        }
    }
}