server.go

package raft

import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"
)

// KVServer wraps Raft to provide a distributed key-value store
type KVServer struct {
	Raft     *Raft
	FSM      *KVStateMachine
	stopCh   chan struct{}
	wg       sync.WaitGroup
	stopOnce sync.Once
}

// NewKVServer creates a new KV server
func NewKVServer(config *Config) (*KVServer, error) {
	fsm := NewKVStateMachine()

	// Configure snapshot provider
	config.SnapshotProvider = func() ([]byte, error) {
		return fsm.Snapshot()
	}

	// Configure get handler for remote reads
	config.GetHandler = func(key string) (string, bool) {
		return fsm.Get(key)
	}

	applyCh := make(chan ApplyMsg, 100)
	transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
	r, err := NewRaft(config, transport, applyCh)
	if err != nil {
		return nil, err
	}

	s := &KVServer{
		Raft: r,
		FSM:  fsm,
	}

	// Start applying entries
	go s.runApplyLoop(applyCh)

	// Start background maintenance loop
	s.stopCh = make(chan struct{})
	s.wg.Add(1)
	go s.maintenanceLoop()

	return s, nil
}

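// Example usage (a minimal sketch, not part of the original file; it assumes
// Config can be built with just NodeID and ListenAddr, which may not hold for
// the real Config in this package):
//
//	cfg := &Config{NodeID: "node1", ListenAddr: "127.0.0.1:7001"}
//	srv, err := NewKVServer(cfg)
//	if err != nil {
//		// handle construction error
//	}
//	if err := srv.Start(); err != nil {
//		// handle startup error
//	}
//	if err := srv.WaitForLeader(5 * time.Second); err != nil {
//		// no leader elected in time
//	}
//	_ = srv.Set("greeting", "hello") // replicated write
//	val, ok := srv.Get("greeting")   // local, possibly stale read
//	_, _ = val, ok
//	_ = srv.Stop()
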
// Start starts the underlying Raft node.
func (s *KVServer) Start() error {
	return s.Raft.Start()
}

// Stop shuts down the maintenance loop and the Raft node. It is safe to call
// more than once.
func (s *KVServer) Stop() error {
	var err error
	s.stopOnce.Do(func() {
		// Stop maintenance loop
		if s.stopCh != nil {
			close(s.stopCh)
			s.wg.Wait()
		}
		err = s.Raft.Stop()
	})
	return err
}

// runApplyLoop applies committed commands and snapshots from Raft to the FSM.
func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
	for msg := range applyCh {
		if msg.CommandValid {
			if _, err := s.FSM.Apply(msg.Command); err != nil {
				s.Raft.config.Logger.Error("FSM Apply failed: %v", err)
			}
		} else if msg.SnapshotValid {
			if err := s.FSM.Restore(msg.Snapshot); err != nil {
				s.Raft.config.Logger.Error("FSM Restore failed: %v", err)
			}
		}
	}
}

// Set sets a key-value pair
func (s *KVServer) Set(key, value string) error {
	cmd := KVCommand{
		Type:  KVSet,
		Key:   key,
		Value: value,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Del deletes a key
func (s *KVServer) Del(key string) error {
	cmd := KVCommand{
		Type: KVDel,
		Key:  key,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

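// Both Set and Del encode a KVCommand as JSON and hand it to Raft through
// ProposeWithForward. Assuming KVCommand declares no custom json tags, a Set
// proposal would look roughly like this (field names and KVSet's value depend
// on the KVCommand definition elsewhere in the package):
//
//	{"Type": <KVSet>, "Key": "greeting", "Value": "hello"}
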
// Get gets a value (local read, can be stale)
// For linearizable reads, use GetLinear instead
func (s *KVServer) Get(key string) (string, bool) {
	return s.FSM.Get(key)
}

// GetLinear gets a value with linearizable consistency
// This ensures the read sees all writes committed before the read started
func (s *KVServer) GetLinear(key string) (string, bool, error) {
	// First, ensure we have up-to-date data via ReadIndex
	_, err := s.Raft.ReadIndex()
	if err != nil {
		// If we're not leader, try forwarding
		if err == ErrNotLeader {
			return s.forwardGet(key)
		}
		return "", false, err
	}
	val, ok := s.FSM.Get(key)
	return val, ok, nil
}

// forwardGet forwards a get request to the leader
func (s *KVServer) forwardGet(key string) (string, bool, error) {
	leaderID := s.Raft.GetLeaderID()
	if leaderID == "" {
		return "", false, ErrNoLeader
	}
	// For now, return an error asking the client to retry on the leader;
	// a full implementation would forward the request itself
	return "", false, NewRaftError(ErrNotLeader, leaderID, 100*time.Millisecond)
}

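// Read-consistency sketch (illustrative only, reusing the srv value from the
// example above): Get serves from the local FSM and may lag, while GetLinear
// issues a ReadIndex barrier first, so callers needing read-your-writes
// semantics should prefer GetLinear:
//
//	_ = srv.Set("config/max", "10")           // replicated write
//	v, ok, err := srv.GetLinear("config/max") // linearizable read
//	v2, ok2 := srv.Get("config/max")          // fast local read, may be stale
//	_, _, _, _, _ = v, ok, err, v2, ok2
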
// Join adds a node to an existing cluster
func (s *KVServer) Join(nodeID, addr string) error {
	return s.Raft.AddNodeWithForward(nodeID, addr)
}

// Leave removes a node from the cluster
func (s *KVServer) Leave(nodeID string) error {
	return s.Raft.RemoveNodeWithForward(nodeID)
}

// WaitForLeader waits until a leader is elected or the timeout expires
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		leader := s.Raft.GetLeaderID()
		if leader != "" {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout waiting for leader")
}

// HealthCheck returns the health status of this server
func (s *KVServer) HealthCheck() HealthStatus {
	return s.Raft.HealthCheck()
}

// GetMetrics returns runtime metrics
func (s *KVServer) GetMetrics() Metrics {
	return s.Raft.GetMetrics()
}

// TransferLeadership transfers leadership to the specified node
func (s *KVServer) TransferLeadership(targetID string) error {
	return s.Raft.TransferLeadership(targetID)
}

// GetClusterNodes returns current cluster membership
func (s *KVServer) GetClusterNodes() map[string]string {
	return s.Raft.GetClusterNodes()
}

// IsLeader returns true if this node is the leader
func (s *KVServer) IsLeader() bool {
	_, isLeader := s.Raft.GetState()
	return isLeader
}

// GetLeaderID returns the current leader ID
func (s *KVServer) GetLeaderID() string {
	return s.Raft.GetLeaderID()
}

// WatchAll registers a watcher for all keys
func (s *KVServer) WatchAll(handler WatchHandler) {
	s.FSM.WatchAll(handler)
}

// Watch registers a watcher for a key
func (s *KVServer) Watch(key string, handler WatchHandler) {
	s.FSM.Watch(key, handler)
}

// Unwatch removes watchers for a key
func (s *KVServer) Unwatch(key string) {
	s.FSM.Unwatch(key)
}

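// Watch usage sketch (illustrative; WatchHandler is defined elsewhere in the
// package, so its signature is assumed here rather than shown):
//
//	srv.Watch("config/max", myHandler) // myHandler must satisfy WatchHandler
//	defer srv.Unwatch("config/max")
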
// maintenanceLoop periodically publishes node info and checks cluster
// connectivity until the server is stopped.
func (s *KVServer) maintenanceLoop() {
	defer s.wg.Done()
	// Tick every second so membership changes are reacted to promptly
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.updateNodeInfo()
			s.checkConnections()
		}
	}
}

// updateNodeInfo publishes this node's address under "CreateNode/<NodeID>"
// and, on the leader, maintains the aggregated "RaftNode" membership key.
func (s *KVServer) updateNodeInfo() {
	// 1. Ensure "CreateNode/<NodeID>" is set to this node's address.
	// The update goes through Set (a Raft proposal) so it is replicated.
	myID := s.Raft.config.NodeID
	myAddr := s.Raft.config.ListenAddr
	key := fmt.Sprintf("CreateNode/%s", myID)

	// Only propose when the stored value differs (avoids spamming logs and proposals)
	val, exists := s.Get(key)
	if !exists || val != myAddr {
		// Propose in a goroutine so the maintenance loop is not blocked
		go func() {
			if err := s.Set(key, myAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
			}
		}()
	}

	// 2. Only the leader updates the RaftNode aggregation
	if s.IsLeader() {
		// Read the current RaftNode value to preserve previously known nodes
		currentVal, _ := s.Get("RaftNode")
		knownNodes := make(map[string]string)
		if currentVal != "" {
			parts := strings.Split(currentVal, ";")
			for _, part := range parts {
				if part == "" {
					continue
				}
				kv := strings.SplitN(part, "=", 2)
				if len(kv) == 2 {
					knownNodes[kv[0]] = kv[1]
				}
			}
		}

		// Merge the current cluster membership
		changed := false
		currentCluster := s.GetClusterNodes()
		for id, addr := range currentCluster {
			if knownNodes[id] != addr {
				knownNodes[id] = addr
				changed = true
			}
		}

		// If anything changed, rewrite the RaftNode key
		if changed {
			var peers []string
			for id, addr := range knownNodes {
				peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
			}
			sort.Strings(peers)
			newVal := strings.Join(peers, ";")
			// Skip the write if the serialized value is unchanged, to avoid
			// proposal loops when Get returned stale data
			if newVal != currentVal {
				go func(k, v string) {
					if err := s.Set(k, v); err != nil {
						s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
					}
				}("RaftNode", newVal)
			}
		}
	}
}

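// The RaftNode key stores the merged membership as sorted "id=addr" pairs
// joined by ';', for example (addresses illustrative):
//
//	node1=10.0.0.1:7001;node2=10.0.0.2:7001;node3=10.0.0.3:7001
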
// checkConnections runs on the leader and compares the membership recorded in
// the RaftNode key against the live cluster, re-adding any node that has gone
// missing.
func (s *KVServer) checkConnections() {
	if !s.IsLeader() {
		return
	}

	// Read the RaftNode key to find previously known members that are missing
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return
	}

	// Parse saved nodes
	savedParts := strings.Split(val, ";")
	currentNodes := s.GetClusterNodes()

	// Build an address set from the current membership for quick lookup
	currentAddrs := make(map[string]bool)
	for _, addr := range currentNodes {
		currentAddrs[addr] = true
	}

	for _, part := range savedParts {
		if part == "" {
			continue
		}
		// Expect "id=addr"
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		id, addr := kv[0], kv[1]
		if !currentAddrs[addr] {
			// This node was previously in the cluster but is missing now; try to
			// add it back. Run in a goroutine so a slow AddNodeWithForward call
			// cannot block the maintenance loop.
			go func(nodeID, nodeAddr string) {
				s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
				if err := s.Join(nodeID, nodeAddr); err != nil {
					s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
				}
			}(id, addr)
		}
	}
}