4
4
from datetime import datetime
5
5
from typing import (
6
6
Any ,
7
+ AsyncGenerator ,
7
8
Callable ,
8
9
Dict ,
9
10
Generator ,
12
13
Optional ,
13
14
Sequence ,
14
15
Tuple ,
16
+ Union ,
15
17
)
16
18
17
19
import pytz
18
- from psycopg import sql
20
+ from psycopg import AsyncConnection , sql
19
21
from psycopg .connection import Connection
20
- from psycopg_pool import ConnectionPool
22
+ from psycopg_pool import AsyncConnectionPool , ConnectionPool
21
23
22
24
from feast import Entity
23
25
from feast .feature_view import FeatureView
24
26
from feast .infra .key_encoding_utils import get_list_val_str , serialize_entity_key
25
27
from feast .infra .online_stores .online_store import OnlineStore
26
- from feast .infra .utils .postgres .connection_utils import _get_conn , _get_connection_pool
28
+ from feast .infra .utils .postgres .connection_utils import (
29
+ _get_conn ,
30
+ _get_conn_async ,
31
+ _get_connection_pool ,
32
+ _get_connection_pool_async ,
33
+ )
27
34
from feast .infra .utils .postgres .postgres_config import ConnectionType , PostgreSQLConfig
28
35
from feast .protos .feast .types .EntityKey_pb2 import EntityKey as EntityKeyProto
29
36
from feast .protos .feast .types .Value_pb2 import Value as ValueProto
@@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore):
51
58
_conn : Optional [Connection ] = None
52
59
_conn_pool : Optional [ConnectionPool ] = None
53
60
61
+ _conn_async : Optional [AsyncConnection ] = None
62
+ _conn_pool_async : Optional [AsyncConnectionPool ] = None
63
+
54
64
@contextlib .contextmanager
55
65
def _get_conn (self , config : RepoConfig ) -> Generator [Connection , Any , Any ]:
56
66
assert config .online_store .type == "postgres"
@@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
67
77
self ._conn = _get_conn (config .online_store )
68
78
yield self ._conn
69
79
80
+ @contextlib .asynccontextmanager
81
+ async def _get_conn_async (
82
+ self , config : RepoConfig
83
+ ) -> AsyncGenerator [AsyncConnection , Any ]:
84
+ if config .online_store .conn_type == ConnectionType .pool :
85
+ if not self ._conn_pool_async :
86
+ self ._conn_pool_async = await _get_connection_pool_async (
87
+ config .online_store
88
+ )
89
+ await self ._conn_pool_async .open ()
90
+ connection = await self ._conn_pool_async .getconn ()
91
+ yield connection
92
+ await self ._conn_pool_async .putconn (connection )
93
+ else :
94
+ if not self ._conn_async :
95
+ self ._conn_async = await _get_conn_async (config .online_store )
96
+ yield self ._conn_async
97
+
70
98
def online_write_batch (
71
99
self ,
72
100
config : RepoConfig ,
@@ -132,69 +160,107 @@ def online_read(
132
160
entity_keys : List [EntityKeyProto ],
133
161
requested_features : Optional [List [str ]] = None ,
134
162
) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
135
- result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
163
+ keys = self ._prepare_keys (entity_keys , config .entity_key_serialization_version )
164
+ query , params = self ._construct_query_and_params (
165
+ config , table , keys , requested_features
166
+ )
136
167
137
- project = config .project
138
168
with self ._get_conn (config ) as conn , conn .cursor () as cur :
139
- # Collecting all the keys to a list allows us to make fewer round trips
140
- # to PostgreSQL
141
- keys = []
142
- for entity_key in entity_keys :
143
- keys .append (
144
- serialize_entity_key (
145
- entity_key ,
146
- entity_key_serialization_version = config .entity_key_serialization_version ,
147
- )
148
- )
169
+ cur .execute (query , params )
170
+ rows = cur .fetchall ()
149
171
150
- if not requested_features :
151
- cur .execute (
152
- sql .SQL (
153
- """
154
- SELECT entity_key, feature_name, value, event_ts
155
- FROM {} WHERE entity_key = ANY(%s);
156
- """
157
- ).format (
158
- sql .Identifier (_table_id (project , table )),
159
- ),
160
- (keys ,),
161
- )
162
- else :
163
- cur .execute (
164
- sql .SQL (
165
- """
166
- SELECT entity_key, feature_name, value, event_ts
167
- FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s);
168
- """
169
- ).format (
170
- sql .Identifier (_table_id (project , table )),
171
- ),
172
- (keys , requested_features ),
173
- )
172
+ return self ._process_rows (keys , rows )
174
173
175
- rows = cur .fetchall ()
174
+ async def online_read_async (
175
+ self ,
176
+ config : RepoConfig ,
177
+ table : FeatureView ,
178
+ entity_keys : List [EntityKeyProto ],
179
+ requested_features : Optional [List [str ]] = None ,
180
+ ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
181
+ keys = self ._prepare_keys (entity_keys , config .entity_key_serialization_version )
182
+ query , params = self ._construct_query_and_params (
183
+ config , table , keys , requested_features
184
+ )
176
185
177
- # Since we don't know the order returned from PostgreSQL we'll need
178
- # to construct a dict to be able to quickly look up the correct row
179
- # when we iterate through the keys since they are in the correct order
180
- values_dict = defaultdict (list )
181
- for row in rows if rows is not None else []:
182
- values_dict [
183
- row [0 ] if isinstance (row [0 ], bytes ) else row [0 ].tobytes ()
184
- ].append (row [1 :])
185
-
186
- for key in keys :
187
- if key in values_dict :
188
- value = values_dict [key ]
189
- res = {}
190
- for feature_name , value_bin , event_ts in value :
191
- val = ValueProto ()
192
- val .ParseFromString (bytes (value_bin ))
193
- res [feature_name ] = val
194
- result .append ((event_ts , res ))
195
- else :
196
- result .append ((None , None ))
186
+ async with self ._get_conn_async (config ) as conn :
187
+ async with conn .cursor () as cur :
188
+ await cur .execute (query , params )
189
+ rows = await cur .fetchall ()
190
+
191
+ return self ._process_rows (keys , rows )
192
+
193
+ @staticmethod
194
+ def _construct_query_and_params (
195
+ config : RepoConfig ,
196
+ table : FeatureView ,
197
+ keys : List [bytes ],
198
+ requested_features : Optional [List [str ]] = None ,
199
+ ) -> Tuple [sql .Composed , Union [Tuple [List [bytes ], List [str ]], Tuple [List [bytes ]]]]:
200
+ """Construct the SQL query based on the given parameters."""
201
+ if requested_features :
202
+ query = sql .SQL (
203
+ """
204
+ SELECT entity_key, feature_name, value, event_ts
205
+ FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
206
+ """
207
+ ).format (
208
+ sql .Identifier (_table_id (config .project , table )),
209
+ )
210
+ params = (keys , requested_features )
211
+ else :
212
+ query = sql .SQL (
213
+ """
214
+ SELECT entity_key, feature_name, value, event_ts
215
+ FROM {} WHERE entity_key = ANY(%s);
216
+ """
217
+ ).format (
218
+ sql .Identifier (_table_id (config .project , table )),
219
+ )
220
+ params = (keys , [])
221
+ return query , params
222
+
223
+ @staticmethod
224
+ def _prepare_keys (
225
+ entity_keys : List [EntityKeyProto ], entity_key_serialization_version : int
226
+ ) -> List [bytes ]:
227
+ """Prepare all keys in a list to make fewer round trips to the database."""
228
+ return [
229
+ serialize_entity_key (
230
+ entity_key ,
231
+ entity_key_serialization_version = entity_key_serialization_version ,
232
+ )
233
+ for entity_key in entity_keys
234
+ ]
235
+
236
+ @staticmethod
237
+ def _process_rows (
238
+ keys : List [bytes ], rows : List [Tuple ]
239
+ ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
240
+ """Transform the retrieved rows in the desired output.
197
241
242
+ PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
243
+ is created to quickly look up the correct row using the keys, since these are
244
+ actually in the correct order.
245
+ """
246
+ values_dict = defaultdict (list )
247
+ for row in rows if rows is not None else []:
248
+ values_dict [
249
+ row [0 ] if isinstance (row [0 ], bytes ) else row [0 ].tobytes ()
250
+ ].append (row [1 :])
251
+
252
+ result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
253
+ for key in keys :
254
+ if key in values_dict :
255
+ value = values_dict [key ]
256
+ res = {}
257
+ for feature_name , value_bin , event_ts in value :
258
+ val = ValueProto ()
259
+ val .ParseFromString (bytes (value_bin ))
260
+ res [feature_name ] = val
261
+ result .append ((event_ts , res ))
262
+ else :
263
+ result .append ((None , None ))
198
264
return result
199
265
200
266
def update (
0 commit comments