server.go

package raft
import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"
)
// KVServer wraps Raft to provide a distributed key-value store
type KVServer struct {
	Raft     *Raft
	FSM      *KVStateMachine
	stopCh   chan struct{}
	wg       sync.WaitGroup
	stopOnce sync.Once
}
// NewKVServer creates a new KV server
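//
// Typical wiring (a sketch; NodeID, ListenAddr and Logger are Config fields
// used elsewhere in this file, any other setup is an assumption):
//
//	cfg := &Config{NodeID: "node1", ListenAddr: "127.0.0.1:7001"}
//	srv, err := NewKVServer(cfg)
//	if err != nil {
//		// handle error
//	}
//	if err := srv.Start(); err != nil {
//		// handle error
//	}
//	defer srv.Stop()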
func NewKVServer(config *Config) (*KVServer, error) {
	fsm := NewKVStateMachine()
	// Configure snapshot provider
	config.SnapshotProvider = func() ([]byte, error) {
		return fsm.Snapshot()
	}
	// Configure get handler for remote reads
	config.GetHandler = func(key string) (string, bool) {
		return fsm.Get(key)
	}
	applyCh := make(chan ApplyMsg, 100)
	transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
	r, err := NewRaft(config, transport, applyCh)
	if err != nil {
		return nil, err
	}
	s := &KVServer{
		Raft: r,
		FSM:  fsm,
	}
	// Start applying entries
	go s.runApplyLoop(applyCh)
	// Start background maintenance loop
	s.stopCh = make(chan struct{})
	s.wg.Add(1)
	go s.maintenanceLoop()
	return s, nil
}
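// Start starts the underlying Raft node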
func (s *KVServer) Start() error {
	return s.Raft.Start()
}
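// Stop shuts down the maintenance loop and the underlying Raft node.
// It is safe to call more than once.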
func (s *KVServer) Stop() error {
	var err error
	s.stopOnce.Do(func() {
		// Stop maintenance loop
		if s.stopCh != nil {
			close(s.stopCh)
			s.wg.Wait()
		}
		err = s.Raft.Stop()
	})
	return err
}
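// runApplyLoop consumes committed entries and snapshots from applyCh and
// applies them to the local state machine. It exits when applyCh is closed.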
func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
	for msg := range applyCh {
		if msg.CommandValid {
			if _, err := s.FSM.Apply(msg.Command); err != nil {
				s.Raft.config.Logger.Error("FSM Apply failed: %v", err)
			}
		} else if msg.SnapshotValid {
			if err := s.FSM.Restore(msg.Snapshot); err != nil {
				s.Raft.config.Logger.Error("FSM Restore failed: %v", err)
			}
		}
	}
}
// Set sets a key-value pair
func (s *KVServer) Set(key, value string) error {
	cmd := KVCommand{
		Type:  KVSet,
		Key:   key,
		Value: value,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}
// Del deletes a key
func (s *KVServer) Del(key string) error {
	cmd := KVCommand{
		Type: KVDel,
		Key:  key,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}
// Get gets a value (local read, can be stale)
// For linearizable reads, use GetLinear instead
func (s *KVServer) Get(key string) (string, bool) {
	return s.FSM.Get(key)
}
// GetLinear gets a value with linearizable consistency
// This ensures the read sees all writes committed before the read started
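//
// Example (a sketch, assuming srv is a *KVServer created by NewKVServer):
// use GetLinear when a read must observe recently committed writes; plain
// Get is a cheaper local read that may lag behind the leader.
//
//	_ = srv.Set("config/mode", "active")       // replicated via Raft
//	v, ok, err := srv.GetLinear("config/mode") // sees the value once the Set has committed
//	stale, _ := srv.Get("config/mode")         // local read, may be stale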
func (s *KVServer) GetLinear(key string) (string, bool, error) {
	// First, ensure we have up-to-date data via ReadIndex
	_, err := s.Raft.ReadIndex()
	if err != nil {
		// If we're not leader, try forwarding
		if errors.Is(err, ErrNotLeader) {
			return s.forwardGet(key)
		}
		return "", false, err
	}
	val, ok := s.FSM.Get(key)
	return val, ok, nil
}
// forwardGet forwards a get request to the leader
func (s *KVServer) forwardGet(key string) (string, bool, error) {
	return s.Raft.ForwardGet(key)
}
// Join adds the node with the given ID and address to the cluster,
// forwarding the request to the leader if this node is not the leader.
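//
// Example (sketch; the node ID and address are placeholders):
//
//	// called on any running member to admit a new node
//	_ = srv.Join("node2", "127.0.0.1:7002")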
func (s *KVServer) Join(nodeID, addr string) error {
	return s.Raft.AddNodeWithForward(nodeID, addr)
}
// Leave removes the given node from the cluster, forwarding the request
// to the leader if necessary.
func (s *KVServer) Leave(nodeID string) error {
	return s.Raft.RemoveNodeWithForward(nodeID)
}
// WaitForLeader waits until a leader is elected
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		leader := s.Raft.GetLeaderID()
		if leader != "" {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout waiting for leader")
}
// HealthCheck returns the health status of this server
func (s *KVServer) HealthCheck() HealthStatus {
	return s.Raft.HealthCheck()
}
// GetStats returns runtime statistics
func (s *KVServer) GetStats() Stats {
	return s.Raft.GetStats()
}
// GetMetrics returns runtime metrics
func (s *KVServer) GetMetrics() Metrics {
	return s.Raft.GetMetrics()
}
// TransferLeadership transfers leadership to the specified node
func (s *KVServer) TransferLeadership(targetID string) error {
	return s.Raft.TransferLeadership(targetID)
}
// GetClusterNodes returns current cluster membership
func (s *KVServer) GetClusterNodes() map[string]string {
	return s.Raft.GetClusterNodes()
}
// IsLeader returns true if this node is the leader
func (s *KVServer) IsLeader() bool {
	_, isLeader := s.Raft.GetState()
	return isLeader
}
// GetLeaderID returns the current leader ID
func (s *KVServer) GetLeaderID() string {
	return s.Raft.GetLeaderID()
}
// WatchAll registers a watcher for all keys
func (s *KVServer) WatchAll(handler WatchHandler) {
	s.FSM.WatchAll(handler)
}
// Watch registers a watcher for a key
func (s *KVServer) Watch(key string, handler WatchHandler) {
	s.FSM.Watch(key, handler)
}
// Unwatch removes watchers for a key
func (s *KVServer) Unwatch(key string) {
	s.FSM.Unwatch(key)
}
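// maintenanceLoop periodically refreshes replicated node metadata and, on the
// leader, tries to re-add members that have dropped out of the live
// configuration.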
func (s *KVServer) maintenanceLoop() {
	defer s.wg.Done()
	// Check every 1 second for faster reaction
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.updateNodeInfo()
			s.checkConnections()
		}
	}
}
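// updateNodeInfo publishes this node's address under "CreateNode/<NodeID>"
// and, on the leader, maintains an aggregated "RaftNode" key of the form
// "id1=addr1;id2=addr2" (sorted by id) that records every node seen so far,
// which checkConnections later uses to re-add missing members.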
func (s *KVServer) updateNodeInfo() {
	// 1. Ensure "CreateNode/<NodeID>" is set to this node's address
	// We do this via Propose (Set) so it's replicated
	myID := s.Raft.config.NodeID
	myAddr := s.Raft.config.ListenAddr
	key := fmt.Sprintf("CreateNode/%s", myID)
	// Check if we need to update (avoid spamming logs/proposals)
	val, exists := s.Get(key)
	if !exists || val != myAddr {
		// Run in a goroutine to avoid blocking the maintenance loop
		go func() {
			if err := s.Set(key, myAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
			}
		}()
	}
	// 2. Only the leader updates the RaftNode aggregation
	if s.IsLeader() {
		// Read current RaftNode to preserve history
		currentVal, _ := s.Get("RaftNode")
		knownNodes := make(map[string]string)
		if currentVal != "" {
			parts := strings.Split(currentVal, ";")
			for _, part := range parts {
				if part == "" {
					continue
				}
				kv := strings.SplitN(part, "=", 2)
				if len(kv) == 2 {
					knownNodes[kv[0]] = kv[1]
				}
			}
		}
		// Merge current cluster nodes
		changed := false
		currentCluster := s.GetClusterNodes()
		for id, addr := range currentCluster {
			if knownNodes[id] != addr {
				knownNodes[id] = addr
				changed = true
			}
		}
		// If changed, update RaftNode
		if changed {
			var peers []string
			for id, addr := range knownNodes {
				peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
			}
			sort.Strings(peers)
			newVal := strings.Join(peers, ";")
			// Re-check before writing to avoid update loops if Get returned stale data
			if newVal != currentVal {
				go func(k, v string) {
					if err := s.Set(k, v); err != nil {
						s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
					}
				}("RaftNode", newVal)
			}
		}
	}
}
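// checkConnections runs on the leader and compares the membership recorded in
// the "RaftNode" key against the live cluster configuration, re-adding any
// node that appears in the key but is missing from the configuration.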
func (s *KVServer) checkConnections() {
	if !s.IsLeader() {
		return
	}
	// Read RaftNode key to find potential members that are missing
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return
	}
	// Parse saved nodes
	savedParts := strings.Split(val, ";")
	currentNodes := s.GetClusterNodes()
	// Invert currentNodes for address check
	currentAddrs := make(map[string]bool)
	for _, addr := range currentNodes {
		currentAddrs[addr] = true
	}
	for _, part := range savedParts {
		if part == "" {
			continue
		}
		// Expect id=addr
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		id, addr := kv[0], kv[1]
		if !currentAddrs[addr] {
			// Found a node that was previously in the cluster but is now missing;
			// try to add it back. AddNodeWithForward is partly non-blocking
			// internally, but we still run it in a goroutine so the maintenance
			// loop is not blocked.
			go func(nodeID, nodeAddr string) {
				s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
				if err := s.Join(nodeID, nodeAddr); err != nil {
					s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
				}
			}(id, addr)
		}
	}
}