Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce code indent in ResponseHandler.data_received #8699

Merged
merged 9 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 52 additions & 53 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def data_received(self, data: bytes) -> None:
if not data:
return

# custom payload parser
# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
eof, tail = self._payload_parser.feed_data(data)
if eof:
Expand All @@ -268,57 +268,56 @@ def data_received(self, data: bytes) -> None:
if tail:
self.data_received(tail)
return
else:
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data

if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
return

# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return

self._upgraded = upgraded

payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._payload = payload

if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD))
else:
self.feed_data((message, payload))
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return

if tail:
if upgraded:
self.data_received(tail)
else:
self._tail = tail
self._upgraded = upgraded

payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._payload = payload

if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD))
else:
self.feed_data((message, payload))

if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()

if upgraded and tail:
self.data_received(tail)
82 changes: 82 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,88 @@ async def test_uncompleted_message(loop: asyncio.AbstractEventLoop) -> None:
assert dict(exc.message.headers) == {"Location": "http://python.org/"}


async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
transport = mock.Mock()
proto.connection_made(transport)
proto.set_response_params(read_until_eof=True)
proto.close()
assert transport.close.called
transport.close.reset_mock()
proto.data_received(b"HTTP\r\n\r\n")
assert proto.should_close
assert not transport.close.called
assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_multiple_responses_one_byte_at_a_time(
loop: asyncio.AbstractEventLoop,
) -> None:
proto = ResponseHandler(loop=loop)
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
proto.set_response_params(read_until_eof=True)

for _ in range(2):
messages = (
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab"
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd"
b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nef"
)
for i in range(len(messages)):
proto.data_received(messages[i : i + 1])

expected = [b"ab", b"cd", b"ef"]
for payload in expected:
response = ClientResponse(
"get",
URL("http://def-cl-resp.org"),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
request_info=mock.Mock(),
traces=[],
loop=loop,
session=mock.Mock(),
)
await response.start(conn)
await response.read() == payload


async def test_unexpected_exception_during_data_received(
loop: asyncio.AbstractEventLoop,
) -> None:
proto = ResponseHandler(loop=loop)

class PatchableHttpResponseParser(http.HttpResponseParser):
"""Subclass of HttpResponseParser to make it patchable."""

with mock.patch(
"aiohttp.client_proto.HttpResponseParser", PatchableHttpResponseParser
):
proto.connection_made(mock.Mock())
conn = mock.Mock(protocol=proto)
proto.set_response_params(read_until_eof=True)
proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab")
response = ClientResponse(
"get",
URL("http://def-cl-resp.org"),
writer=mock.Mock(),
continue100=None,
timer=TimerNoop(),
request_info=mock.Mock(),
traces=[],
loop=loop,
session=mock.Mock(),
)
await response.start(conn)
await response.read() == b"ab"
with mock.patch.object(proto._parser, "feed_data", side_effect=ValueError):
proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd")

assert isinstance(proto.exception(), http.HttpProcessingError)


async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> None:
proto = ResponseHandler(loop=loop)
transport = mock.Mock()
Expand Down
Loading