package config import ( "encoding/json" "fmt" "os" "os/user" "path/filepath" "runtime" "github.com/spf13/viper" ) // Config 程序配置结构 type Config struct { HostsPath string `json:"hostsPath"` BackupCount int `json:"backupCount"` DNSTimeout int `json:"dnsTimeout"` MaxConcurrent int `json:"maxConcurrent"` LogLevel string `json:"logLevel"` // debug, info, warning, error, silent LogPath string `json:"logPath"` // 空字符串表示只输出到控制台,有路径表示同时输出到文件和控制台 } // DNSServer DNS服务器配置 type DNSServer struct { Name string `json:"Name"` DNS string `json:"Dns"` DoH string `json:"Doh"` } var ( AppConfig *Config DNSServers []DNSServer ConfigPath string ServersPath string ) // Init 初始化配置(使用默认用户目录) func Init() { userConfigDir, err := getUserConfigDir() if err != nil { fmt.Fprintf(os.Stderr, "获取用户配置目录失败: %v\n", err) os.Exit(1) } InitWithConfigDir(userConfigDir) } // InitWithConfigDir 使用指定配置目录初始化配置 func InitWithConfigDir(userConfigDir string) { ConfigPath = filepath.Join(userConfigDir, "config", "config.json") ServersPath = filepath.Join(userConfigDir, "config", "servers.json") // 初始化默认配置 AppConfig = &Config{ HostsPath: getDefaultHostsPath(), BackupCount: 5, DNSTimeout: 5000, MaxConcurrent: 10, LogLevel: "info", LogPath: filepath.Join(userConfigDir, "logs"), } // 加载配置文件 loadConfig() loadServers() // 创建日志目录 if err := os.MkdirAll(AppConfig.LogPath, 0755); err != nil { fmt.Fprintf(os.Stderr, "创建日志目录失败: %v\n", err) } } // getUserConfigDir 获取用户配置目录 func getUserConfigDir() (string, error) { currentUser, err := user.Current() if err != nil { return "", fmt.Errorf("获取当前用户信息失败: %v", err) } return filepath.Join(currentUser.HomeDir, ".hostsync"), nil } // getDefaultHostsPath 获取默认hosts文件路径 func getDefaultHostsPath() string { switch runtime.GOOS { case "windows": return `C:\Windows\System32\drivers\etc\hosts` default: return "/etc/hosts" } } // loadConfig 加载配置文件 func loadConfig() { if _, err := os.Stat(ConfigPath); os.IsNotExist(err) { // 配置文件不存在,创建默认配置 saveConfig() return } data, err := os.ReadFile(ConfigPath) if err != nil { fmt.Fprintf(os.Stderr, "读取配置文件失败: %v\n", err) return } // 创建临时配置对象进行解析测试 tempConfig := &Config{} if err := json.Unmarshal(data, tempConfig); err != nil { fmt.Fprintf(os.Stderr, "解析配置文件失败: %v\n", err) fmt.Fprintln(os.Stderr, "💡 将使用默认配置并重新创建配置文件") saveConfig() return } // 解析成功,更新当前配置 AppConfig = tempConfig // 检查并补充缺失的字段 needSave := false if AppConfig.LogPath == "" { userConfigDir, _ := getUserConfigDir() AppConfig.LogPath = filepath.Join(userConfigDir, "logs") needSave = true } if AppConfig.LogLevel == "" { AppConfig.LogLevel = "info" needSave = true } // 如果有缺失字段,保存更新后的配置 if needSave { saveConfig() } } // saveConfig 保存配置文件 func saveConfig() { data, err := json.MarshalIndent(AppConfig, "", " ") if err != nil { fmt.Fprintf(os.Stderr, "序列化配置失败: %v\n", err) return } // 确保配置文件目录存在 configDir := filepath.Dir(ConfigPath) if err := os.MkdirAll(configDir, 0755); err != nil { fmt.Fprintf(os.Stderr, "创建配置目录失败: %v\n", err) return } if err := os.WriteFile(ConfigPath, data, 0644); err != nil { fmt.Fprintf(os.Stderr, "保存配置文件失败: %v\n", err) } } // loadServers 加载DNS服务器配置 func loadServers() { if _, err := os.Stat(ServersPath); os.IsNotExist(err) { // DNS服务器配置文件不存在时,不输出错误信息 // init 命令会创建这个文件 return } data, err := os.ReadFile(ServersPath) if err != nil { fmt.Fprintf(os.Stderr, "读取DNS服务器配置失败: %v\n", err) return } // 创建临时变量进行解析测试 var tempServers []DNSServer if err := json.Unmarshal(data, &tempServers); err != nil { fmt.Fprintf(os.Stderr, "解析DNS服务器配置失败: %v\n", err) fmt.Fprintln(os.Stderr, "💡 请运行 'hostsync init --force' 重新创建配置文件") return } // 解析成功,更新服务器列表 DNSServers = tempServers } // GetDNSServer 根据名称获取DNS服务器 func GetDNSServer(name string) *DNSServer { for _, server := range DNSServers { if server.Name == name { return &server } } return nil } // SetupViper 设置viper配置 func SetupViper() { viper.SetConfigName("config") viper.SetConfigType("json") // 添加用户配置目录到搜索路径 if userConfigDir, err := getUserConfigDir(); err == nil { viper.AddConfigPath(filepath.Join(userConfigDir, "config")) } viper.AddConfigPath(".") // 设置默认值 viper.SetDefault("hostsPath", getDefaultHostsPath()) viper.SetDefault("backupCount", 5) viper.SetDefault("dnsTimeout", 5000) viper.SetDefault("maxConcurrent", 10) viper.SetDefault("logLevel", "info") } // GetBackupDir 获取备份目录 func GetBackupDir() string { if userConfigDir, err := getUserConfigDir(); err == nil { return filepath.Join(userConfigDir, "backup") } return "backup" // 回退到相对路径 }