Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make mlflow not thread safe by reopening the same run before each node #638

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

- :sparkles: All datasets can now take a metadata parameter ([#625](https://github.com/Galileo-Galilei/kedro-mlflow/issues/625), [#633](https://github.com/Galileo-Galilei/kedro-mlflow/pull/633))

### Fixed

- :bug: Reopen the mlflow run before each node to bypass mlflow thread safety and ensure all tracking is done within the same run_id ([#623](https://github.com/Galileo-Galilei/kedro-mlflow/issues/623), [#624](https://github.com/Galileo-Galilei/kedro-mlflow/issues/624))

## [0.14.2] - 2025-02-16

### Fixed
Expand Down Expand Up @@ -462,7 +466,7 @@

### Removed

- :recycle: :boom: `kedro mlflow init` command is no longer declaring hooks in `run.py`. You must now [register your hooks manually](https://kedro-mlflow.readthedocs.io/en/stable/source/02_installation/02_setup.html#declaring-kedro-mlflow-hooks) in the `run.py` if you use `kedro>=0.16.0, <0.16.3` ([#62](https://github.com/Galileo-Galilei/kedro-mlflow/issues/62)).
- :recycle: :boom: `kedro mlflow init` command is no longer declaring hooks in `run.py`. You must now [register your hooks manually](https://kedro-mlflow.readthedocs.io/en/stable/source/02_getting_started/01_installation/02_setup.html#declaring-kedro-mlflow-hooks) in the `run.py` if you use `kedro>=0.16.0, <0.16.3` ([#62](https://github.com/Galileo-Galilei/kedro-mlflow/issues/62)).
- :fire: Remove `pipeline_ml` function which was deprecated in 0.3.0. It is now replaced by `pipeline_ml_factory` ([#105](https://github.com/Galileo-Galilei/kedro-mlflow/issues/105))
- :fire: Remove `MlflowDataSet` dataset which was deprecated in 0.3.0. It is now replaced by `MlflowArtifactDataSet` ([#105](https://github.com/Galileo-Galilei/kedro-mlflow/issues/105))

Expand Down
56 changes: 45 additions & 11 deletions kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
_flatten_dict,
_generate_kedro_command,
)
from kedro_mlflow.io.catalog.add_run_id_to_artifact_datasets import (
add_run_id_to_artifact_datasets,
)
from kedro_mlflow.io.catalog.switch_catalog_logging import switch_catalog_logging
from kedro_mlflow.io.metrics import (
MlflowMetricDataset,
Expand All @@ -47,6 +44,7 @@ def __init__(self):
self.recursive = True
self.sep = "."
self.long_parameters_strategy = "fail"
self.run_id = None # we store the run_id because the hook is stateful and we need to keep track of the active run between the different threads

@property
def _logger(self) -> Logger:
Expand Down Expand Up @@ -242,8 +240,9 @@ def before_pipeline_run(
)

if self._already_active_mlflow:
self.run_id = mlflow.active_run().info.run_id
self._logger.warning(
f"A mlflow run was already active (run_id='{mlflow.active_run().info.run_id}') before the KedroSession was started. This run will be used for logging."
f"A mlflow run was already active (run_id='{self.run_id}') before the KedroSession was started. This run will be used for logging."
)
else:
mlflow.start_run(
Expand All @@ -252,8 +251,9 @@ def before_pipeline_run(
run_name=run_name,
nested=self.mlflow_config.tracking.run.nested,
)
self.run_id = mlflow.active_run().info.run_id
self._logger.info(
f"Mlflow run '{mlflow.active_run().info.run_id}' has started"
f"Mlflow run '{mlflow.active_run().info.run_name}' - '{self.run_id}' has started"
)
# Set tags only for run parameters that have values.
mlflow.set_tags({k: v for k, v in run_params.items() if v})
Expand All @@ -274,12 +274,6 @@ def before_pipeline_run(
),
)

# This function ensures the run_id started at the beginning of the pipeline
# is associated to all the datasets. This is necessary because to make mlflow thread safe
# each call to the "active run" now creates a new run when started in a new thread. See
# https://github.com/Galileo-Galilei/kedro-mlflow/issues/613 and https://github.com/Galileo-Galilei/kedro-mlflow/pull/615
add_run_id_to_artifact_datasets(catalog, mlflow.active_run().info.run_id)

else:
self._logger.info(
"kedro-mlflow logging is deactivated for this pipeline in the configuration. This includes DataSets and parameters."
Expand All @@ -298,6 +292,30 @@ def before_node_run(
inputs: The dictionary of inputs dataset.
is_async: Whether the node was run in ``async`` mode.
"""
if self.run_id is not None:
# Reopening the run ensures the run_id started at the beginning of the pipeline
# is used for all tracking. This bypasses mlflow's thread-safety design, under which
# each call to the "active run" creates a new run when started in a new thread. See
# https://github.com/Galileo-Galilei/kedro-mlflow/issues/613
# https://github.com/Galileo-Galilei/kedro-mlflow/pull/615
# https://github.com/Galileo-Galilei/kedro-mlflow/issues/623
# https://github.com/Galileo-Galilei/kedro-mlflow/issues/624

# If self.run_id is None, this means that no run was ever started, i.e. that mlflow has been deactivated for this pipeline
try:
mlflow.start_run(
run_id=self.run_id,
nested=self.mlflow_config.tracking.run.nested,
)
self._logger.info(
f"Restarting mlflow run '{mlflow.active_run().info.run_name}' - '{self.run_id}' at node level for multi-threading"
)
except Exception as err: # pragma: no cover
if f"Run with UUID {self.run_id} is already active" in str(err):
# This means that the run was started before in the same thread, likely at the beginning of another node
pass
else:
raise err

# only parameters will be logged. Artifacts must be declared manually in the catalog
if self._is_mlflow_enabled:
Expand Down Expand Up @@ -467,8 +485,24 @@ def on_pipeline_error(
f"The run '{mlflow.active_run().info.run_id}' was already opened before launching 'kedro run' so it is not closed. You should close it manually."
)
else:
# first, close all runs within the thread
while mlflow.active_run():
current_run_id = mlflow.active_run().info.run_id
self._logger.info(
f"The run '{current_run_id}' was closed because of an error in the pipeline."
)
mlflow.end_run(RunStatus.to_string(RunStatus.FAILED))
pipeline_run_id_is_closed = current_run_id == self.run_id

# second, ensure that parent run in another thread is closed
if not pipeline_run_id_is_closed:
self.mlflow_config.server._mlflow_client.set_terminated(
self.run_id, RunStatus.to_string(RunStatus.FAILED)
)
self._logger.info(
f"The parent run '{self.run_id}' was closed because of an error in the pipeline."
)

else: # pragma: no cover
# the catalog is supposed to be reloaded each time with _get_catalog,
# hence it should not be modified. this is only a safeguard
Expand Down
9 changes: 0 additions & 9 deletions kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py

This file was deleted.

83 changes: 7 additions & 76 deletions tests/framework/hooks/test_hook_log_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import mlflow
import pandas as pd
import pytest
from kedro.framework.hooks import _create_hook_manager
from kedro.framework.hooks.manager import _register_hooks
from kedro.framework.session import KedroSession
from kedro.framework.startup import bootstrap_project
from kedro.io import DataCatalog, MemoryDataset
Expand Down Expand Up @@ -71,84 +73,10 @@ def dummy_run_params(tmp_path):
return dummy_run_params


def test_mlflow_hook_automatically_update_artifact_run_id(
kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
):
# since mlflow>=2.18, the fluent API creates a new run for each thread
# hence, for the thread runner, we need to prefix the catalog with the run id

bootstrap_project(kedro_project)
with KedroSession.create(project_path=kedro_project) as session:
context = session.load_context() # triggers conf setup

mlflow_hook = MlflowHook()
mlflow_hook.after_context_created(context) # setup mlflow config

mlflow_hook.after_catalog_created(
catalog=dummy_catalog,
# `after_catalog_created` is not using any of below arguments,
# so we are setting them to empty values.
conf_catalog={},
conf_creds={},
feed_dict={},
save_version="",
load_versions="",
)

mlflow_hook.before_pipeline_run(
run_params=dummy_run_params, pipeline=dummy_pipeline, catalog=dummy_catalog
)

run_id = mlflow.active_run().info.run_id
# Check if artifact datasets have the run_id
assert dummy_catalog._datasets["model"].run_id == run_id


def test_mlflow_hook_automatically_update_artifact_run_id_except_if_it_already_has_run_id(
kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
):
# since mlflow>=2.18, the fluent API creates a new run for each thread
# hence, for the thread runner, we need to prefix the catalog with the run id

bootstrap_project(kedro_project)
with KedroSession.create(project_path=kedro_project) as session:
context = session.load_context() # triggers conf setup

# we modify the run id to simulate an existing run id.
# We need to do it after load_context() to ensure the tracking uri is properly set up.
with mlflow.start_run():
existing_run_id = mlflow.active_run().info.run_id
dummy_catalog_with_run_id = dummy_catalog.shallow_copy()
dummy_catalog_with_run_id._datasets["model"].run_id = existing_run_id

mlflow_hook = MlflowHook()
mlflow_hook.after_context_created(context) # setup mlflow config

mlflow_hook.after_catalog_created(
catalog=dummy_catalog,
# `after_catalog_created` is not using any of below arguments,
# so we are setting them to empty values.
conf_catalog={},
conf_creds={},
feed_dict={},
save_version="",
load_versions="",
)

mlflow_hook.before_pipeline_run(
run_params=dummy_run_params, pipeline=dummy_pipeline, catalog=dummy_catalog
)

run_id = mlflow.active_run().info.run_id
# Check if artifact datasets have the run_id
assert run_id != existing_run_id
assert dummy_catalog._datasets["model"].run_id == existing_run_id


def test_mlflow_hook_log_artifacts_within_same_run_with_thread_runner(
kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
):
# this test is very specific to a new design introduced in mlflow 2.18 to make it htread safe
# this test is very specific to a new design introduced in mlflow 2.18 to make it thread safe
# see https://github.com/Galileo-Galilei/kedro-mlflow/issues/613
bootstrap_project(kedro_project)

Expand Down Expand Up @@ -178,7 +106,10 @@ def test_mlflow_hook_log_artifacts_within_same_run_with_thread_runner(
# we get the run id BEFORE running the pipeline because it was modified in different thread
run_id_before_run = mlflow.active_run().info.run_id

runner.run(dummy_pipeline, dummy_catalog, session._hook_manager)
hook_manager = _create_hook_manager()
_register_hooks(hook_manager, (mlflow_hook,))

runner.run(dummy_pipeline, dummy_catalog, hook_manager)

run_id_after_run = mlflow.active_run().info.run_id

Expand Down
7 changes: 6 additions & 1 deletion tests/framework/hooks/test_hook_log_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import mlflow
import pandas as pd
import pytest
from kedro.framework.hooks import _create_hook_manager
from kedro.framework.hooks.manager import _register_hooks
from kedro.framework.session import KedroSession
from kedro.framework.startup import bootstrap_project
from kedro.io import DataCatalog, MemoryDataset
Expand Down Expand Up @@ -219,7 +221,10 @@ def test_mlflow_hook_metrics_dataset_with_run_id(
pipeline=dummy_pipeline,
catalog=dummy_catalog_with_run_id,
)
runner.run(dummy_pipeline, dummy_catalog_with_run_id, session._hook_manager)

hook_manager = _create_hook_manager()
_register_hooks(hook_manager, (mlflow_hook,))
runner.run(dummy_pipeline, dummy_catalog_with_run_id, hook_manager)

current_run_id = mlflow.active_run().info.run_id

Expand Down
30 changes: 17 additions & 13 deletions tests/framework/hooks/test_hook_on_pipeline_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,30 @@ def mocked_register_pipelines():
)


@pytest.mark.usefixtures("mock_settings_with_mlflow_hooks")
# @pytest.mark.usefixtures("mock_settings_with_mlflow_hooks")
@pytest.mark.usefixtures("mock_failing_pipeline")
def test_on_pipeline_error(kedro_project_with_mlflow_conf):
tracking_uri = (kedro_project_with_mlflow_conf / "mlruns").as_uri()

bootstrap_project(kedro_project_with_mlflow_conf)
with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
context = session.load_context()
from logging import getLogger

LOGGER = getLogger(__name__)
LOGGER.info(f"{mlflow.active_run()=}")
with pytest.raises(ValueError):
LOGGER.info(f"{mlflow.active_run()=}")
session.run()

# the run we want is the last one in the configuration experiment
mlflow_client = MlflowClient(tracking_uri)
experiment = mlflow_client.get_experiment_by_name(
context.mlflow.tracking.experiment.name
)
failing_run_info = (
MlflowClient(tracking_uri).search_runs(experiment.experiment_id)[0].info
)
assert mlflow.active_run() is None # the run must have been closed
assert failing_run_info.status == RunStatus.to_string(
RunStatus.FAILED
) # it must be marked as failed
# the run we want is the last one in the configuration experiment
mlflow_client = MlflowClient(tracking_uri)
experiment = mlflow_client.get_experiment_by_name(
context.mlflow.tracking.experiment.name
)
failing_run_info = mlflow_client.search_runs(experiment.experiment_id)[-1].info

assert mlflow.active_run() is None # the run must have been closed
assert failing_run_info.status == RunStatus.to_string(
RunStatus.FAILED
) # it must be marked as failed
Loading
Loading