package core import ( "context" "fmt" "net" "strings" "sync" "time" "github.com/evil7/hostsync/config" "github.com/evil7/hostsync/utils" "github.com/miekg/dns" ) // DNSResolver DNS解析器 type DNSResolver struct { timeout time.Duration debugMode bool dohClient *DoHClient } // NewDNSResolver 创建DNS解析器 func NewDNSResolver(debugMode ...bool) *DNSResolver { timeout := time.Duration(config.AppConfig.DNSTimeout) * time.Millisecond debug := false if len(debugMode) > 0 { debug = debugMode[0] } return &DNSResolver{ timeout: timeout, debugMode: debug, dohClient: NewDoHClient(timeout, debug), } } // ResolveDomain 解析域名 func (r *DNSResolver) ResolveDomain(domain string, dnsServer, dohServer string) (string, error) { if r.debugMode { utils.LogDebug("开始解析域名: %s", domain) if dohServer != "" { utils.LogDebug("使用DoH服务器: %s", dohServer) } else if dnsServer != "" { utils.LogDebug("使用DNS服务器: %s", dnsServer) } else { utils.LogDebug("使用系统默认DNS") } } // 优先使用DoH if dohServer != "" { if ip, err := r.dohClient.Resolve(domain, dohServer); err == nil { if r.debugMode { utils.LogDebug("DoH解析成功: %s -> %s", domain, ip) } return ip, nil } else if r.debugMode { utils.LogDebug("DoH解析失败: %v", err) } } // 使用传统DNS if dnsServer != "" { if ip, err := r.resolveWithDNS(domain, dnsServer); err == nil { if r.debugMode { utils.LogDebug("DNS解析成功: %s -> %s\n", domain, ip) } return ip, nil } else if r.debugMode { utils.LogDebug("DNS解析失败: %v\n", err) } } // 使用系统默认DNS if ip, err := r.resolveWithSystem(domain); err == nil { if r.debugMode { utils.LogDebug("系统DNS解析成功: %s -> %s", domain, ip) } return ip, nil } else { if r.debugMode { utils.LogDebug("系统DNS解析失败: %v", err) } return "", err } } // resolveWithDNS 使用DNS服务器解析 func (r *DNSResolver) resolveWithDNS(domain, server string) (string, error) { // 确保服务器地址包含端口 if !strings.Contains(server, ":") { server += ":53" } c := new(dns.Client) c.Timeout = r.timeout m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) resp, _, err := c.Exchange(m, server) if err != nil { 
return "", fmt.Errorf("DNS查询失败: %v", err) } if len(resp.Answer) == 0 { return "", fmt.Errorf("没有找到A记录") } for _, ans := range resp.Answer { if a, ok := ans.(*dns.A); ok { return a.A.String(), nil } } return "", fmt.Errorf("没有有效的A记录") } // resolveWithSystem 使用系统默认DNS解析 func (r *DNSResolver) resolveWithSystem(domain string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() ips, err := net.DefaultResolver.LookupIPAddr(ctx, domain) if err != nil { return "", fmt.Errorf("系统DNS解析失败: %v", err) } for _, ip := range ips { if ip.IP.To4() != nil { // IPv4 return ip.IP.String(), nil } } return "", fmt.Errorf("没有找到IPv4地址") } // BatchResolve 批量解析域名 func (r *DNSResolver) BatchResolve(domains []string, dnsServer, dohServer string) map[string]string { if r.debugMode { utils.LogDebug("开始批量解析 %d 个域名", len(domains)) } results := make(map[string]string) var mu sync.Mutex var wg sync.WaitGroup // 控制并发数 semaphore := make(chan struct{}, config.AppConfig.MaxConcurrent) for _, domain := range domains { wg.Add(1) go func(d string) { defer wg.Done() semaphore <- struct{}{} // 获取信号量 defer func() { <-semaphore }() // 释放信号量 if ip, err := r.ResolveDomain(d, dnsServer, dohServer); err == nil { mu.Lock() results[d] = ip mu.Unlock() } else if r.debugMode { utils.LogDebug("解析失败 %s: %v", d, err) } }(domain) } wg.Wait() if r.debugMode { utils.LogDebug("批量解析完成,成功解析 %d/%d 个域名", len(results), len(domains)) } return results } // TestDNSServer 测试DNS服务器 func (r *DNSResolver) TestDNSServer(server string) (time.Duration, error) { start := time.Now() _, err := r.resolveWithDNS("github.com", server) duration := time.Since(start) if err != nil { return 0, err } return duration, nil } // TestDoHServer 测试DoH服务器 func (r *DNSResolver) TestDoHServer(server string) (time.Duration, error) { start := time.Now() _, err := r.dohClient.Resolve("github.com", server) duration := time.Since(start) if err != nil { return 0, err } return duration, nil } // UpdateBlock 更新块中的域名解析 
//
// It re-resolves every enabled entry of the named block and writes changed IP
// addresses back through hm.
//
// Server selection, in order of precedence:
//   - forceServer names a preset (config.GetDNSServer). Its non-empty DNS/DoH
//     fields override forceDNS/forceDoH. An unknown preset is an error.
//   - forceDNS / forceDoH are used as given.
//   - Without any force option, the block's own DNS, DoH or Server setting is
//     used, whichever is the first non-empty one. An unknown stored preset is
//     logged, and resolution falls back to the system resolver.
//
// When saveConfig is true, each forced option that differs from the block's
// current setting is written into the block, clearing the other server
// fields, and saved before resolution starts.
//
// Domains that fail to resolve keep their current IP. UpdateAt is refreshed
// and hm.Save is called even when no IP changed, so a scheduled run leaves a
// visible timestamp.
//
// It returns an error when the block does not exist, forceServer is unknown,
// the block has no enabled entries, or saving fails.
func (r *DNSResolver) UpdateBlock(hm *HostsManager, blockName string, forceDNS, forceDoH, forceServer string, saveConfig bool) error {
	block := hm.GetBlock(blockName)
	if block == nil {
		return fmt.Errorf("块不存在: %s", blockName)
	}

	if r.debugMode {
		utils.LogDebug("开始更新块: %s", blockName)
	}

	// Determine which servers to use.
	dnsServer := forceDNS
	dohServer := forceDoH

	// applyPreset copies the non-empty DNS/DoH fields of a preset server into
	// dnsServer/dohServer. It reports false when the preset is unknown.
	applyPreset := func(name string) bool {
		srv := config.GetDNSServer(name)
		if srv == nil {
			return false
		}
		if srv.DNS != "" {
			dnsServer = srv.DNS
		}
		if srv.DoH != "" {
			dohServer = srv.DoH
		}
		return true
	}

	switch {
	case forceServer != "":
		// An explicitly requested preset must exist. Otherwise resolution
		// would silently fall back to system DNS, and with saveConfig the
		// bogus preset name would be persisted into the block.
		if !applyPreset(forceServer) {
			return fmt.Errorf("预设服务器不存在: %s", forceServer)
		}
	case dnsServer == "" && dohServer == "":
		// No force options: use the block's own configuration.
		switch {
		case block.DNS != "":
			dnsServer = block.DNS
		case block.DoH != "":
			dohServer = block.DoH
		case block.Server != "":
			if !applyPreset(block.Server) {
				utils.LogInfo("块 '%s' 配置的预设服务器 '%s' 不存在,将使用系统默认DNS", blockName, block.Server)
			}
		}
	}

	if r.debugMode {
		utils.LogDebug("DNS配置 - DNS: %s, DoH: %s", dnsServer, dohServer)
	}

	// Persist the forced server choice into the block if requested. Each
	// option clears the other two, so the block keeps a single server source.
	if saveConfig {
		configUpdated := false

		if forceDNS != "" && forceDNS != block.DNS {
			block.DNS = forceDNS
			block.DoH = ""
			block.Server = ""
			configUpdated = true
			utils.LogInfo("已将DNS服务器 '%s' 保存到块 '%s'", forceDNS, blockName)
		}

		if forceDoH != "" && forceDoH != block.DoH {
			block.DoH = forceDoH
			block.DNS = ""
			block.Server = ""
			configUpdated = true
			utils.LogInfo("已将DoH服务器 '%s' 保存到块 '%s'", forceDoH, blockName)
		}

		if forceServer != "" && forceServer != block.Server {
			block.Server = forceServer
			block.DNS = ""
			block.DoH = ""
			configUpdated = true
			utils.LogInfo("已将预设服务器 '%s' 保存到块 '%s'", forceServer, blockName)
		}

		if configUpdated {
			if err := hm.Save(); err != nil {
				return fmt.Errorf("保存配置失败: %w", err)
			}
		}
	}

	// Collect the enabled domains.
	domains := make([]string, 0, len(block.Entries))
	for _, entry := range block.Entries {
		if entry.Enabled {
			domains = append(domains, entry.Domain)
		}
	}

	if len(domains) == 0 {
		return fmt.Errorf("没有需要更新的域名")
	}

	if r.debugMode {
		utils.LogDebug("需要解析 %d 个域名: %v", len(domains), domains)
	}

	results := r.BatchResolve(domains, dnsServer, dohServer)

	if r.debugMode {
		utils.LogDebug("解析结果: %d 个成功", len(results))
		for domain, ip := range results {
			utils.LogDebug("  %s -> %s", domain, ip)
		}
	}

	// Apply new IPs. Write through the slice index so the stored entry
	// changes, not the range copy.
	updated, failed := 0, 0
	for i, entry := range block.Entries {
		if !entry.Enabled {
			continue
		}

		newIP, ok := results[entry.Domain]
		if !ok {
			failed++
			if r.debugMode {
				utils.LogDebug("解析失败: %s", entry.Domain)
			}
			continue
		}

		if r.debugMode {
			utils.LogDebug("检查 %s: 当前IP=%s, 新IP=%s", entry.Domain, entry.IP, newIP)
		}
		if newIP != entry.IP {
			if r.debugMode {
				utils.LogDebug("更新 %s: %s -> %s", entry.Domain, entry.IP, newIP)
			}
			block.Entries[i].IP = newIP
			updated++
		}
	}

	// Refresh the timestamp even when nothing changed, so users can see that
	// the (possibly scheduled) check actually ran.
	block.UpdateAt = time.Now()
	if err := hm.Save(); err != nil {
		return fmt.Errorf("保存文件失败: %w", err)
	}

	if updated > 0 {
		utils.LogInfo("已更新 %d 个域名的IP地址", updated)
	} else {
		utils.LogInfo("没有IP地址需要更新,已更新检查时间")
	}
	// Surface resolution failures outside debug mode, so a run where every
	// lookup failed is not reported merely as "nothing to update".
	if failed > 0 {
		utils.LogInfo("%d 个域名解析失败,保留原IP地址", failed)
	}

	return nil
}