Skip to content

Commit de3e7ae

Browse files
author
lorenabalan
committed
Keep mapping of connections. Fix unit tests
Signed-off-by: lorenabalan <lorena.balan@quantumblack.com>
1 parent 03c8c2d commit de3e7ae

File tree

2 files changed

+126
-114
lines changed

2 files changed

+126
-114
lines changed

kedro/extras/datasets/pandas/sql_dataset.py

+52 -52
Original file line number | Diff line number | Diff line change
@@ -207,60 +207,57 @@ def __init__(
207207
self._load_args["table_name"] = table_name
208208
self._save_args["name"] = table_name
209209

210-
self._load_args["con"] = self._save_args["con"] = credentials["con"]
211-
self.create_connection(self._load_args["con"])
210+
self._connection_str = credentials["con"]
211+
self.create_connection(self._connection_str)
212212

213213
@classmethod
214-
def create_connection(cls, con):
215-
"""Create singleton connection to be used
216-
across all instances of `SQLTableDataSet`.
214+
def create_connection(cls, connection_str: str) -> None:
215+
"""Given a connection string, create singleton connection
216+
to be used across all instances of `SQLTableDataSet` that
217+
need to connect to the same source.
217218
"""
218-
if hasattr(cls, "engine"):
219+
if connection_str in getattr(cls, "engines", {}):
219220
return
220221

221-
engine = create_engine(con)
222-
cls.engine = engine
222+
engines = cls.engines if hasattr(cls, "engines") else {} # type:ignore
223+
224+
try:
225+
engine = create_engine(connection_str)
226+
except ImportError as import_error:
227+
raise _get_missing_module_error(import_error) from import_error
228+
except NoSuchModuleError as exc:
229+
raise _get_sql_alchemy_missing_error() from exc
230+
231+
engines[connection_str] = engine
232+
cls.engines = engines # type: ignore
223233

224234
def _describe(self) -> Dict[str, Any]:
225-
load_args = self._load_args.copy()
226-
save_args = self._save_args.copy()
235+
load_args = copy.deepcopy(self._load_args)
236+
save_args = copy.deepcopy(self._save_args)
227237
del load_args["table_name"]
228-
del load_args["con"]
229238
del save_args["name"]
230-
del save_args["con"]
231239
return dict(
232240
table_name=self._load_args["table_name"],
233241
load_args=load_args,
234242
save_args=save_args,
235243
)
236244

237245
def _load(self) -> pd.DataFrame:
238-
load_args = copy.deepcopy(self._load_args)
239-
load_args["con"] = self.engine # type: ignore
246+
engine = self.engines.get(self._connection_str) # type:ignore
240247

241-
try:
242-
return pd.read_sql_table(**load_args)
243-
except ImportError as import_error:
244-
raise _get_missing_module_error(import_error) from import_error
245-
except NoSuchModuleError as exc:
246-
raise _get_sql_alchemy_missing_error() from exc
248+
# TODO: handle engine = None
249+
return pd.read_sql_table(con=engine, **self._load_args)
247250

248251
def _save(self, data: pd.DataFrame) -> None:
249-
save_args = copy.deepcopy(self._save_args)
250-
save_args["con"] = self.engine # type: ignore
252+
engine = self.engines.get(self._connection_str) # type: ignore
251253

252-
try:
253-
data.to_sql(**save_args)
254-
except ImportError as import_error:
255-
raise _get_missing_module_error(import_error) from import_error
256-
except NoSuchModuleError as exc:
257-
raise _get_sql_alchemy_missing_error() from exc
254+
# TODO: handle engine = None
255+
data.to_sql(con=engine, **self._save_args)
258256

259257
def _exists(self) -> bool:
260-
eng = self.engine # type: ignore
258+
eng = self.engines[self._connection_str] # type: ignore
261259
schema = self._load_args.get("schema", None)
262260
exists = self._load_args["table_name"] in eng.table_names(schema)
263-
# eng.dispose()
264261
return exists
265262

266263

@@ -392,45 +389,48 @@ def __init__( # pylint: disable=too-many-arguments
392389
self._protocol = protocol
393390
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
394391
self._filepath = path
395-
self._load_args["con"] = credentials["con"]
396-
self.create_connection(self._load_args["con"])
392+
self._connection_str = credentials["con"]
393+
self.create_connection(self._connection_str)
397394

398395
@classmethod
399-
def create_connection(cls, con):
400-
"""Create singleton connection to be used
401-
across all instances of `SQLQueryDataSet`.
396+
def create_connection(cls, connection_str: str) -> None:
397+
"""Given a connection string, create singleton connection
398+
to be used across all instances of `SQLQueryDataSet` that
399+
need to connect to the same source.
402400
"""
403-
if hasattr(cls, "engine"):
401+
if connection_str in getattr(cls, "engines", {}):
404402
return
405403

406-
engine = create_engine(con)
407-
cls.engine = engine
404+
engines = cls.engines if hasattr(cls, "engines") else {} # type:ignore
405+
406+
try:
407+
engine = create_engine(connection_str)
408+
except ImportError as import_error:
409+
raise _get_missing_module_error(import_error) from import_error
410+
except NoSuchModuleError as exc:
411+
raise _get_sql_alchemy_missing_error() from exc
412+
413+
engines[connection_str] = engine
414+
cls.engines = engines # type: ignore
408415

409416
def _describe(self) -> Dict[str, Any]:
410417
load_args = copy.deepcopy(self._load_args)
411-
desc = {}
412-
desc["sql"] = str(load_args.pop("sql", None))
413-
desc["filepath"] = str(self._filepath)
414-
del load_args["con"]
415-
desc["load_args"] = str(load_args)
416-
417-
return desc
418+
return dict(
419+
sql=str(load_args.pop("sql", None)),
420+
filepath=str(self._filepath),
421+
load_args=str(load_args),
422+
)
418423

419424
def _load(self) -> pd.DataFrame:
420425
load_args = copy.deepcopy(self._load_args)
421-
load_args["con"] = self.engine # type: ignore
426+
engine = self.engines[self._connection_str] # type: ignore
422427

423428
if self._filepath:
424429
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
425430
with self._fs.open(load_path, mode="r") as fs_file:
426431
load_args["sql"] = fs_file.read()
427432

428-
try:
429-
return pd.read_sql_query(**load_args)
430-
except ImportError as import_error:
431-
raise _get_missing_module_error(import_error) from import_error
432-
except NoSuchModuleError as exc:
433-
raise _get_sql_alchemy_missing_error() from exc
433+
return pd.read_sql_query(con=engine, **load_args)
434434

435435
def _save(self, data: pd.DataFrame) -> None:
436436
raise DataSetError("`save` is not supported on SQLQueryDataSet")

0 commit comments

Comments
 (0)