Skip to content

Commit a343d68

Browse files
cty123yuhan6665
authored andcommitted
fix(proxy): removed the udp payload length check when encryption is disabled
1 parent f67167b commit a343d68

File tree

3 files changed

+106
-52
lines changed

3 files changed

+106
-52
lines changed

proxy/shadowsocks/protocol.go

+34-28
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/hmac"
55
"crypto/rand"
66
"crypto/sha256"
7+
"errors"
78
"hash/crc32"
89
"io"
910

@@ -236,37 +237,37 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
236237
}
237238

238239
func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
239-
bs := payload.Bytes()
240-
if len(bs) <= 32 {
241-
return nil, nil, newError("len(bs) <= 32")
242-
}
240+
rawPayload := payload.Bytes()
241+
user, _, d, _, err := validator.Get(rawPayload, protocol.RequestCommandUDP)
243242

244-
user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP)
245-
switch err {
246-
case ErrIVNotUnique:
243+
if errors.Is(err, ErrIVNotUnique) {
247244
return nil, nil, newError("failed iv check").Base(err)
248-
case ErrNotFound:
245+
}
246+
247+
if errors.Is(err, ErrNotFound) {
249248
return nil, nil, newError("failed to match an user").Base(err)
250-
default:
251-
account := user.Account.(*MemoryAccount)
252-
if account.Cipher.IsAEAD() {
253-
payload.Clear()
254-
payload.Write(d)
255-
} else {
256-
if account.Cipher.IVSize() > 0 {
257-
iv := make([]byte, account.Cipher.IVSize())
258-
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
259-
}
260-
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
261-
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
262-
}
263-
}
264249
}
265250

266-
request := &protocol.RequestHeader{
267-
Version: Version,
268-
User: user,
269-
Command: protocol.RequestCommandUDP,
251+
if err != nil {
252+
return nil, nil, newError("unexpected error").Base(err)
253+
}
254+
255+
account, ok := user.Account.(*MemoryAccount)
256+
if !ok {
257+
return nil, nil, newError("expected MemoryAccount returned from validator")
258+
}
259+
260+
if account.Cipher.IsAEAD() {
261+
payload.Clear()
262+
payload.Write(d)
263+
} else {
264+
if account.Cipher.IVSize() > 0 {
265+
iv := make([]byte, account.Cipher.IVSize())
266+
copy(iv, payload.BytesTo(account.Cipher.IVSize()))
267+
}
268+
if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
269+
return nil, nil, newError("failed to decrypt UDP payload").Base(err)
270+
}
270271
}
271272

272273
payload.SetByte(0, payload.Byte(0)&0x0F)
@@ -276,8 +277,13 @@ func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.Reque
276277
return nil, nil, newError("failed to parse address").Base(err)
277278
}
278279

279-
request.Address = addr
280-
request.Port = port
280+
request := &protocol.RequestHeader{
281+
Version: Version,
282+
User: user,
283+
Command: protocol.RequestCommandUDP,
284+
Address: addr,
285+
Port: port,
286+
}
281287

282288
return request, payload, nil
283289
}

proxy/shadowsocks/protocol_test.go

+67-24
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,80 @@ func equalRequestHeader(x, y *protocol.RequestHeader) bool {
2323
}))
2424
}
2525

26-
func TestUDPEncoding(t *testing.T) {
27-
request := &protocol.RequestHeader{
28-
Version: Version,
29-
Command: protocol.RequestCommandUDP,
30-
Address: net.LocalHostIP,
31-
Port: 1234,
32-
User: &protocol.MemoryUser{
33-
Email: "love@example.com",
34-
Account: toAccount(&Account{
35-
Password: "password",
36-
CipherType: CipherType_AES_128_GCM,
37-
}),
26+
func TestUDPEncodingDecoding(t *testing.T) {
27+
testRequests := []protocol.RequestHeader{
28+
{
29+
Version: Version,
30+
Command: protocol.RequestCommandUDP,
31+
Address: net.LocalHostIP,
32+
Port: 1234,
33+
User: &protocol.MemoryUser{
34+
Email: "love@example.com",
35+
Account: toAccount(&Account{
36+
Password: "password",
37+
CipherType: CipherType_AES_128_GCM,
38+
}),
39+
},
40+
},
41+
{
42+
Version: Version,
43+
Command: protocol.RequestCommandUDP,
44+
Address: net.LocalHostIP,
45+
Port: 1234,
46+
User: &protocol.MemoryUser{
47+
Email: "love@example.com",
48+
Account: toAccount(&Account{
49+
Password: "123",
50+
CipherType: CipherType_NONE,
51+
}),
52+
},
3853
},
3954
}
4055

41-
data := buf.New()
42-
common.Must2(data.WriteString("test string"))
43-
encodedData, err := EncodeUDPPacket(request, data.Bytes())
44-
common.Must(err)
56+
for _, request := range testRequests {
57+
data := buf.New()
58+
common.Must2(data.WriteString("test string"))
59+
encodedData, err := EncodeUDPPacket(&request, data.Bytes())
60+
common.Must(err)
4561

46-
validator := new(Validator)
47-
validator.Add(request.User)
48-
decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
49-
common.Must(err)
62+
validator := new(Validator)
63+
validator.Add(request.User)
64+
decodedRequest, decodedData, err := DecodeUDPPacket(validator, encodedData)
65+
common.Must(err)
5066

51-
if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
52-
t.Error("data: ", r)
67+
if r := cmp.Diff(decodedData.Bytes(), data.Bytes()); r != "" {
68+
t.Error("data: ", r)
69+
}
70+
71+
if equalRequestHeader(decodedRequest, &request) == false {
72+
t.Error("different request")
73+
}
5374
}
75+
}
5476

55-
if equalRequestHeader(decodedRequest, request) == false {
56-
t.Error("different request")
77+
func TestUDPDecodingWithPayloadTooShort(t *testing.T) {
78+
testAccounts := []protocol.Account{
79+
toAccount(&Account{
80+
Password: "password",
81+
CipherType: CipherType_AES_128_GCM,
82+
}),
83+
toAccount(&Account{
84+
Password: "password",
85+
CipherType: CipherType_NONE,
86+
}),
87+
}
88+
89+
for _, account := range testAccounts {
90+
data := buf.New()
91+
data.WriteString("short payload")
92+
validator := new(Validator)
93+
validator.Add(&protocol.MemoryUser{
94+
Account: account,
95+
})
96+
_, _, err := DecodeUDPPacket(validator, data)
97+
if err == nil {
98+
t.Fatal("expected error")
99+
}
57100
}
58101
}
59102

proxy/shadowsocks/validator.go

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ func (v *Validator) Get(bs []byte, command protocol.RequestCommand) (u *protocol
8080

8181
for _, user := range v.users {
8282
if account := user.Account.(*MemoryAccount); account.Cipher.IsAEAD() {
83+
// AEAD payload decoding requires the payload to be over 32 bytes
84+
if len(bs) < 32 {
85+
continue
86+
}
87+
8388
aeadCipher := account.Cipher.(*AEADCipher)
8489
ivLen = aeadCipher.IVSize()
8590
iv := bs[:ivLen]

0 commit comments

Comments
 (0)