Skip to content

Commit 2a1e934

Browse files
authored
server: after GracefulStop, ensure connections are closed when final RPC completes (#5968)
Fixes #5930
1 parent e2d69aa commit 2a1e934

File tree

4 files changed

+89
-9
lines changed

4 files changed

+89
-9
lines changed

internal/transport/controlbuf.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,9 @@ const minBatchSize = 1000
527527
// As an optimization, to increase the batch size for each flush, loopy yields the processor, once
528528
// if the batch size is too low to give stream goroutines a chance to fill it up.
529529
func (l *loopyWriter) run() (err error) {
530+
// Always flush the writer before exiting in case there are pending frames
531+
// to be sent.
532+
defer l.framer.writer.Flush()
530533
for {
531534
it, err := l.cbuf.get(true)
532535
if err != nil {
@@ -759,7 +762,7 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
759762
return err
760763
}
761764
}
762-
if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
765+
if l.draining && len(l.estdStreams) == 0 {
763766
return errors.New("finished processing active streams while in draining mode")
764767
}
765768
return nil
@@ -814,7 +817,6 @@ func (l *loopyWriter) goAwayHandler(g *goAway) error {
814817
}
815818

816819
func (l *loopyWriter) closeConnectionHandler() error {
817-
l.framer.writer.Flush()
818820
// Exit loopyWriter entirely by returning an error here. This will lead to
819821
// the transport closing the connection, and, ultimately, transport
820822
// closure.

test/gracefulstop_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"testing"
2727
"time"
2828

29+
"golang.org/x/net/http2"
2930
"google.golang.org/grpc"
3031
"google.golang.org/grpc/codes"
3132
"google.golang.org/grpc/credentials/insecure"
@@ -164,3 +165,53 @@ func (s) TestGracefulStop(t *testing.T) {
164165
cancel()
165166
wg.Wait()
166167
}
168+
169+
func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
170+
// This test ensures that a server closes the connections to its clients
171+
// when the final stream has completed after a GOAWAY.
172+
173+
handlerCalled := make(chan struct{})
174+
gracefulStopCalled := make(chan struct{})
175+
176+
ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
177+
close(handlerCalled) // Initiate call to GracefulStop.
178+
<-gracefulStopCalled // Wait for GOAWAYs to be received by the client.
179+
return nil
180+
}}
181+
182+
te := newTest(t, tcpClearEnv)
183+
te.startServer(ts)
184+
defer te.tearDown()
185+
186+
te.withServerTester(func(st *serverTester) {
187+
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)
188+
189+
<-handlerCalled // Wait for the server to invoke its handler.
190+
191+
// Gracefully stop the server.
192+
gracefulStopDone := make(chan struct{})
193+
go func() {
194+
te.srv.GracefulStop()
195+
close(gracefulStopDone)
196+
}()
197+
st.wantGoAway(http2.ErrCodeNo) // Server sends a GOAWAY due to GracefulStop.
198+
pf := st.wantPing() // Server sends a ping to verify client receipt.
199+
st.writePing(true, pf.Data) // Send ping ack to confirm.
200+
st.wantGoAway(http2.ErrCodeNo) // Wait for subsequent GOAWAY to indicate no new stream processing.
201+
202+
close(gracefulStopCalled) // Unblock server handler.
203+
204+
fr := st.wantAnyFrame() // Wait for trailer.
205+
hdr, ok := fr.(*http2.MetaHeadersFrame)
206+
if !ok {
207+
t.Fatalf("Received unexpected frame of type (%T) from server: %v; want HEADERS", fr, fr)
208+
}
209+
if !hdr.StreamEnded() {
210+
t.Fatalf("Received unexpected HEADERS frame from server: %v; want END_STREAM set", fr)
211+
}
212+
213+
st.wantRSTStream(http2.ErrCodeNo) // Server should send RST_STREAM because client did not half-close.
214+
215+
<-gracefulStopDone // Wait for GracefulStop to return.
216+
})
217+
}

test/servertester.go

+31-4
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,46 @@ func (st *serverTester) writeSettingsAck() {
138138
}
139139
}
140140

141+
func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame {
142+
f, err := st.readFrame()
143+
if err != nil {
144+
st.t.Fatalf("Error while expecting an RST frame: %v", err)
145+
}
146+
gaf, ok := f.(*http2.GoAwayFrame)
147+
if !ok {
148+
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
149+
}
150+
if gaf.ErrCode != errCode {
151+
st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String())
152+
}
153+
return gaf
154+
}
155+
156+
func (st *serverTester) wantPing() *http2.PingFrame {
157+
f, err := st.readFrame()
158+
if err != nil {
159+
st.t.Fatalf("Error while expecting an RST frame: %v", err)
160+
}
161+
pf, ok := f.(*http2.PingFrame)
162+
if !ok {
163+
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
164+
}
165+
return pf
166+
}
167+
141168
func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
142169
f, err := st.readFrame()
143170
if err != nil {
144171
st.t.Fatalf("Error while expecting an RST frame: %v", err)
145172
}
146-
sf, ok := f.(*http2.RSTStreamFrame)
173+
rf, ok := f.(*http2.RSTStreamFrame)
147174
if !ok {
148175
st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
149176
}
150-
if sf.ErrCode != errCode {
151-
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
177+
if rf.ErrCode != errCode {
178+
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String())
152179
}
153-
return sf
180+
return rf
154181
}
155182

156183
func (st *serverTester) wantSettings() *http2.SettingsFrame {

test/stream_cleanup_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (s) TestStreamCleanup(t *testing.T) {
4646
return &testpb.Empty{}, nil
4747
},
4848
}
49-
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
49+
if err := ss.Start(nil, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
5050
t.Fatalf("Error starting endpoint server: %v", err)
5151
}
5252
defer ss.Stop()
@@ -79,7 +79,7 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
7979
})
8080
},
8181
}
82-
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
82+
if err := ss.Start(nil, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
8383
t.Fatalf("Error starting endpoint server: %v", err)
8484
}
8585
defer ss.Stop()
@@ -132,6 +132,6 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
132132
case <-gracefulStopDone:
133133
timer.Stop()
134134
case <-timer.C:
135-
t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC")
135+
t.Fatalf("s.GracefulStop() didn't finish within 1 second after the last RPC")
136136
}
137137
}

0 commit comments

Comments
 (0)