-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathkedro_mlflow_config.py
202 lines (157 loc) · 6.91 KB
/
kedro_mlflow_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
from pathlib import Path, PurePath
from typing import List, Optional
import mlflow
from kedro.config import MissingConfigException
from kedro.framework.session import KedroSession, get_current_session
from kedro.framework.startup import _is_project
from mlflow.entities import Experiment
from mlflow.tracking.client import MlflowClient
from pydantic import BaseModel, PrivateAttr, StrictBool, validator
from typing_extensions import Literal
class DisableTrackingOptions(BaseModel):
    """Options listing the pipelines for which tracking is disabled.

    Unknown keys are rejected (``extra = "forbid"``).
    """

    class Config:
        extra = "forbid"

    # A mutable default is safe with pydantic, see:
    # https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value
    pipelines: List[str] = []
class ExperimentOptions(BaseModel):
    """Options describing the mlflow experiment to use.

    Unknown keys are rejected (``extra = "forbid"``).
    """

    class Config:
        extra = "forbid"

    # name of the mlflow experiment looked up by KedroMlflowConfig
    name: str = "Default"
    # when True, a missing experiment is created and a deleted one is restored
    create: StrictBool = True
class RunOptions(BaseModel):
    """Options describing the mlflow run.

    Every field is optional so that ``RunOptions()`` can be built without
    arguments. Unknown keys are rejected (``extra = "forbid"``).
    """

    # The defaults are spelled out explicitly: relying on the implicit ``None``
    # default of a bare ``Optional[...]`` annotation only works with pydantic v1,
    # such fields are *required* in pydantic v2.
    id: Optional[str] = None
    name: Optional[str] = None
    nested: StrictBool = True

    class Config:
        extra = "forbid"
class UiOptions(BaseModel):
    """Host and port options of the ui section of the configuration.

    Both values are stored as strings. Unknown keys are rejected
    (``extra = "forbid"``).
    """

    class Config:
        extra = "forbid"

    # NOTE(review): presumably forwarded to the ``mlflow ui`` command - confirm in the CLI code
    host: str = "127.0.0.1"
    port: str = "5000"
class NodeHookOptions(BaseModel):
    """Options of the node hook.

    Unknown keys are rejected (``extra = "forbid"``).
    NOTE(review): the exact effect of each option is implemented by the hook
    that consumes this model, not in this module - check it before relying
    on the short descriptions below.
    """

    class Config:
        extra = "forbid"

    # presumably: whether nested parameter dictionaries are flattened before logging
    flatten_dict_params: StrictBool = False
    # presumably: whether the flattening is applied recursively
    recursive: StrictBool = True
    # presumably: separator used to join nested keys when flattening
    sep: str = "."
    # only one of these three values is accepted
    long_parameters_strategy: Literal["fail", "truncate", "tag"] = "fail"
class HookOptions(BaseModel):
    """Container grouping the options of the hooks.

    Unknown keys are rejected (``extra = "forbid"``).
    """

    class Config:
        extra = "forbid"

    # options of the node hook, built with their default values when omitted
    node: NodeHookOptions = NodeHookOptions()
