Skip to content

Commit 8cb63db

Browse files
authored
XHTTP server: Set remoteAddr & localAddr correctly
Completes 22c50a7
1 parent eef74b2 commit 8cb63db

File tree

5 files changed

+77
-71
lines changed

5 files changed

+77
-71
lines changed

common/net/system.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ type (
7676
)
7777

7878
var (
79-
ResolveUnixAddr = net.ResolveUnixAddr
79+
ResolveTCPAddr = net.ResolveTCPAddr
8080
ResolveUDPAddr = net.ResolveUDPAddr
81+
ResolveUnixAddr = net.ResolveUnixAddr
8182
)
8283

8384
type Resolver = net.Resolver

proxy/proxy.go

+19-16
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ type TrafficState struct {
113113

114114
type InboundState struct {
115115
// reader link state
116-
WithinPaddingBuffers bool
117-
UplinkReaderDirectCopy bool
118-
RemainingCommand int32
119-
RemainingContent int32
120-
RemainingPadding int32
121-
CurrentCommand int
116+
WithinPaddingBuffers bool
117+
UplinkReaderDirectCopy bool
118+
RemainingCommand int32
119+
RemainingContent int32
120+
RemainingPadding int32
121+
CurrentCommand int
122122
// write link state
123123
IsPadding bool
124124
DownlinkWriterDirectCopy bool
@@ -133,19 +133,19 @@ type OutboundState struct {
133133
RemainingPadding int32
134134
CurrentCommand int
135135
// write link state
136-
IsPadding bool
137-
UplinkWriterDirectCopy bool
136+
IsPadding bool
137+
UplinkWriterDirectCopy bool
138138
}
139139

140140
func NewTrafficState(userUUID []byte) *TrafficState {
141141
return &TrafficState{
142-
UserUUID: userUUID,
143-
NumberOfPacketToFilter: 8,
144-
EnableXtls: false,
145-
IsTLS12orAbove: false,
146-
IsTLS: false,
147-
Cipher: 0,
148-
RemainingServerHello: -1,
142+
UserUUID: userUUID,
143+
NumberOfPacketToFilter: 8,
144+
EnableXtls: false,
145+
IsTLS12orAbove: false,
146+
IsTLS: false,
147+
Cipher: 0,
148+
RemainingServerHello: -1,
149149
Inbound: InboundState{
150150
WithinPaddingBuffers: true,
151151
UplinkReaderDirectCopy: false,
@@ -524,7 +524,7 @@ func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx conte
524524
}
525525
}
526526

527-
// UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
527+
// UnwrapRawConn support unwrap stats, tls, utls, reality, proxyproto, uds-wrapper conn and get raw tcp/uds conn from it
528528
func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
529529
var readCounter, writerCounter stats.Counter
530530
if conn != nil {
@@ -547,6 +547,9 @@ func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
547547
conn = pc.Raw()
548548
// 8192 > 4096, there is no need to process pc's bufReader
549549
}
550+
if uc, ok := conn.(*internet.UDSWrapperConn); ok {
551+
conn = uc.Conn
552+
}
550553
}
551554
return conn, readCounter, writerCounter
552555
}

transport/internet/splithttp/hub.go

+40-38
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ package splithttp
33
import (
44
"bytes"
55
"context"
6-
"crypto/tls"
6+
gotls "crypto/tls"
77
"io"
8-
gonet "net"
98
"net/http"
109
"net/url"
1110
"strconv"
@@ -24,7 +23,7 @@ import (
2423
"github.com/xtls/xray-core/transport/internet"
2524
"github.com/xtls/xray-core/transport/internet/reality"
2625
"github.com/xtls/xray-core/transport/internet/stat"
27-
v2tls "github.com/xtls/xray-core/transport/internet/tls"
26+
"github.com/xtls/xray-core/transport/internet/tls"
2827
"golang.org/x/net/http2"
2928
"golang.org/x/net/http2/h2c"
3029
)
@@ -36,7 +35,7 @@ type requestHandler struct {
3635
ln *Listener
3736
sessionMu *sync.Mutex
3837
sessions sync.Map
39-
localAddr gonet.TCPAddr
38+
localAddr net.Addr
4039
}
4140

4241
type httpSession struct {
@@ -144,14 +143,25 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
144143
}
145144

146145
forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
147-
remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
146+
var remoteAddr net.Addr
147+
var err error
148+
remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr)
148149
if err != nil {
149-
remoteAddr = &gonet.TCPAddr{}
150+
remoteAddr = &net.TCPAddr{
151+
IP: []byte{0, 0, 0, 0},
152+
Port: 0,
153+
}
154+
}
155+
if request.ProtoMajor == 3 {
156+
remoteAddr = &net.UDPAddr{
157+
IP: remoteAddr.(*net.TCPAddr).IP,
158+
Port: remoteAddr.(*net.TCPAddr).Port,
159+
}
150160
}
151161
if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
152162
remoteAddr = &net.TCPAddr{
153163
IP: forwardedAddrs[0].IP(),
154-
Port: int(0),
164+
Port: 0,
155165
}
156166
}
157167

