[KED-2865] Make sql datasets use a singleton pattern for connection #1163

Merged
merged 19 commits on Feb 3, 2022
1 change: 1 addition & 0 deletions RELEASE.md
@@ -2,6 +2,7 @@

## Major features and improvements
* `pipeline` now accepts `tags` and a collection of `Node`s and/or `Pipeline`s rather than just a single `Pipeline` object. `pipeline` should be used in preference to `Pipeline` when creating a Kedro pipeline.
* `pandas.SQLTableDataSet` and `pandas.SQLQueryDataSet` now open only one connection per database, at instantiation time (therefore at catalog creation time), rather than one connection per load/save operation.

## Bug fixes and other changes
* Added tutorial documentation for experiment tracking (`03_tutorial/07_set_up_experiment_tracking.md`).
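
The release note above describes the user-visible effect: datasets pointing at the same database now share a single SQLAlchemy engine, created when the dataset (and hence the catalog) is instantiated. A minimal sketch of that behaviour, assuming this change is installed and using an illustrative SQLite connection string:

```python
# Minimal sketch of the shared-engine behaviour described in the release note.
# The connection string and table names are illustrative only.
from kedro.extras.datasets.pandas import SQLTableDataSet

credentials = {"con": "sqlite:///kedro_example.db"}

cars = SQLTableDataSet(table_name="cars", credentials=credentials)
trucks = SQLTableDataSet(table_name="trucks", credentials=credentials)

# Both datasets reuse the same engine, keyed by the connection string.
assert cars.engines["sqlite:///kedro_example.db"] is trucks.engines["sqlite:///kedro_example.db"]
```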
98 changes: 63 additions & 35 deletions kedro/extras/datasets/pandas/sql_dataset.py
@@ -147,8 +147,11 @@ class SQLTableDataSet(AbstractDataSet):

"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__(
self,
@@ -207,42 +210,50 @@ def __init__(
self._load_args["table_name"] = table_name
self._save_args["name"] = table_name

self._load_args["con"] = self._save_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLTableDataSet` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
save_args = self._save_args.copy()
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
del load_args["con"]
del save_args["name"]
del save_args["con"]
return dict(
table_name=self._load_args["table_name"],
load_args=load_args,
save_args=save_args,
)

def _load(self) -> pd.DataFrame:
try:
return pd.read_sql_table(**self._load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
try:
data.to_sql(**self._save_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)

def _exists(self) -> bool:
eng = create_engine(self._load_args["con"])
eng = self.engines[self._connection_str] # type: ignore
schema = self._load_args.get("schema", None)
exists = self._load_args["table_name"] in eng.table_names(schema)
eng.dispose()
return exists


@@ -299,6 +310,10 @@ class SQLQueryDataSet(AbstractDataSet):

"""

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__( # pylint: disable=too-many-arguments
self,
sql: str = None,
@@ -374,32 +389,45 @@ def __init__( # pylint: disable=too-many-arguments
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
self._filepath = path
self._load_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLQueryDataSet` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
desc = {}
desc["sql"] = str(load_args.pop("sql", None))
desc["filepath"] = str(self._filepath)
del load_args["con"]
desc["load_args"] = str(load_args)

return desc
return dict(
sql=str(load_args.pop("sql", None)),
filepath=str(self._filepath),
load_args=str(load_args),
)

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str] # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

try:
return pd.read_sql_query(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
return pd.read_sql_query(con=engine, **load_args)

def _save(self, data: pd.DataFrame) -> None:
raise DataSetError("`save` is not supported on SQLQueryDataSet")
290 changes: 166 additions & 124 deletions tests/extras/datasets/pandas/test_sql_dataset.py
@@ -1,7 +1,5 @@
# pylint: disable=no-member

from pathlib import PosixPath
from typing import Any

import pandas as pd
import pytest
@@ -19,6 +17,13 @@
)


@pytest.fixture(autouse=True)
def cleanup_engines():
yield
SQLTableDataSet.engines = {}
SQLQueryDataSet.engines = {}


@pytest.fixture
def dummy_dataframe():
return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
@@ -52,10 +57,16 @@ def query_file_data_set(request, sql_file):
return SQLQueryDataSet(**kwargs)


class TestSQLTableDataSetLoad:
class TestSQLTableDataSet:
_unknown_conn = "mysql+unknown_module://scott:tiger@localhost/foo"

@staticmethod
def _assert_pd_called_once():
pd.read_sql_table.assert_called_once_with(table_name=TABLE_NAME, con=CONNECTION)
def _assert_sqlalchemy_called_once(*args):
_callable = sqlalchemy.engine.Engine.table_names
if args:
_callable.assert_called_once_with(*args)
else:
assert _callable.call_count == 1

def test_empty_table_name(self):
"""Check the error when instantiating with an empty table"""
@@ -73,65 +84,80 @@ def test_empty_connection(self):
with pytest.raises(DataSetError, match=pattern):
SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=""))

def test_load_sql_params(self, mocker, table_data_set):
"""Test `load` method invocation"""
mocker.patch("pandas.read_sql_table")
table_data_set.load()
self._assert_pd_called_once()

def test_load_driver_missing(self, mocker, table_data_set):
def test_driver_missing(self, mocker):
"""Check the error when the sql driver is missing"""
mocker.patch(
"pandas.read_sql_table",
"kedro.extras.datasets.pandas.sql_dataset.create_engine",
side_effect=ImportError("No module named 'mysqldb'"),
)
with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"):
table_data_set.load()
self._assert_pd_called_once()
SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))

def test_invalid_module(self, mocker, table_data_set):
"""Test that if an invalid module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("Invalid module some_module")
mocker.patch("pandas.read_sql_table", side_effect=_err)
pattern = ERROR_PREFIX + r"Invalid module some\_module"
def test_unknown_sql(self):
"""Check the error when unknown sql dialect is provided;
this means the error is raised on catalog creation, rather
than on load or save operation.
"""
pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
with pytest.raises(DataSetError, match=pattern):
table_data_set.load()
self._assert_pd_called_once()
SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=FAKE_CONN_STR))

def test_load_unknown_module(self, mocker, table_data_set):
def test_unknown_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
mocker.patch(
"pandas.read_sql_table",
"kedro.extras.datasets.pandas.sql_dataset.create_engine",
side_effect=ImportError("No module named 'unknown_module'"),
)
pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
with pytest.raises(DataSetError, match=pattern):
table_data_set.load()
SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))

def test_str_representation_table(self, table_data_set):
"""Test the data set instance string representation"""
str_repr = str(table_data_set)
assert (
"SQLTableDataSet(load_args={}, save_args={'index': False}, "
f"table_name={TABLE_NAME})" in str_repr
)
assert CONNECTION not in str(str_repr)

def test_table_exists(self, mocker, table_data_set):
"""Test `exists` method invocation"""
mocker.patch("sqlalchemy.engine.Engine.table_names")
assert not table_data_set.exists()
self._assert_sqlalchemy_called_once()

@pytest.mark.parametrize(
"table_data_set", [{"credentials": dict(con=FAKE_CONN_STR)}], indirect=True
"table_data_set", [{"load_args": dict(schema="ingested")}], indirect=True
)
def test_load_unknown_sql(self, table_data_set):
"""Check the error when unknown sql dialect is provided"""
pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
with pytest.raises(DataSetError, match=pattern):
table_data_set.load()

def test_table_exists_schema(self, mocker, table_data_set):
"""Test `exists` method invocation with DB schema provided"""
mocker.patch("sqlalchemy.engine.Engine.table_names")
assert not table_data_set.exists()
self._assert_sqlalchemy_called_once("ingested")

class TestSQLTableDataSetSave:
_unknown_conn = "mysql+unknown_module://scott:tiger@localhost/foo"
def test_table_exists_mocked(self, mocker, table_data_set):
"""Test `exists` method invocation with mocked list of tables"""
mocker.patch("sqlalchemy.engine.Engine.table_names", return_value=[TABLE_NAME])
assert table_data_set.exists()
self._assert_sqlalchemy_called_once()

@staticmethod
def _assert_to_sql_called_once(df: Any, index: bool = False):
df.to_sql.assert_called_once_with(name=TABLE_NAME, con=CONNECTION, index=index)
def test_load_sql_params(self, mocker, table_data_set):
"""Test `load` method invocation"""
mocker.patch("pandas.read_sql_table")
table_data_set.load()
pd.read_sql_table.assert_called_once_with(
table_name=TABLE_NAME, con=table_data_set.engines[CONNECTION]
)

def test_save_default_index(self, mocker, table_data_set, dummy_dataframe):
"""Test `save` method invocation"""
mocker.patch.object(dummy_dataframe, "to_sql")
table_data_set.save(dummy_dataframe)
self._assert_to_sql_called_once(dummy_dataframe)
dummy_dataframe.to_sql.assert_called_once_with(
name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=False
)

@pytest.mark.parametrize(
"table_data_set", [{"save_args": dict(index=True)}], indirect=True
@@ -140,36 +166,9 @@ def test_save_overwrite_index(self, mocker, table_data_set, dummy_dataframe):
"""Test writing DataFrame index as a column"""
mocker.patch.object(dummy_dataframe, "to_sql")
table_data_set.save(dummy_dataframe)
self._assert_to_sql_called_once(dummy_dataframe, True)

def test_save_driver_missing(self, mocker, table_data_set, dummy_dataframe):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("No module named 'mysqldb'")
mocker.patch.object(dummy_dataframe, "to_sql", side_effect=_err)
with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"):
table_data_set.save(dummy_dataframe)

@pytest.mark.parametrize(
"table_data_set", [{"credentials": dict(con=FAKE_CONN_STR)}], indirect=True
)
def test_save_unknown_sql(self, table_data_set, dummy_dataframe):
"""Check the error when unknown sql dialect is provided"""
pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
with pytest.raises(DataSetError, match=pattern):
table_data_set.save(dummy_dataframe)

@pytest.mark.parametrize(
"table_data_set", [{"credentials": dict(con=_unknown_conn)}], indirect=True
)
def test_save_unknown_module(self, mocker, table_data_set, dummy_dataframe):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("No module named 'unknown_module'")
mocker.patch.object(dummy_dataframe, "to_sql", side_effect=_err)
pattern = r"No module named \'unknown_module\'"
with pytest.raises(DataSetError, match=pattern):
table_data_set.save(dummy_dataframe)
dummy_dataframe.to_sql.assert_called_once_with(
name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=True
)

@pytest.mark.parametrize(
"table_data_set", [{"save_args": dict(name="TABLE_B")}], indirect=True
@@ -181,55 +180,75 @@ def test_save_ignore_table_name_override(
effect"""
mocker.patch.object(dummy_dataframe, "to_sql")
table_data_set.save(dummy_dataframe)
self._assert_to_sql_called_once(dummy_dataframe)
dummy_dataframe.to_sql.assert_called_once_with(
name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=False
)


class TestSQLTableDataSet:
@staticmethod
def _assert_sqlalchemy_called_once(*args):
_callable = sqlalchemy.engine.Engine.table_names
if args:
_callable.assert_called_once_with(*args)
else:
assert _callable.call_count == 1
class TestSQLTableDataSetSingleConnection:
def test_single_connection(self, dummy_dataframe, mocker):
"""Test to make sure multiple instances use the same connection object."""
mocker.patch("pandas.read_sql_table")
dummy_to_sql = mocker.patch.object(dummy_dataframe, "to_sql")
kwargs = dict(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))

first = SQLTableDataSet(**kwargs)
unique_connection = first.engines[CONNECTION]
datasets = [SQLTableDataSet(**kwargs) for _ in range(10)]

for ds in datasets:
ds.save(dummy_dataframe)
engine = ds.engines[CONNECTION]
assert engine is unique_connection

expected_call = mocker.call(name=TABLE_NAME, con=unique_connection, index=False)
dummy_to_sql.assert_has_calls([expected_call] * 10)

for ds in datasets:
ds.load()
engine = ds.engines[CONNECTION]
assert engine is unique_connection

def test_create_connection_only_once(self, mocker):
"""Test that two datasets that need to connect to the same db
(but different tables, for example) only create a connection once.
"""
mock_engine = mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine"
)
first = SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))
assert len(first.engines) == 1

def test_str_representation_table(self, table_data_set):
"""Test the data set instance string representation"""
str_repr = str(table_data_set)
assert (
"SQLTableDataSet(load_args={}, save_args={'index': False}, "
f"table_name={TABLE_NAME})" in str_repr
second = SQLTableDataSet(
table_name="other_table", credentials=dict(con=CONNECTION)
)
assert CONNECTION not in str(str_repr)
assert len(second.engines) == 1
assert len(first.engines) == 1

def test_table_exists(self, mocker, table_data_set):
"""Test `exists` method invocation"""
mocker.patch("sqlalchemy.engine.Engine.table_names")
assert not table_data_set.exists()
self._assert_sqlalchemy_called_once()
mock_engine.assert_called_once_with(CONNECTION)

@pytest.mark.parametrize(
"table_data_set", [{"load_args": dict(schema="ingested")}], indirect=True
)
def test_able_exists_schema(self, mocker, table_data_set):
"""Test `exists` method invocation with DB schema provided"""
mocker.patch("sqlalchemy.engine.Engine.table_names")
assert not table_data_set.exists()
self._assert_sqlalchemy_called_once("ingested")
def test_multiple_connections(self, mocker):
"""Test that two datasets that need to connect to different dbs
only create one connection per db.
"""
mock_engine = mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine"
)
first = SQLTableDataSet(table_name=TABLE_NAME, credentials=dict(con=CONNECTION))
assert len(first.engines) == 1

def test_table_exists_mocked(self, mocker, table_data_set):
"""Test `exists` method invocation with mocked list of tables"""
mocker.patch("sqlalchemy.engine.Engine.table_names", return_value=[TABLE_NAME])
assert table_data_set.exists()
self._assert_sqlalchemy_called_once()
second_con = f"other_{CONNECTION}"
second = SQLTableDataSet(
table_name=TABLE_NAME, credentials=dict(con=second_con)
)
assert len(second.engines) == 2
assert len(first.engines) == 2

expected_calls = [mocker.call(CONNECTION), mocker.call(second_con)]
assert mock_engine.call_args_list == expected_calls

class TestSQLQueryDataSet:
@staticmethod
def _assert_pd_called_once():
_callable = pd.read_sql_query
_callable.assert_called_once_with(sql=SQL_QUERY, con=CONNECTION)

class TestSQLQueryDataSet:
def test_empty_query_error(self):
"""Check the error when instantiating with empty query or file"""
pattern = (
@@ -252,49 +271,56 @@ def test_load(self, mocker, query_data_set):
"""Test `load` method invocation"""
mocker.patch("pandas.read_sql_query")
query_data_set.load()
self._assert_pd_called_once()
pd.read_sql_query.assert_called_once_with(
sql=SQL_QUERY, con=query_data_set.engines[CONNECTION]
)

def test_load_query_file(self, mocker, query_file_data_set):
"""Test `load` method with a query file"""
mocker.patch("pandas.read_sql_query")
query_file_data_set.load()
self._assert_pd_called_once()
pd.read_sql_query.assert_called_once_with(
sql=SQL_QUERY, con=query_file_data_set.engines[CONNECTION]
)

def test_load_driver_missing(self, mocker, query_data_set):
def test_load_driver_missing(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("No module named 'mysqldb'")
mocker.patch("pandas.read_sql_query", side_effect=_err)
mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err
)
with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"):
query_data_set.load()
SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=CONNECTION))

def test_invalid_module(self, mocker, query_data_set):
def test_invalid_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("Invalid module some_module")
mocker.patch("pandas.read_sql_query", side_effect=_err)
mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err
)
pattern = ERROR_PREFIX + r"Invalid module some\_module"
with pytest.raises(DataSetError, match=pattern):
query_data_set.load()
SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=CONNECTION))

def test_load_unknown_module(self, mocker, query_data_set):
def test_load_unknown_module(self, mocker):
"""Test that if an unknown module/driver is encountered by SQLAlchemy
then the error should contain the original error message"""
_err = ImportError("No module named 'unknown_module'")
mocker.patch("pandas.read_sql_query", side_effect=_err)
mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine", side_effect=_err
)
pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'"
with pytest.raises(DataSetError, match=pattern):
query_data_set.load()
SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=CONNECTION))

@pytest.mark.parametrize(
"query_data_set", [{"credentials": dict(con=FAKE_CONN_STR)}], indirect=True
)
def test_load_unknown_sql(self, query_data_set):
def test_load_unknown_sql(self):
"""Check the error when unknown SQL dialect is provided
in the connection string"""
pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy"
with pytest.raises(DataSetError, match=pattern):
query_data_set.load()
SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=FAKE_CONN_STR))

def test_save_error(self, query_data_set, dummy_dataframe):
"""Check the error when trying to save to the data set"""
@@ -330,3 +356,19 @@ def test_sql_and_filepath_args(self, sql_file):
)
with pytest.raises(DataSetError, match=pattern):
SQLQueryDataSet(sql=SQL_QUERY, filepath=sql_file)

def test_create_connection_only_once(self, mocker):
"""Test that two datasets that need to connect to the same db
(but different tables, for example) only create a connection once.
"""
mock_engine = mocker.patch(
"kedro.extras.datasets.pandas.sql_dataset.create_engine"
)
first = SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=CONNECTION))
assert len(first.engines) == 1

second = SQLQueryDataSet(sql=SQL_QUERY, credentials=dict(con=CONNECTION))
assert len(second.engines) == 1
assert len(first.engines) == 1

mock_engine.assert_called_once_with(CONNECTION)
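
Because `engines` is a class attribute, cached engines persist for the lifetime of the Python process, which is why the autouse `cleanup_engines` fixture above resets it between tests. Outside of tests, a cached engine could be released explicitly along these lines (an illustrative sketch, not part of the PR):

```python
# Illustrative only: drop a cached engine so the next dataset re-creates it.
# Assumes the `engines` class attribute introduced by this PR.
from kedro.extras.datasets.pandas import SQLTableDataSet

connection_str = "sqlite:///kedro_example.db"  # example connection string
engine = SQLTableDataSet.engines.pop(connection_str, None)
if engine is not None:
    engine.dispose()  # close any pooled connections held by the engine
```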