Skip to content

Commit b8bd243

Browse files
authored
Fix buffer.UDP destination override (#2356)
1 parent e013dce commit b8bd243

File tree

6 files changed

+75
-94
lines changed

6 files changed

+75
-94
lines changed

app/dispatcher/default.go

+28-79
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package dispatcher
44

55
import (
66
"context"
7-
"fmt"
87
"strings"
98
"sync"
109
"time"
@@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
135134
// Close implements common.Closable.
136135
func (*DefaultDispatcher) Close() error { return nil }
137136

138-
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
139-
downOpt := pipe.OptionsFromContext(ctx)
140-
upOpt := downOpt
141-
142-
if network == net.Network_UDP {
143-
var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
144-
// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
145-
// When target replies, server will restore the domain and send back to client.
146-
// Note: this map is not global but per connection context
147-
upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
148-
for i, buffer := range mb {
149-
if buffer.UDP == nil {
150-
continue
151-
}
152-
addr := buffer.UDP.Address
153-
if addr.Family().IsIP() {
154-
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
155-
domain := fkr0.GetDomainFromFakeDNS(addr)
156-
if len(domain) > 0 {
157-
buffer.UDP.Address = net.DomainAddress(domain)
158-
newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
159-
} else {
160-
newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
161-
}
162-
}
163-
} else {
164-
if ip2domain == nil {
165-
ip2domain = new(sync.Map)
166-
newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
167-
}
168-
domain := addr.Domain()
169-
ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
170-
if err == nil {
171-
for _, ip := range ips {
172-
ip2domain.Store(ip.String(), domain)
173-
}
174-
newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
175-
} else {
176-
newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
177-
}
178-
}
179-
}
180-
return mb
181-
}))
182-
downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
183-
for i, buffer := range mb {
184-
if buffer.UDP == nil {
185-
continue
186-
}
187-
addr := buffer.UDP.Address
188-
if addr.Family().IsIP() {
189-
if ip2domain == nil {
190-
continue
191-
}
192-
if domain, found := ip2domain.Load(addr.IP().String()); found {
193-
buffer.UDP.Address = net.DomainAddress(domain.(string))
194-
newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
195-
}
196-
} else {
197-
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
198-
fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
199-
buffer.UDP.Address = fakeIp[0]
200-
newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
201-
}
202-
}
203-
}
204-
return mb
205-
}))
206-
}
207-
uplinkReader, uplinkWriter := pipe.New(upOpt...)
208-
downlinkReader, downlinkWriter := pipe.New(downOpt...)
137+
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
138+
opt := pipe.OptionsFromContext(ctx)
139+
uplinkReader, uplinkWriter := pipe.New(opt...)
140+
downlinkReader, downlinkWriter := pipe.New(opt...)
209141

