hostSync/core/doh.go

272 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package core
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/evil7/hostsync/utils"
"github.com/miekg/dns"
)
// DoHClient DoH客户端
type DoHClient struct {
client *http.Client
debugMode bool
timeout time.Duration
}
// NewDoHClient 创建DoH客户端
func NewDoHClient(timeout time.Duration, debugMode bool) *DoHClient {
// 为DoH请求设置合理的超时时间
dohTimeout := timeout
if dohTimeout < 5*time.Second {
dohTimeout = 5 * time.Second // DoH请求至少5秒超时
}
if dohTimeout > 15*time.Second {
dohTimeout = 15 * time.Second // DoH请求最多15秒超时
}
return &DoHClient{
debugMode: debugMode,
timeout: timeout,
client: &http.Client{
Timeout: dohTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 5 * time.Second, // TLS握手超时
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second, // 响应头超时
DisableKeepAlives: false, // 启用Keep-Alive以提高性能
},
},
}
}
// Resolve 使用DoH解析域名
func (c *DoHClient) Resolve(domain, dohURL string) (string, error) {
// 创建DNS查询消息
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
// 根据RFC 8484为了HTTP缓存友好性DNS ID应该设置为0
m.Id = 0
if c.debugMode {
utils.LogDebug("DoH 创建DNS查询域名: %s, ID: %d", domain, m.Id)
}
// 将DNS消息打包
data, err := m.Pack()
if err != nil {
return "", fmt.Errorf("打包DNS消息失败: %v", err)
}
if c.debugMode {
utils.LogDebug("DoH DNS消息大小: %d 字节", len(data))
}
// 优先尝试POST方法推荐fallback到GET方法
ip, err := c.doRequest(dohURL, data, "POST")
if err != nil {
if c.debugMode {
utils.LogDebug("DoH POST请求失败尝试GET: %v", err)
}
return c.doRequest(dohURL, data, "GET")
}
return ip, nil
}
// doRequest 执行DoH请求
func (c *DoHClient) doRequest(dohURL string, dnsData []byte, method string) (string, error) {
var req *http.Request
var err error
if method == "POST" {
// POST方法直接发送DNS消息作为请求体 (符合RFC 8484 Section 4.1)
// 确保URL不包含查询参数因为POST使用原始路径
baseURL := strings.Split(dohURL, "?")[0]
req, err = http.NewRequest("POST", baseURL, bytes.NewReader(dnsData))
if err != nil {
return "", fmt.Errorf("创建DoH POST请求失败: %v", err)
}
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(dnsData)))
if c.debugMode {
utils.LogDebug("DoH POST 请求URL: %s", baseURL)
utils.LogDebug("DoH POST 数据大小: %d 字节", len(dnsData))
}
} else {
// GET方法Base64url编码并作为URL参数 (符合RFC 8484 Section 6)
// 使用Base64url编码无填充确保URL安全字符转换
b64data := base64.RawURLEncoding.EncodeToString(dnsData)
// 根据RFC 8484 Section 6确保使用URL安全的Base64编码
// 虽然RawURLEncoding已经处理了大部分但我们再次确认字符替换
b64data = strings.ReplaceAll(b64data, "+", "-")
b64data = strings.ReplaceAll(b64data, "/", "_")
// 移除任何可能的填充字符RawURLEncoding应该已经不包含但确保安全
b64data = strings.TrimRight(b64data, "=")
if c.debugMode {
utils.LogDebug("DoH GET Base64url编码后: %s (长度: %d)", b64data, len(b64data))
// 验证编码的正确性
if decoded, err := base64.RawURLEncoding.DecodeString(b64data); err == nil {
utils.LogDebug("DoH GET Base64url解码验证成功原始数据长度: %d", len(decoded))
} else {
utils.LogDebug("DoH GET Base64url解码验证失败: %v", err)
}
}
// 处理URL模板如果URL包含{?dns}模板,替换它;否则直接添加查询参数
var finalURL string
if strings.Contains(dohURL, "{?dns}") {
// URI模板格式替换{?dns}
finalURL = strings.ReplaceAll(dohURL, "{?dns}", "?dns="+b64data)
} else if strings.Contains(dohURL, "?") {
// URL已包含查询参数添加dns参数
finalURL = dohURL + "&dns=" + b64data
} else {
// URL不包含查询参数添加dns参数
finalURL = dohURL + "?dns=" + b64data
}
req, err = http.NewRequest("GET", finalURL, nil)
if err != nil {
return "", fmt.Errorf("创建DoH GET请求失败: %v", err)
}
if c.debugMode {
utils.LogDebug("DoH GET 请求URL: %s", finalURL)
}
}
// 设置标准DoH头部符合RFC 8484 Section 4.1
req.Header.Set("Accept", "application/dns-message")
req.Header.Set("User-Agent", "HostSync/1.0")
// 根据RFC 8484建议GET方法更缓存友好但POST方法通常更小
if method == "GET" {
// GET请求更适合HTTP缓存
req.Header.Set("Cache-Control", "max-age=300")
} else {
// POST请求避免缓存问题
req.Header.Set("Cache-Control", "no-cache")
}
if c.debugMode {
utils.LogDebug("DoH %s请求 URL: %s", method, req.URL.String())
utils.LogDebug("DoH 请求头: %v", req.Header)
}
// 设置超时上下文
timeout := c.client.Timeout
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req = req.WithContext(ctx)
// 执行请求
resp, err := c.client.Do(req)
if err != nil {
// 检查是否为超时错误
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("DoH请求超时 (超过%v): %v", timeout, err)
}
return "", fmt.Errorf("DoH请求失败: %v", err)
}
defer resp.Body.Close()
if c.debugMode {
utils.LogDebug("DoH 响应状态: %d %s", resp.StatusCode, resp.Status)
utils.LogDebug("DoH 响应头: %v", resp.Header)
}
// 根据RFC 8484 Section 4.2.1成功的2xx状态码用于任何有效的DNS响应
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("DoH请求返回错误状态: %d %s, 响应体: %s",
resp.StatusCode, resp.Status, string(bodyBytes))
}
// 验证响应内容类型符合RFC 8484 Section 6
contentType := resp.Header.Get("Content-Type")
if contentType != "" && !strings.Contains(contentType, "application/dns-message") {
if c.debugMode {
utils.LogDebug("DoH 意外的Content-Type: %s (期望: application/dns-message)", contentType)
}
// 继续处理某些服务器可能不设置正确的Content-Type
}
// 读取响应数据
respData, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取DoH响应失败: %v", err)
}
if c.debugMode {
utils.LogDebug("DoH 响应数据大小: %d 字节", len(respData))
}
// 检查响应数据是否为空
if len(respData) == 0 {
return "", fmt.Errorf("DoH响应数据为空")
}
// 解析DNS响应
var respMsg dns.Msg
if err := respMsg.Unpack(respData); err != nil {
return "", fmt.Errorf("解析DNS响应失败 (数据可能损坏): %v", err)
}
if c.debugMode {
utils.LogDebug("DoH DNS响应 - ID: %d, 答案数量: %d, 响应代码: %d (%s)",
respMsg.Id, len(respMsg.Answer), respMsg.Rcode, dns.RcodeToString[respMsg.Rcode])
}
// 检查DNS响应代码符合RFC标准
if respMsg.Rcode != dns.RcodeSuccess {
return "", fmt.Errorf("DNS响应错误响应代码: %d (%s)",
respMsg.Rcode, dns.RcodeToString[respMsg.Rcode])
}
if len(respMsg.Answer) == 0 {
return "", fmt.Errorf("DoH响应中没有找到答案记录")
}
// 查找A记录
for _, ans := range respMsg.Answer {
if a, ok := ans.(*dns.A); ok {
if c.debugMode {
utils.LogDebug("DoH 找到A记录: %s (TTL: %d)", a.A.String(), a.Hdr.Ttl)
}
return a.A.String(), nil
}
}
return "", fmt.Errorf("DoH响应中没有找到A记录")
}
// Test 测试DoH服务器
func (c *DoHClient) Test(server string) (time.Duration, error) {
start := time.Now()
_, err := c.Resolve("github.com", server)
duration := time.Since(start)
if err != nil {
return 0, err
}
return duration, nil
}