Skip to content

Commit ee9cfe1

Browse files
committed
Iterate refactor
1 parent 228e791 commit ee9cfe1

16 files changed

+532
-821
lines changed

httpcore/_async/connection.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from .._backends.auto import AutoBackend
1010
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
1111
from .._exceptions import ConnectError, ConnectTimeout
12-
from .._models import Origin, Request, Response
12+
from .._models import Origin, Request
1313
from .._ssl import default_ssl_context
14-
from .._synchronization import AsyncLock
14+
from .._synchronization import AsyncSemaphore
1515
from .._trace import Trace
1616
from .http11 import AsyncHTTP11Connection
17-
from .interfaces import AsyncConnectionInterface
17+
from .interfaces import AsyncConnectionInterface, StartResponse
1818

1919
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
2020

@@ -63,10 +63,10 @@ def __init__(
6363
)
6464
self._connection: AsyncConnectionInterface | None = None
6565
self._connect_failed: bool = False
66-
self._request_lock = AsyncLock()
66+
self._request_lock = AsyncSemaphore(bound=1)
6767
self._socket_options = socket_options
6868

69-
async def handle_async_request(self, request: Request) -> Response:
69+
async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]:
7070
if not self.can_handle_request(request.url.origin):
7171
raise RuntimeError(
7272
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
@@ -100,7 +100,11 @@ async def handle_async_request(self, request: Request) -> Response:
100100
self._connect_failed = True
101101
raise exc
102102

103-
return await self._connection.handle_async_request(request)
103+
iterator = self._connection.iterate_response(request)
104+
start_response = await anext(iterator)
105+
yield start_response
106+
async for body in iterator:
107+
yield body
104108

105109
async def _connect(self, request: Request) -> AsyncNetworkStream:
106110
timeouts = request.extensions.get("timeout", {})
@@ -174,14 +178,7 @@ async def aclose(self) -> None:
174178

175179
def is_available(self) -> bool:
176180
if self._connection is None:
177-
# If HTTP/2 support is enabled, and the resulting connection could
178-
# end up as HTTP/2 then we should indicate the connection as being
179-
# available to service multiple requests.
180-
return (
181-
self._http2
182-
and (self._origin.scheme == b"https" or not self._http1)
183-
and not self._connect_failed
184-
)
181+
return False
185182
return self._connection.is_available()
186183

187184
def has_expired(self) -> bool:

httpcore/_async/connection_pool.py

+54-218
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,15 @@
22

33
import ssl
44
import sys
5-
import types
65
import typing
76

87
from .._backends.auto import AutoBackend
98
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
10-
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
11-
from .._models import Origin, Proxy, Request, Response
12-
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
9+
from .._exceptions import UnsupportedProtocol
10+
from .._models import Origin, Proxy, Request
11+
from .._synchronization import AsyncSemaphore
1312
from .connection import AsyncHTTPConnection
14-
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
15-
16-
17-
class AsyncPoolRequest:
18-
def __init__(self, request: Request) -> None:
19-
self.request = request
20-
self.connection: AsyncConnectionInterface | None = None
21-
self._connection_acquired = AsyncEvent()
22-
23-
def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
24-
self.connection = connection
25-
self._connection_acquired.set()
26-
27-
def clear_connection(self) -> None:
28-
self.connection = None
29-
self._connection_acquired = AsyncEvent()
30-
31-
async def wait_for_connection(
32-
self, timeout: float | None = None
33-
) -> AsyncConnectionInterface:
34-
if self.connection is None:
35-
await self._connection_acquired.wait(timeout=timeout)
36-
assert self.connection is not None
37-
return self.connection
38-
39-
def is_queued(self) -> bool:
40-
return self.connection is None
13+
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface, StartResponse
4114

4215

4316
class AsyncConnectionPool(AsyncRequestInterface):
@@ -49,6 +22,7 @@ def __init__(
4922
self,
5023
ssl_context: ssl.SSLContext | None = None,
5124
proxy: Proxy | None = None,
25+
concurrency_limit: int = 100,
5226
max_connections: int | None = 10,
5327
max_keepalive_connections: int | None = None,
5428
keepalive_expiry: float | None = None,
@@ -102,6 +76,7 @@ def __init__(
10276
self._max_keepalive_connections = min(
10377
self._max_connections, self._max_keepalive_connections
10478
)
79+
self._limits = AsyncSemaphore(bound=concurrency_limit)
10580

10681
self._keepalive_expiry = keepalive_expiry
10782
self._http1 = http1
@@ -123,7 +98,7 @@ def __init__(
12398
# We only mutate the state of the connection pool within an 'optional_thread_lock'
12499
# context. This holds a threading lock unless we're running in async mode,
125100
# in which case it is a no-op.
126-
self._optional_thread_lock = AsyncThreadLock()
101+
# self._optional_thread_lock = AsyncThreadLock()
127102

128103
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
129104
if self._proxy is not None:
@@ -196,7 +171,7 @@ def connections(self) -> list[AsyncConnectionInterface]:
196171
"""
197172
return list(self._connections)
198173

199-
async def handle_async_request(self, request: Request) -> Response:
174+
async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]:
200175
"""
201176
Send an HTTP request, and return an HTTP response.
202177
@@ -212,145 +187,50 @@ async def handle_async_request(self, request: Request) -> Response:
212187
f"Request URL has an unsupported protocol '{scheme}://'."
213188
)
214189

215-
timeouts = request.extensions.get("timeout", {})
216-
timeout = timeouts.get("pool", None)
217-
218-
with self._optional_thread_lock:
219-
# Add the incoming request to our request queue.
220-
pool_request = AsyncPoolRequest(request)
221-
self._requests.append(pool_request)
222-
223-
try:
224-
while True:
225-
with self._optional_thread_lock:
226-
# Assign incoming requests to available connections,
227-
# closing or creating new connections as required.
228-
closing = self._assign_requests_to_connections()
229-
await self._close_connections(closing)
230-
231-
# Wait until this request has an assigned connection.
232-
connection = await pool_request.wait_for_connection(timeout=timeout)
233-
234-
try:
235-
# Send the request on the assigned connection.
236-
response = await connection.handle_async_request(
237-
pool_request.request
238-
)
239-
except ConnectionNotAvailable:
240-
# In some cases a connection may initially be available to
241-
# handle a request, but then become unavailable.
242-
#
243-
# In this case we clear the connection and try again.
244-
pool_request.clear_connection()
245-
else:
246-
break # pragma: nocover
247-
248-
except BaseException as exc:
249-
with self._optional_thread_lock:
250-
# For any exception or cancellation we remove the request from
251-
# the queue, and then re-assign requests to connections.
252-
self._requests.remove(pool_request)
253-
closing = self._assign_requests_to_connections()
254-
255-
await self._close_connections(closing)
256-
raise exc from None
257-
258-
# Return the response. Note that in this case we still have to manage
259-
# the point at which the response is closed.
260-
assert isinstance(response.stream, typing.AsyncIterable)
261-
return Response(
262-
status=response.status,
263-
headers=response.headers,
264-
content=PoolByteStream(
265-
stream=response.stream, pool_request=pool_request, pool=self
266-
),
267-
extensions=response.extensions,
268-
)
269-
270-
def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
271-
"""
272-
Manage the state of the connection pool, assigning incoming
273-
requests to connections as available.
274-
275-
Called whenever a new request is added or removed from the pool.
276-
277-
Any closing connections are returned, allowing the I/O for closing
278-
those connections to be handled separately.
279-
"""
280-
closing_connections = []
281-
282-
# First we handle cleaning up any connections that are closed,
283-
# have expired their keep-alive, or surplus idle connections.
284-
for connection in list(self._connections):
285-
if connection.is_closed():
286-
# log: "removing closed connection"
287-
self._connections.remove(connection)
288-
elif connection.has_expired():
289-
# log: "closing expired connection"
290-
self._connections.remove(connection)
291-
closing_connections.append(connection)
292-
elif (
293-
connection.is_idle()
294-
and len([connection.is_idle() for connection in self._connections])
295-
> self._max_keepalive_connections
296-
):
297-
# log: "closing idle connection"
298-
self._connections.remove(connection)
299-
closing_connections.append(connection)
300-
301-
# Assign queued requests to connections.
302-
queued_requests = [request for request in self._requests if request.is_queued()]
303-
for pool_request in queued_requests:
304-
origin = pool_request.request.url.origin
305-
available_connections = [
306-
connection
307-
for connection in self._connections
308-
if connection.can_handle_request(origin) and connection.is_available()
309-
]
310-
idle_connections = [
311-
connection for connection in self._connections if connection.is_idle()
312-
]
313-
314-
# There are three cases for how we may be able to handle the request:
315-
#
316-
# 1. There is an existing connection that can handle the request.
317-
# 2. We can create a new connection to handle the request.
318-
# 3. We can close an idle connection and then create a new connection
319-
# to handle the request.
320-
if available_connections:
321-
# log: "reusing existing connection"
322-
connection = available_connections[0]
323-
pool_request.assign_to_connection(connection)
324-
elif len(self._connections) < self._max_connections:
325-
# log: "creating new connection"
326-
connection = self.create_connection(origin)
327-
self._connections.append(connection)
328-
pool_request.assign_to_connection(connection)
329-
elif idle_connections:
330-
# log: "closing idle connection"
331-
connection = idle_connections[0]
332-
self._connections.remove(connection)
333-
closing_connections.append(connection)
334-
# log: "creating new connection"
335-
connection = self.create_connection(origin)
336-
self._connections.append(connection)
337-
pool_request.assign_to_connection(connection)
338-
339-
return closing_connections
340-
341-
async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
342-
# Close connections which have been removed from the pool.
343-
with AsyncShieldCancellation():
344-
for connection in closing:
345-
await connection.aclose()
190+
# timeouts = request.extensions.get("timeout", {})
191+
# timeout = timeouts.get("pool", None)
192+
193+
async with self._limits:
194+
connection = self._get_connection(request)
195+
iterator = connection.iterate_response(request)
196+
try:
197+
response_start = await anext(iterator)
198+
# Return the response status and headers.
199+
yield response_start
200+
# Return the response.
201+
async for event in iterator:
202+
yield event
203+
finally:
204+
await iterator.aclose()
205+
closing = self._close_connections()
206+
for conn in closing:
207+
await conn.aclose()
208+
209+
def _get_connection(self, request):
210+
origin = request.url.origin
211+
for connection in self._connections:
212+
if connection.can_handle_request(origin) and connection.is_available():
213+
return connection
214+
215+
connection = self.create_connection(origin)
216+
self._connections.append(connection)
217+
return connection
218+
219+
def _close_connections(self):
220+
closing = [conn for conn in self._connections if conn.has_expired()]
221+
self._connections = [
222+
conn for conn in self._connections
223+
if not (conn.has_expired() or conn.is_closed())
224+
]
225+
return closing
346226

347227
async def aclose(self) -> None:
348228
# Explicitly close the connection pool.
349229
# Clears all existing requests and connections.
350-
with self._optional_thread_lock:
351-
closing_connections = list(self._connections)
352-
self._connections = []
353-
await self._close_connections(closing_connections)
230+
closing = list(self._connections)
231+
self._connections = []
232+
for conn in closing:
233+
await conn.aclose()
354234

355235
async def __aenter__(self) -> AsyncConnectionPool:
356236
return self
@@ -365,56 +245,12 @@ async def __aexit__(
365245

366246
def __repr__(self) -> str:
367247
class_name = self.__class__.__name__
368-
with self._optional_thread_lock:
369-
request_is_queued = [request.is_queued() for request in self._requests]
370-
connection_is_idle = [
371-
connection.is_idle() for connection in self._connections
372-
]
373-
374-
num_active_requests = request_is_queued.count(False)
375-
num_queued_requests = request_is_queued.count(True)
376-
num_active_connections = connection_is_idle.count(False)
377-
num_idle_connections = connection_is_idle.count(True)
378-
379-
requests_info = (
380-
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
381-
)
248+
connection_is_idle = [
249+
connection.is_idle() for connection in self._connections
250+
]
251+
num_active_connections = connection_is_idle.count(False)
252+
num_idle_connections = connection_is_idle.count(True)
382253
connection_info = (
383254
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
384255
)
385-
386-
return f"<{class_name} [{requests_info} | {connection_info}]>"
387-
388-
389-
class PoolByteStream:
390-
def __init__(
391-
self,
392-
stream: typing.AsyncIterable[bytes],
393-
pool_request: AsyncPoolRequest,
394-
pool: AsyncConnectionPool,
395-
) -> None:
396-
self._stream = stream
397-
self._pool_request = pool_request
398-
self._pool = pool
399-
self._closed = False
400-
401-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
402-
try:
403-
async for part in self._stream:
404-
yield part
405-
except BaseException as exc:
406-
await self.aclose()
407-
raise exc from None
408-
409-
async def aclose(self) -> None:
410-
if not self._closed:
411-
self._closed = True
412-
with AsyncShieldCancellation():
413-
if hasattr(self._stream, "aclose"):
414-
await self._stream.aclose()
415-
416-
with self._pool._optional_thread_lock:
417-
self._pool._requests.remove(self._pool_request)
418-
closing = self._pool._assign_requests_to_connections()
419-
420-
await self._pool._close_connections(closing)
256+
return f"<{class_name} [{connection_info}]>"

0 commit comments

Comments
 (0)