-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathquery.py
477 lines (398 loc) · 17.6 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
# Copyright 2017 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for representing queries for the Google Cloud Firestore API.
A :class:`~google.cloud.firestore_v1.query.Query` is more commonly created from a
:class:`~google.cloud.firestore_v1.collection.CollectionReference` (for example
``collection.where(...)``) than by calling its constructor directly.
"""
from __future__ import annotations

from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING

from google.api_core import exceptions
from google.api_core import gapic_v1
from google.api_core import retry as retries

from google.cloud import firestore_v1
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import (
    BaseCollectionGroup,
    BaseQuery,
    QueryPartition,
    _query_response_to_snapshot,
    _collection_group_query_response_to_snapshot,
    _enum_from_direction,
)
from google.cloud.firestore_v1 import aggregation
from google.cloud.firestore_v1 import document
from google.cloud.firestore_v1.watch import Watch

if TYPE_CHECKING:  # pragma: NO COVER
    from google.cloud.firestore_v1.field_path import FieldPath
class Query(BaseQuery):
    """Represents a query to the Firestore API.

    Instances of this class are considered immutable: all methods that
    would modify an instance instead return a new instance.

    Args:
        parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`):
            The collection that this query applies to.
        projection (Optional[:class:`google.cloud.proto.firestore.v1.\
            query.StructuredQuery.Projection`]):
            A projection of document fields to limit the query results to.
        field_filters (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\
            query.StructuredQuery.FieldFilter`, ...]]):
            The filters to be applied in the query.
        orders (Optional[Tuple[:class:`google.cloud.proto.firestore.v1.\
            query.StructuredQuery.Order`, ...]]):
            The "order by" entries to use in the query.
        limit (Optional[int]):
            The maximum number of documents the query is allowed to return.
        limit_to_last (Optional[bool]):
            When :data:`True`, :meth:`get` returns up to ``limit`` documents
            taken from the end of the ordered result set (still returned in
            the query's order) instead of from the start.
        offset (Optional[int]):
            The number of results to skip.
        start_at (Optional[Tuple[dict, bool]]):
            Two-tuple of :

            * a mapping of fields. Any field that is present in this mapping
              must also be present in ``orders``
            * an ``after`` flag

            The fields and the flag combine to form a cursor used as
            a starting point in a query result set. If the ``after``
            flag is :data:`True`, the results will start just after any
            documents which have fields matching the cursor, otherwise
            any matching documents will be included in the result set.
            When the query is formed, the document values
            will be used in the order given by ``orders``.
        end_at (Optional[Tuple[dict, bool]]):
            Two-tuple of:

            * a mapping of fields. Any field that is present in this mapping
              must also be present in ``orders``
            * a ``before`` flag

            The fields and the flag combine to form a cursor used as
            an ending point in a query result set. If the ``before``
            flag is :data:`True`, the results will end just before any
            documents which have fields matching the cursor, otherwise
            any matching documents will be included in the result set.
            When the query is formed, the document values
            will be used in the order given by ``orders``.
        all_descendants (Optional[bool]):
            When false, selects only collections that are immediate children
            of the `parent` specified in the containing `RunQueryRequest`.
            When true, selects all descendant collections.
        recursive (Optional[bool]):
            Forwarded unchanged to :class:`BaseQuery`.
    """

    def __init__(
        self,
        parent,
        projection=None,
        field_filters=(),
        orders=(),
        limit=None,
        limit_to_last=False,
        offset=None,
        start_at=None,
        end_at=None,
        all_descendants=False,
        recursive=False,
    ) -> None:
        super(Query, self).__init__(
            parent=parent,
            projection=projection,
            field_filters=field_filters,
            orders=orders,
            limit=limit,
            limit_to_last=limit_to_last,
            offset=offset,
            start_at=start_at,
            end_at=end_at,
            all_descendants=all_descendants,
            recursive=recursive,
        )

    def get(
        self,
        transaction=None,
        retry: retries.Retry = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
    ) -> List[DocumentSnapshot]:
        """Read the documents in the collection that match this query.

        This sends a ``RunQuery`` RPC and returns a list of documents
        returned in the stream of ``RunQueryResponse`` messages.

        For a query built with ``limit_to_last``, the orderings are flipped
        on a copy of this query, that copy is streamed, and the results are
        reversed back into the requested order. This query is not modified,
        so calling :meth:`get` repeatedly returns consistent results.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried.  Defaults to a system-specified policy.
            timeout (float): The timeout for this request.  Defaults to a
                system-specified value.

        Returns:
            list: The documents in the collection that match this query.
        """
        if not self._limit_to_last:
            return list(
                self.stream(transaction=transaction, retry=retry, timeout=timeout)
            )

        # In order to fetch up to `self._limit` results from the end of the
        # query, flip every ordering so the backend starts from the end.
        # The flip is applied to *copies* of the order protos: mutating
        # `self._orders` (and `self._limit_to_last`) in place would leave this
        # supposedly immutable query permanently reversed, so a second `get()`
        # call would return the wrong documents.
        # NOTE(review): start_at / end_at cursors are not swapped here (nor were
        # they before); confirm whether BaseQuery accounts for that.
        flipped_orders = []
        for order in self._orders:
            flipped = type(order)(order)  # proto-plus copy constructor
            flipped.direction = _enum_from_direction(
                self.DESCENDING
                if order.direction.name == self.ASCENDING
                else self.ASCENDING
            )
            flipped_orders.append(flipped)

        # `limit_to_last` is cleared on the copy before streaming, mirroring
        # what the previous implementation did to `self`.
        flipped_query = self._copy(orders=tuple(flipped_orders), limit_to_last=False)
        snapshots = list(
            flipped_query.stream(transaction=transaction, retry=retry, timeout=timeout)
        )
        snapshots.reverse()
        return snapshots

    def _chunkify(
        self, chunk_size: int
    ) -> Generator[List[DocumentSnapshot], None, None]:
        """Yield the query results as lists of at most ``chunk_size`` documents.

        Each chunk is fetched with its own :meth:`get` call, resuming just
        after the last document of the previous chunk. A limit previously
        applied via :meth:`limit` is honored across all chunks combined.
        """
        max_to_return: Optional[int] = self._limit
        num_returned: int = 0
        original: Query = self._copy()
        last_document: Optional[DocumentSnapshot] = None

        while True:
            # Optionally trim the `chunk_size` down to honor a previously
            # applied limit as set by `self.limit()`.
            _chunk_size: int = original._resolve_chunk_size(num_returned, chunk_size)

            # Apply the optionally pruned limit and the cursor, if we are past
            # the first page.
            _q = original.limit(_chunk_size)
            if last_document is not None:
                _q = _q.start_after(last_document)

            snapshots = _q.get()
            if snapshots:
                last_document = snapshots[-1]

            num_returned += len(snapshots)

            yield snapshots

            # Terminate the iterator if we have reached either of two end
            # conditions:
            #   1. There are no more documents, or
            #   2. We have reached the desired overall limit
            # `is not None` matters: a limit of 0 is falsy and previously never
            # satisfied condition 2, so the loop never ended.
            if len(snapshots) < _chunk_size or (
                max_to_return is not None and num_returned >= max_to_return
            ):
                return

    def _get_stream_iterator(self, transaction, retry, timeout):
        """Helper method for :meth:`stream`: open a ``RunQuery`` response stream.

        Returns:
            Tuple: ``(response_iterator, expected_prefix)`` where
            ``expected_prefix`` is the value produced by ``_prep_stream``.
        """
        request, expected_prefix, kwargs = self._prep_stream(
            transaction,
            retry,
            timeout,
        )

        response_iterator = self._client._firestore_api.run_query(
            request=request,
            metadata=self._client._rpc_metadata,
            **kwargs,
        )

        return response_iterator, expected_prefix

    def _retry_query_after_exception(self, exc, retry, transaction):
        """Helper method for :meth:`stream`: decide whether to resume after ``exc``.

        Returns:
            bool: ``True`` when the stream should be re-opened. Never ``True``
            inside a transaction, or when retries were disabled with
            ``retry=None``.
        """
        if transaction is not None:
            # No snapshot-based retry inside a transaction.
            return False

        if retry is gapic_v1.method.DEFAULT:
            transport = self._client._firestore_api._transport
            gapic_callable = transport.run_query
            retry = gapic_callable._retry

        if retry is None:
            # `retry=None` means "do not retry"; calling `None._predicate`
            # previously raised AttributeError and masked the API error.
            return False

        return retry._predicate(exc)

    def _resume_query(self, last_snapshot, num_yielded):
        """Build the query used to resume :meth:`stream` after a retryable error.

        Args:
            last_snapshot (Optional[DocumentSnapshot]): The last document
                yielded so far, or ``None`` if nothing has been yielded.
            num_yielded (int): How many documents have been yielded so far.

        Returns:
            Optional[Query]: The query to re-open, or ``None`` when the
            original ``limit`` has already been satisfied.
        """
        if last_snapshot is None:
            # There is no document to resume after, so re-run the query
            # unchanged rather than calling `start_after(None)`.
            return self

        remaining = None
        if self._limit is not None:
            remaining = self._limit - num_yielded
            if remaining <= 0:
                return None

        # Documents have already been yielded past the original offset, so the
        # offset must not be applied again after the cursor (it would skip
        # documents), and only the documents still owed under the limit are
        # requested.
        return self._copy(offset=None, limit=remaining).start_after(last_snapshot)

    def count(
        self, alias: str | None = None
    ) -> "aggregation.AggregationQuery":
        """
        Adds a count over the query.

        :type alias: Optional[str]
        :param alias: Optional name of the field to store the result of the aggregation into.
            If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
        """
        return aggregation.AggregationQuery(self).count(alias=alias)

    def sum(
        self, field_ref: str | FieldPath, alias: str | None = None
    ) -> "aggregation.AggregationQuery":
        """
        Adds a sum over the query.

        :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
        :param field_ref: The field to aggregate across.

        :type alias: Optional[str]
        :param alias: Optional name of the field to store the result of the aggregation into.
            If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
        """
        return aggregation.AggregationQuery(self).sum(field_ref, alias=alias)

    def avg(
        self, field_ref: str | FieldPath, alias: str | None = None
    ) -> "aggregation.AggregationQuery":
        """
        Adds an avg over the query.

        :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath]
        :param field_ref: The field to aggregate across.

        :type alias: Optional[str]
        :param alias: Optional name of the field to store the result of the aggregation into.
            If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
        """
        return aggregation.AggregationQuery(self).avg(field_ref, alias=alias)

    def stream(
        self,
        transaction=None,
        retry: retries.Retry = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
    ) -> Generator[document.DocumentSnapshot, Any, None]:
        """Read the documents in the collection that match this query.

        This sends a ``RunQuery`` RPC and then returns an iterator which
        consumes each document returned in the stream of ``RunQueryResponse``
        messages.

        .. note::

           The underlying stream of responses will time out after
           the ``max_rpc_timeout_millis`` value set in the GAPIC
           client configuration for the ``RunQuery`` API.  Snapshots
           not consumed from the iterator before that point will be lost.

        If a ``transaction`` is used and it already has write operations
        added, this method cannot be used (i.e. read-after-write is not
        allowed).

        When the stream fails with a retryable error outside a transaction,
        it is re-opened: from scratch if nothing was yielded yet, otherwise
        just after the last yielded document, with the offset dropped and
        the limit reduced by the number of documents already yielded.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried.  Defaults to a system-specified policy.
            timeout (float): The timeout for this request.  Defaults to a
                system-specified value.

        Yields:
            :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
            The next document that fulfills the query.
        """
        response_iterator, expected_prefix = self._get_stream_iterator(
            transaction,
            retry,
            timeout,
        )

        last_snapshot = None
        num_yielded = 0

        while True:
            try:
                response = next(response_iterator, None)
            except exceptions.GoogleAPICallError as exc:
                if not self._retry_query_after_exception(exc, retry, transaction):
                    raise
                resume_query = self._resume_query(last_snapshot, num_yielded)
                if resume_query is None:
                    # The original limit has already been satisfied.
                    break
                response_iterator, _ = resume_query._get_stream_iterator(
                    transaction,
                    retry,
                    timeout,
                )
                continue

            if response is None:  # EOI
                break

            if self._all_descendants:
                snapshot = _collection_group_query_response_to_snapshot(
                    response, self._parent
                )
            else:
                snapshot = _query_response_to_snapshot(
                    response, self._parent, expected_prefix
                )
            if snapshot is not None:
                last_snapshot = snapshot
                num_yielded += 1
                yield snapshot

    def on_snapshot(self, callback: Callable) -> Watch:
        """Monitor the documents in this collection that match this query.

        This starts a watch on this query using a background thread. The
        provided callback is run on the snapshot of the documents.

        Args:
            callback (Callable): Run when a change occurs; it is invoked with
                ``(docs, changes, read_time)`` as in the example below.

        Example:

        .. code-block:: python

            from google.cloud import firestore_v1

            db = firestore_v1.Client()
            query_ref = db.collection(u'users').where("user", "==", u'Ada')

            def on_snapshot(docs, changes, read_time):
                for doc in docs:
                    print(u'{} => {}'.format(doc.id, doc.to_dict()))

            # Watch this query
            query_watch = query_ref.on_snapshot(on_snapshot)

            # Terminate this watch
            query_watch.unsubscribe()
        """
        return Watch.for_query(self, callback, document.DocumentSnapshot)

    @staticmethod
    def _get_collection_reference_class() -> (
        Type["firestore_v1.collection.CollectionReference"]
    ):
        """Return the synchronous ``CollectionReference`` class.

        Imported lazily inside the function, presumably to avoid a circular
        import with the ``collection`` module.
        """
        from google.cloud.firestore_v1.collection import CollectionReference

        return CollectionReference
class CollectionGroup(Query, BaseCollectionGroup):
    """Represents a Collection Group in the Firestore API.

    This is a specialization of :class:`.Query` that includes all documents in the
    database that are contained in a collection or subcollection of the given
    parent.

    Args:
        parent (:class:`~google.cloud.firestore_v1.collection.CollectionReference`):
            The collection that this query applies to.
    """

    def __init__(
        self,
        parent,
        projection=None,
        field_filters=(),
        orders=(),
        limit=None,
        limit_to_last=False,
        offset=None,
        start_at=None,
        end_at=None,
        all_descendants=True,
        recursive=False,
    ) -> None:
        # Identical to Query.__init__ except that `all_descendants` defaults
        # to True, which is what makes this a collection-group query.
        super(CollectionGroup, self).__init__(
            parent=parent,
            projection=projection,
            field_filters=field_filters,
            orders=orders,
            limit=limit,
            limit_to_last=limit_to_last,
            offset=offset,
            start_at=start_at,
            end_at=end_at,
            all_descendants=all_descendants,
            recursive=recursive,
        )

    @staticmethod
    def _get_query_class():
        """Return :class:`Query`, the plain (non-group) query class."""
        return Query

    def get_partitions(
        self,
        partition_count,
        retry: retries.Retry = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
    ) -> Generator[QueryPartition, None, None]:
        """Partition a query for parallelization.

        Partitions a query by returning partition cursors that can be used to run the
        query in parallel. The returned partition cursors are split points that can be
        used as starting/end points for the query results.

        Args:
            partition_count (int): The desired maximum number of partition points. The
                number must be strictly positive. The actual number of partitions
                returned may be fewer.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried.  Defaults to a system-specified policy.
            timeout (float): The timeout for this request.  Defaults to a
                system-specified value.

        Yields:
            :class:`~google.cloud.firestore_v1.base_query.QueryPartition`:
            Consecutive partitions, each ending where the next one starts.
            The first partition's start and the last partition's end are
            ``None`` (presumably meaning "unbounded").
        """
        request, kwargs = self._prep_get_partitions(partition_count, retry, timeout)

        pager = self._client._firestore_api.partition_query(
            request=request,
            metadata=self._client._rpc_metadata,
            **kwargs,
        )

        start_at = None
        for cursor_pb in pager:
            # Only the first cursor value is used; it is expected to be a
            # document reference, which is turned into a DocumentReference.
            cursor = self._client.document(cursor_pb.values[0].reference_value)
            yield QueryPartition(self, start_at, cursor)
            start_at = cursor

        yield QueryPartition(self, start_at, None)