package core import ( "bytes" "context" "crypto/tls" "encoding/base64" "fmt" "io" "net/http" "strings" "time" "github.com/evil7/hostsync/utils" "github.com/miekg/dns" ) // DoHClient DoH客户端 type DoHClient struct { client *http.Client debugMode bool timeout time.Duration } // NewDoHClient 创建DoH客户端 func NewDoHClient(timeout time.Duration, debugMode bool) *DoHClient { // 为DoH请求设置合理的超时时间 dohTimeout := timeout if dohTimeout < 5*time.Second { dohTimeout = 5 * time.Second // DoH请求至少5秒超时 } if dohTimeout > 15*time.Second { dohTimeout = 15 * time.Second // DoH请求最多15秒超时 } return &DoHClient{ debugMode: debugMode, timeout: timeout, client: &http.Client{ Timeout: dohTimeout, Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 30 * time.Second, TLSHandshakeTimeout: 5 * time.Second, // TLS握手超时 ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second, // 响应头超时 DisableKeepAlives: false, // 启用Keep-Alive以提高性能 }, }, } } // Resolve 使用DoH解析域名 func (c *DoHClient) Resolve(domain, dohURL string) (string, error) { // 创建DNS查询消息 m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) // 根据RFC 8484,为了HTTP缓存友好性,DNS ID应该设置为0 m.Id = 0 if c.debugMode { utils.LogDebug("DoH 创建DNS查询,域名: %s, ID: %d", domain, m.Id) } // 将DNS消息打包 data, err := m.Pack() if err != nil { return "", fmt.Errorf("打包DNS消息失败: %v", err) } if c.debugMode { utils.LogDebug("DoH DNS消息大小: %d 字节", len(data)) } // 优先尝试POST方法(推荐),fallback到GET方法 ip, err := c.doRequest(dohURL, data, "POST") if err != nil { if c.debugMode { utils.LogDebug("DoH POST请求失败,尝试GET: %v", err) } return c.doRequest(dohURL, data, "GET") } return ip, nil } // doRequest 执行DoH请求 func (c *DoHClient) doRequest(dohURL string, dnsData []byte, method string) (string, error) { var req *http.Request var err error if method == "POST" { // POST方法:直接发送DNS消息作为请求体 (符合RFC 8484 Section 4.1) // 确保URL不包含查询参数,因为POST使用原始路径 baseURL := strings.Split(dohURL, "?")[0] req, err = http.NewRequest("POST", baseURL, bytes.NewReader(dnsData)) if err != nil { return "", fmt.Errorf("创建DoH POST请求失败: %v", err) } req.Header.Set("Content-Type", "application/dns-message") req.Header.Set("Content-Length", fmt.Sprintf("%d", len(dnsData))) if c.debugMode { utils.LogDebug("DoH POST 请求URL: %s", baseURL) utils.LogDebug("DoH POST 数据大小: %d 字节", len(dnsData)) } } else { // GET方法:Base64url编码并作为URL参数 (符合RFC 8484 Section 6) // 使用Base64url编码(无填充),确保URL安全字符转换 b64data := base64.RawURLEncoding.EncodeToString(dnsData) // 根据RFC 8484 Section 6,确保使用URL安全的Base64编码 // 虽然RawURLEncoding已经处理了大部分,但我们再次确认字符替换 b64data = strings.ReplaceAll(b64data, "+", "-") b64data = strings.ReplaceAll(b64data, "/", "_") // 移除任何可能的填充字符(RawURLEncoding应该已经不包含,但确保安全) b64data = strings.TrimRight(b64data, "=") if c.debugMode { utils.LogDebug("DoH GET Base64url编码后: %s (长度: %d)", b64data, len(b64data)) // 验证编码的正确性 if decoded, err := base64.RawURLEncoding.DecodeString(b64data); err == nil { utils.LogDebug("DoH GET Base64url解码验证成功,原始数据长度: %d", len(decoded)) } else { utils.LogDebug("DoH GET Base64url解码验证失败: %v", err) } } // 处理URL模板:如果URL包含{?dns}模板,替换它;否则直接添加查询参数 var finalURL string if strings.Contains(dohURL, "{?dns}") { // URI模板格式,替换{?dns} finalURL = strings.ReplaceAll(dohURL, "{?dns}", "?dns="+b64data) } else if strings.Contains(dohURL, "?") { // URL已包含查询参数,添加dns参数 finalURL = dohURL + "&dns=" + b64data } else { // URL不包含查询参数,添加dns参数 finalURL = dohURL + "?dns=" + b64data } req, err = http.NewRequest("GET", finalURL, nil) if err != nil { return "", fmt.Errorf("创建DoH GET请求失败: %v", err) } if c.debugMode { utils.LogDebug("DoH GET 请求URL: %s", finalURL) } } // 设置标准DoH头部(符合RFC 8484 Section 4.1) req.Header.Set("Accept", "application/dns-message") req.Header.Set("User-Agent", "HostSync/1.0") // 根据RFC 8484建议,GET方法更缓存友好,但POST方法通常更小 if method == "GET" { // GET请求更适合HTTP缓存 req.Header.Set("Cache-Control", "max-age=300") } else { // POST请求避免缓存问题 req.Header.Set("Cache-Control", "no-cache") } if c.debugMode { utils.LogDebug("DoH %s请求 URL: %s", method, req.URL.String()) utils.LogDebug("DoH 请求头: %v", req.Header) } // 设置超时上下文 timeout := c.client.Timeout ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() req = req.WithContext(ctx) // 执行请求 resp, err := c.client.Do(req) if err != nil { // 检查是否为超时错误 if ctx.Err() == context.DeadlineExceeded { return "", fmt.Errorf("DoH请求超时 (超过%v): %v", timeout, err) } return "", fmt.Errorf("DoH请求失败: %v", err) } defer resp.Body.Close() if c.debugMode { utils.LogDebug("DoH 响应状态: %d %s", resp.StatusCode, resp.Status) utils.LogDebug("DoH 响应头: %v", resp.Header) } // 根据RFC 8484 Section 4.2.1,成功的2xx状态码用于任何有效的DNS响应 if resp.StatusCode < 200 || resp.StatusCode >= 300 { bodyBytes, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("DoH请求返回错误状态: %d %s, 响应体: %s", resp.StatusCode, resp.Status, string(bodyBytes)) } // 验证响应内容类型(符合RFC 8484 Section 6) contentType := resp.Header.Get("Content-Type") if contentType != "" && !strings.Contains(contentType, "application/dns-message") { if c.debugMode { utils.LogDebug("DoH 意外的Content-Type: %s (期望: application/dns-message)", contentType) } // 继续处理,某些服务器可能不设置正确的Content-Type } // 读取响应数据 respData, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("读取DoH响应失败: %v", err) } if c.debugMode { utils.LogDebug("DoH 响应数据大小: %d 字节", len(respData)) } // 检查响应数据是否为空 if len(respData) == 0 { return "", fmt.Errorf("DoH响应数据为空") } // 解析DNS响应 var respMsg dns.Msg if err := respMsg.Unpack(respData); err != nil { return "", fmt.Errorf("解析DNS响应失败 (数据可能损坏): %v", err) } if c.debugMode { utils.LogDebug("DoH DNS响应 - ID: %d, 答案数量: %d, 响应代码: %d (%s)", respMsg.Id, len(respMsg.Answer), respMsg.Rcode, dns.RcodeToString[respMsg.Rcode]) } // 检查DNS响应代码(符合RFC标准) if respMsg.Rcode != dns.RcodeSuccess { return "", fmt.Errorf("DNS响应错误,响应代码: %d (%s)", respMsg.Rcode, dns.RcodeToString[respMsg.Rcode]) } if len(respMsg.Answer) == 0 { return "", fmt.Errorf("DoH响应中没有找到答案记录") } // 查找A记录 for _, ans := range respMsg.Answer { if a, ok := ans.(*dns.A); ok { if c.debugMode { utils.LogDebug("DoH 找到A记录: %s (TTL: %d)", a.A.String(), a.Hdr.Ttl) } return a.A.String(), nil } } return "", fmt.Errorf("DoH响应中没有找到A记录") } // Test 测试DoH服务器 func (c *DoHClient) Test(server string) (time.Duration, error) { start := time.Now() _, err := c.Resolve("github.com", server) duration := time.Since(start) if err != nil { return 0, err } return duration, nil }