package raft

import (
	"encoding/json"
	"errors"
	"fmt"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"time"

	"igit.com/xbase/raft/db"
)

// KVServer wraps Raft to provide a distributed key-value store.
type KVServer struct {
	Raft *Raft
	DB   *db.Engine

	stopCh   chan struct{}  // closed by Stop to terminate background loops
	wg       sync.WaitGroup // tracks the maintenance loop
	stopOnce sync.Once      // makes Stop idempotent
}

// NewKVServer creates a new KV server: it opens the storage engine, wires the
// snapshot/read callbacks into config, creates the Raft node, and starts the
// apply and maintenance goroutines. Call Start to begin serving.
func NewKVServer(config *Config) (*KVServer, error) {
	// Use a subdirectory for the DB to avoid conflict with Raft logs if they
	// share DataDir.
	dbPath := filepath.Join(config.DataDir, "kv_engine")
	engine, err := db.NewEngine(dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create db engine: %w", err)
	}

	// Initialize LastAppliedIndex from the DB to prevent re-applying entries
	// the engine already persisted.
	config.LastAppliedIndex = engine.GetLastAppliedIndex()

	// Create the stop channel early so the callbacks below can observe
	// shutdown.
	stopCh := make(chan struct{})

	// Snapshot provider: Raft calls this before compacting its log.
	config.SnapshotProvider = func(minIncludeIndex uint64) ([]byte, error) {
		// Wait for the DB to catch up to the requested index. This is
		// critical for data integrity during compaction: the snapshot must
		// cover every entry the log is about to drop.
		for engine.GetLastAppliedIndex() < minIncludeIndex {
			select {
			case <-stopCh:
				return nil, fmt.Errorf("server stopping")
			default:
				time.Sleep(10 * time.Millisecond)
			}
		}
		// Force a sync to disk before compaction. This prevents data loss if
		// Raft logs are compacted while DB data is only in the OS cache.
		if err := engine.Sync(); err != nil {
			return nil, fmt.Errorf("failed to sync engine before snapshot: %w", err)
		}
		return engine.Snapshot()
	}

	// Get handler used to answer reads forwarded from other nodes.
	config.GetHandler = func(key string) (string, bool) {
		return engine.Get(key)
	}

	applyCh := make(chan ApplyMsg, 1000) // large buffer for async processing
	transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)
	r, err := NewRaft(config, transport, applyCh)
	if err != nil {
		engine.Close()
		return nil, err
	}

	s := &KVServer{
		Raft:   r,
		DB:     engine,
		stopCh: stopCh,
	}

	// Start applying committed entries. This goroutine exits when applyCh is
	// closed by the Raft layer during shutdown.
	go s.runApplyLoop(applyCh)

	// Start the background maintenance loop (membership bookkeeping).
	s.wg.Add(1)
	go s.maintenanceLoop()

	return s, nil
}

// Start starts the underlying Raft node.
func (s *KVServer) Start() error {
	return s.Raft.Start()
}

// Stop shuts the server down: maintenance loop first, then Raft, then the
// storage engine. Safe to call multiple times; only the first call acts.
func (s *KVServer) Stop() error {
	var err error
	s.stopOnce.Do(func() {
		// Stop the maintenance loop and wait for it to finish.
		if s.stopCh != nil {
			close(s.stopCh)
			s.wg.Wait()
		}
		// Stop Raft first so no new entries are applied while closing the DB.
		err = s.Raft.Stop()
		// Close the DB; errors.Join keeps both errors if both steps fail.
		if s.DB != nil {
			err = errors.Join(err, s.DB.Close())
		}
	})
	return err
}

