272 lines
7.9 KiB
Go
272 lines
7.9 KiB
Go
package core
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/tls"
|
||
"encoding/base64"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/evil7/hostsync/utils"
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// DoHClient DoH客户端
|
||
type DoHClient struct {
|
||
client *http.Client
|
||
debugMode bool
|
||
timeout time.Duration
|
||
}
|
||
|
||
// NewDoHClient 创建DoH客户端
|
||
func NewDoHClient(timeout time.Duration, debugMode bool) *DoHClient {
|
||
// 为DoH请求设置合理的超时时间
|
||
dohTimeout := timeout
|
||
if dohTimeout < 5*time.Second {
|
||
dohTimeout = 5 * time.Second // DoH请求至少5秒超时
|
||
}
|
||
if dohTimeout > 15*time.Second {
|
||
dohTimeout = 15 * time.Second // DoH请求最多15秒超时
|
||
}
|
||
|
||
return &DoHClient{
|
||
debugMode: debugMode,
|
||
timeout: timeout,
|
||
client: &http.Client{
|
||
Timeout: dohTimeout,
|
||
Transport: &http.Transport{
|
||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 10,
|
||
IdleConnTimeout: 30 * time.Second,
|
||
TLSHandshakeTimeout: 5 * time.Second, // TLS握手超时
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
ResponseHeaderTimeout: 10 * time.Second, // 响应头超时
|
||
DisableKeepAlives: false, // 启用Keep-Alive以提高性能
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
// Resolve 使用DoH解析域名
|
||
func (c *DoHClient) Resolve(domain, dohURL string) (string, error) {
|
||
// 创建DNS查询消息
|
||
m := new(dns.Msg)
|
||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||
|
||
// 根据RFC 8484,为了HTTP缓存友好性,DNS ID应该设置为0
|
||
m.Id = 0
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH 创建DNS查询,域名: %s, ID: %d", domain, m.Id)
|
||
}
|
||
|
||
// 将DNS消息打包
|
||
data, err := m.Pack()
|
||
if err != nil {
|
||
return "", fmt.Errorf("打包DNS消息失败: %v", err)
|
||
}
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH DNS消息大小: %d 字节", len(data))
|
||
}
|
||
|
||
// 优先尝试POST方法(推荐),fallback到GET方法
|
||
ip, err := c.doRequest(dohURL, data, "POST")
|
||
if err != nil {
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH POST请求失败,尝试GET: %v", err)
|
||
}
|
||
return c.doRequest(dohURL, data, "GET")
|
||
}
|
||
return ip, nil
|
||
}
|
||
|
||
// doRequest 执行DoH请求
|
||
func (c *DoHClient) doRequest(dohURL string, dnsData []byte, method string) (string, error) {
|
||
var req *http.Request
|
||
var err error
|
||
|
||
if method == "POST" {
|
||
// POST方法:直接发送DNS消息作为请求体 (符合RFC 8484 Section 4.1)
|
||
// 确保URL不包含查询参数,因为POST使用原始路径
|
||
baseURL := strings.Split(dohURL, "?")[0]
|
||
|
||
req, err = http.NewRequest("POST", baseURL, bytes.NewReader(dnsData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建DoH POST请求失败: %v", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/dns-message")
|
||
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(dnsData)))
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH POST 请求URL: %s", baseURL)
|
||
utils.LogDebug("DoH POST 数据大小: %d 字节", len(dnsData))
|
||
}
|
||
} else {
|
||
// GET方法:Base64url编码并作为URL参数 (符合RFC 8484 Section 6)
|
||
// 使用Base64url编码(无填充),确保URL安全字符转换
|
||
b64data := base64.RawURLEncoding.EncodeToString(dnsData)
|
||
|
||
// 根据RFC 8484 Section 6,确保使用URL安全的Base64编码
|
||
// 虽然RawURLEncoding已经处理了大部分,但我们再次确认字符替换
|
||
b64data = strings.ReplaceAll(b64data, "+", "-")
|
||
b64data = strings.ReplaceAll(b64data, "/", "_")
|
||
// 移除任何可能的填充字符(RawURLEncoding应该已经不包含,但确保安全)
|
||
b64data = strings.TrimRight(b64data, "=")
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH GET Base64url编码后: %s (长度: %d)", b64data, len(b64data))
|
||
// 验证编码的正确性
|
||
if decoded, err := base64.RawURLEncoding.DecodeString(b64data); err == nil {
|
||
utils.LogDebug("DoH GET Base64url解码验证成功,原始数据长度: %d", len(decoded))
|
||
} else {
|
||
utils.LogDebug("DoH GET Base64url解码验证失败: %v", err)
|
||
}
|
||
}
|
||
|
||
// 处理URL模板:如果URL包含{?dns}模板,替换它;否则直接添加查询参数
|
||
var finalURL string
|
||
if strings.Contains(dohURL, "{?dns}") {
|
||
// URI模板格式,替换{?dns}
|
||
finalURL = strings.ReplaceAll(dohURL, "{?dns}", "?dns="+b64data)
|
||
} else if strings.Contains(dohURL, "?") {
|
||
// URL已包含查询参数,添加dns参数
|
||
finalURL = dohURL + "&dns=" + b64data
|
||
} else {
|
||
// URL不包含查询参数,添加dns参数
|
||
finalURL = dohURL + "?dns=" + b64data
|
||
}
|
||
|
||
req, err = http.NewRequest("GET", finalURL, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建DoH GET请求失败: %v", err)
|
||
}
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH GET 请求URL: %s", finalURL)
|
||
}
|
||
}
|
||
|
||
// 设置标准DoH头部(符合RFC 8484 Section 4.1)
|
||
req.Header.Set("Accept", "application/dns-message")
|
||
req.Header.Set("User-Agent", "HostSync/1.0")
|
||
|
||
// 根据RFC 8484建议,GET方法更缓存友好,但POST方法通常更小
|
||
if method == "GET" {
|
||
// GET请求更适合HTTP缓存
|
||
req.Header.Set("Cache-Control", "max-age=300")
|
||
} else {
|
||
// POST请求避免缓存问题
|
||
req.Header.Set("Cache-Control", "no-cache")
|
||
}
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH %s请求 URL: %s", method, req.URL.String())
|
||
utils.LogDebug("DoH 请求头: %v", req.Header)
|
||
}
|
||
|
||
// 设置超时上下文
|
||
timeout := c.client.Timeout
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
req = req.WithContext(ctx)
|
||
|
||
// 执行请求
|
||
resp, err := c.client.Do(req)
|
||
if err != nil {
|
||
// 检查是否为超时错误
|
||
if ctx.Err() == context.DeadlineExceeded {
|
||
return "", fmt.Errorf("DoH请求超时 (超过%v): %v", timeout, err)
|
||
}
|
||
return "", fmt.Errorf("DoH请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH 响应状态: %d %s", resp.StatusCode, resp.Status)
|
||
utils.LogDebug("DoH 响应头: %v", resp.Header)
|
||
}
|
||
|
||
// 根据RFC 8484 Section 4.2.1,成功的2xx状态码用于任何有效的DNS响应
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("DoH请求返回错误状态: %d %s, 响应体: %s",
|
||
resp.StatusCode, resp.Status, string(bodyBytes))
|
||
}
|
||
|
||
// 验证响应内容类型(符合RFC 8484 Section 6)
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType != "" && !strings.Contains(contentType, "application/dns-message") {
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH 意外的Content-Type: %s (期望: application/dns-message)", contentType)
|
||
}
|
||
// 继续处理,某些服务器可能不设置正确的Content-Type
|
||
}
|
||
|
||
// 读取响应数据
|
||
respData, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取DoH响应失败: %v", err)
|
||
}
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH 响应数据大小: %d 字节", len(respData))
|
||
}
|
||
|
||
// 检查响应数据是否为空
|
||
if len(respData) == 0 {
|
||
return "", fmt.Errorf("DoH响应数据为空")
|
||
}
|
||
|
||
// 解析DNS响应
|
||
var respMsg dns.Msg
|
||
if err := respMsg.Unpack(respData); err != nil {
|
||
return "", fmt.Errorf("解析DNS响应失败 (数据可能损坏): %v", err)
|
||
}
|
||
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH DNS响应 - ID: %d, 答案数量: %d, 响应代码: %d (%s)",
|
||
respMsg.Id, len(respMsg.Answer), respMsg.Rcode, dns.RcodeToString[respMsg.Rcode])
|
||
}
|
||
|
||
// 检查DNS响应代码(符合RFC标准)
|
||
if respMsg.Rcode != dns.RcodeSuccess {
|
||
return "", fmt.Errorf("DNS响应错误,响应代码: %d (%s)",
|
||
respMsg.Rcode, dns.RcodeToString[respMsg.Rcode])
|
||
}
|
||
|
||
if len(respMsg.Answer) == 0 {
|
||
return "", fmt.Errorf("DoH响应中没有找到答案记录")
|
||
}
|
||
|
||
// 查找A记录
|
||
for _, ans := range respMsg.Answer {
|
||
if a, ok := ans.(*dns.A); ok {
|
||
if c.debugMode {
|
||
utils.LogDebug("DoH 找到A记录: %s (TTL: %d)", a.A.String(), a.Hdr.Ttl)
|
||
}
|
||
return a.A.String(), nil
|
||
}
|
||
}
|
||
|
||
return "", fmt.Errorf("DoH响应中没有找到A记录")
|
||
}
|
||
|
||
// Test 测试DoH服务器
|
||
func (c *DoHClient) Test(server string) (time.Duration, error) {
|
||
start := time.Now()
|
||
_, err := c.Resolve("github.com", server)
|
||
duration := time.Since(start)
|
||
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return duration, nil
|
||
}
|