package raft

import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"igit.com/xbase/raft/db"
)

// KVServer wraps Raft to provide a distributed key-value store
type KVServer struct {
	Raft *Raft
	DB   *db.Engine

	stopCh   chan struct{}
	wg       sync.WaitGroup
	stopOnce sync.Once
}

// NewKVServer creates a new KV server
func NewKVServer(config *Config) (*KVServer, error) {
	// Initialize DB Engine
	// Use a subdirectory for DB to avoid conflict with Raft logs if they share DataDir
	dbPath := config.DataDir + "/kv_engine"
	engine, err := db.NewEngine(dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to create db engine: %w", err)
	}

	// Configure snapshot provider
	config.SnapshotProvider = func() ([]byte, error) {
		return engine.Snapshot()
	}

	// Configure get handler for remote reads
	config.GetHandler = func(key string) (string, bool) {
		return engine.Get(key)
	}

	applyCh := make(chan ApplyMsg, 1000) // Increase buffer for async processing
	transport := NewTCPTransport(config.ListenAddr, 10, config.Logger)

	r, err := NewRaft(config, transport, applyCh)
	if err != nil {
		engine.Close()
		return nil, err
	}

	s := &KVServer{
		Raft: r,
		DB:   engine,
	}

	// Start applying entries
	go s.runApplyLoop(applyCh)

	// Start background maintenance loop
	s.stopCh = make(chan struct{})
	s.wg.Add(1)
	go s.maintenanceLoop()

	return s, nil
}
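// Usage sketch (illustrative only, not part of this file's API): a caller builds a
// Config, constructs the server, starts it, and waits for a leader before issuing
// writes. The Config fields shown (NodeID, ListenAddr, DataDir) are assumed from how
// they are used elsewhere in this file; the real Config may require more (e.g. Logger).
//
//	cfg := &Config{
//		NodeID:     "node-1",
//		ListenAddr: "127.0.0.1:7001",
//		DataDir:    "/var/lib/kv/node-1",
//	}
//	srv, err := NewKVServer(cfg)
//	if err != nil {
//		// handle error
//	}
//	if err := srv.Start(); err != nil {
//		// handle error
//	}
//	defer srv.Stop()
//	if err := srv.WaitForLeader(5 * time.Second); err != nil {
//		// handle error
//	}
//	_ = srv.Set("greeting", "hello")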
func (s *KVServer) Start() error {
	return s.Raft.Start()
}

func (s *KVServer) Stop() error {
	var err error
	s.stopOnce.Do(func() {
		// Stop maintenance loop
		if s.stopCh != nil {
			close(s.stopCh)
			s.wg.Wait()
		}

		// Stop Raft first
		if errRaft := s.Raft.Stop(); errRaft != nil {
			err = errRaft
		}

		// Close DB
		if s.DB != nil {
			if errDB := s.DB.Close(); errDB != nil {
				// Combine errors if both fail
				if err != nil {
					err = fmt.Errorf("raft stop error: %v, db close error: %v", err, errDB)
				} else {
					err = errDB
				}
			}
		}
	})
	return err
}

func (s *KVServer) runApplyLoop(applyCh chan ApplyMsg) {
	for msg := range applyCh {
		if msg.CommandValid {
			// Optimization: skip if already applied
			// We check this here to avoid unmarshalling and locking the DB for known duplicates
			if msg.CommandIndex <= s.DB.GetLastAppliedIndex() {
				continue
			}

			var cmd KVCommand
			if err := json.Unmarshal(msg.Command, &cmd); err != nil {
				s.Raft.config.Logger.Error("Failed to unmarshal command: %v", err)
				continue
			}

			var err error
			switch cmd.Type {
			case KVSet:
				err = s.DB.Set(cmd.Key, cmd.Value, msg.CommandIndex)
			case KVDel:
				err = s.DB.Delete(cmd.Key, msg.CommandIndex)
			default:
				s.Raft.config.Logger.Error("Unknown command type: %d", cmd.Type)
			}
			if err != nil {
				s.Raft.config.Logger.Error("DB Apply failed: %v", err)
			}
		} else if msg.SnapshotValid {
			if err := s.DB.Restore(msg.Snapshot); err != nil {
				s.Raft.config.Logger.Error("DB Restore failed: %v", err)
			}
		}
	}
}
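// For reference, a write issued through Set (below) round-trips through the loop
// above roughly as follows. This is an illustrative trace, not additional behavior;
// the exact JSON layout depends on how KVCommand is declared elsewhere in the package.
//
//	Set("user:1", "alice")
//	  -> json.Marshal(KVCommand{Type: KVSet, Key: "user:1", Value: "alice"})
//	  -> Raft.ProposeWithForward(data)                 // replicated, then committed
//	  -> ApplyMsg{CommandValid: true, CommandIndex: N} // delivered on applyCh
//	  -> DB.Set("user:1", "alice", N)                  // N becomes the last applied index,
//	                                                   // so a redelivery of index N is skipped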
// Set sets a key-value pair
func (s *KVServer) Set(key, value string) error {
	cmd := KVCommand{
		Type:  KVSet,
		Key:   key,
		Value: value,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Del deletes a key
func (s *KVServer) Del(key string) error {
	cmd := KVCommand{
		Type: KVDel,
		Key:  key,
	}
	data, err := json.Marshal(cmd)
	if err != nil {
		return err
	}
	_, _, err = s.Raft.ProposeWithForward(data)
	return err
}

// Get gets a value (local read, can be stale)
// For linearizable reads, use GetLinear instead
func (s *KVServer) Get(key string) (string, bool) {
	return s.DB.Get(key)
}

// GetLinear gets a value with linearizable consistency
// This ensures the read sees all writes committed before the read started
func (s *KVServer) GetLinear(key string) (string, bool, error) {
	// First, ensure we have up-to-date data via ReadIndex
	_, err := s.Raft.ReadIndex()
	if err != nil {
		// If we're not leader, try forwarding
		if errors.Is(err, ErrNotLeader) {
			return s.forwardGet(key)
		}
		return "", false, err
	}
	val, ok := s.DB.Get(key)
	return val, ok, nil
}

// forwardGet forwards a get request to the leader
func (s *KVServer) forwardGet(key string) (string, bool, error) {
	return s.Raft.ForwardGet(key)
}
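// Read-path sketch (illustrative; srv is a *KVServer obtained as in the usage sketch
// above, error handling elided): Get serves whatever the local DB has and may lag the
// leader, while GetLinear first confirms leadership via ReadIndex (or forwards to the
// leader), so the read reflects every write committed before it started.
//
//	// Fast, possibly stale:
//	if v, ok := srv.Get("config/feature-x"); ok {
//		fmt.Println("cached value:", v)
//	}
//
//	// Linearizable, may forward to the leader:
//	v, ok, err := srv.GetLinear("config/feature-x")
//	if err == nil && ok {
//		fmt.Println("authoritative value:", v)
//	}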
// Join joins an existing cluster
func (s *KVServer) Join(nodeID, addr string) error {
	return s.Raft.AddNodeWithForward(nodeID, addr)
}

// Leave leaves the cluster
func (s *KVServer) Leave(nodeID string) error {
	return s.Raft.RemoveNodeWithForward(nodeID)
}

// WaitForLeader waits until a leader is elected
func (s *KVServer) WaitForLeader(timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		leader := s.Raft.GetLeaderID()
		if leader != "" {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout waiting for leader")
}

// HealthCheck returns the health status of this server
func (s *KVServer) HealthCheck() HealthStatus {
	return s.Raft.HealthCheck()
}

// GetStats returns runtime statistics
func (s *KVServer) GetStats() Stats {
	return s.Raft.GetStats()
}

// GetMetrics returns runtime metrics
func (s *KVServer) GetMetrics() Metrics {
	return s.Raft.GetMetrics()
}

// TransferLeadership transfers leadership to the specified node
func (s *KVServer) TransferLeadership(targetID string) error {
	return s.Raft.TransferLeadership(targetID)
}

// GetClusterNodes returns current cluster membership
func (s *KVServer) GetClusterNodes() map[string]string {
	return s.Raft.GetClusterNodes()
}

// IsLeader returns true if this node is the leader
func (s *KVServer) IsLeader() bool {
	_, isLeader := s.Raft.GetState()
	return isLeader
}

// GetLeaderID returns the current leader ID
func (s *KVServer) GetLeaderID() string {
	return s.Raft.GetLeaderID()
}

// WatchAll registers a watcher for all keys
func (s *KVServer) WatchAll(handler WatchHandler) {
	// s.FSM.WatchAll(handler)
	// TODO: Implement Watcher for DB
}

// Watch registers a watcher for a key
func (s *KVServer) Watch(key string, handler WatchHandler) {
	// s.FSM.Watch(key, handler)
	// TODO: Implement Watcher for DB
}

// Unwatch removes watchers for a key
func (s *KVServer) Unwatch(key string) {
	// s.FSM.Unwatch(key)
	// TODO: Implement Watcher for DB
}
func (s *KVServer) maintenanceLoop() {
	defer s.wg.Done()

	// Check every second for faster reaction
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.updateNodeInfo()
			s.checkConnections()
		}
	}
}

func (s *KVServer) updateNodeInfo() {
	// 1. Ensure "CreateNode/<NodeID>" is set to this node's address.
	// We do this via Propose (Set) so it is replicated.
	myID := s.Raft.config.NodeID
	myAddr := s.Raft.config.ListenAddr
	key := fmt.Sprintf("CreateNode/%s", myID)

	// Only propose an update when the stored value differs, to avoid spamming logs/proposals
	val, exists := s.Get(key)
	if !exists || val != myAddr {
		// Run in a goroutine to avoid blocking the maintenance loop
		go func() {
			if err := s.Set(key, myAddr); err != nil {
				s.Raft.config.Logger.Debug("Failed to update node info: %v", err)
			}
		}()
	}

	// 2. Only the leader updates the "RaftNode" aggregation key
	if s.IsLeader() {
		// Read the current "RaftNode" value to preserve previously seen members
		currentVal, _ := s.Get("RaftNode")

		knownNodes := make(map[string]string)
		if currentVal != "" {
			parts := strings.Split(currentVal, ";")
			for _, part := range parts {
				if part == "" {
					continue
				}
				kv := strings.SplitN(part, "=", 2)
				if len(kv) == 2 {
					knownNodes[kv[0]] = kv[1]
				}
			}
		}

		// Merge in the current cluster membership
		changed := false
		currentCluster := s.GetClusterNodes()
		for id, addr := range currentCluster {
			if knownNodes[id] != addr {
				knownNodes[id] = addr
				changed = true
			}
		}

		// If anything changed, rewrite the "RaftNode" key
		if changed {
			var peers []string
			for id, addr := range knownNodes {
				peers = append(peers, fmt.Sprintf("%s=%s", id, addr))
			}
			sort.Strings(peers)
			newVal := strings.Join(peers, ";")

			// Skip the write if the serialized value is identical; this avoids
			// re-proposing in a loop when the earlier Get returned stale data
			if newVal != currentVal {
				go func(k, v string) {
					if err := s.Set(k, v); err != nil {
						s.Raft.config.Logger.Warn("Failed to update RaftNode key: %v", err)
					}
				}("RaftNode", newVal)
			}
		}
	}
}
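// Value format of the "RaftNode" key (example with made-up IDs and addresses):
// entries are "id=addr" pairs sorted by ID and joined with ';', so a three-node
// cluster would be stored roughly as:
//
//	node-1=10.0.0.1:7001;node-2=10.0.0.2:7001;node-3=10.0.0.3:7001
//
// checkConnections below parses this value and re-adds any previously known member
// whose address is no longer present in the live cluster view.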
func (s *KVServer) checkConnections() {
	if !s.IsLeader() {
		return
	}

	// Read the "RaftNode" key to find previously known members that are missing now
	val, ok := s.Get("RaftNode")
	if !ok || val == "" {
		return
	}

	// Parse saved nodes
	savedParts := strings.Split(val, ";")
	currentNodes := s.GetClusterNodes()

	// Invert currentNodes for an address-based membership check
	currentAddrs := make(map[string]bool)
	for _, addr := range currentNodes {
		currentAddrs[addr] = true
	}

	for _, part := range savedParts {
		if part == "" {
			continue
		}
		// Expect "id=addr"
		kv := strings.SplitN(part, "=", 2)
		if len(kv) != 2 {
			continue
		}
		id, addr := kv[0], kv[1]
		if !currentAddrs[addr] {
			// This node was previously in the cluster but is missing now; try to add it back.
			// Join forwards to the leader when needed, but run it in a goroutine so a slow
			// or failing RPC does not block the maintenance loop.
			go func(nodeID, nodeAddr string) {
				s.Raft.config.Logger.Info("Auto-rejoining node found in RaftNode: %s (%s)", nodeID, nodeAddr)
				if err := s.Join(nodeID, nodeAddr); err != nil {
					s.Raft.config.Logger.Debug("Failed to auto-rejoin node %s: %v", nodeID, err)
				}
			}(id, addr)
		}
	}
}