
Commit ccc922b

job-almekinders authored and nick-amaya-sp committed
feat: Bump psycopg2 to psycopg3 for all Postgres components (feast-dev#4303)
* Makefile: Formatting
* Makefile: Exclude Snowflake tests for postgres offline store tests
* Bootstrap: Use conninfo
* Tests: Make connection string compatible with psycopg3
* Tests: Test connection type pool and singleton
* Global: Replace conn.set_session() calls to be psycopg3 compatible; set connection read only
* Offline: Use psycopg3
* Online: Use psycopg3
* Online: Restructure online_write_batch
* Online: Use correct placeholder
* Online: Handle bytes properly in online_read()
* Online: Whitespace
* Online: Open ConnectionPool
* Online: Add typehint
* Utils: Use psycopg3; use new ConnectionPool; pass kwargs as named argument; use executemany over execute_values; remove not-required open argument in psycopg.connect; use SpooledTemporaryFile over StringIO object, with max_size and docstring; fix df_to_postgres_table; remove unused import
* Lint: Raise exceptions if cursor returned no columns or rows (ZeroColumnQueryResult, ZeroRowsQueryResult); add log statement; fix _to_arrow_internal and _get_entity_df_event_timestamp_range
* Add comment on +psycopg string
* Docs: Remove mention of psycopg2
* Lint: Fix
* Default to postgresql+psycopg and log warning; add typehints; use better variable names
* Solve merge conflicts

Signed-off-by: Job Almekinders <job.almekinders@teampicnic.com>
1 parent 5a8edc5 commit ccc922b
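At its core, the change is the mechanical psycopg2 to psycopg3 translation described in the commit message above: the package is now plain `psycopg`, pooling moves to `psycopg_pool`, `set_session()` becomes a connection attribute, and `execute_values` is replaced by `executemany`. A rough before/after sketch of those idioms (the connection string and table are placeholders, not taken from this diff):

```python
import psycopg  # psycopg3 ships as the plain "psycopg" package

# Connection: psycopg3 favors a libpq conninfo string (placeholder values).
conn = psycopg.connect("host=localhost dbname=feast user=feast")

# Session characteristics are attributes now; psycopg2's
#   conn.set_session(readonly=True)
# becomes
#   conn.read_only = True   (only assignable while no transaction is open)

# Writes: executemany with per-row placeholders replaces
# psycopg2.extras.execute_values(cur, "... VALUES %s", rows).
rows = [(1, "x"), (2, "y")]
with conn.cursor() as cur:
    cur.executemany("INSERT INTO t (a, b) VALUES (%s, %s)", rows)
conn.commit()
conn.close()
```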

File tree

18 files changed: +925 -408 lines

Makefile (+4 -3)

@@ -65,7 +65,7 @@ install-python:
 	python setup.py develop
 
 lock-python-dependencies:
-	uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt
+	uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py$(PYTHON)-requirements.txt
 
 lock-python-dependencies-all:
 	pixi run --environment py39 --manifest-path infra/scripts/pixi/pixi.toml "uv pip compile --system --no-strip-extras setup.py --output-file sdk/python/requirements/py3.9-requirements.txt"
@@ -164,7 +164,7 @@ test-python-universal-mssql:
 	 sdk/python/tests
 
 
-# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
+# To use Athena as an offline store, you need to create an Athena database and an S3 bucket on AWS.
 # https://docs.aws.amazon.com/athena/latest/ug/getting-started.html
 # Modify environment variables ATHENA_REGION, ATHENA_DATA_SOURCE, ATHENA_DATABASE, ATHENA_WORKGROUP or
 # ATHENA_S3_BUCKET_NAME according to your needs. If tests fail with the pytest -n 8 option, change the number to 1.
@@ -191,7 +191,7 @@ test-python-universal-athena:
 	 not s3_registry and \
 	 not test_snowflake" \
 	 sdk/python/tests
-
+
 test-python-universal-postgres-offline:
 	PYTHONPATH='.' \
 	FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.postgres_repo_configuration \
@@ -209,6 +209,7 @@ test-python-universal-postgres-offline:
 	 not test_push_features_to_offline_store and \
 	 not gcs_registry and \
 	 not s3_registry and \
+	 not test_snowflake and \
 	 not test_universal_types" \
 	 sdk/python/tests

docs/tutorials/using-scalable-registry.md (+1 -1)

@@ -49,7 +49,7 @@ When this happens, your database is likely using what is referred to as an
 in `SQLAlchemy` terminology. See your database's documentation for examples on
 how to set its scheme in the Database URL.
 
-`Psycopg2`, which is the database library leveraged by the online and offline
+`Psycopg`, which is the database library leveraged by the online and offline
 stores, is not impacted by the need to speak a particular dialect, and so the
 following only applies to the registry.
 
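Only the registry goes through SQLAlchemy, so only its URL scheme needs a dialect suffix. Per the commit message, this change defaults registry paths to `postgresql+psycopg` and logs a warning otherwise. A hedged sketch with placeholder credentials:

```python
from sqlalchemy import create_engine

# Registry URLs go through SQLAlchemy, where the scheme selects the driver.
# psycopg2 dialect (old): postgresql+psycopg2://...
# psycopg3 dialect (new): postgresql+psycopg://...
# Host, user, and database below are placeholders.
engine = create_engine("postgresql+psycopg://feast:feast@localhost:5432/registry")
```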

sdk/python/feast/errors.py (+10 -0)

@@ -389,3 +389,13 @@ def __init__(self, input_dict: dict):
         super().__init__(
             f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
         )
+
+
+class ZeroRowsQueryResult(Exception):
+    def __init__(self, query: str):
+        super().__init__(f"This query returned zero rows:\n{query}")
+
+
+class ZeroColumnQueryResult(Exception):
+    def __init__(self, query: str):
+        super().__init__(f"This query returned zero columns:\n{query}")
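The two new exceptions cover the two ways a cursor can come back empty: a statement that produced no result columns (`cur.description` is `None`) and a result set with no rows. A minimal sketch of the intended usage pattern; `fetch_range` is a hypothetical helper, not part of the diff:

```python
import psycopg

from feast.errors import ZeroColumnQueryResult, ZeroRowsQueryResult


def fetch_range(conn: psycopg.Connection, query: str) -> tuple:
    """Run `query` and guard both failure modes the new exceptions cover."""
    with conn.cursor() as cur:
        cur.execute(query)
        if not cur.description:  # statement returned no result columns
            raise ZeroColumnQueryResult(query)
        row = cur.fetchone()
        if row is None:  # result set was empty
            raise ZeroRowsQueryResult(query)
        return row
```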

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py (+16 -11)

@@ -19,11 +19,11 @@
 import pandas as pd
 import pyarrow as pa
 from jinja2 import BaseLoader, Environment
-from psycopg2 import sql
+from psycopg import sql
 from pytz import utc
 
 from feast.data_source import DataSource
-from feast.errors import InvalidEntityType
+from feast.errors import InvalidEntityType, ZeroColumnQueryResult, ZeroRowsQueryResult
 from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
 from feast.infra.offline_stores import offline_utils
 from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
@@ -274,8 +274,10 @@ def to_sql(self) -> str:
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
         with self._query_generator() as query:
             with _get_conn(self.config.offline_store) as conn, conn.cursor() as cur:
-                conn.set_session(readonly=True)
+                conn.read_only = True
                 cur.execute(query)
+                if not cur.description:
+                    raise ZeroColumnQueryResult(query)
                 fields = [
                     (c.name, pg_type_code_to_arrow(c.type_code))
                     for c in cur.description
@@ -331,16 +333,19 @@ def _get_entity_df_event_timestamp_range(
             entity_df_event_timestamp.max().to_pydatetime(),
         )
     elif isinstance(entity_df, str):
-        # If the entity_df is a string (SQL query), determine range
-        # from table
+        # If the entity_df is a string (SQL query), determine range from table
        with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
-            (
-                cur.execute(
-                    f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
-                ),
-            )
+            query = f"""
+                SELECT
+                    MIN({entity_df_event_timestamp_col}) AS min,
+                    MAX({entity_df_event_timestamp_col}) AS max
+                FROM ({entity_df}) AS tmp_alias
+            """
+            cur.execute(query)
             res = cur.fetchone()
-            entity_df_event_timestamp_range = (res[0], res[1])
+            if not res:
+                raise ZeroRowsQueryResult(query)
+            entity_df_event_timestamp_range = (res[0], res[1])
     else:
         raise InvalidEntityType(type(entity_df))
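Two psycopg3 details in this file are easy to miss: `read_only` is a plain connection attribute that can only be assigned while no transaction is open, and `cur.description` is `None` for statements that return no columns. A minimal standalone sketch, with a placeholder DSN (in Feast the connection comes from `_get_conn(config)`):

```python
import psycopg

# Placeholder DSN; not taken from the repo configuration.
with psycopg.connect("host=localhost dbname=feast") as conn:
    conn.read_only = True  # psycopg3 replacement for set_session(readonly=True)
    with conn.cursor() as cur:
        cur.execute("SELECT now() AS ts")
        if not cur.description:  # None when the statement returns no columns
            raise RuntimeError("query returned zero columns")
        print([c.name for c in cur.description], cur.fetchone())
```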

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py (+6 -2)

@@ -4,7 +4,7 @@
 from typeguard import typechecked
 
 from feast.data_source import DataSource
-from feast.errors import DataSourceNoNameException
+from feast.errors import DataSourceNoNameException, ZeroColumnQueryResult
 from feast.infra.utils.postgres.connection_utils import _get_conn
 from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
 from feast.protos.feast.core.SavedDataset_pb2 import (
@@ -111,7 +111,11 @@ def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
         with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
-            cur.execute(f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0")
+            query = f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0"
+            cur.execute(query)
+            if not cur.description:
+                raise ZeroColumnQueryResult(query)
+
         return (
             (c.name, pg_type_code_to_pg_type(c.type_code)) for c in cur.description
         )
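The `LIMIT 0` trick used here introspects a table's schema without transferring any rows: the query returns an empty result set, but `cur.description` still carries the column names and type OIDs. A hedged standalone sketch, with placeholder DSN and table name:

```python
import psycopg

# Sketch of LIMIT 0 schema introspection; "my_table" and the DSN are placeholders.
with psycopg.connect("host=localhost dbname=feast") as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT * FROM my_table AS sub LIMIT 0")
        if not cur.description:
            raise RuntimeError("query returned zero columns")
        for col in cur.description:
            # col.type_code is the Postgres type OID, which Feast maps back
            # to a type name via pg_type_code_to_pg_type().
            print(col.name, col.type_code)
```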

sdk/python/feast/infra/online_stores/contrib/postgres.py (+66 -54)

@@ -2,13 +2,22 @@
 import logging
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)
 
-import psycopg2
 import pytz
-from psycopg2 import sql
-from psycopg2.extras import execute_values
-from psycopg2.pool import SimpleConnectionPool
+from psycopg import sql
+from psycopg.connection import Connection
+from psycopg_pool import ConnectionPool
 
 from feast import Entity
 from feast.feature_view import FeatureView
@@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
 
 
 class PostgreSQLOnlineStore(OnlineStore):
-    _conn: Optional[psycopg2._psycopg.connection] = None
-    _conn_pool: Optional[SimpleConnectionPool] = None
+    _conn: Optional[Connection] = None
+    _conn_pool: Optional[ConnectionPool] = None
 
     @contextlib.contextmanager
-    def _get_conn(self, config: RepoConfig):
+    def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"
+
         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool:
                 self._conn_pool = _get_connection_pool(config.online_store)
+                self._conn_pool.open()
             connection = self._conn_pool.getconn()
             yield connection
             self._conn_pool.putconn(connection)
@@ -64,57 +75,56 @@ def online_write_batch(
             Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
         ],
         progress: Optional[Callable[[int], Any]],
+        batch_size: int = 5000,
     ) -> None:
-        project = config.project
+        # Format insert values
+        insert_values = []
+        for entity_key, values, timestamp, created_ts in data:
+            entity_key_bin = serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=config.entity_key_serialization_version,
+            )
+            timestamp = _to_naive_utc(timestamp)
+            if created_ts is not None:
+                created_ts = _to_naive_utc(created_ts)
 
-        with self._get_conn(config) as conn, conn.cursor() as cur:
-            insert_values = []
-            for entity_key, values, timestamp, created_ts in data:
-                entity_key_bin = serialize_entity_key(
-                    entity_key,
-                    entity_key_serialization_version=config.entity_key_serialization_version,
-                )
-                timestamp = _to_naive_utc(timestamp)
-                if created_ts is not None:
-                    created_ts = _to_naive_utc(created_ts)
-
-                for feature_name, val in values.items():
-                    vector_val = None
-                    if config.online_store.pgvector_enabled:
-                        vector_val = get_list_val_str(val)
-                    insert_values.append(
-                        (
-                            entity_key_bin,
-                            feature_name,
-                            val.SerializeToString(),
-                            vector_val,
-                            timestamp,
-                            created_ts,
-                        )
+            for feature_name, val in values.items():
+                vector_val = None
+                if config.online_store.pgvector_enabled:
+                    vector_val = get_list_val_str(val)
+                insert_values.append(
+                    (
+                        entity_key_bin,
+                        feature_name,
+                        val.SerializeToString(),
+                        vector_val,
+                        timestamp,
+                        created_ts,
                     )
-            # Control the batch so that we can update the progress
-            batch_size = 5000
+                )
+
+        # Create insert query
+        sql_query = sql.SQL(
+            """
+            INSERT INTO {}
+            (entity_key, feature_name, value, vector_value, event_ts, created_ts)
+            VALUES (%s, %s, %s, %s, %s, %s)
+            ON CONFLICT (entity_key, feature_name) DO
+            UPDATE SET
+                value = EXCLUDED.value,
+                vector_value = EXCLUDED.vector_value,
+                event_ts = EXCLUDED.event_ts,
+                created_ts = EXCLUDED.created_ts;
+            """
+        ).format(sql.Identifier(_table_id(config.project, table)))
+
+        # Push data in batches to online store
+        with self._get_conn(config) as conn, conn.cursor() as cur:
             for i in range(0, len(insert_values), batch_size):
                 cur_batch = insert_values[i : i + batch_size]
-                execute_values(
-                    cur,
-                    sql.SQL(
-                        """
-                        INSERT INTO {}
-                        (entity_key, feature_name, value, vector_value, event_ts, created_ts)
-                        VALUES %s
-                        ON CONFLICT (entity_key, feature_name) DO
-                        UPDATE SET
-                            value = EXCLUDED.value,
-                            vector_value = EXCLUDED.vector_value,
-                            event_ts = EXCLUDED.event_ts,
-                            created_ts = EXCLUDED.created_ts;
-                        """,
-                    ).format(sql.Identifier(_table_id(project, table))),
-                    cur_batch,
-                    page_size=batch_size,
-                )
+                cur.executemany(sql_query, cur_batch)
                 conn.commit()
+
                 if progress:
                     progress(len(cur_batch))
@@ -172,7 +182,9 @@ def online_read(
         # when we iterate through the keys since they are in the correct order
         values_dict = defaultdict(list)
         for row in rows if rows is not None else []:
-            values_dict[row[0].tobytes()].append(row[1:])
+            values_dict[
+                row[0] if isinstance(row[0], bytes) else row[0].tobytes()
+            ].append(row[1:])
 
         for key in keys:
             if key in values_dict:
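The pooling pattern above comes from `psycopg_pool`, which replaces `psycopg2.pool.SimpleConnectionPool`: the pool must be opened explicitly (the diff calls `open()` right after constructing it), and connections are checked out with `getconn()` and returned with `putconn()`. A hedged sketch of that lifecycle; the DSN and table are placeholders, not Feast configuration:

```python
from psycopg_pool import ConnectionPool

# Create the pool closed, then open it explicitly, mirroring the diff.
pool = ConnectionPool("host=localhost dbname=feast", min_size=1, open=False)
pool.open()

conn = pool.getconn()
try:
    rows = [(b"key1", "feat", b"val", None), (b"key2", "feat", b"val", None)]
    with conn.cursor() as cur:
        # executemany with explicit per-row placeholders replaces
        # psycopg2.extras.execute_values and its single "VALUES %s" slot.
        cur.executemany(
            "INSERT INTO feast_table (entity_key, feature_name, value, vector_value) "
            "VALUES (%s, %s, %s, %s)",
            rows,
        )
    conn.commit()
finally:
    pool.putconn(conn)  # always return the connection to the pool
pool.close()
```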