class KedroMlflowConfig(BaseModel):
    """Validated mlflow configuration of a kedro project.

    The model checks that ``project_path`` is the root of a kedro project and
    normalises ``mlflow_tracking_uri``. Creating the object builds a
    ``MlflowClient`` but does not look up or create the experiment: this
    only happens when ``setup()`` is called.
    """

    # NOTE: ``project_path`` must stay declared before ``mlflow_tracking_uri``,
    # the uri validator reads the already validated project path.
    project_path: Path  # if str, will be converted
    mlflow_tracking_uri: str = "mlruns"
    # explicit default: a bare ``Optional[str]`` is only implicitly optional
    # with pydantic v1
    credentials: Optional[str] = None
    disable_tracking: DisableTrackingOptions = DisableTrackingOptions()
    experiment: ExperimentOptions = ExperimentOptions()
    run: RunOptions = RunOptions()
    ui: UiOptions = UiOptions()
    hooks: HookOptions = HookOptions()

    _mlflow_client: MlflowClient = PrivateAttr()
    # do not create _experiment immediately to avoid creating
    # a database connection when creating the object:
    # it will be instantiated on setup() call
    _experiment: Experiment = PrivateAttr()

    class Config:
        # force triggering type control when setting value instead of init
        validate_assignment = True
        # raise an error if an unknown key is passed to the constructor
        extra = "forbid"

    def __init__(self, **kwargs):
        """Validate the fields, then create the mlflow client."""
        super().__init__(**kwargs)
        # init after validating the uri, else mlflow creates a mlruns folder at the root
        self._mlflow_client = MlflowClient(tracking_uri=self.mlflow_tracking_uri)

    def setup(self, session: Optional[KedroSession] = None):
        """Setup all the mlflow configuration.

        Exports the credentials as environment variables, sets the mlflow
        tracking uri and retrieves (or creates) the experiment.

        Arguments:
            session -- The kedro session used to load the credentials. If
                None, the current session is used.
        """
        self._export_credentials(session)

        # we set the configuration now: it takes priority
        # if it has already been set in export_credentials
        mlflow.set_tracking_uri(self.mlflow_tracking_uri)

        self._get_or_create_experiment()

    def _export_credentials(self, session: Optional[KedroSession] = None):
        """Export the mlflow credentials as environment variables.

        The entry named ``self.credentials`` is read from the credentials of
        the kedro context; each of its key/value pairs becomes an environment
        variable. Nothing is exported if the entry does not exist or is empty.

        Arguments:
            session -- The kedro session used to load the context. If None,
                the current session is used.

        Raises:
            KedroMlflowConfigError -- if a credential has no value.
        """
        session = session or get_current_session()
        context = session.load_context()
        conf_creds = context._get_config_credentials()
        # "or {}" also covers an entry declared without any content (None)
        mlflow_creds = conf_creds.get(self.credentials) or {}
        for key, value in mlflow_creds.items():
            if value is None:
                raise KedroMlflowConfigError(
                    f"The credential '{key}' of '{self.credentials}' has no value"
                )
            # os.environ only accepts strings while YAML scalars
            # can be parsed as int, float or bool
            os.environ[key] = str(value)

    def _get_or_create_experiment(self):
        """Best effort to get the experiment associated to the configuration.

        The experiment is stored in ``self._experiment``, nothing is returned.
        ``self._experiment`` is None if the experiment does not exist and
        ``experiment.create`` is False.
        """
        # retrieve the experiment
        self._experiment = self._mlflow_client.get_experiment_by_name(
            name=self.experiment.name
        )

        # Deal with two side cases when retrieving the experiment
        if self.experiment.create:
            if self._experiment is None:
                # case 1 : the experiment does not exist, it must be created manually
                experiment_id = self._mlflow_client.create_experiment(
                    name=self.experiment.name
                )
                self._experiment = self._mlflow_client.get_experiment(
                    experiment_id=experiment_id
                )
            elif self._experiment.lifecycle_stage == "deleted":
                # case 2: the experiment was created, then deleted : we have to restore it manually
                experiment_id = self._experiment.experiment_id
                self._mlflow_client.restore_experiment(experiment_id)
                # reload the entity, else it still reports the "deleted" stage
                self._experiment = self._mlflow_client.get_experiment(
                    experiment_id=experiment_id
                )

    @validator("project_path")
    def _is_kedro_project(cls, folder_path):
        """Check that ``folder_path`` is the root of a kedro project.

        Raises:
            KedroMlflowConfigError -- if it is not the root of a kedro project.
        """
        if not _is_project(folder_path):
            raise KedroMlflowConfigError(
                f"'project_path' = '{folder_path}' is not the root of kedro project"
            )
        return folder_path

    # pre=make a conversion before it is set
    # always=even for default value
    # values enable access to the other fields, see https://pydantic-docs.helpmanual.io/usage/validators/
    @validator("mlflow_tracking_uri", pre=True, always=True)
    def _validate_uri(cls, uri, values):
        """Format the uri provided to match mlflow expectations.

        Arguments:
            uri {Union[None, str]} -- A valid filepath for mlflow uri
            values {dict} -- The fields validated so far

        Returns:
            str -- A valid mlflow_tracking_uri
        """
        from urllib.parse import urlparse

        # if no tracking uri is provided, we register the runs locally at the root of the project
        if uri is None:
            uri = "mlruns"

        # this is a special reserved keyword for mlflow which should not be converted to a path
        # see: https://mlflow.org/docs/latest/tracking.html#where-runs-are-recorded
        if uri == "databricks":
            return uri

        pathlib_uri = PurePath(uri)
        if pathlib_uri.is_absolute():
            return pathlib_uri.as_uri()

        if urlparse(uri).scheme != "":
            # not a local path: assume it is already an uri
            return uri

        if "project_path" not in values:
            # 'project_path' is missing or invalid: pydantic already recorded
            # that error, do not hide it behind a KeyError
            return uri

        # if it is a local relative path, make it absolute
        # .resolve() does not work well on windows
        # .absolute is undocumented and has known bugs
        # Path.cwd() / uri is the way recommended by core developers.
        # See : https://discuss.python.org/t/pathlib-absolute-vs-resolve/2573/6
        return (values["project_path"] / uri).as_uri()
class KedroMlflowConfigError(Exception):
    """Raised when the mlflow configuration cannot be loaded or is invalid."""
def get_mlflow_config(session: Optional[KedroSession] = None):
    """Load the mlflow configuration of the project as a KedroMlflowConfig.

    Arguments:
        session -- The kedro session used to load the context. If None, the
            current session is used.

    Returns:
        KedroMlflowConfig -- The configuration found by the config loader of
            the context, completed with the project path of the context.

    Raises:
        KedroMlflowConfigError -- if no mlflow configuration file is found.
    """
    session = session or get_current_session()
    context = session.load_context()
    try:
        conf_mlflow_yml = context.config_loader.get("mlflow*", "mlflow*/**")
    except MissingConfigException as err:
        # chain explicitly so the original loader error stays in the traceback
        raise KedroMlflowConfigError(
            "No 'mlflow.yml' config file found in environment. Use ``kedro mlflow init`` command in CLI to create a default config file."
        ) from err
    # the project path always comes from the context
    conf_mlflow_yml["project_path"] = context.project_path
    mlflow_config = KedroMlflowConfig.parse_obj(conf_mlflow_yml)
    return mlflow_config