@@ -289,6 +299,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
289299
responseFlusher: responseFlusher,
290300
},
291301
reader: request.Body,
302+
localAddr: h.localAddr,
292303
remoteAddr: remoteAddr,
293304
}
294305
if sessionId != "" { // if not stream-one
@@ -362,34 +373,30 @@ type Listener struct {
362373
isH3 bool
363374
}
364375

365-
func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
376+
func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
366377
l := &Listener{
367378
addConn: addConn,
368379
}
369-
shSettings := streamSettings.ProtocolSettings.(*Config)
370-
l.config = shSettings
380+
l.config = streamSettings.ProtocolSettings.(*Config)
371381
if l.config != nil {
372382
if streamSettings.SocketSettings == nil {
373383
streamSettings.SocketSettings = &internet.SocketConfig{}
374384
}
375385
}
376-
var listener net.Listener
377-
var err error
378-
var localAddr = gonet.TCPAddr{}
379386
handler := &requestHandler{
380-
config: shSettings,
381-
host: shSettings.Host,
382-
path: shSettings.GetNormalizedPath(),
387+
config: l.config,
388+
host: l.config.Host,
389+
path: l.config.GetNormalizedPath(),
383390
ln: l,
384391
sessionMu: &sync.Mutex{},
385392
sessions: sync.Map{},
386-
localAddr: localAddr,
387393
}
388394
tlsConfig := getTLSConfig(streamSettings)
389395
l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
390396

397+
var err error
391398
if port == net.Port(0) { // unix
392-
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
399+
l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
393400
Name: address.Domain(),
394401
Net: "unix",
395402
}, streamSettings.SocketSettings)
@@ -405,13 +412,14 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
405412
if err != nil {
406413
return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err)
407414
}
408-
h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
415+
l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil)
409416
if err != nil {
410417
return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err)
411418
}
412-
l.h3listener = h3listener
413419
errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port)
414420

421+
handler.localAddr = l.h3listener.Addr()
422+
415423
l.h3server = &http3.Server{
416424
Handler: handler,
417425
}
@@ -421,11 +429,7 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
421429
}
422430
}()
423431
} else { // tcp
424-
localAddr = gonet.TCPAddr{
425-
IP: address.IP(),
426-
Port: int(port),
427-
}
428-
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
432+
l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
429433
IP: address.IP(),
430434
Port: int(port),
431435
}, streamSettings.SocketSettings)
@@ -436,26 +440,24 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet
436440
}
437441

438442
// tcp/unix (h1/h2)
439-
if listener != nil {
440-
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
443+
if l.listener != nil {
444+
if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
441445
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
442-
listener = tls.NewListener(listener, tlsConfig)
446+
l.listener = gotls.NewListener(l.listener, tlsConfig)
443447
}
444448
}
445-
446449
if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
447-
listener = goreality.NewListener(listener, config.GetREALITYConfig())
450+
l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig())
448451
}
449452

