# kedro_mlflow_config.py
#
# Pydantic models describing the kedro-mlflow configuration (``mlflow.yml``)
# and helpers to load that configuration from a kedro session.
import os
from pathlib import Path, PurePath
from typing import Dict, List, Optional
from urllib.parse import urlparse
import kedro.framework.session.session as kfss # necessary to access the global variable _active_session of the namespace
import mlflow
from kedro.config import MissingConfigException
from kedro.framework.session import KedroSession # , get_current_session
from kedro.framework.startup import _is_project
from mlflow.entities import Experiment
from mlflow.tracking.client import MlflowClient
from pydantic import BaseModel, PrivateAttr, StrictBool, validator
from typing_extensions import Literal
class MlflowServerOptions(BaseModel):
    """Settings describing how to reach the mlflow tracking server.

    Keys that do not match a declared field are rejected.
    """

    # Mutable defaults are fine on pydantic models, see
    # https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value

    # Tracking location; ``KedroMlflowConfig.__init__`` rewrites it into a URI.
    mlflow_tracking_uri: str = "mlruns"
    # Not read anywhere in this file.
    stores_environment_variables: Dict[str, str] = {}
    # Key looked up in the kedro credentials by
    # ``KedroMlflowConfig._export_credentials``; ``None`` exports nothing.
    credentials: Optional[str] = None
    # Client bound to ``mlflow_tracking_uri``; assigned by
    # ``KedroMlflowConfig.__init__`` once the uri has been normalised.
    _mlflow_client: MlflowClient = PrivateAttr()

    class Config:
        extra = "forbid"
class DisableTrackingOptions(BaseModel):
    """Options listing what should not be tracked.

    Keys that do not match a declared field are rejected.
    """

    # Mutable defaults are fine on pydantic models, see
    # https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value

    # Presumably the names of the pipelines for which tracking is turned
    # off; the list is not consumed in this file -- confirm with callers.
    pipelines: List[str] = []

    class Config:
        extra = "forbid"
class ExperimentOptions(BaseModel):
    """Options selecting the mlflow experiment to log into.

    Keys that do not match a declared field are rejected.
    """

    # Name of the experiment fetched (or created) by
    # ``KedroMlflowConfig._set_experiment``.
    name: str = "Default"
    # When True, an experiment found in the "deleted" lifecycle stage is
    # restored before being activated.
    restore_if_deleted: StrictBool = True
    # The mlflow ``Experiment`` entity. It is deliberately NOT created here,
    # so that instantiating the options does not open a connection to the
    # tracking database; it is filled in when ``setup()`` is called.
    _experiment: Experiment = PrivateAttr()

    class Config:
        extra = "forbid"
class RunOptions(BaseModel):
    """Options describing the mlflow run.

    None of these fields is read in this file; their meaning below is
    inferred from the names -- confirm against the code that starts the run.
    Keys that do not match a declared field are rejected.
    """

    # Presumably the id of an existing run to resume.
    id: Optional[str] = None
    # Presumably the display name given to the run.
    name: Optional[str] = None
    # Presumably whether the run may be nested inside an active run.
    nested: StrictBool = True

    class Config:
        extra = "forbid"
class DictParamsOptions(BaseModel):
    """Options controlling how dictionary parameters are handled.

    The fields are not read in this file; the notes below are inferred from
    the names -- confirm against the code that logs parameters.
    Keys that do not match a declared field are rejected.
    """

    # Presumably: flatten nested dictionaries before logging them.
    flatten: StrictBool = False
    # Presumably: apply the flattening at every nesting level.
    recursive: StrictBool = True
    # Presumably: separator joining the keys of a flattened dictionary.
    sep: str = "."

    class Config:
        extra = "forbid"
class MlflowParamsOptions(BaseModel):
    """Options controlling how parameters are sent to mlflow.

    Keys that do not match a declared field are rejected.
    """

    # Handling of dictionary-valued parameters.
    dict_params: DictParamsOptions = DictParamsOptions()
    # One of "fail", "truncate" or "tag". Presumably what to do with
    # parameter values that are too long for mlflow -- the strategy is not
    # applied in this file, confirm against the code that logs parameters.
    long_params_strategy: Literal["fail", "truncate", "tag"] = "fail"

    class Config:
        extra = "forbid"
