Enable all datasets to take a metadata parameter #633

Merged: 3 commits, Feb 17, 2025
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

## [Unreleased]

### Added

- :sparkles: All datasets can now take a metadata parameter ([#625](https://github.com/Galileo-Galilei/kedro-mlflow/issues/625), [#633](https://github.com/Galileo-Galilei/kedro-mlflow/pull/633))

## [0.14.2] - 2025-02-16

### Fixed
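
To illustrate the entry above (a sketch of intended usage, not part of the diff): Kedro forwards a `metadata` block declared under a catalog entry straight to the dataset constructor, so each kedro-mlflow dataset has to accept the argument even though it never acts on it. The import path below mirrors the test files in this PR; treat it as an assumption.

```python
# Python equivalent of a catalog.yml entry such as:
#
#   accuracy:
#     type: kedro_mlflow.io.metrics.MlflowMetricDataset
#     key: accuracy
#     metadata:
#       owner: ml-team
#
from kedro_mlflow.io.metrics import MlflowMetricDataset

accuracy = MlflowMetricDataset(
    key="accuracy",
    metadata={"owner": "ml-team"},  # stored verbatim; ignored by Kedro and MLflow
)

assert accuracy.metadata == {"owner": "ml-team"}
```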
8 changes: 6 additions & 2 deletions kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
@@ -19,6 +19,7 @@ def __new__(
run_id: str = None,
artifact_path: str = None,
credentials: Dict[str, Any] = None,
metadata: Dict[str, Any] | None = None,
):
dataset_obj, dataset_args = parse_dataset_definition(config=dataset)

@@ -27,11 +28,12 @@ def __new__(
# instead and since we can't modify the core package,
# we create a subclass which inherits dynamically from the dataset class
class MlflowArtifactDatasetChildren(dataset_obj):
def __init__(self, run_id, artifact_path):
def __init__(self, run_id, artifact_path, metadata):
super().__init__(**dataset_args)
self.run_id = run_id
self.artifact_path = artifact_path
self._logging_activated = True
self.metadata = metadata

@property
def _logging_activated(self):
@@ -147,7 +149,9 @@ def _load(self) -> Any: # pragma: no cover
)

mlflow_dataset_instance = MlflowArtifactDatasetChildren(
run_id=run_id, artifact_path=artifact_path
run_id=run_id,
artifact_path=artifact_path,
metadata=metadata,
)
return mlflow_dataset_instance
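
A short sketch of what the change to the dynamically created child class means in practice, assuming `kedro_mlflow.io.artifacts.MlflowArtifactDataset` and the `pandas.CSVDataset` resolution from kedro-datasets, as in the tests below: the wrapper instance now exposes the `metadata` passed to `MlflowArtifactDataset`, while loading and saving behave exactly as before.

```python
from kedro_mlflow.io.artifacts import MlflowArtifactDataset

wrapped_csv = MlflowArtifactDataset(
    dataset={"type": "pandas.CSVDataset", "filepath": "data/01_raw/iris.csv"},
    artifact_path="raw_data",
    metadata={"description": "raw iris dump"},  # carried by the generated child class
)

# The generated MlflowArtifactDatasetChildren instance keeps the annotation...
assert wrapped_csv.metadata == {"description": "raw iris dump"}
# ...but it is not part of the dataset description reported to Kedro.
assert "metadata" not in wrapped_csv._describe()
```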

2 changes: 2 additions & 0 deletions kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py
@@ -12,6 +12,7 @@ def __init__(
run_id: str = None,
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
metadata: Dict[str, Any] | None = None,
):
"""Initialise MlflowMetricsHistoryDataset.

@@ -24,6 +25,7 @@ def __init__(
self._load_args = load_args or {}
self._save_args = save_args or {}
self._logging_activated = True # by default, logging is activated!
self.metadata = metadata

@property
def run_id(self) -> Union[str, None]:
3 changes: 2 additions & 1 deletion kedro_mlflow/io/metrics/mlflow_metric_dataset.py
@@ -18,13 +18,14 @@ def __init__(
run_id: str = None,
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
metadata: Dict[str, Any] | None = None,
):
"""Initialise MlflowMetricDataset.
Args:
run_id (str): The ID of the mlflow run where the metric should be logged
"""

super().__init__(key, run_id, load_args, save_args)
super().__init__(key, run_id, load_args, save_args, metadata)

# We add an extra argument mode="overwrite" / "append" to enable logging update an existing metric
# this is not an offical mlflow argument for log_metric, so we separate it from the others
3 changes: 2 additions & 1 deletion kedro_mlflow/io/metrics/mlflow_metric_history_dataset.py
@@ -14,13 +14,14 @@ def __init__(
run_id: str = None,
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
metadata: Dict[str, Any] | None = None,
):
"""Initialise MlflowMetricDataset.
Args:
run_id (str): The ID of the mlflow run where the metric should be logged
"""

super().__init__(key, run_id, load_args, save_args)
super().__init__(key, run_id, load_args, save_args, metadata)

def _load(self):
self._validate_run_id()
2 changes: 2 additions & 0 deletions kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py
@@ -18,6 +18,7 @@ def __init__(
self,
run_id: str = None,
prefix: Optional[str] = None,
metadata: Dict[str, Any] | None = None,
):
"""Initialise MlflowMetricsHistoryDataset.

@@ -28,6 +29,7 @@ def __init__(
self._prefix = prefix
self.run_id = run_id
self._logging_activated = True # by default, logging is activated!
self.metadata = metadata

@property
def run_id(self):
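
The two per-metric datasets above forward `metadata` positionally to `MlflowAbstractMetricDataset`, whereas `MlflowMetricsHistoryDataset` (which takes a `prefix` rather than a `key`) stores it directly; either way the attribute ends up on the instance. A minimal sketch, with import paths assumed from the tests below:

```python
from kedro_mlflow.io.metrics import (
    MlflowMetricHistoryDataset,
    MlflowMetricsHistoryDataset,
)

lr_history = MlflowMetricHistoryDataset(
    key="learning_rate",
    metadata={"unit": "per step"},  # hypothetical annotation
)
all_metrics = MlflowMetricsHistoryDataset(
    prefix="training",
    metadata={"owner": "ml-team"},
)

assert lr_history.metadata == {"unit": "per step"}
assert all_metrics.metadata == {"owner": "ml-team"}
```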
4 changes: 4 additions & 0 deletions kedro_mlflow/io/models/mlflow_abstract_model_dataset.py
@@ -20,6 +20,7 @@ def __init__(
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
version: Version = None,
metadata: Dict[str, Any] | None = None,
) -> None:
"""Initialize the Kedro MlflowAbstractModelDataSet.

@@ -39,6 +40,8 @@ def __init__(
save_args (Dict[str, Any], optional): Arguments to `log_model`
function from specified `flavor`. Defaults to {}.
version (Version, optional): Specific version to load.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.

Raises:
DatasetError: When passed `flavor` does not exist.
@@ -61,6 +64,7 @@ def __init__(

self._load_args = load_args or {}
self._save_args = save_args or {}
self.metadata = metadata

try:
self._mlflow_model_module
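
The docstring added above says the metadata is "ignored by Kedro, but may be consumed by users or external plugins". A hypothetical consumer could be a project hook that reports the annotations after the catalog is built; this is only a sketch (the hook class, and the use of the catalog's private `_get_dataset` accessor, are assumptions about a typical Kedro 0.19 project, not part of kedro-mlflow):

```python
import logging
from kedro.framework.hooks import hook_impl

logger = logging.getLogger(__name__)


class DatasetMetadataReporterHook:
    """Hypothetical hook that logs the metadata declared on catalog datasets."""

    @hook_impl
    def after_catalog_created(self, catalog):
        for name in catalog.list():
            dataset = catalog._get_dataset(name)  # private accessor; adjust to your Kedro version
            metadata = getattr(dataset, "metadata", None)
            if metadata:
                logger.info("Dataset %r carries metadata: %s", name, metadata)
```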
kedro_mlflow/io/models/mlflow_model_local_filesystem_dataset.py
@@ -21,6 +21,7 @@ def __init__(
save_args: Dict[str, Any] = None,
log_args: Dict[str, Any] = None,
version: Version = None,
metadata: Dict[str, Any] | None = None,
) -> None:
"""Initialize the Kedro MlflowModelDataSet.

@@ -40,6 +41,9 @@ def __init__(
save_args (Dict[str, Any], optional): Arguments to `save_model`
function from specified `flavor`. Defaults to None.
version (Version, optional): Kedro version to use. Defaults to None.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.

Raises:
DatasetError: When passed `flavor` does not exist.
"""
@@ -50,6 +54,7 @@ def __init__(
load_args=load_args,
save_args=save_args,
version=version,
metadata=metadata,
)

def _load(self) -> Any:
4 changes: 4 additions & 0 deletions kedro_mlflow/io/models/mlflow_model_registry_dataset.py
@@ -19,6 +19,7 @@ def __init__(
flavor: Optional[str] = "mlflow.pyfunc",
pyfunc_workflow: Optional[str] = "python_model",
load_args: Optional[Dict[str, Any]] = None,
metadata: Dict[str, Any] | None = None,
) -> None:
"""Initialize the Kedro MlflowModelRegistryDataset.

@@ -37,6 +38,8 @@ def __init__(
See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
load_args (Dict[str, Any], optional): Arguments to `load_model`
function from specified `flavor`. Defaults to None.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.

Raises:
DatasetError: When passed `flavor` does not exist.
@@ -48,6 +51,7 @@ def __init__(
load_args=load_args,
save_args={},
version=None,
metadata=metadata,
)

if alias is None and stage_or_version is None:
4 changes: 4 additions & 0 deletions kedro_mlflow/io/models/mlflow_model_tracking_dataset.py
@@ -19,6 +19,7 @@ def __init__(
pyfunc_workflow: Optional[str] = None,
load_args: Optional[Dict[str, Any]] = None,
save_args: Optional[Dict[str, Any]] = None,
metadata: Dict[str, Any] | None = None,
) -> None:
"""Initialize the Kedro MlflowModelDataSet.

@@ -40,6 +41,8 @@ def __init__(
function from specified `flavor`. Defaults to None.
save_args (Dict[str, Any], optional): Arguments to `log_model`
function from specified `flavor`. Defaults to None.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.

Raises:
DatasetError: When passed `flavor` does not exist.
@@ -51,6 +54,7 @@ def __init__(
load_args=load_args,
save_args=save_args,
version=None,
metadata=metadata,
)

self._run_id = run_id
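
The three concrete model datasets above simply thread `metadata` through to the abstract model dataset; the registry dataset still pins `save_args={}` and `version=None` but now forwards the annotation as well. A quick sketch, assuming the `kedro_mlflow.io.models` import path used by the tests below:

```python
from kedro_mlflow.io.models import MlflowModelTrackingDataset

model_ds = MlflowModelTrackingDataset(
    flavor="mlflow.sklearn",
    metadata={"stage": "experimentation"},  # hypothetical annotation
)

assert model_ds.metadata == {"stage": "experimentation"}
assert "metadata" not in model_ds._describe()
```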
20 changes: 20 additions & 0 deletions tests/io/artifacts/test_mlflow_artifact_dataset.py
@@ -376,3 +376,23 @@ def _describe(self):
remote_path = (Path("artifact_dir") / filepath.name).as_posix()
assert remote_path in run_artifacts
assert df1.equals(mlflow_dataset.load())


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_artifact_dataset_with_metadata(metadata):
mlflow_csv_dataset = MlflowArtifactDataset(
dataset=dict(type=CSVDataset, filepath="/my/file/path"),
metadata=metadata,
)

assert mlflow_csv_dataset.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_csv_dataset._describe()
20 changes: 20 additions & 0 deletions tests/io/metrics/test_mlflow_metric_dataset.py
@@ -193,3 +193,23 @@ def test_mlflow_metric_logging_deactivation_is_bool():

with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
mlflow_metric_dataset._logging_activated = "hello"


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_metric_dataset_with_metadata(tmp_path, metadata):
mlflow_metric_dataset = MlflowMetricDataset(
key="hello",
metadata=metadata,
)

assert mlflow_metric_dataset.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_metric_dataset._describe()
20 changes: 20 additions & 0 deletions tests/io/metrics/test_mlflow_metric_history_dataset.py
@@ -68,3 +68,23 @@ def test_mlflow_metric_history_dataset_logging_deactivation(mlflow_tracking_uri)
with mlflow.start_run():
metric_ds.save([0.1])
assert metric_ds._exists() is False


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_metric_history_dataset_with_metadata(tmp_path, metadata):
metric_ds = MlflowMetricHistoryDataset(
key="hello",
metadata=metadata,
)

assert metric_ds.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in metric_ds._describe()
20 changes: 20 additions & 0 deletions tests/io/metrics/test_mlflow_metrics_dataset.py
@@ -189,3 +189,23 @@ def test_mlflow_metrics_logging_deactivation_is_bool():

with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
mlflow_metrics_dataset._logging_activated = "hello"


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_metrics_history_dataset_with_metadata(tmp_path, metadata):
mlflow_metrics_dataset = MlflowMetricsHistoryDataset(
prefix="hello",
metadata=metadata,
)

assert mlflow_metrics_dataset.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_metrics_dataset._describe()
21 changes: 21 additions & 0 deletions tests/io/models/test_mlflow_model_local_filesystem_dataset.py
@@ -193,3 +193,24 @@ def test_pyfunc_flavor_python_model_save_and_load(
loaded_model.predict(pd.DataFrame(data=[1], columns=["a"])) == pd.DataFrame(
data=[2], columns=["a"]
)


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_model_local_filesystem_dataset_with_metadata(metadata):
mlflow_model_ds = MlflowModelLocalFileSystemDataset(
flavor="mlflow.sklearn",
filepath="/my/file/path",
metadata=metadata,
)

assert mlflow_model_ds.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_model_ds._describe()
20 changes: 20 additions & 0 deletions tests/io/models/test_mlflow_model_registry_dataset.py
@@ -156,3 +156,23 @@ def test_mlflow_model_registry_load_given_alias(tmp_path, monkeypatch):
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", alias="champion")
loaded_model = ml_ds.load()
assert loaded_model.metadata.run_id == runs[1]


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_model_registry_dataset_with_metadata(tmp_path, metadata):
mlflow_model_ds = MlflowModelRegistryDataset(
model_name="demo_model",
metadata=metadata,
)

assert mlflow_model_ds.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_model_ds._describe()
20 changes: 20 additions & 0 deletions tests/io/models/test_mlflow_model_tracking_dataset.py
@@ -364,3 +364,23 @@ def test_mlflow_model_tracking_logging_deactivation_is_bool():

with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
mlflow_model_tracking_dataset._logging_activated = "hello"


@pytest.mark.parametrize(
"metadata",
(
None,
{"description": "My awsome dataset"},
{"string": "bbb", "int": 0},
),
)
def test_model_tracking_dataset_with_metadata(metadata):
mlflow_model_ds = MlflowModelTrackingDataset(
flavor="mlflow.sklearn",
metadata=metadata,
)

assert mlflow_model_ds.metadata == metadata

# Metadata should not show in _describe
assert "metadata" not in mlflow_model_ds._describe()