`MlflowArtifactDataset` is a wrapper for any `AbstractDataset` which automatically logs the dataset to mlflow as an artifact when its `save` method is called. It can be used both with the YAML API:
```yaml
my_dataset_to_version:
  type: kedro_mlflow.io.artifacts.MlflowArtifactDataset
  dataset:
    type: pandas.CSVDataset  # or any valid kedro DataSet
    filepath: /path/to/a/local/destination/file.csv
```
or with additional parameters:
```yaml
my_dataset_to_version:
  type: kedro_mlflow.io.artifacts.MlflowArtifactDataset
  dataset:
    type: pandas.CSVDataset  # or any valid kedro DataSet
    filepath: /path/to/a/local/destination/file.csv
    load_args:
      sep: ;
    save_args:
      sep: ;
    # ... any other valid arguments for the dataset
  run_id: 13245678910111213  # a valid mlflow run to log in. If None, defaults to the active run
  artifact_path: reporting  # relative path where the artifact must be stored. If None, saved in the root folder.
```
or with the python API:
```python
import pandas as pd

from kedro_datasets.pandas import CSVDataset
from kedro_mlflow.io.artifacts import MlflowArtifactDataset

csv_dataset = MlflowArtifactDataset(
    dataset={"type": CSVDataset, "filepath": r"/path/to/a/local/destination/file.csv"}
)
csv_dataset.save(data=pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
```
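The extra parameters shown in the YAML example (`run_id`, `artifact_path`) can also be passed with the python API. Continuing the previous snippet, a minimal sketch that logs the file under a `reporting` artifact folder:

```python
# Sketch: store the CSV locally, then log it under the "reporting" artifact folder
# of the active mlflow run (pass run_id=... to target a specific run instead).
reporting_dataset = MlflowArtifactDataset(
    dataset={"type": CSVDataset, "filepath": "/path/to/a/local/destination/file.csv"},
    artifact_path="reporting",
)
reporting_dataset.save(pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
```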
The `MlflowMetricDataset` is documented here.

The `MlflowMetricHistoryDataset` is documented here.
The `MlflowModelTrackingDataset` accepts the following arguments:

- `flavor` (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- `run_id` (Optional[str], optional): MLflow run ID to use to load the model from or save the model to. It plays the same role as `filepath` for standard kedro datasets. Defaults to None.
- `artifact_path` (str, optional): the run-relative path to the model.
- `pyfunc_workflow` (str, optional): Either `python_model` or `loader_module`. See mlflow workflows (and the sketch right after this list).
- `load_args` (Dict[str, Any], optional): Arguments to the `load_model` function of the specified `flavor`. Defaults to None.
- `save_args` (Dict[str, Any], optional): Arguments to the `log_model` function of the specified `flavor`. Defaults to None.
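The `pyfunc_workflow` argument only matters for the `mlflow.pyfunc` flavor: it indicates whether the object you save should be handled through the `python_model` or the `loader_module` workflow. A minimal sketch, assuming the `python_model` workflow and a toy `PythonModel` subclass (names are illustrative):

```python
import mlflow.pyfunc

from kedro_mlflow.io.models import MlflowModelTrackingDataset


class AddOneModel(mlflow.pyfunc.PythonModel):
    """Toy model: adds 1 to every value of the input."""

    def predict(self, context, model_input):
        return model_input + 1


# Sketch: with the "python_model" workflow, the saved object is the PythonModel instance itself.
pyfunc_dataset = MlflowModelTrackingDataset(
    flavor="mlflow.pyfunc",
    pyfunc_workflow="python_model",
)
pyfunc_dataset.save(AddOneModel())
```

For the built-in flavors used in the examples below (e.g. `mlflow.sklearn`), `pyfunc_workflow` can simply be omitted.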
You can either only specify the flavor:
```python
from kedro_mlflow.io.models import MlflowModelTrackingDataset
from sklearn.linear_model import LinearRegression

mlflow_model_tracking = MlflowModelTrackingDataset(flavor="mlflow.sklearn")
mlflow_model_tracking.save(LinearRegression())
```
Let's assume that this first model has been saved once, and you want to retrieve it (for prediction, for instance):
```python
mlflow_model_tracking = MlflowModelTrackingDataset(
    flavor="mlflow.sklearn", run_id="<the-model-run-id>"
)
my_linear_regression = mlflow_model_tracking.load()
my_linear_regression.predict(
    data
)  # will obviously fail if you have not fitted your model object first :)
```
You can also specify some logging parameters:
```python
mlflow_model_tracking = MlflowModelTrackingDataset(
    flavor="mlflow.sklearn",
    run_id="<the-model-run-id>",
    save_args={
        "conda_env": {"python": "3.10.0", "dependencies": ["kedro==0.18.11"]},
        "input_example": data.iloc[0:5, :],
    },
)
mlflow_model_tracking.save(LinearRegression().fit(data))
```
As always with kedro, you can use it directly in the `catalog.yml` file:
```yaml
my_model:
  type: kedro_mlflow.io.models.MlflowModelTrackingDataset
  flavor: "mlflow.sklearn"
  run_id: <the-model-run-id>
  save_args:
    conda_env:
      python: "3.10.0"
      dependencies:
        - "kedro==0.18.11"
```
The `MlflowModelLocalFileSystemDataset` accepts the following arguments:

- `flavor` (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- `filepath` (str): Path to store the dataset locally.
- `pyfunc_workflow` (str, optional): Either `python_model` or `loader_module`. See mlflow workflows.
- `load_args` (Dict[str, Any], optional): Arguments to the `load_model` function of the specified `flavor`. Defaults to None.
- `save_args` (Dict[str, Any], optional): Arguments to the `save_model` function of the specified `flavor`. Defaults to None.
- `version` (Version, optional): Kedro version to use. Defaults to None.
The use is very similar to `MlflowModelTrackingDataset`, but you have to specify a local `filepath` instead of a `run_id`:
```python
from kedro_mlflow.io.models import MlflowModelLocalFileSystemDataset
from sklearn.linear_model import LinearRegression

mlflow_model_local = MlflowModelLocalFileSystemDataset(
    flavor="mlflow.sklearn", filepath="path/to/where/you/want/model"
)
mlflow_model_local.save(LinearRegression().fit(data))
```
The same arguments are available, plus an additional `version` argument common to the usual `AbstractVersionedDataset`:
```python
mlflow_model_local = MlflowModelLocalFileSystemDataset(
    flavor="mlflow.sklearn",
    filepath="path/to/where/you/want/model",
    version="<valid-kedro-version>",
)
my_model = mlflow_model_local.load()
```
and with the YAML API in the `catalog.yml`:
```yaml
my_model:
  type: kedro_mlflow.io.models.MlflowModelLocalFileSystemDataset
  flavor: mlflow.sklearn
  filepath: path/to/where/you/want/model
  version: <valid-kedro-version>
```
The `MlflowModelRegistryDataset` accepts the following arguments:

- `model_name` (str): The name of the registered model in the mlflow registry.
- `stage_or_version` (str): A valid stage (either "staging" or "production") or version number for the registered model. Defaults to None (internally converted to "latest" if no alias is provided), which fetches the last version with the highest stage available (see the example at the end of this section).
- `alias` (str): A valid alias, which is used instead of the stage to filter models since mlflow 2.9.0. Will raise an error if both `stage_or_version` and `alias` are provided.
- `flavor` (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- `pyfunc_workflow` (str, optional): Either `python_model` or `loader_module`. See mlflow workflows.
- `load_args` (Dict[str, Any], optional): Arguments to the `load_model` function of the specified `flavor`. Defaults to None.
We assume you have registered an mlflow model first, either with the `MlflowClient` or within the mlflow UI, e.g.:
```python
import mlflow
import mlflow.sklearn
from sklearn.tree import DecisionTreeClassifier

with mlflow.start_run():
    model = DecisionTreeClassifier()
    # Log the sklearn model and register it as version 1
    mlflow.sklearn.log_model(
        sk_model=model, artifact_path="model", registered_model_name="my_awesome_model"
    )
```
You can fetch the model by its name:
```python
from kedro_mlflow.io.models import MlflowModelRegistryDataset

mlflow_model_registry = MlflowModelRegistryDataset(model_name="my_awesome_model")
my_model = mlflow_model_registry.load()
```
and with the YAML API in the `catalog.yml` (only for loading an existing model):
```yaml
my_model:
  type: kedro_mlflow.io.models.MlflowModelRegistryDataset
  model_name: my_awesome_model
```
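If you do not want the latest model, you can pin a specific version or an alias instead. A minimal sketch, assuming a version "1" exists and that a "champion" alias has been set on the registered model:

```python
from kedro_mlflow.io.models import MlflowModelRegistryDataset

# Load a specific registered version
model_v1 = MlflowModelRegistryDataset(
    model_name="my_awesome_model", stage_or_version="1"
).load()

# Or load by alias (mlflow >= 2.9.0); do not combine with stage_or_version
champion_model = MlflowModelRegistryDataset(
    model_name="my_awesome_model", alias="champion"
).load()
```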