forked from Galileo-Galilei/kedro-mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlflow_metrics_history_dataset.py
216 lines (177 loc) · 7.07 KB
/
mlflow_metrics_history_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from functools import partial
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import mlflow
from kedro.io import AbstractDataset, DatasetError
from mlflow.tracking import MlflowClient
# A single metric as handled by this dataset: either one
# {"step": int, "value": float} record, or a list of such records when the
# metric has a multi-step history (see _convert_metric_history_to_list_or_dict).
MetricItem = Union[Dict[str, float], List[Dict[str, float]]]
# (key, value, step) triple in the argument order expected by
# MlflowClient.log_metric / mlflow.log_metric.
MetricTuple = Tuple[str, float, int]
# Mapping of metric name -> MetricItem, the load/save payload of the dataset.
MetricsDict = Dict[str, MetricItem]
class MlflowMetricsHistoryDataset(AbstractDataset):
    """Kedro dataset that saves/loads metrics (with their step history) to/from MLflow.

    On save, each entry of the dictionary is logged as an MLflow metric
    (one ``log_metric`` call per step). On load, the full metric history is
    fetched back from the tracking server and returned in the same shape.
    """

    def __init__(
        self,
        run_id: Optional[str] = None,
        prefix: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Initialise MlflowMetricsHistoryDataset.

        Args:
            run_id (Optional[str]): ID of the MLflow run to read from / write to.
                If ``None``, the currently active run is used (see ``run_id``
                property).
            prefix (Optional[str]): Prefix prepended to metric keys on save and
                used to filter the run's metrics on load.
            metadata (Optional[Dict[str, Any]]): Arbitrary user metadata,
                stored as-is (kedro convention; not used by this class).
        """
        self._prefix = prefix
        self.run_id = run_id
        self._logging_activated = True  # by default, logging is activated!
        self.metadata = metadata

    @property
    def run_id(self):
        """Get run id.

        If no run id was given explicitly, falls back to the currently active
        MLflow run. Raises ``DatasetError`` if neither is available.

        Returns:
            str: The run id.
        """
        if self._run_id is not None:
            return self._run_id
        run = mlflow.active_run()
        if run:
            return run.info.run_id
        raise DatasetError("Cannot find run id.")

    @run_id.setter
    def run_id(self, run_id):
        self._run_id = run_id

    # we want to be able to turn logging off for an entire pipeline run
    # To avoid that a single call to a dataset in the catalog creates a new run automatically
    # we want to be able to turn everything off
    @property
    def _logging_activated(self):
        return self.__logging_activated

    @_logging_activated.setter
    def _logging_activated(self, flag):
        # Validate eagerly so a mis-configured hook fails loudly rather than
        # silently enabling/disabling logging with a truthy non-bool.
        if not isinstance(flag, bool):
            raise ValueError(f"_logging_activated must be a boolean, got {type(flag)}")
        self.__logging_activated = flag

    def _load(self) -> MetricsDict:
        """Load the dataset's metrics from the MLflow run.

        Returns:
            MetricsDict: Mapping of metric key to either a single
            ``{"step": ..., "value": ...}`` dict (single-step metric) or a
            list of such dicts (multi-step history).
        """
        client = MlflowClient()
        all_metrics_keys = list(client.get_run(self.run_id).data.metrics.keys())
        # Only keep the metrics that belong to this dataset (prefix match).
        dataset_metrics_keys = [
            key for key in all_metrics_keys if self._is_dataset_metric(key)
        ]
        dataset_metrics = {
            key: self._convert_metric_history_to_list_or_dict(
                client.get_metric_history(self.run_id, key)
            )
            for key in dataset_metrics_keys
        }
        return dataset_metrics

    def _save(self, data: MetricsDict) -> None:
        """Save the given metrics dictionary by logging it to MLflow.

        Args:
            data (MetricsDict): Metrics to log, keyed by metric name.
        """
        client = MlflowClient()
        try:
            run_id = self.run_id
        except DatasetError:
            # No explicit run id and no active run: fall back to the fluent
            # API, which creates/uses a run itself. Calling client.log_metric
            # with run_id=None would not do that.
            run_id = None

        log_metric = (
            partial(client.log_metric, run_id)
            if run_id is not None
            else mlflow.log_metric
        )
        metrics = (
            self._build_args_list_from_metric_item(k, v) for k, v in data.items()
        )

        if self._logging_activated:
            for k, v, i in chain.from_iterable(metrics):
                log_metric(k, v, step=i)

    def _exists(self) -> bool:
        """Check whether any metric of this dataset exists in the MLflow run.

        Returns:
            bool: True if the run holds at least one metric matching the
            dataset's prefix.
        """
        client = MlflowClient()
        all_metrics_keys = client.get_run(self.run_id).data.metrics.keys()
        return any(self._is_dataset_metric(x) for x in all_metrics_keys)

    def _describe(self) -> Dict[str, Any]:
        """Describe MLflow metrics dataset.

        Returns:
            Dict[str, Any]: Dictionary with MLflow metrics dataset description.
        """
        return {
            "run_id": self._run_id,
            "prefix": self._prefix,
        }

    def _is_dataset_metric(self, key: str) -> bool:
        """Check if a given MLflow metric key belongs to this dataset.

        Args:
            key (str): The name of an mlflow metric registered in the run.
        """
        # No prefix configured -> every metric belongs to the dataset.
        # bool() fixes the annotation: the bare `and` chain could previously
        # return "" (falsy but not a bool) for an empty-string prefix, which
        # (preserved here) matches no metric.
        return self._prefix is None or bool(
            self._prefix and key.startswith(self._prefix)
        )

    @staticmethod
    def _convert_metric_history_to_list_or_dict(
        metrics: List[mlflow.entities.Metric],
    ) -> MetricItem:
        """Convert MLflow metric objects from ``MlflowClient().get_metric_history``
        to ``{"step": x, "value": y}`` dicts.

        A single-entry history is unwrapped to a bare dict so that a metric
        logged once round-trips as a dict rather than a one-element list.

        Args:
            metrics (List[mlflow.entities.Metric]): Metric history entries
                retrieved from the run.
        """
        metrics_as_list = [
            {"step": metric.step, "value": metric.value} for metric in metrics
        ]
        metrics_result = (
            metrics_as_list[0] if len(metrics_as_list) == 1 else metrics_as_list
        )
        return metrics_result

    def _build_args_list_from_metric_item(
        self, key: str, value: MetricItem
    ) -> Generator[MetricTuple, None, None]:
        """Build an iterable of (key, value, step) tuples for one metric item.

        If the dataset has a prefix, it is prepended to the key.

        Args:
            key (str): Metric key.
            value (MetricItem): Metric value (single dict or list of dicts).

        Returns:
            Generator[MetricTuple, None, None]: The log_metric argument tuples.

        Raises:
            DatasetError: If ``value`` is neither a dict nor a non-empty list.
                Raised eagerly (not on iteration) because the generator is
                returned, not yielded from.
        """
        if self._prefix:
            key = f"{self._prefix}.{key}"

        if isinstance(value, dict):
            return (item for item in [(key, value["value"], value["step"])])

        if isinstance(value, list) and len(value) > 0:
            return ((key, x["value"], x["step"]) for x in value)

        raise DatasetError(
            f"Unexpected metric value. Should be of type `{MetricItem}`, got {type(value)}"
        )