hostSync/core/dns.go

351 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package core
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/evil7/hostsync/config"
"github.com/evil7/hostsync/utils"
"github.com/miekg/dns"
)
// DNSResolver resolves domain names to IPv4 addresses (see ResolveDomain),
// trying a DoH server, a plain DNS server, and finally the system resolver.
type DNSResolver struct {
// timeout bounds plain DNS queries and system lookups; it is also handed to the DoH client.
timeout time.Duration
// debugMode enables utils.LogDebug output from the resolver's methods.
debugMode bool
// dohClient performs the DNS-over-HTTPS queries.
dohClient *DoHClient
}
// NewDNSResolver builds a DNSResolver whose query timeout is taken from
// config.AppConfig.DNSTimeout, interpreted as milliseconds. The optional
// debugMode argument turns on debug logging; only its first value is
// consulted and it defaults to false. The same timeout and debug flag are
// passed on to the embedded DoH client.
func NewDNSResolver(debugMode ...bool) *DNSResolver {
	queryTimeout := time.Millisecond * time.Duration(config.AppConfig.DNSTimeout)

	var verbose bool
	if len(debugMode) != 0 {
		verbose = debugMode[0]
	}

	return &DNSResolver{
		debugMode: verbose,
		timeout:   queryTimeout,
		dohClient: NewDoHClient(queryTimeout, verbose),
	}
}
// ResolveDomain resolves domain to a single IPv4 address.
//
// Sources are tried in this order, moving on to the next when one fails:
//  1. dohServer through the DoH client, if dohServer is non-empty;
//  2. dnsServer through a plain A query, if dnsServer is non-empty;
//  3. the system resolver.
//
// If every source fails, the system resolver's error is returned (wrapped
// with %w) together with the messages of any earlier attempts that also
// failed, so the caller can see why the configured servers did not answer.
func (r *DNSResolver) ResolveDomain(domain string, dnsServer, dohServer string) (string, error) {
	if r.debugMode {
		utils.LogDebug("开始解析域名: %s", domain)
		switch {
		case dohServer != "":
			utils.LogDebug("使用DoH服务器: %s", dohServer)
		case dnsServer != "":
			utils.LogDebug("使用DNS服务器: %s", dnsServer)
		default:
			utils.LogDebug("使用系统默认DNS")
		}
	}

	// Failures from the explicitly configured servers; reported alongside the
	// system-resolver error if the fallback fails too.
	var failures []string

	// Prefer DoH when configured.
	if dohServer != "" {
		ip, err := r.dohClient.Resolve(domain, dohServer)
		if err == nil {
			if r.debugMode {
				utils.LogDebug("DoH解析成功: %s -> %s", domain, ip)
			}
			return ip, nil
		}
		if r.debugMode {
			utils.LogDebug("DoH解析失败: %v", err)
		}
		failures = append(failures, "DoH: "+err.Error())
	}

	// Then the plain DNS server.
	if dnsServer != "" {
		ip, err := r.resolveWithDNS(domain, dnsServer)
		if err == nil {
			if r.debugMode {
				utils.LogDebug("DNS解析成功: %s -> %s", domain, ip)
			}
			return ip, nil
		}
		if r.debugMode {
			utils.LogDebug("DNS解析失败: %v", err)
		}
		failures = append(failures, "DNS: "+err.Error())
	}

	// Finally fall back to the system resolver.
	ip, err := r.resolveWithSystem(domain)
	if err != nil {
		if r.debugMode {
			utils.LogDebug("系统DNS解析失败: %v", err)
		}
		if len(failures) > 0 {
			return "", fmt.Errorf("%w; %s", err, strings.Join(failures, "; "))
		}
		return "", err
	}
	if r.debugMode {
		utils.LogDebug("系统DNS解析成功: %s -> %s", domain, ip)
	}
	return ip, nil
}
// resolveWithDNS sends one A query for domain to server and returns the
// address from the first A record in the answer section.
//
// server may be "host", "host:port", a bare IPv6 literal, or "[ipv6]:port".
// Port 53 is assumed when none is given. The query is bounded by r.timeout.
// A non-success response code (NXDOMAIN, SERVFAIL, REFUSED, ...) is
// reported as such rather than as a missing A record.
func (r *DNSResolver) resolveWithDNS(domain, server string) (string, error) {
	server = withDefaultDNSPort(server, "53")

	client := &dns.Client{Timeout: r.timeout}
	msg := new(dns.Msg)
	msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)

	resp, _, err := client.Exchange(msg, server)
	if err != nil {
		return "", fmt.Errorf("DNS查询失败: %w", err)
	}
	if resp.Rcode != dns.RcodeSuccess {
		return "", fmt.Errorf("DNS服务器返回错误: %s", dns.RcodeToString[resp.Rcode])
	}
	if len(resp.Answer) == 0 {
		return "", fmt.Errorf("没有找到A记录")
	}
	// The answer may also hold CNAMEs etc.; take the first actual A record.
	for _, rr := range resp.Answer {
		if a, ok := rr.(*dns.A); ok {
			return a.A.String(), nil
		}
	}
	return "", fmt.Errorf("没有有效的A记录")
}

