-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathconcurrency.py
90 lines (75 loc) · 2.74 KB
/
concurrency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import contextlib
import threading
from types import TracebackType
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Type
from cassandra.cluster import ResponseFuture, Session
from cassandra.query import PreparedStatement
class ConcurrentQueries(contextlib.AbstractContextManager):
    """Context manager for running queries concurrently.

    Queries submitted with :meth:`execute` are started with
    ``session.execute_async``. At most ``concurrency`` of them are in flight
    at once; ``execute`` blocks until a slot is free. Leaving the ``with``
    block waits until every submitted query has finished, or until one of
    them has failed. In the failure case the first recorded error is raised.

    Example (sketch)::

        with ConcurrentQueries(session) as cq:
            for params in rows_to_insert:
                cq.execute(insert_statement, params)
    """

    def __init__(self, session: "Session", *, concurrency: int = 20) -> None:
        """Create a query runner.

        Args:
            session: Session used to issue queries via ``execute_async``.
            concurrency: Maximum number of queries in flight at once; must
                be at least 1.

        Raises:
            ValueError: If ``concurrency`` is less than 1.
        """
        if concurrency < 1:
            # A zero-permit semaphore would make the first ``execute`` block
            # forever.
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")
        self._session = session
        # Each submitted query holds one permit until its last page has been
        # handled or it has failed.
        self._semaphore = threading.Semaphore(concurrency)
        # Guards ``_pending`` and ``_error``; ``__exit__`` waits on it.
        self._completion = threading.Condition()
        # Number of queries submitted but not yet finished.
        self._pending = 0
        # First failure reported by a query or raised by a result callback.
        self._error: Optional[BaseException] = None

    def _finish_query(self, error: Optional[BaseException] = None) -> None:
        """Mark one submitted query as finished and free its slot.

        Args:
            error: The failure that ended the query, if any. Only the first
                failure is kept.
        """
        self._semaphore.release()
        with self._completion:
            if error is not None and self._error is None:
                self._error = error
            self._pending -= 1
            # Wake ``__exit__`` whenever it may be able to stop waiting.
            if self._error is not None or self._pending == 0:
                self._completion.notify_all()

    def _handle_result(
        self,
        result: Sequence[NamedTuple],
        future: "ResponseFuture",
        callback: Optional[Callable[[Sequence[NamedTuple]], Any]],
    ) -> None:
        """Result callback for one page of a successful query.

        Passes the page to ``callback`` (if given), then either requests the
        next page or, on the last page, marks the query finished.

        An exception raised here would otherwise escape into whoever invoked
        this callback, skipping the bookkeeping and leaving ``__exit__``
        waiting forever. Such an exception is therefore recorded as the
        query's error.
        """
        try:
            if callback is not None:
                callback(result)
            if future.has_more_pages:
                # The callbacks registered on ``future`` fire again for the
                # next page; the query keeps its slot until the last page.
                future.start_fetching_next_page()
                return
        except Exception as error:  # surfaced to the caller by ``__exit__``
            self._handle_error(error, future)
            return
        self._finish_query()

    def _handle_error(self, error: BaseException, future: "ResponseFuture") -> None:
        """Error callback: report the failure and free the query's slot."""
        print(f"Failed to execute {future.query}: {error}")
        self._finish_query(error)

    def execute(
        self,
        query: "PreparedStatement",
        parameters: Optional[Tuple] = None,
        callback: Optional[Callable[[Sequence[NamedTuple]], Any]] = None,
    ) -> None:
        """Submit ``query`` for asynchronous execution.

        Blocks while ``concurrency`` queries are already in flight. Does
        nothing once a previously submitted query has failed; that failure
        is raised when the ``with`` block exits.

        Args:
            query: Statement to execute.
            parameters: Bind parameters for ``query``.
            callback: Called with each page of result rows.

        Raises:
            Whatever ``session.execute_async`` raises synchronously.
        """
        with self._completion:
            if self._error is not None:
                return
            self._pending += 1

        self._semaphore.acquire()

        # A query may have failed while we waited for a free slot; don't start
        # new work once the batch is known to have failed.
        with self._completion:
            failed = self._error is not None
        if failed:
            self._finish_query()
            return

        try:
            future: "ResponseFuture" = self._session.execute_async(query, parameters)
        except BaseException:
            # Undo the bookkeeping so ``__exit__`` doesn't wait for a query
            # that never started.
            self._finish_query()
            raise

        future.add_callbacks(
            self._handle_result,
            self._handle_error,
            callback_kwargs={
                "future": future,
                "callback": callback,
            },
            errback_kwargs={
                "future": future,
            },
        )

    def __enter__(self) -> "ConcurrentQueries":
        return super().__enter__()

    def __exit__(
        self,
        _exc_type: Optional[Type[BaseException]],
        _exc_inst: Optional[BaseException],
        _exc_traceback: Optional[TracebackType],
    ) -> bool:
        """Wait for all submitted queries, then raise the first query failure.

        Returns:
            False, so an exception raised inside the ``with`` body is never
            swallowed. Python re-raises it automatically.
        """
        with self._completion:
            self._completion.wait_for(
                lambda: self._error is not None or self._pending == 0
            )
            error = self._error
        if error is not None:
            raise error
        return False