Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XTLS: More separate uplink/downlink flags for splice copy #4407

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions proxy/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc := func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
Expand Down
1 change: 1 addition & 0 deletions proxy/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *buf
}

responseDone := func() error {
inbound.CanSpliceCopy = 1
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

v2writer := buf.NewWriter(conn)
Expand Down
175 changes: 117 additions & 58 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,33 @@ type TrafficState struct {
IsTLS bool
Cipher uint16
RemainingServerHello int32
Inbound InboundState
Outbound OutboundState
}

type InboundState struct {
// reader link state
WithinPaddingBuffers bool
DownlinkReaderDirectCopy bool
UplinkReaderDirectCopy bool
RemainingCommand int32
RemainingContent int32
RemainingPadding int32
CurrentCommand int

// write link state
IsPadding bool
DownlinkWriterDirectCopy bool
}

type OutboundState struct {
// reader link state
WithinPaddingBuffers bool
DownlinkReaderDirectCopy bool
RemainingCommand int32
RemainingContent int32
RemainingPadding int32
CurrentCommand int
// write link state
IsPadding bool
UplinkWriterDirectCopy bool
}

Expand All @@ -132,16 +146,26 @@ func NewTrafficState(userUUID []byte) *TrafficState {
IsTLS: false,
Cipher: 0,
RemainingServerHello: -1,
WithinPaddingBuffers: true,
DownlinkReaderDirectCopy: false,
UplinkReaderDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
DownlinkWriterDirectCopy: false,
UplinkWriterDirectCopy: false,
Inbound: InboundState{
WithinPaddingBuffers: true,
UplinkReaderDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
DownlinkWriterDirectCopy: false,
},
Outbound: OutboundState{
WithinPaddingBuffers: true,
DownlinkReaderDirectCopy: false,
RemainingCommand: -1,
RemainingContent: -1,
RemainingPadding: -1,
CurrentCommand: 0,
IsPadding: true,
UplinkWriterDirectCopy: false,
},
}
}

Expand All @@ -166,28 +190,43 @@ func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, cont
func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
buffer, err := w.Reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
var withinPaddingBuffers *bool
var remainingContent *int32
var remainingPadding *int32
var currentCommand *int
var switchToDirectCopy *bool
if w.isUplink {
withinPaddingBuffers = &w.trafficState.Inbound.WithinPaddingBuffers
remainingContent = &w.trafficState.Inbound.RemainingContent
remainingPadding = &w.trafficState.Inbound.RemainingPadding
currentCommand = &w.trafficState.Inbound.CurrentCommand
switchToDirectCopy = &w.trafficState.Inbound.UplinkReaderDirectCopy
} else {
withinPaddingBuffers = &w.trafficState.Outbound.WithinPaddingBuffers
remainingContent = &w.trafficState.Outbound.RemainingContent
remainingPadding = &w.trafficState.Outbound.RemainingPadding
currentCommand = &w.trafficState.Outbound.CurrentCommand
switchToDirectCopy = &w.trafficState.Outbound.DownlinkReaderDirectCopy
}