class MlflowTrackingOptions(BaseModel):
    """Container grouping every tracking-related option.

    Keys that do not match a declared field are rejected.
    """

    # Model instances as defaults are fine on pydantic models, see
    # https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value

    # What must not be tracked.
    disable_tracking: DisableTrackingOptions = DisableTrackingOptions()
    # Which experiment to log into.
    experiment: ExperimentOptions = ExperimentOptions()
    # Run-level options.
    run: RunOptions = RunOptions()
    # Parameter logging options.
    params: MlflowParamsOptions = MlflowParamsOptions()

    class Config:
        extra = "forbid"
class UiOptions(BaseModel):
    """Network options of the ``ui`` section of the configuration.

    Presumably used when launching the mlflow UI; the values are not read
    in this file. Keys that do not match a declared field are rejected.
    """

    # Kept as a string, not an int.
    port: str = "5000"
    host: str = "127.0.0.1"

    class Config:
        extra = "forbid"
class KedroMlflowConfig(BaseModel):
    """Root model of the kedro-mlflow configuration.

    It aggregates the ``server``, ``tracking`` and ``ui`` sections and is
    anchored to a kedro project through ``project_path``. Building the object
    normalises the tracking uri and creates the ``MlflowClient``; exporting
    credentials and activating the experiment only happen in ``setup()``.
    """

    project_path: Path  # a str is accepted and converted
    server: MlflowServerOptions = MlflowServerOptions()
    tracking: MlflowTrackingOptions = MlflowTrackingOptions()
    ui: UiOptions = UiOptions()

    class Config:
        # type control is also triggered when a value is assigned after init
        validate_assignment = True
        # an unknown key passed to the constructor raises an error
        extra = "forbid"

    def __init__(self, **kwargs):
        """Validate the fields, then normalise the uri and build the client.

        Arguments:
            **kwargs -- The field values, forwarded to pydantic
        """
        super().__init__(**kwargs)
        server = self.server
        server.mlflow_tracking_uri = self._validate_uri(server.mlflow_tracking_uri)
        # the client is created only once the uri is normalised,
        # else mlflow creates a mlruns folder at the root
        server._mlflow_client = MlflowClient(tracking_uri=server.mlflow_tracking_uri)

    def setup(self, session: KedroSession = None):
        """Apply the whole mlflow configuration.

        The credentials are exported as environment variables, the tracking
        uri is registered in mlflow and the experiment is activated.

        Arguments:
            session {KedroSession} -- The session used to read the
                credentials; the active session is used when not provided
        """
        self._export_credentials(session)
        # the tracking uri is set after the export on purpose: the configured
        # value takes priority if one was already set by the credentials
        mlflow.set_tracking_uri(self.server.mlflow_tracking_uri)
        self._set_experiment()

    def _export_credentials(self, session: KedroSession = None):
        """Copy the mlflow credentials into ``os.environ``.

        The entries are read from the kedro credentials under the key
        ``server.credentials``; nothing is exported when that key is absent.

        Arguments:
            session {KedroSession} -- The session giving access to the
                context; the active session is used when not provided
        """
        if not session:
            session = _get_current_session()
        all_credentials = session.load_context()._get_config_credentials()
        exported = all_credentials.get(self.server.credentials, {})
        for env_name, env_value in exported.items():
            os.environ[env_name] = env_value

    def _set_experiment(self):
        """Activate the configured experiment and cache its entity.

        The experiment is restored first if it was deleted and
        ``restore_if_deleted`` is set. Nothing is returned: the resulting
        ``mlflow.entities.Experiment`` is stored in
        ``self.tracking.experiment._experiment``.
        """
        client = self.server._mlflow_client
        options = self.tracking.experiment

        # manual lookup, to find out whether the experiment already exists
        found = client.get_experiment_by_name(name=options.name)
        if (
            found is not None
            and options.restore_if_deleted
            and found.lifecycle_stage == "deleted"
        ):
            # created then deleted: it must be restored by hand before it
            # can become the active experiment
            client.restore_experiment(found.experiment_id)

        # creates the experiment when it does not exist yet and registers it
        # as a global variable of mlflow, but returns nothing...
        mlflow.set_experiment(experiment_name=options.name)
        # ...hence a second lookup instead of reusing ``found``, which is
        # None if the experiment has just been created
        options._experiment = client.get_experiment_by_name(name=options.name)

    def _validate_uri(self, uri):
        """Format the uri provided to match mlflow expectations.

        Arguments:
            uri {str} -- A local path (absolute or relative to the project),
                an uri, or the "databricks" keyword

        Returns:
            str -- A valid mlflow_tracking_uri
        """
        # special keyword reserved by mlflow, it must not be turned into a path
        # see: https://mlflow.org/docs/latest/tracking.html#where-runs-are-recorded
        if uri == "databricks":
            return uri

        as_path = PurePath(uri)
        if as_path.is_absolute():
            return as_path.as_uri()

        if urlparse(uri).scheme != "":
            # a scheme is present: assume it is already an uri
            return uri

        # local relative path: anchor it to the project root.
        # .resolve() does not work well on windows and .absolute() is
        # undocumented with known bugs; joining onto a known root is the way
        # recommended by core developers.
        # See: https://discuss.python.org/t/pathlib-absolute-vs-resolve/2573/6
        return (self.project_path / uri).as_uri()

    @validator("project_path")
    def _is_kedro_project(cls, folder_path):
        """Reject a ``project_path`` which is not the root of a kedro project.

        Arguments:
            folder_path {Path} -- The candidate project root

        Raises:
            KedroMlflowConfigError -- If ``_is_project`` rejects the folder

        Returns:
            Path -- The folder path, unchanged
        """
        if _is_project(folder_path):
            return folder_path
        raise KedroMlflowConfigError(
            f"'project_path' = '{folder_path}' is not the root of kedro project"
        )