453+
handler.localAddr = l.listener.Addr()
454+
450455
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
451-
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
452-
l.listener = listener
453456
l.server = http.Server{
454-
Handler: h2cHandler,
457+
Handler: h2c.NewHandler(handler, &http2.Server{}),
455458
ReadHeaderTimeout: time.Second * 4,
456459
MaxHeaderBytes: 8192,
457460
}
458-
459461
go func() {
460462
if err := l.server.Serve(l.listener); err != nil {
461463
errors.LogWarningInner(ctx, err, "failed to serve HTTP for XHTTP")
@@ -488,13 +490,13 @@ func (ln *Listener) Close() error {
488490
}
489491
return errors.New("listener does not have an HTTP/3 server or a net.listener")
490492
}
491-
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
492-
config := v2tls.ConfigFromStreamSettings(streamSettings)
493+
func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config {
494+
config := tls.ConfigFromStreamSettings(streamSettings)
493495
if config == nil {
494-
return &tls.Config{}
496+
return &gotls.Config{}
495497
}
496498
return config.GetTLSConfig()
497499
}
498500
func init() {
499-
common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
501+
common.Must(internet.RegisterTransportListener(protocolName, ListenXH))
500502
}

transport/internet/splithttp/splithttp_test.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ import (
2626
"golang.org/x/net/http2"
2727
)
2828

29-
func Test_listenSHAndDial(t *testing.T) {
29+
func Test_ListenXHAndDial(t *testing.T) {
3030
listenPort := tcp.PickPort()
31-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
31+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
3232
ProtocolName: "splithttp",
3333
ProtocolSettings: &Config{
3434
Path: "/sh",
@@ -85,7 +85,7 @@ func Test_listenSHAndDial(t *testing.T) {
8585

8686
func TestDialWithRemoteAddr(t *testing.T) {
8787
listenPort := tcp.PickPort()
88-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
88+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
8989
ProtocolName: "splithttp",
9090
ProtocolSettings: &Config{
9191
Path: "sh",
@@ -125,7 +125,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
125125
common.Must(listen.Close())
126126
}
127127

128-
func Test_listenSHAndDial_TLS(t *testing.T) {
128+
func Test_ListenXHAndDial_TLS(t *testing.T) {
129129
if runtime.GOARCH == "arm64" {
130130
return
131131
}
@@ -145,7 +145,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
145145
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
146146
},
147147
}
148-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
148+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
149149
go func() {
150150
defer conn.Close()
151151

@@ -180,7 +180,7 @@ func Test_listenSHAndDial_TLS(t *testing.T) {
180180
}
181181
}
182182

183-
func Test_listenSHAndDial_H2C(t *testing.T) {
183+
func Test_ListenXHAndDial_H2C(t *testing.T) {
184184
if runtime.GOARCH == "arm64" {
185185
return
186186
}
@@ -193,7 +193,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
193193
Path: "shs",
194194
},
195195
}
196-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
196+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
197197
go func() {
198198
_ = conn.Close()
199199
}()
@@ -227,7 +227,7 @@ func Test_listenSHAndDial_H2C(t *testing.T) {
227227
}
228228
}
229229

230-
func Test_listenSHAndDial_QUIC(t *testing.T) {
230+
func Test_ListenXHAndDial_QUIC(t *testing.T) {
231231
if runtime.GOARCH == "arm64" {
232232
return
233233
}
@@ -250,7 +250,7 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
250250
}
251251

252252
serverClosed := false
253-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
253+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
254254
go func() {
255255
defer conn.Close()
256256

@@ -309,11 +309,11 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
309309
}
310310
}
311311

312-
func Test_listenSHAndDial_Unix(t *testing.T) {
312+
func Test_ListenXHAndDial_Unix(t *testing.T) {
313313
tempDir := t.TempDir()
314314
tempSocket := tempDir + "/server.sock"
315315

316-
listen, err := ListenSH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
316+
listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
317317
ProtocolName: "splithttp",
318318
ProtocolSettings: &Config{
319319
Path: "/sh",
@@ -373,7 +373,7 @@ func Test_listenSHAndDial_Unix(t *testing.T) {
373373

374374
func Test_queryString(t *testing.T) {
375375
listenPort := tcp.PickPort()
376-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
376+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
377377
ProtocolName: "splithttp",
378378
ProtocolSettings: &Config{
379379
// this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
@@ -431,7 +431,7 @@ func Test_maxUpload(t *testing.T) {
431431
}
432432

433433
var uploadSize int
434-
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
434+
listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
435435
go func(c stat.Connection) {
436436
defer c.Close()
437437
var b [10240]byte

transport/internet/system_listener.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (l *listenUDSWrapper) Accept() (net.Conn, error) {
5454
if err != nil {
5555
return nil, err
5656
}
57-
return &listenUDSWrapperConn{Conn: conn}, nil
57+
return &UDSWrapperConn{Conn: conn}, nil
5858
}
5959

6060
func (l *listenUDSWrapper) Close() error {
@@ -65,11 +65,11 @@ func (l *listenUDSWrapper) Close() error {
6565
return l.Listener.Close()
6666
}
6767

68-
type listenUDSWrapperConn struct {
68+
type UDSWrapperConn struct {
6969
net.Conn
7070
}
7171

72-
func (conn *listenUDSWrapperConn) RemoteAddr() net.Addr {
72+
func (conn *UDSWrapperConn) RemoteAddr() net.Addr {
7373
return &net.TCPAddr{
7474
IP: []byte{0, 0, 0, 0},
7575
}

0 commit comments

Comments
 (0)