if *withinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
mb2 := make(buf.MultiBuffer, 0, len(buffer))
for _, b := range buffer {
newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx)
newbuffer := XtlsUnpadding(b, w.trafficState, w.isUplink, w.ctx)
if newbuffer.Len() > 0 {
mb2 = append(mb2, newbuffer)
}
}
buffer = mb2
if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 || w.trafficState.CurrentCommand == 0 {
w.trafficState.WithinPaddingBuffers = true
} else if w.trafficState.CurrentCommand == 1 {
w.trafficState.WithinPaddingBuffers = false
} else if w.trafficState.CurrentCommand == 2 {
w.trafficState.WithinPaddingBuffers = false
if w.isUplink {
w.trafficState.UplinkReaderDirectCopy = true
} else {
w.trafficState.DownlinkReaderDirectCopy = true
}
if *remainingContent > 0 || *remainingPadding > 0 || *currentCommand == 0 {
*withinPaddingBuffers = true
} else if *currentCommand == 1 {
*withinPaddingBuffers = false
} else if *currentCommand == 2 {
*withinPaddingBuffers = false
*switchToDirectCopy = true
} else {
errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len())
errors.LogInfo(w.ctx, "XtlsRead unknown command ", *currentCommand, buffer.Len())
}
}
if w.trafficState.NumberOfPacketToFilter > 0 {
Expand Down Expand Up @@ -223,7 +262,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.trafficState.NumberOfPacketToFilter > 0 {
XtlsFilterTls(mb, w.trafficState, w.ctx)
}
if w.trafficState.IsPadding {
var isPadding *bool
var switchToDirectCopy *bool
if w.isUplink {
isPadding = &w.trafficState.Outbound.IsPadding
switchToDirectCopy = &w.trafficState.Outbound.UplinkWriterDirectCopy
} else {
isPadding = &w.trafficState.Inbound.IsPadding
switchToDirectCopy = &w.trafficState.Inbound.DownlinkWriterDirectCopy
}
if *isPadding {
if len(mb) == 1 && mb[0] == nil {
mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header
return w.Writer.WriteMultiBuffer(mb)
Expand All @@ -233,11 +281,7 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for i, b := range mb {
if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
if w.trafficState.EnableXtls {
if w.isUplink {
w.trafficState.UplinkWriterDirectCopy = true
} else {
w.trafficState.DownlinkWriterDirectCopy = true
}
*switchToDirectCopy = true
}
var command byte = CommandPaddingContinue
if i == len(mb)-1 {
Expand All @@ -247,16 +291,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
}
}
mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx)
w.trafficState.IsPadding = false // padding going to end
*isPadding = false // padding going to end
longPadding = false
continue
} else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
w.trafficState.IsPadding = false
*isPadding = false
mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx)
break
}
var command byte = CommandPaddingContinue
if i == len(mb)-1 && !w.trafficState.IsPadding {
if i == len(mb)-1 && !*isPadding {
command = CommandPaddingEnd
if w.trafficState.EnableXtls {
command = CommandPaddingDirect
Expand Down Expand Up @@ -343,38 +387,53 @@ func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool
}

// XtlsUnpadding remove padding and parse command
func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer {
if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // initial state
func XtlsUnpadding(b *buf.Buffer, s *TrafficState, isUplink bool, ctx context.Context) *buf.Buffer {
var remainingCommand *int32
var remainingContent *int32
var remainingPadding *int32
var currentCommand *int
if isUplink {
remainingCommand = &s.Inbound.RemainingCommand
remainingContent = &s.Inbound.RemainingContent
remainingPadding = &s.Inbound.RemainingPadding
currentCommand = &s.Inbound.CurrentCommand
} else {
remainingCommand = &s.Outbound.RemainingCommand
remainingContent = &s.Outbound.RemainingContent
remainingPadding = &s.Outbound.RemainingPadding
currentCommand = &s.Outbound.CurrentCommand
}
if *remainingCommand == -1 && *remainingContent == -1 && *remainingPadding == -1 { // initial state
if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) {
b.Advance(16)
s.RemainingCommand = 5
*remainingCommand = 5
} else {
return b
}
}
newbuffer := buf.New()
for b.Len() > 0 {
if s.RemainingCommand > 0 {
if *remainingCommand > 0 {
data, err := b.ReadByte()
if err != nil {
return newbuffer
}
switch s.RemainingCommand {
switch *remainingCommand {
case 5:
s.CurrentCommand = int(data)
*currentCommand = int(data)
case 4:
s.RemainingContent = int32(data) << 8
*remainingContent = int32(data) << 8
case 3:
s.RemainingContent = s.RemainingContent | int32(data)
*remainingContent = *remainingContent | int32(data)
case 2:
s.RemainingPadding = int32(data) << 8
*remainingPadding = int32(data) << 8
case 1:
s.RemainingPadding = s.RemainingPadding | int32(data)
errors.LogInfo(ctx, "Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand)
*remainingPadding = *remainingPadding | int32(data)
errors.LogInfo(ctx, "Xtls Unpadding new block, content ", *remainingContent, " padding ", *remainingPadding, " command ", *currentCommand)
}
s.RemainingCommand--
} else if s.RemainingContent > 0 {
len := s.RemainingContent
*remainingCommand--
} else if *remainingContent > 0 {
len := *remainingContent
if b.Len() < len {
len = b.Len()
}
Expand All @@ -383,22 +442,22 @@ func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buf
return newbuffer
}
newbuffer.Write(data)
s.RemainingContent -= len
*remainingContent -= len
} else { // remainingPadding > 0
len := s.RemainingPadding
len := *remainingPadding
if b.Len() < len {
len = b.Len()
}
b.Advance(len)
s.RemainingPadding -= len
*remainingPadding -= len
}
if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done
if s.CurrentCommand == 0 {
s.RemainingCommand = 5
if *remainingCommand <= 0 && *remainingContent <= 0 && *remainingPadding <= 0 { // this block done
if *currentCommand == 0 {
*remainingCommand = 5
} else {
s.RemainingCommand = -1 // set to initial state
s.RemainingContent = -1
s.RemainingPadding = -1
*remainingCommand = -1 // set to initial state
*remainingContent = -1
*remainingPadding = -1
if b.Len() > 0 { // shouldn't happen
newbuffer.Write(b.Bytes())
}
Expand Down
2 changes: 2 additions & 0 deletions proxy/socks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
}
responseFunc = func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
}
Expand All @@ -161,6 +162,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
}
responseFunc = func() error {
ob.CanSpliceCopy = 1
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
reader := &UDPReader{Reader: udpConn}
return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
Expand Down
2 changes: 2 additions & 0 deletions proxy/socks/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
}

responseDone := func() error {
inbound.CanSpliceCopy = 1
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)

v2writer := buf.NewWriter(writer)
Expand Down Expand Up @@ -256,6 +257,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
if inbound != nil && inbound.Source.IsValid() {
errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
}
inbound.CanSpliceCopy = 1

var dest *net.Destination

Expand Down
18 changes: 9 additions & 9 deletions proxy/vless/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,16 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
err := func() error {
for {
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
var writerConn net.Conn
var inTimer *signal.ActivityTimer
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
writerConn = inbound.Conn
inTimer = inbound.Timer
if inbound.CanSpliceCopy == 2 {
if isUplink && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
if ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
if !isUplink && ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
ob.CanSpliceCopy = 1
}
}
Expand All @@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
buffer, err := reader.ReadMultiBuffer()
if !buffer.IsEmpty() {
timer.Update()
if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
// XTLS Vision processes struct TLS Conn's input and rawInput
if inputBuffer, err := buf.ReadFrom(input); err == nil {
if !inputBuffer.IsEmpty() {
Expand Down Expand Up @@ -227,22 +227,22 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
var ct stats.Counter
for {
buffer, err := reader.ReadMultiBuffer()
if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy {
if isUplink && trafficState.Outbound.UplinkWriterDirectCopy || !isUplink && trafficState.Inbound.DownlinkWriterDirectCopy {
if inbound := session.InboundFromContext(ctx); inbound != nil {
if inbound.CanSpliceCopy == 2 {
if !isUplink && inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
if ob != nil && ob.CanSpliceCopy == 2 {
if isUplink && ob != nil && ob.CanSpliceCopy == 2 {
ob.CanSpliceCopy = 1
}
}
rawConn, _, writerCounter := proxy.UnwrapRawConn(conn)
writer = buf.NewWriter(rawConn)
ct = writerCounter
if isUplink {
trafficState.UplinkWriterDirectCopy = false
trafficState.Outbound.UplinkWriterDirectCopy = false
} else {
trafficState.DownlinkWriterDirectCopy = false
trafficState.Inbound.DownlinkWriterDirectCopy = false
}
}
if !buffer.IsEmpty() {
Expand Down
Loading