210142
inboundLink := &transport.Link{
211143
Reader: downlinkReader,
@@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
263195
protocolString = resComp.ProtocolForDomainResult()
264196
}
265197
for _, p := range request.OverrideDestinationForProtocol {
266-
if strings.HasPrefix(protocolString, p) {
198+
if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
267199
return true
268200
}
269201
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
@@ -287,17 +219,17 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
287219
panic("Dispatcher: Invalid destination.")
288220
}
289221
ob := &session.Outbound{
290-
Target: destination,
222+
OriginalTarget: destination,
223+
Target: destination,
291224
}
292225
ctx = session.ContextWithOutbound(ctx, ob)
293226
content := session.ContentFromContext(ctx)
294227
if content == nil {
295228
content = new(session.Content)
296229
ctx = session.ContextWithContent(ctx, content)
297230
}
298-
299231
sniffingRequest := content.SniffingRequest
300-
inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
232+
inbound, outbound := d.getLink(ctx)
301233
if !sniffingRequest.Enabled {
302234
go d.routedDispatch(ctx, outbound, destination)
303235
} else {
@@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
314246
domain := result.Domain()
315247
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
316248
destination.Address = net.ParseAddress(domain)
317-
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
249+
protocol := result.Protocol()
250+
if resComp, ok := result.(SnifferResultComposite); ok {
251+
protocol = resComp.ProtocolForDomainResult()
252+
}
253+
isFakeIP := false
254+
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
255+
isFakeIP = true
256+
}
257+
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
318258
ob.RouteTarget = destination
319259
} else {
320260
ob.Target = destination
@@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
332272
return newError("Dispatcher: Invalid destination.")
333273
}
334274
ob := &session.Outbound{
335-
Target: destination,
275+
OriginalTarget: destination,
276+
Target: destination,
336277
}
337278
ctx = session.ContextWithOutbound(ctx, ob)
338279
content := session.ContentFromContext(ctx)
@@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
356297
domain := result.Domain()
357298
newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
358299
destination.Address = net.ParseAddress(domain)
359-
if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
300+
protocol := result.Protocol()
301+
if resComp, ok := result.(SnifferResultComposite); ok {
302+
protocol = resComp.ProtocolForDomainResult()
303+
}
304+
isFakeIP := false
305+
if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
306+
isFakeIP = true
307+
}
308+
if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
360309
ob.RouteTarget = destination
361310
} else {
362311
ob.Target = destination

app/proxyman/outbound/handler.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/xtls/xray-core/app/proxyman"
1010
"github.com/xtls/xray-core/common"
11+
"github.com/xtls/xray-core/common/buf"
1112
"github.com/xtls/xray-core/common/mux"
1213
"github.com/xtls/xray-core/common/net"
1314
"github.com/xtls/xray-core/common/net/cnc"
@@ -166,6 +167,11 @@ func (h *Handler) Tag() string {
166167

167168
// Dispatch implements proxy.Outbound.Dispatch.
168169
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
170+
outbound := session.OutboundFromContext(ctx)
171+
if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
172+
link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
173+
link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
174+
}
169175
if h.mux != nil {
170176
test := func(err error) {
171177
if err != nil {
@@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
175181
common.Interrupt(link.Writer)
176182
}
177183
}
178-
outbound := session.OutboundFromContext(ctx)
179184
if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
180185
switch h.udp443 {
181186
case "reject":

common/buf/override.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package buf
2+
3+
import (
4+
"github.com/xtls/xray-core/common/net"
5+
)
6+
7+
type EndpointOverrideReader struct {
8+
Reader
9+
Dest net.Address
10+
OriginalDest net.Address
11+
}
12+
13+
func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
14+
mb, err := r.Reader.ReadMultiBuffer()
15+
if err == nil {
16+
for _, b := range mb {
17+
if b.UDP != nil && b.UDP.Address == r.OriginalDest {
18+
b.UDP.Address = r.Dest
19+
}
20+
}
21+
}
22+
return mb, err
23+
}
24+
25+
type EndpointOverrideWriter struct {
26+
Writer
27+
Dest net.Address
28+
OriginalDest net.Address
29+
}
30+
31+
func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
32+
for _, b := range mb {
33+
if b.UDP != nil && b.UDP.Address == w.Dest {
34+
b.UDP.Address = w.OriginalDest
35+
}
36+
}
37+
return w.Writer.WriteMultiBuffer(mb)
38+
}

common/session/session.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ type Inbound struct {
5555
// Outbound is the metadata of an outbound connection.
5656
type Outbound struct {
5757
// Target address of the outbound connection.
58-
Target net.Destination
59-
RouteTarget net.Destination
58+
OriginalTarget net.Destination
59+
Target net.Destination
60+
RouteTarget net.Destination
6061
// Gateway address
6162
Gateway net.Address
6263
}

transport/pipe/impl.go

-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ const (
2424
type pipeOption struct {
2525
limit int32 // maximum buffer size in bytes
2626
discardOverflow bool
27-
onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
2827
}
2928

3029
func (o *pipeOption) isFull(curSize int32) bool {
@@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
141140
return nil
142141
}
143142

144-
if p.option.onTransmission != nil {
145-
mb = p.option.onTransmission(mb)
146-
}
147-
148143
for {
149144
err := p.writeMultiBufferInternal(mb)
150145
if err == nil {

transport/pipe/pipe.go

-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package pipe
33
import (
44
"context"
55

6-
"github.com/xtls/xray-core/common/buf"
76
"github.com/xtls/xray-core/common/signal"
87
"github.com/xtls/xray-core/common/signal/done"
98
"github.com/xtls/xray-core/features/policy"
@@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
2625
}
2726
}
2827

29-
func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
30-
return func(option *pipeOption) {
31-
option.onTransmission = hook
32-
}
33-
}
34-
3528
// DiscardOverflow returns an Option for Pipe to discard writes if full.
3629
func DiscardOverflow() Option {
3730
return func(opt *pipeOption) {

0 commit comments

Comments
 (0)