// runApplyLoop consumes committed entries from applyCh and applies them to
// the storage engine. It exits when applyCh is closed.
func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
	for msg := range applyCh {
		switch {
		case msg.CommandValid:
			// Skip entries the DB has already applied — avoids unmarshalling
			// and locking the DB for known duplicates.
			if msg.CommandIndex <= s.DB.GetLastAppliedIndex() {
				continue
			}
			var cmd KVCommand
			if err := json.Unmarshal(msg.Command, &cmd); err != nil {
				s.Raft.config.Logger.Error("Failed to unmarshal command: %v", err)
				continue
			}
			var err error
			switch cmd.Type {
			case KVSet:
				err = s.DB.Set(cmd.Key, cmd.Value, msg.CommandIndex)
			case KVDel:
				err = s.DB.Delete(cmd.Key, msg.CommandIndex)
			default:
				s.Raft.config.Logger.Error("Unknown command type: %d", cmd.Type)
			}
			if err != nil {
				s.Raft.config.Logger.Error("DB Apply failed: %v", err)
			}
		case msg.SnapshotValid:
			if err := s.DB.Restore(msg.Snapshot); err != nil {
				s.Raft.config.Logger.Error("DB Restore failed: %v", err)
			}
		}
	}
}

// Set replicates a key-value write through Raft (forwarding to the leader if
// this node is a follower).
func (s *KVServer) Set(key, value string) error {
	cmd := KVCommand{
		Type:  KVSet,
		Key:   key,
		Value: value,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Del replicates a key deletion through Raft (forwarding to the leader if
// this node is a follower).
func (s *KVServer) Del(key string) error {
	cmd := KVCommand{
		Type: KVDel,
		Key:  key,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Get reads a value from the local store (may be stale).
// For linearizable reads, use GetLinear instead.
func (s *KVServer) Get(key string) (string, bool) {
	return s.DB.Get(key)
}

// GetLinear reads a value with linearizable consistency: it ensures the read
// observes all writes committed before the read started.
func (s *KVServer) GetLinear(key string) (string, bool, error) {
	// First, confirm up-to-date data via a ReadIndex round.
	_, err := s.Raft.ReadIndex()
	if err != nil {
		// Not the leader: forward the read to the leader instead.
		if errors.Is(err, ErrNotLeader) {
			return s.forwardGet(key)
		}
		return "", false, err
	}
	val, ok := s.DB.Get(key)
	return val, ok, nil
}

// forwardGet forwards a get request to the leader.
func (s *KVServer) forwardGet(key string) (string, bool, error) {
	return s.Raft.ForwardGet(key)
}

// Join adds a node to the cluster (forwarded to the leader if needed).
func (s *KVServer) Join(nodeID, addr string) error {
	return s.Raft.AddNodeWithForward(nodeID, addr)
}

// Leave removes a node from the cluster (forwarded to the leader if needed).
func (s *KVServer) Leave(nodeID string) error {
	return s.Raft.RemoveNodeWithForward(nodeID)
}

// WaitForLeader blocks until a leader is elected or the timeout elapses.
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if s.Raft.GetLeaderID() != "" {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout waiting for leader")
}

// HealthCheck returns the health status of this server.
func (s *KVServer) HealthCheck() HealthStatus {
	return s.Raft.HealthCheck()
}

// GetStats returns runtime statistics.
func (s *KVServer) GetStats() Stats {
	return s.Raft.GetStats()
}

// GetMetrics returns runtime metrics.
func (s *KVServer) GetMetrics() Metrics {
	return s.Raft.GetMetrics()
}

// TransferLeadership transfers leadership to the specified node.
func (s *KVServer) TransferLeadership(targetID string) error {
	return s.Raft.TransferLeadership(targetID)
}

// GetClusterNodes returns the current cluster membership (nodeID -> address).
func (s *KVServer) GetClusterNodes() map[string]string {
	return s.Raft.GetClusterNodes()
}

// IsLeader reports whether this node is the current leader.
func (s *KVServer) IsLeader() bool {
	_, isLeader := s.Raft.GetState()
	return isLeader
}

// GetLeaderID returns the current leader ID ("" if unknown).
func (s *KVServer) GetLeaderID() string {
	return s.Raft.GetLeaderID()
}

// WatchAll registers a watcher for all keys.
func (s *KVServer) WatchAll(handler WatchHandler) {
	// s.FSM.WatchAll(handler)
	// TODO: Implement Watcher for DB
}

// Watch registers a watcher for a key.
func (s *KVServer) Watch(key string, handler WatchHandler) {
	// s.FSM.Watch(key, handler)
	// TODO: Implement Watcher for DB
}

// Unwatch removes watchers for a key.
func (s *KVServer) Unwatch(key string) {
	// s.FSM.Unwatch(key)
	// TODO: Implement Watcher for DB
}

// maintenanceLoop periodically refreshes membership bookkeeping until the
// server is stopped.
func (s *KVServer) maintenanceLoop() {
	defer s.wg.Done()
	// Check every second for faster reaction.
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.updateNodeInfo()
			s.checkConnections()
		}
	}
}

// updateNodeInfo publishes this node's address under "CreateNode/<id>" and,
// on the leader, maintains the aggregated "RaftNode" membership key.
func (s *KVServer) updateNodeInfo() {
	// 1. Ensure "CreateNode/<id>" holds our own address. We write it via
	// Propose (Set) so it is replicated.
	myID := s.Raft.config.NodeID
	myAddr := s.Raft.config.ListenAddr
	key := fmt.Sprintf("CreateNode/%s", myID)

	// Only propose when the stored value differs (avoid spamming proposals).
	val, exists := s.Get(key)
	if !exists || val != myAddr {
		// Run in a goroutine so a slow proposal does not block the loop.
		go func() {
			if err := s.Set(key, myAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
			}
		}()
	}

	// 2. Only the leader updates the "RaftNode" aggregation.
	if !s.IsLeader() {
		return
	}

	// Read the current RaftNode value to preserve history.
	currentVal, _ := s.Get("RaftNode")
	knownNodes := make(map[string]string)
	if currentVal != "" {
		for _, part := range strings.Split(currentVal, ";") {
			if part == "" {
				continue
			}
			kv := strings.SplitN(part, "=", 2)
			if len(kv) == 2 {
				knownNodes[kv[0]] = kv[1]
			}
		}
	}

	// Merge in the live cluster membership.
	changed := false
	for id, addr := range s.GetClusterNodes() {
		if knownNodes[id] != addr {
			knownNodes[id] = addr
			changed = true
		}
	}
	if !changed {
		return
	}

	// Serialize deterministically as "id=addr;id=addr;...".
	peers := make([]string, 0, len(knownNodes))
	for id, addr := range knownNodes {
		peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
	}
	sort.Strings(peers)
	newVal := strings.Join(peers, ";")

	// Check again before writing to avoid proposal loops if Get was stale.
	if newVal != currentVal {
		go func(k, v string) {
			if err := s.Set(k, v); err != nil {
				s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
			}
		}("RaftNode", newVal)
	}
}

// checkConnections (leader only) compares the persisted "RaftNode" membership
// against the live cluster and attempts to re-add any missing node.
func (s *KVServer) checkConnections() {
	if !s.IsLeader() {
		return
	}

	// Read the RaftNode key to find previously-known members.
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return
	}

	// Index the current cluster by address for the membership check below.
	currentAddrs := make(map[string]bool)
	for _, addr := range s.GetClusterNodes() {
		currentAddrs[addr] = true
	}

	for _, part := range strings.Split(val, ";") {
		if part == "" {
			continue
		}
		// Each entry is "id=addr".
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		id, addr := kv[0], kv[1]
		if currentAddrs[addr] {
			continue
		}
		// Node was previously in the cluster but is now missing: try to add
		// it back. Run in a goroutine so the maintenance loop is not blocked.
		// NOTE(review): this will also re-add nodes that were deliberately
		// removed via Leave — confirm that is the intended behavior.
		go func(nodeID, nodeAddr string) {
			s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
			if err := s.Join(nodeID, nodeAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
			}
		}(id, addr)
	}
}