Skip to content

Commit f569786

Browse files
stanconiaStanley Opara
and
Stanley Opara
authored
feat: Add Async refresh to Sql Registry (#4251)
* Add sql registry async refresh Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * make refresh code a daemon thread Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Change RegistryConfig to cacheMode Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Only run async when ttl > 0 Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * make refresh async run in a loop Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * make refresh async run in a loop Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Reorder async refresh call Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Add documentation Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Update test_universal_registry.py Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Force rerun of tests Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Force rerun of tests Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> * Format repo config file Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> --------- Signed-off-by: Stanley Opara <a-sopara@expediagroup.com> Co-authored-by: Stanley Opara <a-sopara@expediagroup.com>
1 parent cea52e9 commit f569786

File tree

4 files changed

+145
-32
lines changed

4 files changed

+145
-32
lines changed

sdk/python/feast/infra/registry/caching_registry.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import atexit
12
import logging
3+
import threading
24
from abc import abstractmethod
35
from datetime import timedelta
46
from threading import Lock
@@ -21,18 +23,18 @@
2123

2224

2325
class CachingRegistry(BaseRegistry):
24-
def __init__(
25-
self,
26-
project: str,
27-
cache_ttl_seconds: int,
28-
):
26+
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
2927
self.cached_registry_proto = self.proto()
3028
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
3129
self.cached_registry_proto_created = _utc_now()
3230
self._refresh_lock = Lock()
3331
self.cached_registry_proto_ttl = timedelta(
3432
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
3533
)
34+
self.cache_mode = cache_mode
35+
if cache_mode == "thread":
36+
self._start_thread_async_refresh(cache_ttl_seconds)
37+
atexit.register(self._exit_handler)
3638

3739
@abstractmethod
3840
def _get_data_source(self, name: str, project: str) -> DataSource:
@@ -322,22 +324,35 @@ def refresh(self, project: Optional[str] = None):
322324
self.cached_registry_proto_created = _utc_now()
323325

324326
def _refresh_cached_registry_if_necessary(self):
325-
with self._refresh_lock:
326-
expired = (
327-
self.cached_registry_proto is None
328-
or self.cached_registry_proto_created is None
329-
) or (
330-
self.cached_registry_proto_ttl.total_seconds()
331-
> 0 # 0 ttl means infinity
332-
and (
333-
_utc_now()
334-
> (
335-
self.cached_registry_proto_created
336-
+ self.cached_registry_proto_ttl
327+
if self.cache_mode == "sync":
328+
with self._refresh_lock:
329+
expired = (
330+
self.cached_registry_proto is None
331+
or self.cached_registry_proto_created is None
332+
) or (
333+
self.cached_registry_proto_ttl.total_seconds()
334+
> 0 # 0 ttl means infinity
335+
and (
336+
_utc_now()
337+
> (
338+
self.cached_registry_proto_created
339+
+ self.cached_registry_proto_ttl
340+
)
337341
)
338342
)
339-
)
343+
if expired:
344+
logger.info("Registry cache expired, so refreshing")
345+
self.refresh()
346+
347+
def _start_thread_async_refresh(self, cache_ttl_seconds):
348+
self.refresh()
349+
if cache_ttl_seconds <= 0:
350+
return
351+
self.registry_refresh_thread = threading.Timer(
352+
cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds]
353+
)
354+
self.registry_refresh_thread.setDaemon(True)
355+
self.registry_refresh_thread.start()
340356

341-
if expired:
342-
logger.info("Registry cache expired, so refreshing")
343-
self.refresh()
357+
def _exit_handler(self):
358+
self.registry_refresh_thread.cancel()

sdk/python/feast/infra/registry/sql.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def __init__(
193193
)
194194
metadata.create_all(self.engine)
195195
super().__init__(
196-
project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds
196+
project=project,
197+
cache_ttl_seconds=registry_config.cache_ttl_seconds,
198+
cache_mode=registry_config.cache_mode,
197199
)
198200

199201
def teardown(self):

sdk/python/feast/repo_config.py

+3
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ class RegistryConfig(FeastBaseModel):
124124
sqlalchemy_config_kwargs: Dict[str, Any] = {}
125125
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """
126126

127+
cache_mode: StrictStr = "sync"
128+
""" str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)"""
129+
127130
@field_validator("path")
128131
def validate_path(cls, path: str, values: ValidationInfo) -> str:
129132
if values.data.get("registry_type") == "sql":

sdk/python/tests/integration/registration/test_universal_registry.py

+103-10
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def minio_registry() -> Registry:
125125
logger = logging.getLogger(__name__)
126126

127127

128-
@pytest.fixture(scope="session")
128+
@pytest.fixture(scope="function")
129129
def pg_registry():
130130
container = (
131131
DockerContainer("postgres:latest")
@@ -137,6 +137,35 @@ def pg_registry():
137137

138138
container.start()
139139

140+
registry_config = _given_registry_config_for_pg_sql(container)
141+
142+
yield SqlRegistry(registry_config, "project", None)
143+
144+
container.stop()
145+
146+
147+
@pytest.fixture(scope="function")
148+
def pg_registry_async():
149+
container = (
150+
DockerContainer("postgres:latest")
151+
.with_exposed_ports(5432)
152+
.with_env("POSTGRES_USER", POSTGRES_USER)
153+
.with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
154+
.with_env("POSTGRES_DB", POSTGRES_DB)
155+
)
156+
157+
container.start()
158+
159+
registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")
160+
161+
yield SqlRegistry(registry_config, "project", None)
162+
163+
container.stop()
164+
165+
166+
def _given_registry_config_for_pg_sql(
167+
container, cache_ttl_seconds=2, cache_mode="sync"
168+
):
140169
log_string_to_wait_for = "database system is ready to accept connections"
141170
waited = wait_for_logs(
142171
container=container,
@@ -148,42 +177,57 @@ def pg_registry():
148177
container_port = container.get_exposed_port(5432)
149178
container_host = container.get_container_host_ip()
150179

151-
registry_config = RegistryConfig(
180+
return RegistryConfig(
152181
registry_type="sql",
182+
cache_ttl_seconds=cache_ttl_seconds,
183+
cache_mode=cache_mode,
153184
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
154185
# to understand that we are using psycopg3.
155186
path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
156187
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
157188
)
158189

190+
191+
@pytest.fixture(scope="function")
192+
def mysql_registry():
193+
container = MySqlContainer("mysql:latest")
194+
container.start()
195+
196+
registry_config = _given_registry_config_for_mysql(container)
197+
159198
yield SqlRegistry(registry_config, "project", None)
160199

161200
container.stop()
162201

163202

164-
@pytest.fixture(scope="session")
165-
def mysql_registry():
203+
@pytest.fixture(scope="function")
204+
def mysql_registry_async():
166205
container = MySqlContainer("mysql:latest")
167206
container.start()
168207

169-
# testing for the database to exist and ready to connect and start testing.
208+
registry_config = _given_registry_config_for_mysql(container, 2, "thread")
209+
210+
yield SqlRegistry(registry_config, "project", None)
211+
212+
container.stop()
213+
214+
215+
def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
170216
import sqlalchemy
171217

172218
engine = sqlalchemy.create_engine(
173219
container.get_connection_url(), pool_pre_ping=True
174220
)
175221
engine.connect()
176222

177-
registry_config = RegistryConfig(
223+
return RegistryConfig(
178224
registry_type="sql",
179225
path=container.get_connection_url(),
226+
cache_ttl_seconds=cache_ttl_seconds,
227+
cache_mode=cache_mode,
180228
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
181229
)
182230

183-
yield SqlRegistry(registry_config, "project", None)
184-
185-
container.stop()
186-
187231

188232
@pytest.fixture(scope="session")
189233
def sqlite_registry():
@@ -269,6 +313,17 @@ def mock_remote_registry():
269313
lazy_fixture("sqlite_registry"),
270314
]
271315

316+
async_sql_fixtures = [
317+
pytest.param(
318+
lazy_fixture("pg_registry_async"),
319+
marks=pytest.mark.xdist_group(name="pg_registry_async"),
320+
),
321+
pytest.param(
322+
lazy_fixture("mysql_registry_async"),
323+
marks=pytest.mark.xdist_group(name="mysql_registry_async"),
324+
),
325+
]
326+
272327

273328
@pytest.mark.integration
274329
@pytest.mark.parametrize("test_registry", all_fixtures)
@@ -999,6 +1054,44 @@ def test_registry_cache(test_registry):
9991054
test_registry.teardown()
10001055

10011056

1057+
@pytest.mark.integration
1058+
@pytest.mark.parametrize(
1059+
"test_registry",
1060+
async_sql_fixtures,
1061+
)
1062+
def test_registry_cache_thread_async(test_registry):
1063+
# Create Feature View
1064+
batch_source = FileSource(
1065+
name="test_source",
1066+
file_format=ParquetFormat(),
1067+
path="file://feast/*",
1068+
timestamp_field="ts_col",
1069+
created_timestamp_column="timestamp",
1070+
)
1071+
1072+
project = "project"
1073+
1074+
# Register data source
1075+
test_registry.apply_data_source(batch_source, project)
1076+
registry_data_sources_cached = test_registry.list_data_sources(
1077+
project, allow_cache=True
1078+
)
1079+
# async ttl yet to expire, so there will be a cache miss
1080+
assert len(registry_data_sources_cached) == 0
1081+
1082+
# Wait for cache to be refreshed
1083+
time.sleep(4)
1084+
# Now objects exist
1085+
registry_data_sources_cached = test_registry.list_data_sources(
1086+
project, allow_cache=True
1087+
)
1088+
assert len(registry_data_sources_cached) == 1
1089+
registry_data_source = registry_data_sources_cached[0]
1090+
assert registry_data_source == batch_source
1091+
1092+
test_registry.teardown()
1093+
1094+
10021095
@pytest.mark.integration
10031096
@pytest.mark.parametrize(
10041097
"test_registry",

0 commit comments

Comments
 (0)