351 lines
8.4 KiB
Go
351 lines
8.4 KiB
Go
package core
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/evil7/hostsync/config"
|
||
"github.com/evil7/hostsync/utils"
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// DNSResolver resolves domain names to IPv4 addresses. Its methods can try a
// DNS-over-HTTPS server, a plain DNS server and the system resolver.
type DNSResolver struct {
	// timeout bounds each plain-DNS query and system-resolver lookup;
	// NewDNSResolver also passes it to the DoH client.
	timeout time.Duration

	// debugMode enables verbose utils.LogDebug output.
	debugMode bool

	// dohClient performs the DNS-over-HTTPS lookups.
	dohClient *DoHClient
}
|
||
|
||
// NewDNSResolver 创建DNS解析器
|
||
func NewDNSResolver(debugMode ...bool) *DNSResolver {
|
||
timeout := time.Duration(config.AppConfig.DNSTimeout) * time.Millisecond
|
||
|
||
debug := false
|
||
if len(debugMode) > 0 {
|
||
debug = debugMode[0]
|
||
}
|
||
|
||
return &DNSResolver{
|
||
timeout: timeout,
|
||
debugMode: debug,
|
||
dohClient: NewDoHClient(timeout, debug),
|
||
}
|
||
}
|
||
|
||
// ResolveDomain 解析域名
|
||
func (r *DNSResolver) ResolveDomain(domain string, dnsServer, dohServer string) (string, error) {
|
||
if r.debugMode {
|
||
utils.LogDebug("开始解析域名: %s", domain)
|
||
if dohServer != "" {
|
||
utils.LogDebug("使用DoH服务器: %s", dohServer)
|
||
} else if dnsServer != "" {
|
||
utils.LogDebug("使用DNS服务器: %s", dnsServer)
|
||
} else {
|
||
utils.LogDebug("使用系统默认DNS")
|
||
}
|
||
}
|
||
|
||
// 优先使用DoH
|
||
if dohServer != "" {
|
||
if ip, err := r.dohClient.Resolve(domain, dohServer); err == nil {
|
||
if r.debugMode {
|
||
utils.LogDebug("DoH解析成功: %s -> %s", domain, ip)
|
||
}
|
||
return ip, nil
|
||
} else if r.debugMode {
|
||
utils.LogDebug("DoH解析失败: %v", err)
|
||
}
|
||
}
|
||
|
||
// 使用传统DNS
|
||
if dnsServer != "" {
|
||
if ip, err := r.resolveWithDNS(domain, dnsServer); err == nil {
|
||
if r.debugMode {
|
||
utils.LogDebug("DNS解析成功: %s -> %s\n", domain, ip)
|
||
}
|
||
return ip, nil
|
||
} else if r.debugMode {
|
||
utils.LogDebug("DNS解析失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
// 使用系统默认DNS
|
||
if ip, err := r.resolveWithSystem(domain); err == nil {
|
||
if r.debugMode {
|
||
utils.LogDebug("系统DNS解析成功: %s -> %s", domain, ip)
|
||
}
|
||
return ip, nil
|
||
} else {
|
||
if r.debugMode {
|
||
utils.LogDebug("系统DNS解析失败: %v", err)
|
||
}
|
||
return "", err
|
||
}
|
||
}
|
||
|
||
// resolveWithDNS 使用DNS服务器解析
|
||
func (r *DNSResolver) resolveWithDNS(domain, server string) (string, error) {
|
||
// 确保服务器地址包含端口
|
||
if !strings.Contains(server, ":") {
|
||
server += ":53"
|
||
}
|
||
|
||
c := new(dns.Client)
|
||
c.Timeout = r.timeout
|
||
|
||
m := new(dns.Msg)
|
||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||
|
||
resp, _, err := c.Exchange(m, server)
|
||
if err != nil {
|
||
return "", fmt.Errorf("DNS查询失败: %v", err)
|
||
}
|
||
|
||
if len(resp.Answer) == 0 {
|
||
return "", fmt.Errorf("没有找到A记录")
|
||
}
|
||
|
||
for _, ans := range resp.Answer {
|
||
if a, ok := ans.(*dns.A); ok {
|
||
return a.A.String(), nil
|
||
}
|
||
}
|
||
|
||
return "", fmt.Errorf("没有有效的A记录")
|
||
}
|
||
|
||
// resolveWithSystem 使用系统默认DNS解析
|
||
func (r *DNSResolver) resolveWithSystem(domain string) (string, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||
defer cancel()
|
||
|
||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, domain)
|
||
if err != nil {
|
||
return "", fmt.Errorf("系统DNS解析失败: %v", err)
|
||
}
|
||
|
||
for _, ip := range ips {
|
||
if ip.IP.To4() != nil { // IPv4
|
||
return ip.IP.String(), nil
|
||
}
|
||
}
|
||
|
||
return "", fmt.Errorf("没有找到IPv4地址")
|
||
}
|
||
|
||
// BatchResolve 批量解析域名
|
||
func (r *DNSResolver) BatchResolve(domains []string, dnsServer, dohServer string) map[string]string {
|
||
if r.debugMode {
|
||
utils.LogDebug("开始批量解析 %d 个域名", len(domains))
|
||
}
|
||
|
||
results := make(map[string]string)
|
||
var mu sync.Mutex
|
||
var wg sync.WaitGroup
|
||
|
||
// 控制并发数
|
||
semaphore := make(chan struct{}, config.AppConfig.MaxConcurrent)
|
||
|
||
for _, domain := range domains {
|
||
wg.Add(1)
|
||
go func(d string) {
|
||
defer wg.Done()
|
||
semaphore <- struct{}{} // 获取信号量
|
||
defer func() { <-semaphore }() // 释放信号量
|
||
|
||
if ip, err := r.ResolveDomain(d, dnsServer, dohServer); err == nil {
|
||
mu.Lock()
|
||
results[d] = ip
|
||
mu.Unlock()
|
||
} else if r.debugMode {
|
||
utils.LogDebug("解析失败 %s: %v", d, err)
|
||
}
|
||
}(domain)
|
||
}
|
||
|
||
wg.Wait()
|
||
|
||
if r.debugMode {
|
||
utils.LogDebug("批量解析完成,成功解析 %d/%d 个域名", len(results), len(domains))
|
||
}
|
||
|
||
return results
|
||
}
|
||
|
||
// TestDNSServer 测试DNS服务器
|
||
func (r *DNSResolver) TestDNSServer(server string) (time.Duration, error) {
|
||
start := time.Now()
|
||
_, err := r.resolveWithDNS("github.com", server)
|
||
duration := time.Since(start)
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return duration, nil
|
||
}
|
||
|
||
// TestDoHServer 测试DoH服务器
|
||
func (r *DNSResolver) TestDoHServer(server string) (time.Duration, error) {
|
||
start := time.Now()
|
||
_, err := r.dohClient.Resolve("github.com", server)
|
||
duration := time.Since(start)
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return duration, nil
|
||
}
|
||
|
||
// UpdateBlock 更新块中的域名解析
|
||
func (r *DNSResolver) UpdateBlock(hm *HostsManager, blockName string, forceDNS, forceDoH, forceServer string, saveConfig bool) error {
|
||
block := hm.GetBlock(blockName)
|
||
if block == nil {
|
||
return fmt.Errorf("块不存在: %s", blockName)
|
||
}
|
||
|
||
if r.debugMode {
|
||
utils.LogDebug("开始更新块: %s", blockName)
|
||
}
|
||
|
||
// 确定使用的DNS服务器
|
||
dnsServer := forceDNS
|
||
dohServer := forceDoH
|
||
|
||
if dnsServer == "" && dohServer == "" && forceServer == "" {
|
||
// 使用块配置的DNS设置
|
||
if block.DNS != "" {
|
||
dnsServer = block.DNS
|
||
} else if block.DoH != "" {
|
||
dohServer = block.DoH
|
||
} else if block.Server != "" {
|
||
// 使用预设服务器
|
||
if srv := config.GetDNSServer(block.Server); srv != nil {
|
||
if srv.DNS != "" {
|
||
dnsServer = srv.DNS
|
||
}
|
||
if srv.DoH != "" {
|
||
dohServer = srv.DoH
|
||
}
|
||
}
|
||
}
|
||
} else if forceServer != "" {
|
||
// 强制使用预设服务器
|
||
if srv := config.GetDNSServer(forceServer); srv != nil {
|
||
if srv.DNS != "" {
|
||
dnsServer = srv.DNS
|
||
}
|
||
if srv.DoH != "" {
|
||
dohServer = srv.DoH
|
||
}
|
||
}
|
||
}
|
||
|
||
if r.debugMode {
|
||
utils.LogDebug("DNS配置 - DNS: %s, DoH: %s", dnsServer, dohServer)
|
||
}
|
||
|
||
// 如果开启saveConfig,保存DNS配置到块中
|
||
if saveConfig {
|
||
configUpdated := false
|
||
if forceDNS != "" && forceDNS != block.DNS {
|
||
block.DNS = forceDNS
|
||
block.DoH = "" // 清除DoH设置
|
||
block.Server = "" // 清除预设服务器
|
||
configUpdated = true
|
||
utils.LogInfo("已将DNS服务器 '%s' 保存到块 '%s'", forceDNS, blockName)
|
||
}
|
||
if forceDoH != "" && forceDoH != block.DoH {
|
||
block.DoH = forceDoH
|
||
block.DNS = "" // 清除DNS设置
|
||
block.Server = "" // 清除预设服务器
|
||
configUpdated = true
|
||
utils.LogInfo("已将DoH服务器 '%s' 保存到块 '%s'", forceDoH, blockName)
|
||
}
|
||
if forceServer != "" && forceServer != block.Server {
|
||
block.Server = forceServer
|
||
block.DNS = "" // 清除DNS设置
|
||
block.DoH = "" // 清除DoH设置
|
||
configUpdated = true
|
||
utils.LogInfo("已将预设服务器 '%s' 保存到块 '%s'", forceServer, blockName)
|
||
}
|
||
|
||
if configUpdated {
|
||
if err := hm.Save(); err != nil {
|
||
return fmt.Errorf("保存配置失败: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 收集需要解析的域名
|
||
domains := make([]string, 0, len(block.Entries))
|
||
for _, entry := range block.Entries {
|
||
if entry.Enabled {
|
||
domains = append(domains, entry.Domain)
|
||
}
|
||
}
|
||
|
||
if len(domains) == 0 {
|
||
return fmt.Errorf("没有需要更新的域名")
|
||
}
|
||
|
||
if r.debugMode {
|
||
utils.LogDebug("需要解析 %d 个域名: %v", len(domains), domains)
|
||
}
|
||
|
||
// 批量解析
|
||
results := r.BatchResolve(domains, dnsServer, dohServer)
|
||
|
||
if r.debugMode {
|
||
utils.LogDebug("解析结果: %d 个成功", len(results))
|
||
for domain, ip := range results {
|
||
utils.LogDebug(" %s -> %s", domain, ip)
|
||
}
|
||
}
|
||
|
||
// 更新IP地址
|
||
updated := 0
|
||
for i, entry := range block.Entries {
|
||
if entry.Enabled {
|
||
if newIP, ok := results[entry.Domain]; ok {
|
||
if r.debugMode {
|
||
utils.LogDebug("检查 %s: 当前IP=%s, 新IP=%s", entry.Domain, entry.IP, newIP)
|
||
}
|
||
if newIP != entry.IP {
|
||
if r.debugMode {
|
||
utils.LogDebug("更新 %s: %s -> %s", entry.Domain, entry.IP, newIP)
|
||
}
|
||
block.Entries[i].IP = newIP
|
||
updated++
|
||
}
|
||
} else {
|
||
if r.debugMode {
|
||
utils.LogDebug("解析失败: %s", entry.Domain)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if updated > 0 {
|
||
// 更新时间
|
||
block.UpdateAt = time.Now()
|
||
if err := hm.Save(); err != nil {
|
||
return fmt.Errorf("保存文件失败: %v", err)
|
||
}
|
||
utils.LogInfo("已更新 %d 个域名的IP地址", updated)
|
||
} else {
|
||
// 即使没有IP需要更新,也要更新时间戳,让用户知道定时任务确实执行了
|
||
block.UpdateAt = time.Now()
|
||
if err := hm.Save(); err != nil {
|
||
return fmt.Errorf("保存文件失败: %v", err)
|
||
}
|
||
utils.LogInfo("没有IP地址需要更新,已更新检查时间")
|
||
}
|
||
return nil
|
||
}
|