 import logging
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)
 
-import psycopg2
 import pytz
-from psycopg2 import sql
-from psycopg2.extras import execute_values
-from psycopg2.pool import SimpleConnectionPool
+from psycopg import sql
+from psycopg.connection import Connection
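+# psycopg 3 ships connection pooling in the separate psycopg_pool package.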
+from psycopg_pool import ConnectionPool
 
 from feast import Entity
 from feast.feature_view import FeatureView
@@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
 
 
 class PostgreSQLOnlineStore(OnlineStore):
-    _conn: Optional[psycopg2._psycopg.connection] = None
-    _conn_pool: Optional[SimpleConnectionPool] = None
+    _conn: Optional[Connection] = None
+    _conn_pool: Optional[ConnectionPool] = None
 
     @contextlib.contextmanager
-    def _get_conn(self, config: RepoConfig):
+    def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"
+
         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool:
                 self._conn_pool = _get_connection_pool(config.online_store)
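+                # getconn() fails on a pool that has not been opened, so open it
+                # explicitly after creation.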
+                self._conn_pool.open()
             connection = self._conn_pool.getconn()
             yield connection
             self._conn_pool.putconn(connection)
@@ -64,57 +75,56 @@ def online_write_batch(
             Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
         ],
         progress: Optional[Callable[[int], Any]],
+        batch_size: int = 5000,
     ) -> None:
-        project = config.project
+        # Format insert values
+        insert_values = []
+        for entity_key, values, timestamp, created_ts in data:
+            entity_key_bin = serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=config.entity_key_serialization_version,
+            )
+            timestamp = _to_naive_utc(timestamp)
+            if created_ts is not None:
+                created_ts = _to_naive_utc(created_ts)
 
-        with self._get_conn(config) as conn, conn.cursor() as cur:
-            insert_values = []
-            for entity_key, values, timestamp, created_ts in data:
-                entity_key_bin = serialize_entity_key(
-                    entity_key,
-                    entity_key_serialization_version=config.entity_key_serialization_version,
-                )
-                timestamp = _to_naive_utc(timestamp)
-                if created_ts is not None:
-                    created_ts = _to_naive_utc(created_ts)
-
-                for feature_name, val in values.items():
-                    vector_val = None
-                    if config.online_store.pgvector_enabled:
-                        vector_val = get_list_val_str(val)
-                    insert_values.append(
-                        (
-                            entity_key_bin,
-                            feature_name,
-                            val.SerializeToString(),
-                            vector_val,
-                            timestamp,
-                            created_ts,
-                        )
+            for feature_name, val in values.items():
+                vector_val = None
+                if config.online_store.pgvector_enabled:
+                    vector_val = get_list_val_str(val)
+                insert_values.append(
+                    (
+                        entity_key_bin,
+                        feature_name,
+                        val.SerializeToString(),
+                        vector_val,
+                        timestamp,
+                        created_ts,
                     )
-            # Control the batch so that we can update the progress
-            batch_size = 5000
+                )
+
+        # Create insert query
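+        # sql.Identifier quotes the table name safely; row values are bound to
+        # the %s placeholders at execute time.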
+        sql_query = sql.SQL(
+            """
+            INSERT INTO {}
+            (entity_key, feature_name, value, vector_value, event_ts, created_ts)
+            VALUES (%s, %s, %s, %s, %s, %s)
+            ON CONFLICT (entity_key, feature_name) DO
+            UPDATE SET
+                value = EXCLUDED.value,
+                vector_value = EXCLUDED.vector_value,
+                event_ts = EXCLUDED.event_ts,
+                created_ts = EXCLUDED.created_ts;
+            """
+        ).format(sql.Identifier(_table_id(config.project, table)))
+
+        # Push data in batches to online store
+        with self._get_conn(config) as conn, conn.cursor() as cur:
             for i in range(0, len(insert_values), batch_size):
                 cur_batch = insert_values[i : i + batch_size]
-                execute_values(
-                    cur,
-                    sql.SQL(
-                        """
-                        INSERT INTO {}
-                        (entity_key, feature_name, value, vector_value, event_ts, created_ts)
-                        VALUES %s
-                        ON CONFLICT (entity_key, feature_name) DO
-                        UPDATE SET
-                            value = EXCLUDED.value,
-                            vector_value = EXCLUDED.vector_value,
-                            event_ts = EXCLUDED.event_ts,
-                            created_ts = EXCLUDED.created_ts;
-                        """,
-                    ).format(sql.Identifier(_table_id(project, table))),
-                    cur_batch,
-                    page_size=batch_size,
-                )
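+                # psycopg 3's native executemany() replaces psycopg2's
+                # execute_values() helper.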
+                cur.executemany(sql_query, cur_batch)
                 conn.commit()
+
                 if progress:
                     progress(len(cur_batch))
 
@@ -172,7 +182,9 @@ def online_read(
             # when we iterate through the keys since they are in the correct order
             values_dict = defaultdict(list)
             for row in rows if rows is not None else []:
-                values_dict[row[0].tobytes()].append(row[1:])
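+                # psycopg 3 returns BYTEA columns as bytes, where psycopg2
+                # returned memoryview; handle both shapes here.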
+                values_dict[
+                    row[0] if isinstance(row[0], bytes) else row[0].tobytes()
+                ].append(row[1:])
 
             for key in keys:
                 if key in values_dict: