Skip to content

Commit d076e14

Browse files
hueyparkdfawley
authored andcommitted
rpc_util: Fix RecvBufferPool deactivation issues (#6766)
1 parent 9d981b0 commit d076e14

File tree

3 files changed

+188
-67
lines changed

3 files changed

+188
-67
lines changed

experimental/shared_buffer_pool_test.go

+149-47
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ import (
2626
"time"
2727

2828
"google.golang.org/grpc"
29+
"google.golang.org/grpc/encoding/gzip"
2930
"google.golang.org/grpc/experimental"
3031
"google.golang.org/grpc/internal/grpctest"
3132
"google.golang.org/grpc/internal/stubserver"
3233

3334
testgrpc "google.golang.org/grpc/interop/grpc_testing"
34-
testpb "google.golang.org/grpc/interop/grpc_testing"
3535
)
3636

3737
type s struct {
@@ -44,59 +44,161 @@ func Test(t *testing.T) {
4444

4545
const defaultTestTimeout = 10 * time.Second
4646

47-
func (s) TestRecvBufferPool(t *testing.T) {
48-
ss := &stubserver.StubServer{
49-
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
50-
for i := 0; i < 10; i++ {
51-
preparedMsg := &grpc.PreparedMsg{}
52-
err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
53-
Payload: &testpb.Payload{
54-
Body: []byte{'0' + uint8(i)},
55-
},
56-
})
47+
func (s) TestRecvBufferPoolStream(t *testing.T) {
48+
tcs := []struct {
49+
name string
50+
callOpts []grpc.CallOption
51+
}{
52+
{
53+
name: "default",
54+
},
55+
{
56+
name: "useCompressor",
57+
callOpts: []grpc.CallOption{
58+
grpc.UseCompressor(gzip.Name),
59+
},
60+
},
61+
}
62+
63+
for _, tc := range tcs {
64+
t.Run(tc.name, func(t *testing.T) {
65+
const reqCount = 10
66+
67+
ss := &stubserver.StubServer{
68+
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
69+
for i := 0; i < reqCount; i++ {
70+
preparedMsg := &grpc.PreparedMsg{}
71+
if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
72+
Payload: &testgrpc.Payload{
73+
Body: []byte{'0' + uint8(i)},
74+
},
75+
}); err != nil {
76+
return err
77+
}
78+
stream.SendMsg(preparedMsg)
79+
}
80+
return nil
81+
},
82+
}
83+
84+
pool := &checkBufferPool{}
85+
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
86+
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
87+
if err := ss.Start(sopts, dopts...); err != nil {
88+
t.Fatalf("Error starting endpoint server: %v", err)
89+
}
90+
defer ss.Stop()
91+
92+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
93+
defer cancel()
94+
95+
stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
96+
if err != nil {
97+
t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
98+
}
99+
100+
var ngot int
101+
var buf bytes.Buffer
102+
for {
103+
reply, err := stream.Recv()
104+
if err == io.EOF {
105+
break
106+
}
57107
if err != nil {
58-
return err
108+
t.Fatal(err)
59109
}
60-
stream.SendMsg(preparedMsg)
110+
ngot++
111+
if buf.Len() > 0 {
112+
buf.WriteByte(',')
113+
}
114+
buf.Write(reply.GetPayload().GetBody())
61115
}
62-
return nil
63-
},
116+
if want := 10; ngot != want {
117+
t.Fatalf("Got %d replies, want %d", ngot, want)
118+
}
119+
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
120+
t.Fatalf("Got replies %q; want %q", got, want)
121+
}
122+
123+
if len(pool.puts) != reqCount {
124+
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
125+
}
126+
})
64127
}
65-
sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
66-
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
67-
if err := ss.Start(sopts, dopts...); err != nil {
68-
t.Fatalf("Error starting endpoint server: %v", err)
128+
}
129+
130+
func (s) TestRecvBufferPoolUnary(t *testing.T) {
131+
tcs := []struct {
132+
name string
133+
callOpts []grpc.CallOption
134+
}{
135+
{
136+
name: "default",
137+
},
138+
{
139+
name: "useCompressor",
140+
callOpts: []grpc.CallOption{
141+
grpc.UseCompressor(gzip.Name),
142+
},
143+
},
69144
}
70-
defer ss.Stop()
71145

72-
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
73-
defer cancel()
146+
for _, tc := range tcs {
147+
t.Run(tc.name, func(t *testing.T) {
148+
const largeSize = 1024
74149

75-
stream, err := ss.Client.FullDuplexCall(ctx)
76-
if err != nil {
77-
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
78-
}
150+
ss := &stubserver.StubServer{
151+
UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
152+
return &testgrpc.SimpleResponse{
153+
Payload: &testgrpc.Payload{
154+
Body: make([]byte, largeSize),
155+
},
156+
}, nil
157+
},
158+
}
79159

80-
var ngot int
81-
var buf bytes.Buffer
82-
for {
83-
reply, err := stream.Recv()
84-
if err == io.EOF {
85-
break
86-
}
87-
if err != nil {
88-
t.Fatal(err)
89-
}
90-
ngot++
91-
if buf.Len() > 0 {
92-
buf.WriteByte(',')
93-
}
94-
buf.Write(reply.GetPayload().GetBody())
95-
}
96-
if want := 10; ngot != want {
97-
t.Errorf("Got %d replies, want %d", ngot, want)
98-
}
99-
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
100-
t.Errorf("Got replies %q; want %q", got, want)
160+
pool := &checkBufferPool{}
161+
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
162+
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
163+
if err := ss.Start(sopts, dopts...); err != nil {
164+
t.Fatalf("Error starting endpoint server: %v", err)
165+
}
166+
defer ss.Stop()
167+
168+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
169+
defer cancel()
170+
171+
const reqCount = 10
172+
for i := 0; i < reqCount; i++ {
173+
if _, err := ss.Client.UnaryCall(
174+
ctx,
175+
&testgrpc.SimpleRequest{
176+
Payload: &testgrpc.Payload{
177+
Body: make([]byte, largeSize),
178+
},
179+
},
180+
tc.callOpts...,
181+
); err != nil {
182+
t.Fatalf("ss.Client.UnaryCall failed: %v", err)
183+
}
184+
}
185+
186+
const bufferCount = reqCount * 2 // req + resp
187+
if len(pool.puts) != bufferCount {
188+
t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
189+
}
190+
})
101191
}
102192
}
193+
194+
type checkBufferPool struct {
195+
puts [][]byte
196+
}
197+
198+
func (p *checkBufferPool) Get(size int) []byte {
199+
return make([]byte, size)
200+
}
201+
202+
func (p *checkBufferPool) Put(bs *[]byte) {
203+
p.puts = append(p.puts, *bs)
204+
}

