diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 006112bc6f4..6e628a7c2fe 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -258,7 +258,7 @@ def data_received(self, data: bytes) -> None: if not data: return - # custom payload parser + # custom payload parser - currently always WebSocketReader if self._payload_parser is not None: eof, tail = self._payload_parser.feed_data(data) if eof: @@ -268,57 +268,56 @@ def data_received(self, data: bytes) -> None: if tail: self.data_received(tail) return - else: - if self._upgraded or self._parser is None: - # i.e. websocket connection, websocket parser is not set yet - self._tail += data + + if self._upgraded or self._parser is None: + # i.e. websocket connection, websocket parser is not set yet + self._tail += data + return + + # parse http messages + try: + messages, upgraded, tail = self._parser.feed_data(data) + except BaseException as underlying_exc: + if self.transport is not None: + # connection.release() could be called BEFORE + # data_received(), the transport is already + # closed in this case + self.transport.close() + # should_close is True after the call + if isinstance(underlying_exc, HttpProcessingError): + exc = HttpProcessingError( + code=underlying_exc.code, + message=underlying_exc.message, + headers=underlying_exc.headers, + ) else: - # parse http messages - try: - messages, upgraded, tail = self._parser.feed_data(data) - except BaseException as underlying_exc: - if self.transport is not None: - # connection.release() could be called BEFORE - # data_received(), the transport is already - # closed in this case - self.transport.close() - # should_close is True after the call - if isinstance(underlying_exc, HttpProcessingError): - exc = HttpProcessingError( - code=underlying_exc.code, - message=underlying_exc.message, - headers=underlying_exc.headers, - ) - else: - exc = HttpProcessingError() - self.set_exception(exc, underlying_exc) - return - - self._upgraded = upgraded - - payload: Optional[StreamReader] = None - for message, payload in messages: - if message.should_close: - self._should_close = True - - self._payload = payload - - if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES: - self.feed_data((message, EMPTY_PAYLOAD)) - else: - self.feed_data((message, payload)) - if payload is not None: - # new message(s) was processed - # register timeout handler unsubscribing - # either on end-of-stream or immediately for - # EMPTY_PAYLOAD - if payload is not EMPTY_PAYLOAD: - payload.on_eof(self._drop_timeout) - else: - self._drop_timeout() + exc = HttpProcessingError() + self.set_exception(exc, underlying_exc) + return - if tail: - if upgraded: - self.data_received(tail) - else: - self._tail = tail + self._upgraded = upgraded + + payload: Optional[StreamReader] = None + for message, payload in messages: + if message.should_close: + self._should_close = True + + self._payload = payload + + if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES: + self.feed_data((message, EMPTY_PAYLOAD)) + else: + self.feed_data((message, payload)) + + if payload is not None: + # new message(s) was processed + # register timeout handler unsubscribing + # either on end-of-stream or immediately for + # EMPTY_PAYLOAD + if payload is not EMPTY_PAYLOAD: + payload.on_eof(self._drop_timeout) + else: + self._drop_timeout() + + if upgraded and tail: + self.data_received(tail) diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 52065eca318..e5d62d1e467 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -75,6 +75,88 @@ async def test_uncompleted_message(loop: asyncio.AbstractEventLoop) -> None: assert dict(exc.message.headers) == {"Location": "http://python.org/"} +async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> None: + proto = ResponseHandler(loop=loop) + transport = mock.Mock() + proto.connection_made(transport) + proto.set_response_params(read_until_eof=True) + proto.close() + assert transport.close.called + transport.close.reset_mock() + proto.data_received(b"HTTP\r\n\r\n") + assert proto.should_close + assert not transport.close.called + assert isinstance(proto.exception(), http.HttpProcessingError) + + +async def test_multiple_responses_one_byte_at_a_time( + loop: asyncio.AbstractEventLoop, +) -> None: + proto = ResponseHandler(loop=loop) + proto.connection_made(mock.Mock()) + conn = mock.Mock(protocol=proto) + proto.set_response_params(read_until_eof=True) + + for _ in range(2): + messages = ( + b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab" + b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd" + b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nef" + ) + for i in range(len(messages)): + proto.data_received(messages[i : i + 1]) + + expected = [b"ab", b"cd", b"ef"] + for payload in expected: + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + request_info=mock.Mock(), + traces=[], + loop=loop, + session=mock.Mock(), + ) + await response.start(conn) + await response.read() == payload + + +async def test_unexpected_exception_during_data_received( + loop: asyncio.AbstractEventLoop, +) -> None: + proto = ResponseHandler(loop=loop) + + class PatchableHttpResponseParser(http.HttpResponseParser): + """Subclass of HttpResponseParser to make it patchable.""" + + with mock.patch( + "aiohttp.client_proto.HttpResponseParser", PatchableHttpResponseParser + ): + proto.connection_made(mock.Mock()) + conn = mock.Mock(protocol=proto) + proto.set_response_params(read_until_eof=True) + proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab") + response = ClientResponse( + "get", + URL("http://def-cl-resp.org"), + writer=mock.Mock(), + continue100=None, + timer=TimerNoop(), + request_info=mock.Mock(), + traces=[], + loop=loop, + session=mock.Mock(), + ) + await response.start(conn) + await response.read() == b"ab" + with mock.patch.object(proto._parser, "feed_data", side_effect=ValueError): + proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd") + + assert isinstance(proto.exception(), http.HttpProcessingError) + + async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock()