Skip to content

Commit ceae20c

Browse files
author
Lorena Bălan
authored
[KED-2865] Make sql datasets use a singleton pattern for connection (#1163)
Signed-off-by: lorenabalan <lorena.balan@quantumblack.com>
1 parent 4e75b7d commit ceae20c

File tree

3 files changed

+230
-159
lines changed

3 files changed

+230
-159
lines changed

RELEASE.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Major features and improvements
44
* `pipeline` now accepts `tags` and a collection of `Node`s and/or `Pipeline`s rather than just a single `Pipeline` object. `pipeline` should be used in preference to `Pipeline` when creating a Kedro pipeline.
5+
* `pandas.SQLTableDataSet` and `pandas.SQLQueryDataSet` now only open one connection per database, at instantiation time (therefore at catalog creation time), rather than one per load/save operation.
56

67
## Bug fixes and other changes
78
* Added tutorial documentation for experiment tracking (`03_tutorial/07_set_up_experiment_tracking.md`).

kedro/extras/datasets/pandas/sql_dataset.py

+63-35
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,11 @@ class SQLTableDataSet(AbstractDataSet):
147147
148148
"""
149149

150-
DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
151-
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]
150+
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
151+
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
152+
# using Any because of Sphinx but it should be
153+
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
154+
engines: Dict[str, Any] = {}
152155

153156
def __init__(
154157
self,
@@ -207,42 +210,50 @@ def __init__(
207210
self._load_args["table_name"] = table_name
208211
self._save_args["name"] = table_name
209212

210-
self._load_args["con"] = self._save_args["con"] = credentials["con"]
213+
self._connection_str = credentials["con"]
214+
self.create_connection(self._connection_str)
215+
216+
@classmethod
217+
def create_connection(cls, connection_str: str) -> None:
218+
"""Given a connection string, create singleton connection
219+
to be used across all instances of `SQLTableDataSet` that
220+
need to connect to the same source.
221+
"""
222+
if connection_str in cls.engines:
223+
return
224+
225+
try:
226+
engine = create_engine(connection_str)
227+
except ImportError as import_error:
228+
raise _get_missing_module_error(import_error) from import_error
229+
except NoSuchModuleError as exc:
230+
raise _get_sql_alchemy_missing_error() from exc
231+
232+
cls.engines[connection_str] = engine
211233

212234
def _describe(self) -> Dict[str, Any]:
213-
load_args = self._load_args.copy()
214-
save_args = self._save_args.copy()
235+
load_args = copy.deepcopy(self._load_args)
236+
save_args = copy.deepcopy(self._save_args)
215237
del load_args["table_name"]
216-
del load_args["con"]
217238
del save_args["name"]
218-
del save_args["con"]
219239
return dict(
220240
table_name=self._load_args["table_name"],
221241
load_args=load_args,
222242
save_args=save_args,
223243
)
224244

225245
def _load(self) -> pd.DataFrame:
226-
try:
227-
return pd.read_sql_table(**self._load_args)
228-
except ImportError as import_error:
229-
raise _get_missing_module_error(import_error) from import_error
230-
except NoSuchModuleError as exc:
231-
raise _get_sql_alchemy_missing_error() from exc
246+
engine = self.engines[self._connection_str] # type:ignore
247+
return pd.read_sql_table(con=engine, **self._load_args)
232248

233249
def _save(self, data: pd.DataFrame) -> None:
234-
try:
235-
data.to_sql(**self._save_args)
236-
except ImportError as import_error:
237-
raise _get_missing_module_error(import_error) from import_error
238-
except NoSuchModuleError as exc:
239-
raise _get_sql_alchemy_missing_error() from exc
250+
engine = self.engines[self._connection_str] # type: ignore
251+
data.to_sql(con=engine, **self._save_args)
240252

241253
def _exists(self) -> bool:
242-
eng = create_engine(self._load_args["con"])
254+
eng = self.engines[self._connection_str] # type: ignore
243255
schema = self._load_args.get("schema", None)
244256
exists = self._load_args["table_name"] in eng.table_names(schema)
245-
eng.dispose()
246257
return exists
247258

248259

@@ -299,6 +310,10 @@ class SQLQueryDataSet(AbstractDataSet):
299310
300311
"""
301312

313+
# using Any because of Sphinx but it should be
314+
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
315+
engines: Dict[str, Any] = {}
316+
302317
def __init__( # pylint: disable=too-many-arguments
303318
self,
304319
sql: str = None,
@@ -374,32 +389,45 @@ def __init__( # pylint: disable=too-many-arguments
374389
self._protocol = protocol
375390
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
376391
self._filepath = path
377-
self._load_args["con"] = credentials["con"]
392+
self._connection_str = credentials["con"]
393+
self.create_connection(self._connection_str)
394+
395+
@classmethod
396+
def create_connection(cls, connection_str: str) -> None:
397+
"""Given a connection string, create singleton connection
398+
to be used across all instances of `SQLQueryDataSet` that
399+
need to connect to the same source.
400+
"""
401+
if connection_str in cls.engines:
402+
return
403+
404+
try:
405+
engine = create_engine(connection_str)
406+
except ImportError as import_error:
407+
raise _get_missing_module_error(import_error) from import_error
408+
except NoSuchModuleError as exc:
409+
raise _get_sql_alchemy_missing_error() from exc
410+
411+
cls.engines[connection_str] = engine
378412

379413
def _describe(self) -> Dict[str, Any]:
380414
load_args = copy.deepcopy(self._load_args)
381-
desc = {}
382-
desc["sql"] = str(load_args.pop("sql", None))
383-
desc["filepath"] = str(self._filepath)
384-
del load_args["con"]
385-
desc["load_args"] = str(load_args)
386-
387-
return desc
415+
return dict(
416+
sql=str(load_args.pop("sql", None)),
417+
filepath=str(self._filepath),
418+
load_args=str(load_args),
419+
)
388420

389421
def _load(self) -> pd.DataFrame:
390422
load_args = copy.deepcopy(self._load_args)
423+
engine = self.engines[self._connection_str] # type: ignore
391424

392425
if self._filepath:
393426
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
394427
with self._fs.open(load_path, mode="r") as fs_file:
395428
load_args["sql"] = fs_file.read()
396429

397-
try:
398-
return pd.read_sql_query(**load_args)
399-
except ImportError as import_error:
400-
raise _get_missing_module_error(import_error) from import_error
401-
except NoSuchModuleError as exc:
402-
raise _get_sql_alchemy_missing_error() from exc
430+
return pd.read_sql_query(con=engine, **load_args)
403431

404432
def _save(self, data: pd.DataFrame) -> None:
405433
raise DataSetError("`save` is not supported on SQLQueryDataSet")

0 commit comments

Comments
 (0)