Skip to content

Commit a210f77

Browse files
committed
Cancel SSE publisher in case of async timeouts
Prior to this commit, the MVC `GraphQlSseHandler` would not react to async request timeouts thrown by the Servlet container. This means that when such timeouts happened, the SSE handler would still try to write to the underlying response, whereas it was already recycled. This would lead to NullPointerException thrown by the container. This commit ensures that the SSE handler registers an async listener to be notified of async timeouts and cancels the publisher as a result. The SSE completion is not performed so as to let the client know that the exchange did not complete and that it should re-subscribe. Fixes gh-1067
1 parent 2650eb5 commit a210f77

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.graphql.server.WebGraphQlResponse;
3434
import org.springframework.util.AlternativeJdkIdGenerator;
3535
import org.springframework.util.IdGenerator;
36+
import org.springframework.web.context.request.async.AsyncRequestTimeoutException;
3637
import org.springframework.web.servlet.function.ServerRequest;
3738
import org.springframework.web.servlet.function.ServerResponse;
3839

@@ -93,6 +94,12 @@ private static final class SseSubscriber extends BaseSubscriber<Map<String, Obje
9394

9495
private SseSubscriber(ServerResponse.SseBuilder sseBuilder) {
9596
this.sseBuilder = sseBuilder;
97+
this.sseBuilder.onTimeout(this::onTimeout);
98+
}
99+
100+
private void onTimeout() {
101+
this.cancel();
102+
this.sseBuilder.error(new AsyncRequestTimeoutException());
96103
}
97104

98105
@Override
@@ -116,11 +123,11 @@ protected void hookOnError(Throwable ex) {
116123
if (ex instanceof SubscriptionPublisherException spe) {
117124
ExecutionResult result = ExecutionResult.newExecutionResult().errors(spe.getErrors()).build();
118125
writeResult(result.toSpecification());
126+
hookOnComplete();
119127
}
120128
else {
121129
this.sseBuilder.error(ex);
122130
}
123-
hookOnComplete();
124131
}
125132

126133
@Override

spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlSseHandlerTests.java

+31-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.util.concurrent.atomic.AtomicBoolean;
2525

2626
import graphql.schema.DataFetcher;
27+
import jakarta.servlet.AsyncEvent;
28+
import jakarta.servlet.AsyncListener;
2729
import jakarta.servlet.ServletException;
2830
import jakarta.servlet.ServletOutputStream;
2931
import jakarta.servlet.http.HttpServletResponse;
@@ -35,6 +37,7 @@
3537
import org.springframework.http.MediaType;
3638
import org.springframework.http.converter.HttpMessageConverter;
3739
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
40+
import org.springframework.mock.web.MockAsyncContext;
3841
import org.springframework.mock.web.MockHttpServletRequest;
3942
import org.springframework.mock.web.MockHttpServletResponse;
4043
import org.springframework.web.servlet.function.AsyncServerResponse;
@@ -72,7 +75,7 @@ class GraphQlSseHandlerTests {
7275
void shouldRejectQueryOperations() throws Exception {
7376
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
7477
MockHttpServletRequest request = createServletRequest("{ \"query\": \"{ bookById(id: 42) {name} }\"}");
75-
MockHttpServletResponse response = handleRequest(request, handler);
78+
MockHttpServletResponse response = handleAndAwait(request, handler);
7679

7780
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
7881
assertThat(response.getContentAsString()).isEqualTo("""
@@ -91,7 +94,7 @@ void shouldWriteMultipleEventsForSubscription() throws Exception {
9194
MockHttpServletRequest request = createServletRequest("""
9295
{ "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" }
9396
""");
94-
MockHttpServletResponse response = handleRequest(request, handler);
97+
MockHttpServletResponse response = handleAndAwait(request, handler);
9598

9699
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
97100
assertThat(response.getContentAsString()).isEqualTo("""
@@ -117,7 +120,7 @@ void shouldWriteEventsAndTerminalError() throws Exception {
117120
MockHttpServletRequest request = createServletRequest("""
118121
{ "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" }
119122
""");
120-
MockHttpServletResponse response = handleRequest(request, handler);
123+
MockHttpServletResponse response = handleAndAwait(request, handler);
121124

122125
assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
123126
assertThat(response.getContentAsString()).isEqualTo("""
@@ -153,7 +156,26 @@ void shouldCancelDataFetcherPublisherWhenWritingFails() throws Exception {
153156

154157
response.writeTo(servletRequest, servletResponse, new DefaultContext());
155158
await().atMost(Duration.ofMillis(500)).until(DATA_FETCHER_CANCELLED::get);
159+
}
160+
161+
@Test
162+
void shouldCancelDataFetcherWhenAsyncTimeout() throws Exception {
163+
DataFetcher<?> errorDataFetcher = env -> Flux.just(BookSource.getBook(1L))
164+
.delayElements(Duration.ofMillis(500)).doOnCancel(() -> DATA_FETCHER_CANCELLED.set(true));
165+
166+
GraphQlSseHandler handler = createSseHandler(errorDataFetcher);
167+
MockHttpServletRequest servletRequest = createServletRequest("""
168+
{ "query": "subscription TestSubscription { bookSearch(author:\\\"Orwell\\\") { id name } }" }
169+
""");
156170

171+
MockHttpServletResponse servletResponse = handleRequest(servletRequest, handler);
172+
for (AsyncListener listener : ((MockAsyncContext) servletRequest.getAsyncContext()).getListeners()) {
173+
listener.onTimeout(new AsyncEvent(servletRequest.getAsyncContext()));
174+
}
175+
176+
assertThat(DATA_FETCHER_CANCELLED.get()).isTrue();
177+
assertThat(servletResponse.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
178+
assertThat(servletResponse.getContentAsString()).isEmpty();
157179
}
158180

159181
private GraphQlSseHandler createSseHandler(DataFetcher<?> dataFetcher) {
@@ -174,15 +196,19 @@ private MockHttpServletRequest createServletRequest(String query) {
174196

175197
private MockHttpServletResponse handleRequest(
176198
MockHttpServletRequest servletRequest, GraphQlSseHandler handler) throws ServletException, IOException {
177-
178199
ServerRequest request = ServerRequest.create(servletRequest, MESSAGE_READERS);
179200
ServerResponse response = handler.handleRequest(request);
180201
if (response instanceof AsyncServerResponse asyncResponse) {
181202
asyncResponse.block();
182203
}
183-
184204
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
185205
response.writeTo(servletRequest, servletResponse, new DefaultContext());
206+
return servletResponse;
207+
}
208+
209+
private MockHttpServletResponse handleAndAwait(
210+
MockHttpServletRequest servletRequest, GraphQlSseHandler handler) throws ServletException, IOException {
211+
MockHttpServletResponse servletResponse = handleRequest(servletRequest, handler);
186212
await().atMost(Duration.ofMillis(500)).until(() -> servletResponse.getContentAsString().contains("complete"));
187213
return servletResponse;
188214
}

0 commit comments

Comments
 (0)