216 lines
5.4 KiB
Go
216 lines
5.4 KiB
Go
package config
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"os/user"
|
||
"path/filepath"
|
||
"runtime"
|
||
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
// Config 程序配置结构
|
||
type Config struct {
|
||
HostsPath string `json:"hostsPath"`
|
||
BackupCount int `json:"backupCount"`
|
||
DNSTimeout int `json:"dnsTimeout"`
|
||
MaxConcurrent int `json:"maxConcurrent"`
|
||
LogLevel string `json:"logLevel"` // debug, info, warning, error, silent
|
||
LogPath string `json:"logPath"` // 空字符串表示只输出到控制台,有路径表示同时输出到文件和控制台
|
||
}
|
||
|
||
// DNSServer DNS服务器配置
|
||
type DNSServer struct {
|
||
Name string `json:"Name"`
|
||
DNS string `json:"Dns"`
|
||
DoH string `json:"Doh"`
|
||
}
|
||
|
||
var (
|
||
AppConfig *Config
|
||
DNSServers []DNSServer
|
||
ConfigPath string
|
||
ServersPath string
|
||
)
|
||
|
||
// Init 初始化配置(使用默认用户目录)
|
||
func Init() {
|
||
userConfigDir, err := getUserConfigDir()
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "获取用户配置目录失败: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
InitWithConfigDir(userConfigDir)
|
||
}
|
||
|
||
// InitWithConfigDir 使用指定配置目录初始化配置
|
||
func InitWithConfigDir(userConfigDir string) {
|
||
ConfigPath = filepath.Join(userConfigDir, "config", "config.json")
|
||
ServersPath = filepath.Join(userConfigDir, "config", "servers.json")
|
||
// 初始化默认配置
|
||
AppConfig = &Config{
|
||
HostsPath: getDefaultHostsPath(),
|
||
BackupCount: 5,
|
||
DNSTimeout: 5000,
|
||
MaxConcurrent: 10,
|
||
LogLevel: "info",
|
||
LogPath: filepath.Join(userConfigDir, "logs"),
|
||
}
|
||
|
||
// 加载配置文件
|
||
loadConfig()
|
||
loadServers()
|
||
// 创建日志目录
|
||
if err := os.MkdirAll(AppConfig.LogPath, 0755); err != nil {
|
||
fmt.Fprintf(os.Stderr, "创建日志目录失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
// getUserConfigDir 获取用户配置目录
|
||
func getUserConfigDir() (string, error) {
|
||
currentUser, err := user.Current()
|
||
if err != nil {
|
||
return "", fmt.Errorf("获取当前用户信息失败: %v", err)
|
||
}
|
||
return filepath.Join(currentUser.HomeDir, ".hostsync"), nil
|
||
}
|
||
|
||
// getDefaultHostsPath 获取默认hosts文件路径
|
||
func getDefaultHostsPath() string {
|
||
switch runtime.GOOS {
|
||
case "windows":
|
||
return `C:\Windows\System32\drivers\etc\hosts`
|
||
default:
|
||
return "/etc/hosts"
|
||
}
|
||
}
|
||
|
||
// loadConfig 加载配置文件
|
||
func loadConfig() {
|
||
if _, err := os.Stat(ConfigPath); os.IsNotExist(err) {
|
||
// 配置文件不存在,创建默认配置
|
||
saveConfig()
|
||
return
|
||
}
|
||
|
||
data, err := os.ReadFile(ConfigPath)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "读取配置文件失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
// 创建临时配置对象进行解析测试
|
||
tempConfig := &Config{}
|
||
if err := json.Unmarshal(data, tempConfig); err != nil {
|
||
fmt.Fprintf(os.Stderr, "解析配置文件失败: %v\n", err)
|
||
fmt.Fprintln(os.Stderr, "💡 将使用默认配置并重新创建配置文件")
|
||
saveConfig()
|
||
return
|
||
}
|
||
|
||
// 解析成功,更新当前配置
|
||
AppConfig = tempConfig
|
||
|
||
// 检查并补充缺失的字段
|
||
needSave := false
|
||
if AppConfig.LogPath == "" {
|
||
userConfigDir, _ := getUserConfigDir()
|
||
AppConfig.LogPath = filepath.Join(userConfigDir, "logs")
|
||
needSave = true
|
||
}
|
||
if AppConfig.LogLevel == "" {
|
||
AppConfig.LogLevel = "info"
|
||
needSave = true
|
||
}
|
||
|
||
// 如果有缺失字段,保存更新后的配置
|
||
if needSave {
|
||
saveConfig()
|
||
}
|
||
}
|
||
|
||
// saveConfig 保存配置文件
|
||
func saveConfig() {
|
||
data, err := json.MarshalIndent(AppConfig, "", " ")
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "序列化配置失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
// 确保配置文件目录存在
|
||
configDir := filepath.Dir(ConfigPath)
|
||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||
fmt.Fprintf(os.Stderr, "创建配置目录失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
if err := os.WriteFile(ConfigPath, data, 0644); err != nil {
|
||
fmt.Fprintf(os.Stderr, "保存配置文件失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
// loadServers 加载DNS服务器配置
|
||
func loadServers() {
|
||
if _, err := os.Stat(ServersPath); os.IsNotExist(err) {
|
||
// DNS服务器配置文件不存在时,不输出错误信息
|
||
// init 命令会创建这个文件
|
||
return
|
||
}
|
||
|
||
data, err := os.ReadFile(ServersPath)
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "读取DNS服务器配置失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
// 创建临时变量进行解析测试
|
||
var tempServers []DNSServer
|
||
if err := json.Unmarshal(data, &tempServers); err != nil {
|
||
fmt.Fprintf(os.Stderr, "解析DNS服务器配置失败: %v\n", err)
|
||
fmt.Fprintln(os.Stderr, "💡 请运行 'hostsync init --force' 重新创建配置文件")
|
||
return
|
||
}
|
||
|
||
// 解析成功,更新服务器列表
|
||
DNSServers = tempServers
|
||
}
|
||
|
||
// GetDNSServer 根据名称获取DNS服务器
|
||
func GetDNSServer(name string) *DNSServer {
|
||
for _, server := range DNSServers {
|
||
if server.Name == name {
|
||
return &server
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// SetupViper 设置viper配置
|
||
func SetupViper() {
|
||
viper.SetConfigName("config")
|
||
viper.SetConfigType("json")
|
||
|
||
// 添加用户配置目录到搜索路径
|
||
if userConfigDir, err := getUserConfigDir(); err == nil {
|
||
viper.AddConfigPath(filepath.Join(userConfigDir, "config"))
|
||
}
|
||
|
||
viper.AddConfigPath(".")
|
||
// 设置默认值
|
||
viper.SetDefault("hostsPath", getDefaultHostsPath())
|
||
viper.SetDefault("backupCount", 5)
|
||
viper.SetDefault("dnsTimeout", 5000)
|
||
viper.SetDefault("maxConcurrent", 10)
|
||
viper.SetDefault("logLevel", "info")
|
||
}
|
||
|
||
// GetBackupDir 获取备份目录
|
||
func GetBackupDir() string {
|
||
if userConfigDir, err := getUserConfigDir(); err == nil {
|
||
return filepath.Join(userConfigDir, "backup")
|
||
}
|
||
return "backup" // 回退到相对路径
|
||
}
|