Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to blacklist params. also sanitise keys #595

Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## [Unreleased]

### Added

- :sparkles: Implement missing PipelineML filtering functionalities to let kedro display resume hints and avoid breaking kedro-viz ([#377, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/377), [#601, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/601))
- :sparkles: Sanitize parameters name with unsupported characters to avoid mlflow errors when logging ([#595, pascalwhoop](https://github.com/Galileo-Galilei/kedro-mlflow/pull/595))

## [0.13.2] - 2024-10-15

### Fixed
Expand Down
28 changes: 28 additions & 0 deletions kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import re
from logging import Logger, getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -303,6 +305,11 @@ def before_node_run(
d=params_inputs, recursive=self.recursive, sep=self.sep
)

# sanitize params inputs to avoid mlflow errors
params_inputs = {
self.sanitize_param_name(k): v for k, v in params_inputs.items()
}

# logging parameters based on defined strategy
for k, v in params_inputs.items():
self._log_param(k, v)
Expand Down Expand Up @@ -446,5 +453,26 @@ def on_pipeline_error(
# hence it should not be modified. this is only a safeguard
switch_catalog_logging(catalog, True)

def sanitize_param_name(self, name: str) -> str:
# regex taken from MLFlow codebase: https://github.com/mlflow/mlflow/blob/e40e782b6fcab473159e6d4fee85bc0fc10f78fd/mlflow/utils/validation.py#L140C1-L148C44

# for windows colon ':' are not accepted
matching_pattern = r"^[/\w.\- ]*$" if is_windows() else r"^[/\w.\- :]*$"

if re.match(matching_pattern, name):
return name
else:
replacement_pattern = r"[^/\w.\- ]" if is_windows() else r"[^/\w.\- :]"
# Replace invalid characters with underscore
sanitized_name = re.sub(replacement_pattern, "_", name)
self._logger.warning(
f"'{name}' is not a valid name for a mlflow paramter. It is renamed as '{sanitized_name}'"
)
return sanitized_name


def is_windows():
return os.name == "nt"


mlflow_hook = MlflowHook()
131 changes: 131 additions & 0 deletions tests/framework/hooks/test_hook_log_parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from typing import Dict

Expand Down Expand Up @@ -74,6 +75,136 @@ def dummy_catalog():
return catalog


@pytest.mark.parametrize(
"param_name,expected_name",
[
("valid_param", "valid_param"),
("valid-param", "valid-param"),
("invalid/param", "invalid/param"),
("invalid.param", "invalid.param"),
("[invalid]$param", "_invalid__param"),
],
)
def test_parameter_name_sanitization(
kedro_project, dummy_run_params, param_name, expected_name
):
mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
mlflow.set_tracking_uri(mlflow_tracking_uri)

node_inputs = {f"params:{param_name}": "test_value"}

bootstrap_project(kedro_project)
with KedroSession.create(
project_path=kedro_project,
) as session:
context = session.load_context()
mlflow_node_hook = MlflowHook()
mlflow_node_hook.after_context_created(context)

with mlflow.start_run():
mlflow_node_hook.before_pipeline_run(
run_params=dummy_run_params,
pipeline=Pipeline([]),
catalog=DataCatalog(),
)
mlflow_node_hook.before_node_run(
node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
catalog=DataCatalog(),
inputs=node_inputs,
is_async=False,
)
run_id = mlflow.active_run().info.run_id

mlflow_client = MlflowClient(mlflow_tracking_uri)
current_run = mlflow_client.get_run(run_id)
assert expected_name in current_run.data.params
assert current_run.data.params[expected_name] == "test_value"


@pytest.mark.skipif(
os.name != "nt", reason="Windows does not log params with colon symbol"
)
def test_parameter_name_with_colon_sanitization_on_windows(
kedro_project, dummy_run_params
):
param_name = "valid:param"
expected_name = "valid_param"

mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
mlflow.set_tracking_uri(mlflow_tracking_uri)

node_inputs = {f"params:{param_name}": "test_value"}

bootstrap_project(kedro_project)
with KedroSession.create(
project_path=kedro_project,
) as session:
context = session.load_context()
mlflow_node_hook = MlflowHook()
mlflow_node_hook.after_context_created(context)

with mlflow.start_run():
mlflow_node_hook.before_pipeline_run(
run_params=dummy_run_params,
pipeline=Pipeline([]),
catalog=DataCatalog(),
)
mlflow_node_hook.before_node_run(
node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
catalog=DataCatalog(),
inputs=node_inputs,
is_async=False,
)
run_id = mlflow.active_run().info.run_id

mlflow_client = MlflowClient(mlflow_tracking_uri)
current_run = mlflow_client.get_run(run_id)
assert expected_name in current_run.data.params
assert current_run.data.params[expected_name] == "test_value"


@pytest.mark.skipif(
os.name == "nt", reason="Linux and Mac do log params with colon symbol"
)
def test_parameter_name_with_colon_sanitization_on_mac_linux(
kedro_project, dummy_run_params
):
param_name = "valid:param"
expected_name = "valid:param"

mlflow_tracking_uri = (kedro_project / "mlruns").as_uri()
mlflow.set_tracking_uri(mlflow_tracking_uri)

node_inputs = {f"params:{param_name}": "test_value"}

bootstrap_project(kedro_project)
with KedroSession.create(
project_path=kedro_project,
) as session:
context = session.load_context()
Comment on lines +180 to +184
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nitpick, but can we just ignore all this? Since you invoke the hook manually, I don't think it's necessary creating a whole session and contexte, unless I miss something?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, this is needed to call the after_context_created hook, which format and stores the entire configuration.

mlflow_node_hook = MlflowHook()
mlflow_node_hook.after_context_created(context)
Copy link
Owner

@Galileo-Galilei Galileo-Galilei Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure you need after_context_created since you already set mlflow_tracking_uri up just above.

Overall really like the test👍


with mlflow.start_run():
mlflow_node_hook.before_pipeline_run(
run_params=dummy_run_params,
pipeline=Pipeline([]),
catalog=DataCatalog(),
)
mlflow_node_hook.before_node_run(
node=node(func=lambda x: x, inputs=dict(x="a"), outputs=None),
catalog=DataCatalog(),
inputs=node_inputs,
is_async=False,
)
run_id = mlflow.active_run().info.run_id

mlflow_client = MlflowClient(mlflow_tracking_uri)
current_run = mlflow_client.get_run(run_id)
assert expected_name in current_run.data.params
assert current_run.data.params[expected_name] == "test_value"


def test_pipeline_run_hook_getting_configs(
kedro_project,
dummy_run_params,
Expand Down
Loading