class KedroMlflowConfigError(Exception):
    """Raised when the kedro-mlflow configuration is missing or invalid."""
def get_mlflow_config(session: Optional[KedroSession] = None) -> "KedroMlflowConfig":
    """Load the mlflow configuration of a kedro project.

    The configuration is read through the config loader of the session
    context (patterns ``mlflow*`` and ``mlflow*/**``), completed with the
    project path of the context and validated by ``KedroMlflowConfig``.

    Args:
        session: The kedro session to read from. When omitted, the
            currently active session is used.

    Raises:
        KedroMlflowConfigError: If no mlflow configuration file is found.
            The original ``MissingConfigException`` is attached as
            ``__cause__``.
        RuntimeError: If no session is given and none is active.

    Returns:
        The validated ``KedroMlflowConfig`` instance.
    """
    session = session or _get_current_session()
    context = session.load_context()
    try:
        conf_mlflow_yml = context._config_loader.get("mlflow*", "mlflow*/**")
    except MissingConfigException as err:
        # chain explicitly so the traceback shows the loader error as the
        # direct cause rather than as an error raised while handling another
        raise KedroMlflowConfigError(
            "No 'mlflow.yml' config file found in environment. Use ``kedro mlflow init`` command in CLI to create a default config file."
        ) from err
    # project_path is not part of the yml file: it always comes from the context
    conf_mlflow_yml["project_path"] = context.project_path
    return KedroMlflowConfig.parse_obj(conf_mlflow_yml)
def _get_current_session(silent: bool = False) -> Optional["KedroSession"]:
    """Return the ``KedroSession`` currently registered as active in kedro.

    Args:
        silent: When True, a missing session is not an error and the falsy
            value held by kedro is returned as is.

    Raises:
        RuntimeError: If no session is active and ``silent`` is False.

    Returns:
        The active KedroSession instance.
    """
    # _active_session is a module-level global owned by kedro itself
    current = kfss._active_session
    if current or silent:
        return current
    raise RuntimeError("There is no active Kedro session.")