Skip to content

Commit 85fc3fd

Browse files
authored
✨ Enable all datasets to take a metadata parameter (#633)
Signed-off-by: Guillaume Tauzin <4648633+gtauzin@users.noreply.github.com>
1 parent 2920a26 commit 85fc3fd

17 files changed

+176
-4
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Added
6+
7+
- :sparkles: All datasets can now take a metadata parameter ([#625](https://github.com/Galileo-Galilei/kedro-mlflow/issues/625), [#633](https://github.com/Galileo-Galilei/kedro-mlflow/pull/633))
8+
59
## [0.14.2] - 2025-02-16
610

711
### Fixed

kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __new__(
1919
run_id: str = None,
2020
artifact_path: str = None,
2121
credentials: Dict[str, Any] = None,
22+
metadata: Dict[str, Any] | None = None,
2223
):
2324
dataset_obj, dataset_args = parse_dataset_definition(config=dataset)
2425

@@ -27,11 +28,12 @@ def __new__(
2728
# instead and since we can't modify the core package,
2829
# we create a subclass which inherits dynamically from the dataset class
2930
class MlflowArtifactDatasetChildren(dataset_obj):
30-
def __init__(self, run_id, artifact_path):
31+
def __init__(self, run_id, artifact_path, metadata):
3132
super().__init__(**dataset_args)
3233
self.run_id = run_id
3334
self.artifact_path = artifact_path
3435
self._logging_activated = True
36+
self.metadata = metadata
3537

3638
@property
3739
def _logging_activated(self):
@@ -147,7 +149,9 @@ def _load(self) -> Any: # pragma: no cover
147149
)
148150

149151
mlflow_dataset_instance = MlflowArtifactDatasetChildren(
150-
run_id=run_id, artifact_path=artifact_path
152+
run_id=run_id,
153+
artifact_path=artifact_path,
154+
metadata=metadata,
151155
)
152156
return mlflow_dataset_instance
153157

kedro_mlflow/io/metrics/mlflow_abstract_metric_dataset.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(
1212
run_id: str = None,
1313
load_args: Dict[str, Any] = None,
1414
save_args: Dict[str, Any] = None,
15+
metadata: Dict[str, Any] | None = None,
1516
):
1617
"""Initialise MlflowMetricsHistoryDataset.
1718
@@ -24,6 +25,7 @@ def __init__(
2425
self._load_args = load_args or {}
2526
self._save_args = save_args or {}
2627
self._logging_activated = True # by default, logging is activated!
28+
self.metadata = metadata
2729

2830
@property
2931
def run_id(self) -> Union[str, None]:

kedro_mlflow/io/metrics/mlflow_metric_dataset.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ def __init__(
1818
run_id: str = None,
1919
load_args: Dict[str, Any] = None,
2020
save_args: Dict[str, Any] = None,
21+
metadata: Dict[str, Any] | None = None,
2122
):
2223
"""Initialise MlflowMetricDataset.
2324
Args:
2425
run_id (str): The ID of the mlflow run where the metric should be logged
2526
"""
2627

27-
super().__init__(key, run_id, load_args, save_args)
28+
super().__init__(key, run_id, load_args, save_args, metadata)
2829

2930
# We add an extra argument mode="overwrite" / "append" to allow a logging call to update an existing metric
3031
# this is not an official mlflow argument for log_metric, so we separate it from the others

kedro_mlflow/io/metrics/mlflow_metric_history_dataset.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def __init__(
1414
run_id: str = None,
1515
load_args: Dict[str, Any] = None,
1616
save_args: Dict[str, Any] = None,
17+
metadata: Dict[str, Any] | None = None,
1718
):
1819
"""Initialise MlflowMetricHistoryDataset.
1920
Args:
2021
run_id (str): The ID of the mlflow run where the metric should be logged
2122
"""
2223

23-
super().__init__(key, run_id, load_args, save_args)
24+
super().__init__(key, run_id, load_args, save_args, metadata)
2425

2526
def _load(self):
2627
self._validate_run_id()

kedro_mlflow/io/metrics/mlflow_metrics_history_dataset.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(
1818
self,
1919
run_id: str = None,
2020
prefix: Optional[str] = None,
21+
metadata: Dict[str, Any] | None = None,
2122
):
2223
"""Initialise MlflowMetricsHistoryDataset.
2324
@@ -28,6 +29,7 @@ def __init__(
2829
self._prefix = prefix
2930
self.run_id = run_id
3031
self._logging_activated = True # by default, logging is activated!
32+
self.metadata = metadata
3133

3234
@property
3335
def run_id(self):

kedro_mlflow/io/models/mlflow_abstract_model_dataset.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
load_args: Dict[str, Any] = None,
2121
save_args: Dict[str, Any] = None,
2222
version: Version = None,
23+
metadata: Dict[str, Any] | None = None,
2324
) -> None:
2425
"""Initialize the Kedro MlflowAbstractModelDataSet.
2526
@@ -39,6 +40,8 @@ def __init__(
3940
save_args (Dict[str, Any], optional): Arguments to `log_model`
4041
function from specified `flavor`. Defaults to {}.
4142
version (Version, optional): Specific version to load.
43+
metadata: Any arbitrary metadata.
44+
This is ignored by Kedro, but may be consumed by users or external plugins.
4245
4346
Raises:
4447
DatasetError: When passed `flavor` does not exist.
@@ -61,6 +64,7 @@ def __init__(
6164

6265
self._load_args = load_args or {}
6366
self._save_args = save_args or {}
67+
self.metadata = metadata
6468

6569
try:
6670
self._mlflow_model_module

kedro_mlflow/io/models/mlflow_model_local_filesystem_dataset.py

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
save_args: Dict[str, Any] = None,
2222
log_args: Dict[str, Any] = None,
2323
version: Version = None,
24+
metadata: Dict[str, Any] | None = None,
2425
) -> None:
2526
"""Initialize the Kedro MlflowModelDataSet.
2627
@@ -40,6 +41,9 @@ def __init__(
4041
save_args (Dict[str, Any], optional): Arguments to `save_model`
4142
function from specified `flavor`. Defaults to None.
4243
version (Version, optional): Kedro version to use. Defaults to None.
44+
metadata: Any arbitrary metadata.
45+
This is ignored by Kedro, but may be consumed by users or external plugins.
46+
4347
Raises:
4448
DatasetError: When passed `flavor` does not exist.
4549
"""
@@ -50,6 +54,7 @@ def __init__(
5054
load_args=load_args,
5155
save_args=save_args,
5256
version=version,
57+
metadata=metadata,
5358
)
5459

5560
def _load(self) -> Any:

kedro_mlflow/io/models/mlflow_model_registry_dataset.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
flavor: Optional[str] = "mlflow.pyfunc",
2020
pyfunc_workflow: Optional[str] = "python_model",
2121
load_args: Optional[Dict[str, Any]] = None,
22+
metadata: Dict[str, Any] | None = None,
2223
) -> None:
2324
"""Initialize the Kedro MlflowModelRegistryDataset.
2425
@@ -37,6 +38,8 @@ def __init__(
3738
See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
3839
load_args (Dict[str, Any], optional): Arguments to `load_model`
3940
function from specified `flavor`. Defaults to None.
41+
metadata: Any arbitrary metadata.
42+
This is ignored by Kedro, but may be consumed by users or external plugins.
4043
4144
Raises:
4245
DatasetError: When passed `flavor` does not exist.
@@ -48,6 +51,7 @@ def __init__(
4851
load_args=load_args,
4952
save_args={},
5053
version=None,
54+
metadata=metadata,
5155
)
5256

5357
if alias is None and stage_or_version is None:

kedro_mlflow/io/models/mlflow_model_tracking_dataset.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
pyfunc_workflow: Optional[str] = None,
2020
load_args: Optional[Dict[str, Any]] = None,
2121
save_args: Optional[Dict[str, Any]] = None,
22+
metadata: Dict[str, Any] | None = None,
2223
) -> None:
2324
"""Initialize the Kedro MlflowModelDataSet.
2425
@@ -40,6 +41,8 @@ def __init__(
4041
function from specified `flavor`. Defaults to None.
4142
save_args (Dict[str, Any], optional): Arguments to `log_model`
4243
function from specified `flavor`. Defaults to None.
44+
metadata: Any arbitrary metadata.
45+
This is ignored by Kedro, but may be consumed by users or external plugins.
4346
4447
Raises:
4548
DatasetError: When passed `flavor` does not exist.
@@ -51,6 +54,7 @@ def __init__(
5154
load_args=load_args,
5255
save_args=save_args,
5356
version=None,
57+
metadata=metadata,
5458
)
5559

5660
self._run_id = run_id

tests/io/artifacts/test_mlflow_artifact_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,23 @@ def _describe(self):
376376
remote_path = (Path("artifact_dir") / filepath.name).as_posix()
377377
assert remote_path in run_artifacts
378378
assert df1.equals(mlflow_dataset.load())
379+
380+
381+
@pytest.mark.parametrize(
382+
"metadata",
383+
(
384+
None,
385+
{"description": "My awsome dataset"},
386+
{"string": "bbb", "int": 0},
387+
),
388+
)
389+
def test_artifact_dataset_with_metadata(metadata):
390+
mlflow_csv_dataset = MlflowArtifactDataset(
391+
dataset=dict(type=CSVDataset, filepath="/my/file/path"),
392+
metadata=metadata,
393+
)
394+
395+
assert mlflow_csv_dataset.metadata == metadata
396+
397+
# Metadata should not show in _describe
398+
assert "metadata" not in mlflow_csv_dataset._describe()

tests/io/metrics/test_mlflow_metric_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,23 @@ def test_mlflow_metric_logging_deactivation_is_bool():
193193

194194
with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
195195
mlflow_metric_dataset._logging_activated = "hello"
196+
197+
198+
@pytest.mark.parametrize(
199+
"metadata",
200+
(
201+
None,
202+
{"description": "My awsome dataset"},
203+
{"string": "bbb", "int": 0},
204+
),
205+
)
206+
def test_metric_dataset_with_metadata(tmp_path, metadata):
207+
mlflow_metric_dataset = MlflowMetricDataset(
208+
key="hello",
209+
metadata=metadata,
210+
)
211+
212+
assert mlflow_metric_dataset.metadata == metadata
213+
214+
# Metadata should not show in _describe
215+
assert "metadata" not in mlflow_metric_dataset._describe()

tests/io/metrics/test_mlflow_metric_history_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,23 @@ def test_mlflow_metric_history_dataset_logging_deactivation(mlflow_tracking_uri)
6868
with mlflow.start_run():
6969
metric_ds.save([0.1])
7070
assert metric_ds._exists() is False
71+
72+
73+
@pytest.mark.parametrize(
74+
"metadata",
75+
(
76+
None,
77+
{"description": "My awsome dataset"},
78+
{"string": "bbb", "int": 0},
79+
),
80+
)
81+
def test_metric_history_dataset_with_metadata(tmp_path, metadata):
82+
metric_ds = MlflowMetricHistoryDataset(
83+
key="hello",
84+
metadata=metadata,
85+
)
86+
87+
assert metric_ds.metadata == metadata
88+
89+
# Metadata should not show in _describe
90+
assert "metadata" not in metric_ds._describe()

tests/io/metrics/test_mlflow_metrics_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,23 @@ def test_mlflow_metrics_logging_deactivation_is_bool():
189189

190190
with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
191191
mlflow_metrics_dataset._logging_activated = "hello"
192+
193+
194+
@pytest.mark.parametrize(
195+
"metadata",
196+
(
197+
None,
198+
{"description": "My awsome dataset"},
199+
{"string": "bbb", "int": 0},
200+
),
201+
)
202+
def test_metrics_history_dataset_with_metadata(tmp_path, metadata):
203+
mlflow_metrics_dataset = MlflowMetricsHistoryDataset(
204+
prefix="hello",
205+
metadata=metadata,
206+
)
207+
208+
assert mlflow_metrics_dataset.metadata == metadata
209+
210+
# Metadata should not show in _describe
211+
assert "metadata" not in mlflow_metrics_dataset._describe()

tests/io/models/test_mlflow_model_local_filesystem_dataset.py

+21
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,24 @@ def test_pyfunc_flavor_python_model_save_and_load(
193193
loaded_model.predict(pd.DataFrame(data=[1], columns=["a"])) == pd.DataFrame(
194194
data=[2], columns=["a"]
195195
)
196+
197+
198+
@pytest.mark.parametrize(
199+
"metadata",
200+
(
201+
None,
202+
{"description": "My awsome dataset"},
203+
{"string": "bbb", "int": 0},
204+
),
205+
)
206+
def test_metrics_history_dataset_with_metadata(metadata):
207+
mlflow_model_ds = MlflowModelLocalFileSystemDataset(
208+
flavor="mlflow.sklearn",
209+
filepath="/my/file/path",
210+
metadata=metadata,
211+
)
212+
213+
assert mlflow_model_ds.metadata == metadata
214+
215+
# Metadata should not show in _describe
216+
assert "metadata" not in mlflow_model_ds._describe()

tests/io/models/test_mlflow_model_registry_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,23 @@ def test_mlflow_model_registry_load_given_alias(tmp_path, monkeypatch):
156156
ml_ds = MlflowModelRegistryDataset(model_name="demo_model", alias="champion")
157157
loaded_model = ml_ds.load()
158158
assert loaded_model.metadata.run_id == runs[1]
159+
160+
161+
@pytest.mark.parametrize(
162+
"metadata",
163+
(
164+
None,
165+
{"description": "My awsome dataset"},
166+
{"string": "bbb", "int": 0},
167+
),
168+
)
169+
def test_metrics_history_dataset_with_metadata(tmp_path, metadata):
170+
mlflow_model_ds = MlflowModelRegistryDataset(
171+
model_name="demo_model",
172+
metadata=metadata,
173+
)
174+
175+
assert mlflow_model_ds.metadata == metadata
176+
177+
# Metadata should not show in _describe
178+
assert "metadata" not in mlflow_model_ds._describe()

tests/io/models/test_mlflow_model_tracking_dataset.py

+20
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,23 @@ def test_mlflow_model_tracking_logging_deactivation_is_bool():
364364

365365
with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
366366
mlflow_model_tracking_dataset._logging_activated = "hello"
367+
368+
369+
@pytest.mark.parametrize(
370+
"metadata",
371+
(
372+
None,
373+
{"description": "My awsome dataset"},
374+
{"string": "bbb", "int": 0},
375+
),
376+
)
377+
def test_metrics_history_dataset_with_metadata(metadata):
378+
mlflow_model_ds = MlflowModelTrackingDataset(
379+
flavor="mlflow.sklearn",
380+
metadata=metadata,
381+
)
382+
383+
assert mlflow_model_ds.metadata == metadata
384+
385+
# Metadata should not show in _describe
386+
assert "metadata" not in mlflow_model_ds._describe()

0 commit comments

Comments
 (0)