Browse Source

auth密码输入优化

robert 1 tuần trước cách đây
mục cha
commit
26b864290b
4 tập tin đã thay đổi với 95 bổ sung13 xóa
  1. 26 0
      auth.go
  2. 62 10
      cli.go
  3. 3 3
      example/database/inspector.go
  4. 4 0
      server.go

+ 26 - 0
auth.go

@@ -13,6 +13,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	"igit.com/xbase/raft/db"
 )
 
 // ==================== Auth Constants ====================
@@ -626,3 +628,27 @@ func (am *AuthManager) DisableAuth() error {
 	data, _ := json.Marshal(config)
 	return am.server.Set(AuthConfigKey, string(data))
 }
+
+// LoadFromDB loads auth data from existing DB
+func (am *AuthManager) LoadFromDB() error {
+	prefix := SystemKeyPrefix
+	
+	// We use the DB engine's index to find all system keys
+	am.server.DB.Index.WalkPrefix(prefix, func(key string, entry db.IndexEntry) bool {
+		// Read value from storage
+		// We need to access Storage via Engine.
+		// Engine.Storage is exported.
+		val, err := am.server.DB.Storage.ReadValue(entry.ValueOffset)
+		if err != nil {
+			// Skip corrupted or missing values
+			return true
+		}
+		
+		// Update Cache
+		// Note: We are essentially replaying the state into the cache
+		am.UpdateCache(key, val, false)
+		return true
+	})
+	
+	return nil
+}

+ 62 - 10
cli.go

@@ -340,16 +340,25 @@ func (c *CLI) registerDefaultCommands() {
 		}
 	})
 
-	c.RegisterCommand("login", "Login to system (login <user> <pass> [code])", func(parts []string, server *KVServer) {
-		if len(parts) < 3 {
-			printBoxed("Usage: login <username> <password> [mfa_code]")
+	c.RegisterCommand("login", "Login to system (login <user> [code])", func(parts []string, server *KVServer) {
+		if len(parts) < 2 {
+			printBoxed("Usage: login <username> [mfa_code]")
 			return
 		}
 		username := parts[1]
-		password := parts[2]
 		code := ""
-		if len(parts) > 3 {
-			code = parts[3]
+		if len(parts) > 2 {
+			code = parts[2]
+		}
+
+		// Prompt for password
+		fmt.Print("Password: ")
+		password, err := readPassword()
+		fmt.Println() // Newline after input
+
+		if err != nil {
+			printBoxed(fmt.Sprintf("Error reading password: %v", err))
+			return
 		}
 
 		token, err := server.AuthManager.Login(username, password, code, "cli")
@@ -383,9 +392,9 @@ func (c *CLI) registerDefaultCommands() {
 		c.mu.Unlock()
 	})
 
-	c.RegisterCommand("auth-init", "Initialize Auth System (auth-init <root_pass>)", func(parts []string, server *KVServer) {
-		if len(parts) != 2 {
-			printBoxed("Usage: auth-init <root_password>")
+	c.RegisterCommand("auth-init", "Initialize Auth System (auth-init)", func(parts []string, server *KVServer) {
+		if len(parts) > 1 {
+			printBoxed("Usage: auth-init (prompts for password securely)")
 			return
 		}
 		
@@ -398,9 +407,24 @@ func (c *CLI) registerDefaultCommands() {
 			printBoxed(fmt.Sprintf("%sPermission Denied: Auth already enabled. Login as root to re-init.%s", ColorRed, ColorReset))
 			return
 		}
+
+		// Prompt for password
+		fmt.Print("Enter root password: ")
+		password, err := readPassword()
+		fmt.Println() // Newline after input
+		
+		if err != nil {
+			printBoxed(fmt.Sprintf("Error reading password: %v", err))
+			return
+		}
+		
+		if password == "" {
+			printBoxed("Error: Password cannot be empty")
+			return
+		}
 		
 		// 1. Create Root User
-		if err := server.AuthManager.CreateRootUser(parts[1]); err != nil {
+		if err := server.AuthManager.CreateRootUser(password); err != nil {
 			printBoxed(fmt.Sprintf("Failed to create root user: %v", err))
 			return
 		}
@@ -682,3 +706,31 @@ func (c *CLI) registerDefaultCommands() {
 		printBoxed(sb.String())
 	})
 }
+
+// readPassword securely reads a password from stdin without echoing.
+func readPassword() (string, error) {
+	// 1. Disable echo
+	cmd := exec.Command("stty", "-echo")
+	cmd.Stdin = os.Stdin
+	cmd.Stdout = os.Stdout
+	if err := cmd.Run(); err != nil {
+		return "", fmt.Errorf("failed to disable echo: %v", err)
+	}
+
+	defer func() {
+		// 3. Re-enable echo
+		cmd := exec.Command("stty", "echo")
+		cmd.Stdin = os.Stdin
+		cmd.Stdout = os.Stdout
+		_ = cmd.Run()
+	}()
+
+	// 2. Read input
+	reader := bufio.NewReader(os.Stdin)
+	pass, err := reader.ReadString('\n')
+	if err != nil {
+		return "", err
+	}
+
+	return strings.TrimSpace(pass), nil
+}

+ 3 - 3
example/database/inspector.go

@@ -80,7 +80,7 @@ func main() {
 
 		key := string(data[:keyLen])
 		val := string(data[keyLen:])
-		
+
 		// Truncate long values for display
 		displayVal := val
 		if len(displayVal) > 30 {
@@ -90,14 +90,14 @@ func main() {
 			displayVal = "<tombstone>"
 		}
 
-		fmt.Fprintf(w, "%d\t%s\t%08X\t%d\t%d\t%d\t%s\t%q\n", 
+		fmt.Fprintf(w, "%d\t%s\t%08X\t%d\t%d\t%d\t%s\t%q\n",
 			offset, typeStr, crc, keyLen, valLen, commitIndex, key, displayVal)
 
 		offset += int64(HeaderSize + totalDataLen)
 	}
 	w.Flush()
 	fmt.Println()
-	
+
 	// Also try to open Engine to verify Index Rebuild
 	e, err := db.NewEngine(dataDir)
 	if err == nil {

+ 4 - 0
server.go

@@ -98,6 +98,10 @@ func NewKVServer(config *Config) (*KVServer, error) {
 
 	// Initialize AuthManager
 	s.AuthManager = NewAuthManager(s)
+	// Load Auth Data from DB (if any)
+	if err := s.AuthManager.LoadFromDB(); err != nil {
+		s.Raft.config.Logger.Warn("Failed to load auth data from DB: %v", err)
+	}
 
 	// Initialize CLI
 	s.CLI = NewCLI(s)