@@ -7,10 +7,8 @@ import (
7
7
"bytes"
8
8
"context"
9
9
"crypto/tls"
10
- "encoding/base64"
11
10
"io"
12
11
"net/http"
13
- "regexp"
14
12
"sync"
15
13
"time"
16
14
@@ -25,27 +23,23 @@ func ApplyECH(c *Config, config *tls.Config) error {
25
23
var err error
26
24
27
25
if len (c .EchConfig ) > 0 {
28
- ECHConfig , err = base64 .StdEncoding .DecodeString (c .EchConfig )
29
- if err != nil {
30
- return errors .New ("invalid ECH config" )
31
- }
26
+ ECHConfig = c .EchConfig
32
27
} else { // ECH config > DOH lookup
33
- if c .ServerName == "" {
28
+ if config .ServerName == "" {
34
29
return errors .New ("Using DOH for ECH needs serverName" )
35
30
}
36
- ECHRecord , err : = QueryRecord (c .ServerName , c .Ech_DOHserver )
31
+ ECHConfig , err = QueryRecord (c .ServerName , c .Ech_DOHserver )
37
32
if err != nil {
38
33
return err
39
34
}
40
- ECHConfig , _ = base64 .StdEncoding .DecodeString (ECHRecord )
41
35
}
42
36
43
37
config .EncryptedClientHelloConfigList = ECHConfig
44
38
return nil
45
39
}
46
40
47
41
type record struct {
48
- record string
42
+ record [] byte
49
43
expire time.Time
50
44
}
51
45
@@ -54,34 +48,40 @@ var (
54
48
mutex sync.RWMutex
55
49
)
56
50
57
- func QueryRecord (domain string , server string ) (string , error ) {
58
- rec , found := dnsCache [domain ]
59
- if found && rec .expire .After (time .Now ()) {
60
- return rec .record , nil
61
- }
62
- mutex .Lock ()
63
- defer mutex .Unlock ()
64
- errors .LogDebug (context .Background (), "Tring to query ECH config for domain: " , domain , " with ECH server: " , server )
65
- record , ttl , err := dohQuery (server , domain )
66
- if err != nil {
67
- return "" , err
68
- }
69
- // Use TTL for good, but many HTTPS records have TTL 60, too short
70
- if ttl < 600 {
71
- ttl = 600
72
- }
73
- rec .record = record
74
- rec .expire = time .Now ().Add (time .Second * time .Duration (ttl ))
75
- dnsCache [domain ] = rec
76
- return record , nil
51
+ func QueryRecord (domain string , server string ) ([]byte , error ) {
52
+ mutex .Lock ()
53
+ rec , found := dnsCache [domain ]
54
+ if found && rec .expire .After (time .Now ()) {
55
+ mutex .Unlock ()
56
+ return rec .record , nil
57
+ }
58
+ mutex .Unlock ()
59
+
60
+ errors .LogDebug (context .Background (), "Trying to query ECH config for domain: " , domain , " with ECH server: " , server )
61
+ record , ttl , err := dohQuery (server , domain )
62
+ if err != nil {
63
+ return []byte {}, err
64
+ }
65
+
66
+ if ttl < 600 {
67
+ ttl = 600
68
+ }
69
+
70
+ mutex .Lock ()
71
+ defer mutex .Unlock ()
72
+ rec .record = record
73
+ rec .expire = time .Now ().Add (time .Second * time .Duration (ttl ))
74
+ dnsCache [domain ] = rec
75
+ return record , nil
77
76
}
78
77
79
- func dohQuery (server string , domain string ) (string , uint32 , error ) {
78
+ func dohQuery (server string , domain string ) ([] byte , uint32 , error ) {
80
79
m := new (dns.Msg )
81
80
m .SetQuestion (dns .Fqdn (domain ), dns .TypeHTTPS )
81
+ m .Id = 0
82
82
msg , err := m .Pack ()
83
83
if err != nil {
84
- return "" , 0 , err
84
+ return [] byte {} , 0 , err
85
85
}
86
86
tr := & http.Transport {
87
87
IdleConnTimeout : 90 * time .Second ,
@@ -104,33 +104,37 @@ func dohQuery(server string, domain string) (string, uint32, error) {
104
104
}
105
105
req , err := http .NewRequest ("POST" , server , bytes .NewReader (msg ))
106
106
if err != nil {
107
- return "" , 0 , err
107
+ return [] byte {} , 0 , err
108
108
}
109
109
req .Header .Set ("Content-Type" , "application/dns-message" )
110
110
resp , err := client .Do (req )
111
111
if err != nil {
112
- return "" , 0 , err
112
+ return [] byte {} , 0 , err
113
113
}
114
114
defer resp .Body .Close ()
115
115
respBody , err := io .ReadAll (resp .Body )
116
116
if err != nil {
117
- return "" , 0 , err
117
+ return [] byte {} , 0 , err
118
118
}
119
119
if resp .StatusCode != http .StatusOK {
120
- return "" , 0 , errors .New ("query failed with response code:" , resp .StatusCode )
120
+ return [] byte {} , 0 , errors .New ("query failed with response code:" , resp .StatusCode )
121
121
}
122
122
respMsg := new (dns.Msg )
123
123
err = respMsg .Unpack (respBody )
124
124
if err != nil {
125
- return "" , 0 , err
125
+ return [] byte {} , 0 , err
126
126
}
127
127
if len (respMsg .Answer ) > 0 {
128
- re := regexp .MustCompile (`ech="([^"]+)"` )
129
- match := re .FindStringSubmatch (respMsg .Answer [0 ].String ())
130
- if match [1 ] != "" {
131
- errors .LogDebug (context .Background (), "Get ECH config:" , match [1 ], " TTL:" , respMsg .Answer [0 ].Header ().Ttl )
132
- return match [1 ], respMsg .Answer [0 ].Header ().Ttl , nil
128
+ for _ , answer := range respMsg .Answer {
129
+ if https , ok := answer .(* dns.HTTPS ); ok && https .Hdr .Name == dns .Fqdn (domain ) {
130
+ for _ , v := range https .Value {
131
+ if echConfig , ok := v .(* dns.SVCBECHConfig ); ok {
132
+ errors .LogDebug (context .Background (), "Get ECH config:" , echConfig .String (), " TTL:" , respMsg .Answer [0 ].Header ().Ttl )
133
+ return echConfig .ECH , answer .Header ().Ttl , nil
134
+ }
135
+ }
136
+ }
133
137
}
134
138
}
135
- return "" , 0 , errors .New ("no ech record found" )
139
+ return [] byte {} , 0 , errors .New ("no ech record found" )
136
140
}
0 commit comments