Skip to content

Commit 200c6c1

Browse files
committed
Restrict the IOError
1 parent 11f075b commit 200c6c1

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

starlette/websockets.py

+25-27
Original file line numberDiff line numberDiff line change
@@ -61,36 +61,34 @@ async def send(self, message: Message) -> None:
6161
"""
6262
Send ASGI websocket messages, ensuring valid state transitions.
6363
"""
64-
try:
65-
if self.application_state == WebSocketState.CONNECTING:
66-
message_type = message["type"]
67-
if message_type not in {"websocket.accept", "websocket.close"}:
68-
raise RuntimeError(
69-
'Expected ASGI message "websocket.accept" or '
70-
f'"websocket.close", but got {message_type!r}'
71-
)
72-
if message_type == "websocket.close":
73-
self.application_state = WebSocketState.DISCONNECTED
74-
else:
75-
self.application_state = WebSocketState.CONNECTED
76-
await self._send(message)
77-
elif self.application_state == WebSocketState.CONNECTED:
78-
message_type = message["type"]
79-
if message_type not in {"websocket.send", "websocket.close"}:
80-
raise RuntimeError(
81-
'Expected ASGI message "websocket.send" or "websocket.close", '
82-
f"but got {message_type!r}"
83-
)
84-
if message_type == "websocket.close":
85-
self.application_state = WebSocketState.DISCONNECTED
86-
await self._send(message)
64+
if self.application_state == WebSocketState.CONNECTING:
65+
message_type = message["type"]
66+
if message_type not in {"websocket.accept", "websocket.close"}:
67+
raise RuntimeError(
68+
'Expected ASGI message "websocket.accept" or '
69+
f'"websocket.close", but got {message_type!r}'
70+
)
71+
if message_type == "websocket.close":
72+
self.application_state = WebSocketState.DISCONNECTED
8773
else:
74+
self.application_state = WebSocketState.CONNECTED
75+
await self._send(message)
76+
elif self.application_state == WebSocketState.CONNECTED:
77+
message_type = message["type"]
78+
if message_type not in {"websocket.send", "websocket.close"}:
8879
raise RuntimeError(
89-
'Cannot call "send" once a close message has been sent.'
80+
'Expected ASGI message "websocket.send" or "websocket.close", '
81+
f"but got {message_type!r}"
9082
)
91-
except IOError: # pragma: no cover
92-
self.application_state = WebSocketState.DISCONNECTED
93-
raise WebSocketDisconnect(code=1006)
83+
if message_type == "websocket.close":
84+
self.application_state = WebSocketState.DISCONNECTED
85+
try:
86+
await self._send(message)
87+
except IOError:
88+
self.application_state = WebSocketState.DISCONNECTED
89+
raise WebSocketDisconnect(code=1006)
90+
else:
91+
raise RuntimeError('Cannot call "send" once a close message has been sent.')
9492

9593
async def accept(
9694
self,

tests/test_websockets.py

+22
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,28 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
255255
assert close_reason == "Going Away"
256256

257257

258+
@pytest.mark.anyio
259+
async def test_client_disconnect_on_send():
260+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
261+
websocket = WebSocket(scope, receive=receive, send=send)
262+
await websocket.accept()
263+
await websocket.send_text("Hello, world!")
264+
265+
async def receive() -> Message:
266+
return {"type": "websocket.connect"}
267+
268+
async def send(message: Message) -> None:
269+
if message["type"] == "websocket.accept":
270+
return
271+
# Simulate the exception the server would send to the application when the
272+
# client disconnects.
273+
raise IOError
274+
275+
with pytest.raises(WebSocketDisconnect) as ctx:
276+
await app({"type": "websocket", "path": "/"}, receive, send)
277+
assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE
278+
279+
258280
def test_application_close(test_client_factory: Callable[..., TestClient]):
259281
async def app(scope: Scope, receive: Receive, send: Send) -> None:
260282
websocket = WebSocket(scope, receive=receive, send=send)

0 commit comments

Comments
 (0)