# _exception_handler.py
from __future__ import annotations
import typing
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import (
ASGIApp,
ExceptionHandler,
HTTPExceptionHandler,
Message,
Receive,
Scope,
Send,
WebSocketExceptionHandler,
)
from starlette.websockets import WebSocket
# Handler table consulted by `_lookup_exception_handler`, which probes it with
# each class in the raised exception's MRO.
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
# Handler table consulted by `wrap_app_handling_exceptions` with the
# `status_code` of a raised `HTTPException`.
StatusHandlers = typing.Dict[int, ExceptionHandler]
def _lookup_exception_handler(
exc_handlers: ExceptionHandlers, exc: Exception
) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
return None
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
    """Wrap ``app`` so that exceptions it raises are routed to registered handlers.

    The handler tables are read once from
    ``conn.scope["starlette.exception_handlers"]`` as an
    ``(exception_handlers, status_handlers)`` pair; both tables are empty when
    that key is missing.

    The returned ASGI callable runs ``app`` and, for any ``Exception`` raised:

    * selects a handler by ``status_code`` when the exception is an
      ``HTTPException``; if that yields nothing (or the exception is of another
      type) it falls back to a lookup over the exception's MRO;
    * re-raises the exception when no handler is found;
    * raises ``RuntimeError`` chained to the exception when a handler is found
      but an ``http.response.start`` message has already been sent;
    * for ``"http"`` scopes, calls the handler with ``conn`` and the exception
      and sends the response it returns; for ``"websocket"`` scopes, only
      calls the handler.
    """
    exception_handlers: ExceptionHandlers
    status_handlers: StatusHandlers
    try:
        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
    except KeyError:
        exception_handlers = {}
        status_handlers = {}

    async def _call_handler(
        handler: typing.Any, connection: typing.Any, exc: Exception
    ) -> typing.Any:
        # Handlers reported as async callables are awaited directly; all
        # others are executed through run_in_threadpool.
        if is_async_callable(handler):
            return await handler(connection, exc)
        return await run_in_threadpool(handler, connection, exc)

    async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
        started = False

        async def sender(message: Message) -> None:
            # Record that the response has begun before forwarding the message,
            # so a later failure can tell it is too late to send another one.
            nonlocal started
            if message["type"] == "http.response.start":
                started = True
            await send(message)

        try:
            await app(scope, receive, sender)
        except Exception as exc:
            # Status-code handlers take precedence for HTTPException.
            handler = (
                status_handlers.get(exc.status_code)
                if isinstance(exc, HTTPException)
                else None
            )
            if handler is None:
                handler = _lookup_exception_handler(exception_handlers, exc)
            if handler is None:
                raise exc
            if started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            scope_type = scope["type"]
            if scope_type == "http":
                # The casts only narrow types for the checker; objects are unchanged.
                http_handler = typing.cast(HTTPExceptionHandler, handler)
                request = typing.cast(Request, conn)
                response = await _call_handler(http_handler, request, exc)
                await response(scope, receive, sender)
            elif scope_type == "websocket":
                ws_handler = typing.cast(WebSocketExceptionHandler, handler)
                websocket = typing.cast(WebSocket, conn)
                await _call_handler(ws_handler, websocket, exc)

    return wrapped_app