Skip to content

Commit b5126b2

Browse files
authored
Raise WebSocketDisconnect when WebSocket.send() excepts IOError (#2425)
* Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError` * Restrict the IOError
1 parent 3ae161e commit b5126b2

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

starlette/websockets.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ async def send(self, message: Message) -> None:
8282
)
8383
if message_type == "websocket.close":
8484
self.application_state = WebSocketState.DISCONNECTED
85-
await self._send(message)
85+
try:
86+
await self._send(message)
87+
except IOError:
88+
self.application_state = WebSocketState.DISCONNECTED
89+
raise WebSocketDisconnect(code=1006)
8690
else:
8791
raise RuntimeError('Cannot call "send" once a close message has been sent.')
8892

tests/test_websockets.py

+22
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,28 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
255255
assert close_reason == "Going Away"
256256

257257

258+
@pytest.mark.anyio
259+
async def test_client_disconnect_on_send():
260+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
261+
websocket = WebSocket(scope, receive=receive, send=send)
262+
await websocket.accept()
263+
await websocket.send_text("Hello, world!")
264+
265+
async def receive() -> Message:
266+
return {"type": "websocket.connect"}
267+
268+
async def send(message: Message) -> None:
269+
if message["type"] == "websocket.accept":
270+
return
271+
# Simulate the exception the server would send to the application when the
272+
# client disconnects.
273+
raise IOError
274+
275+
with pytest.raises(WebSocketDisconnect) as ctx:
276+
await app({"type": "websocket", "path": "/"}, receive, send)
277+
assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
278+
279+
258280
def test_application_close(test_client_factory: Callable[..., TestClient]):
259281
async def app(scope: Scope, receive: Receive, send: Send) -> None:
260282
websocket = WebSocket(scope, receive=receive, send=send)

0 commit comments

Comments
 (0)