// withDefaultDNSPort returns server as a dialable "host:port" address and
// appends port when server has none. Bare IPv6 literals, with or without
// brackets, are bracketed correctly. A plain "contains ':'" test would
// mistake them for host:port pairs.
func withDefaultDNSPort(server, port string) string {
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	host := strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")
	return net.JoinHostPort(host, port)
}
// resolveWithSystem looks up domain with net.DefaultResolver and returns the
// first IPv4 address among the results.
//
// r.timeout bounds the lookup only when it is positive. A zero or negative
// timeout would otherwise create an already-expired context and make every
// lookup fail immediately. In that case the system resolver's own limits apply.
// Lookup errors are wrapped with %w so callers can inspect them.
func (r *DNSResolver) resolveWithSystem(domain string) (string, error) {
	ctx := context.Background()
	if r.timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, r.timeout)
		defer cancel()
	}

	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, domain)
	if err != nil {
		return "", fmt.Errorf("系统DNS解析失败: %w", err)
	}
	for _, addr := range addrs {
		if addr.IP.To4() != nil { // IPv4 only
			return addr.IP.String(), nil
		}
	}
	return "", fmt.Errorf("没有找到IPv4地址")
}
// BatchResolve resolves domains concurrently through ResolveDomain and returns
// a map from domain to IPv4 address. Domains that fail to resolve are left out
// of the map and are logged only in debug mode.
//
// At most config.AppConfig.MaxConcurrent lookups run at the same time. A
// non-positive setting is treated as 1. Otherwise the semaphore would be
// unbuffered (or have a negative size) and every worker would block forever.
func (r *DNSResolver) BatchResolve(domains []string, dnsServer, dohServer string) map[string]string {
	if r.debugMode {
		utils.LogDebug("开始批量解析 %d 个域名", len(domains))
	}

	limit := config.AppConfig.MaxConcurrent
	if limit < 1 {
		limit = 1
	}

	results := make(map[string]string, len(domains))
	var (
		mu sync.Mutex
		wg sync.WaitGroup
	)
	sem := make(chan struct{}, limit)

	for _, domain := range domains {
		// Acquire before spawning so no more than `limit` goroutines exist at once.
		sem <- struct{}{}
		wg.Add(1)
		go func(d string) {
			defer wg.Done()
			defer func() { <-sem }()

			ip, err := r.ResolveDomain(d, dnsServer, dohServer)
			if err != nil {
				if r.debugMode {
					utils.LogDebug("解析失败 %s: %v", d, err)
				}
				return
			}
			mu.Lock()
			results[d] = ip
			mu.Unlock()
		}(domain)
	}
	wg.Wait()

	if r.debugMode {
		utils.LogDebug("批量解析完成,成功解析 %d/%d 个域名", len(results), len(domains))
	}
	return results
}
// TestDNSServer times a plain DNS A query for "github.com" against server.
// On success it returns the elapsed time. On failure it returns a zero
// duration and the query error.
func (r *DNSResolver) TestDNSServer(server string) (time.Duration, error) {
	began := time.Now()
	if _, err := r.resolveWithDNS("github.com", server); err != nil {
		return 0, err
	}
	return time.Since(began), nil
}
// TestDoHServer times a DoH lookup of "github.com" against server through the
// resolver's DoH client. On success it returns the elapsed time. On failure
// it returns a zero duration and the lookup error.
func (r *DNSResolver) TestDoHServer(server string) (time.Duration, error) {
	began := time.Now()
	if _, err := r.dohClient.Resolve("github.com", server); err != nil {
		return 0, err
	}
	return time.Since(began), nil
}
// UpdateBlock re-resolves every enabled entry of the block named blockName and
// writes changed IPs back through hm.
//
// Server selection (the resulting DNS/DoH pair is handed to BatchResolve):
//   - If forceServer is set, it names a preset looked up with
//     config.GetDNSServer. The preset's non-empty DNS/DoH values override
//     forceDNS/forceDoH. An unknown preset name is an error.
//   - Otherwise forceDNS and forceDoH are used as given.
//   - When no force* argument is set, the block's own setting is used: the
//     first non-empty of block.DNS, block.DoH, block.Server (a preset name).
//   - If nothing yields a server, resolution falls back to the system resolver.
//
// When saveConfig is true, each non-empty force* value that differs from the
// block's stored setting replaces it and clears the other two fields. They
// are checked in the order DNS, DoH, Server, so the last one applied wins.
// hm is then saved before resolution starts.
//
// block.UpdateAt is refreshed and hm saved even when no IP changed, so
// scheduled runs leave a visible check time.
func (r *DNSResolver) UpdateBlock(hm *HostsManager, blockName string, forceDNS, forceDoH, forceServer string, saveConfig bool) error {
	block := hm.GetBlock(blockName)
	if block == nil {
		return fmt.Errorf("块不存在: %s", blockName)
	}
	if r.debugMode {
		utils.LogDebug("开始更新块: %s", blockName)
	}

	// Decide which servers to use.
	dnsServer, dohServer := forceDNS, forceDoH
	// applyPreset overlays a preset's non-empty DNS/DoH onto the current
	// choice; it reports false when config.GetDNSServer knows no such preset.
	applyPreset := func(name string) bool {
		srv := config.GetDNSServer(name)
		if srv == nil {
			return false
		}
		if srv.DNS != "" {
			dnsServer = srv.DNS
		}
		if srv.DoH != "" {
			dohServer = srv.DoH
		}
		return true
	}
	switch {
	case forceServer != "":
		// Reject unknown presets up front. Ignoring them would silently fall
		// back to the system resolver, and with saveConfig the bad name
		// would also be persisted into the block.
		if !applyPreset(forceServer) {
			return fmt.Errorf("预设服务器不存在: %s", forceServer)
		}
	case forceDNS == "" && forceDoH == "":
		// Nothing forced: use the block's configured setting.
		switch {
		case block.DNS != "":
			dnsServer = block.DNS
		case block.DoH != "":
			dohServer = block.DoH
		case block.Server != "":
			if !applyPreset(block.Server) && r.debugMode {
				utils.LogDebug("预设服务器不存在: %s", block.Server)
			}
		}
	}
	if r.debugMode {
		utils.LogDebug("DNS配置 - DNS: %s, DoH: %s", dnsServer, dohServer)
	}

	// Optionally persist the forced choice into the block.
	if saveConfig {
		configUpdated := false
		if forceDNS != "" && forceDNS != block.DNS {
			block.DNS, block.DoH, block.Server = forceDNS, "", ""
			configUpdated = true
			utils.LogInfo("已将DNS服务器 '%s' 保存到块 '%s'", forceDNS, blockName)
		}
		if forceDoH != "" && forceDoH != block.DoH {
			block.DoH, block.DNS, block.Server = forceDoH, "", ""
			configUpdated = true
			utils.LogInfo("已将DoH服务器 '%s' 保存到块 '%s'", forceDoH, blockName)
		}
		if forceServer != "" && forceServer != block.Server {
			block.Server, block.DNS, block.DoH = forceServer, "", ""
			configUpdated = true
			utils.LogInfo("已将预设服务器 '%s' 保存到块 '%s'", forceServer, blockName)
		}
		if configUpdated {
			if err := hm.Save(); err != nil {
				return fmt.Errorf("保存配置失败: %w", err)
			}
		}
	}

	// Collect the enabled domains.
	domains := make([]string, 0, len(block.Entries))
	for _, entry := range block.Entries {
		if entry.Enabled {
			domains = append(domains, entry.Domain)
		}
	}
	if len(domains) == 0 {
		return fmt.Errorf("没有需要更新的域名")
	}
	if r.debugMode {
		utils.LogDebug("需要解析 %d 个域名: %v", len(domains), domains)
	}

	results := r.BatchResolve(domains, dnsServer, dohServer)
	if r.debugMode {
		utils.LogDebug("解析结果: %d 个成功", len(results))
		for domain, ip := range results {
			utils.LogDebug(" %s -> %s", domain, ip)
		}
	}

	// Apply changed IPs; entries whose lookup failed keep their current IP.
	updated := 0
	for i, entry := range block.Entries {
		if !entry.Enabled {
			continue
		}
		newIP, ok := results[entry.Domain]
		if !ok {
			if r.debugMode {
				utils.LogDebug("解析失败: %s", entry.Domain)
			}
			continue
		}
		if r.debugMode {
			utils.LogDebug("检查 %s: 当前IP=%s, 新IP=%s", entry.Domain, entry.IP, newIP)
		}
		if newIP == entry.IP {
			continue
		}
		if r.debugMode {
			utils.LogDebug("更新 %s: %s -> %s", entry.Domain, entry.IP, newIP)
		}
		// Index into the slice so the write lands on the stored entry.
		block.Entries[i].IP = newIP
		updated++
	}

	// Refresh the timestamp even when nothing changed so users can see that
	// the (scheduled) check actually ran.
	block.UpdateAt = time.Now()
	if err := hm.Save(); err != nil {
		return fmt.Errorf("保存文件失败: %w", err)
	}
	if updated > 0 {
		utils.LogInfo("已更新 %d 个域名的IP地址", updated)
	} else {
		utils.LogInfo("没有IP地址需要更新已更新检查时间")
	}
	return nil
}