rpc_util.go

+35-19
Original file line numberDiff line numberDiff line change
@@ -744,39 +744,55 @@ type payloadInfo struct {
744744
uncompressedBytes []byte
745745
}
746746

747-
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
748-
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
747+
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
748+
//
749+
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
750+
// the buffer is no longer needed.
751+
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
752+
) (uncompressedBuf []byte, cancel func(), err error) {
753+
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
749754
if err != nil {
750-
return nil, err
751-
}
752-
if payInfo != nil {
753-
payInfo.compressedLength = len(buf)
755+
return nil, nil, err
754756
}
755757

756758
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
757-
return nil, st.Err()
759+
return nil, nil, st.Err()
758760
}
759761

760762
var size int
761763
if pf == compressionMade {
762764
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
763765
// use this decompressor as the default.
764766
if dc != nil {
765-
buf, err = dc.Do(bytes.NewReader(buf))
766-
size = len(buf)
767+
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
768+
size = len(uncompressedBuf)
767769
} else {
768-
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
770+
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
769771
}
770772
if err != nil {
771-
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
773+
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
772774
}
773775
if size > maxReceiveMessageSize {
774776
// TODO: Revisit the error code. Currently keep it consistent with java
775777
// implementation.
776-
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
778+
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
777779
}
780+
} else {
781+
uncompressedBuf = compressedBuf
778782
}
779-
return buf, nil
783+
784+
if payInfo != nil {
785+
payInfo.compressedLength = len(compressedBuf)
786+
payInfo.uncompressedBytes = uncompressedBuf
787+
788+
cancel = func() {}
789+
} else {
790+
cancel = func() {
791+
p.recvBufferPool.Put(&compressedBuf)
792+
}
793+
}
794+
795+
return uncompressedBuf, cancel, nil
780796
}
781797

782798
// Using compressor, decompress d, returning data and size.
@@ -796,6 +812,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
796812
// size is used as an estimate to size the buffer, but we
797813
// will read more data if available.
798814
// +MinRead so ReadFrom will not reallocate if size is correct.
815+
//
816+
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
817+
// we can also utilize the recv buffer pool here.
799818
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
800819
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
801820
return buf.Bytes(), int(bytesRead), err
@@ -811,18 +830,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
811830
// dc takes precedence over compressor.
812831
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
813832
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
814-
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
833+
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
815834
if err != nil {
816835
return err
817836
}
837+
defer cancel()
838+
818839
if err := c.Unmarshal(buf, m); err != nil {
819840
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
820841
}
821-
if payInfo != nil {
822-
payInfo.uncompressedBytes = buf
823-
} else {
824-
p.recvBufferPool.Put(&buf)
825-
}
826842
return nil
827843
}
828844

server.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -1342,7 +1342,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
13421342
if len(shs) != 0 || len(binlogs) != 0 {
13431343
payInfo = &payloadInfo{}
13441344
}
1345-
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
1345+
1346+
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
13461347
if err != nil {
13471348
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
13481349
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
@@ -1353,6 +1354,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
13531354
t.IncrMsgRecv()
13541355
}
13551356
df := func(v any) error {
1357+
defer cancel()
1358+
13561359
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
13571360
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
13581361
}

0 commit comments

Comments
 (0)