kedro_pipeline_model.py
from copy import deepcopy
from pathlib import Path

from kedro.io import DataCatalog, MemoryDataSet
from kedro.runner import SequentialRunner
from mlflow.pyfunc import PythonModel

from kedro_mlflow.pipeline.pipeline_ml import PipelineML


class KedroPipelineModel(PythonModel):
    def __init__(self, pipeline_ml: PipelineML, catalog: DataCatalog):
        self.pipeline_ml = pipeline_ml
        self.initial_catalog = pipeline_ml.extract_pipeline_catalog(catalog)
        self.loaded_catalog = DataCatalog()

    def load_context(self, context):
        # a consistency check is made when loading the model:
        # it would be better to check when saving the model,
        # but we rely on an mlflow function for saving, and it is unaware
        # of the kedro pipeline structure
        mlflow_artifacts_keys = set(context.artifacts.keys())
        kedro_artifacts_keys = set(
            self.pipeline_ml.inference.inputs() - {self.pipeline_ml.input_name}
        )
        if mlflow_artifacts_keys != kedro_artifacts_keys:
            in_artifacts_but_not_inference = (
                mlflow_artifacts_keys - kedro_artifacts_keys
            )
            in_inference_but_not_artifacts = (
                kedro_artifacts_keys - mlflow_artifacts_keys
            )
            raise ValueError(
                "Provided artifacts do not match catalog entries:\n"
                f"- 'artifacts - inference.inputs()' = {in_artifacts_but_not_inference}\n"
                f"- 'inference.inputs() - artifacts' = {in_inference_but_not_artifacts}"
            )

        # repoint each catalog entry to the artifact location resolved by mlflow
        self.loaded_catalog = deepcopy(self.initial_catalog)
        for name, uri in context.artifacts.items():
            self.loaded_catalog._data_sets[name]._filepath = Path(uri)

    def predict(self, context, model_input):
        # TODO: check out how to pass extra args in predict,
        # for instance to enable parallelization

        # inject the raw input into the catalog, then run the inference pipeline
        self.loaded_catalog.add(
            data_set_name=self.pipeline_ml.input_name,
            data_set=MemoryDataSet(model_input),
            replace=True,
        )
        runner = SequentialRunner()
        run_outputs = runner.run(
            pipeline=self.pipeline_ml.inference, catalog=self.loaded_catalog
        )
        return run_outputs
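
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of how this wrapper might be logged and reloaded
# with mlflow.pyfunc. The `pipeline_ml`, `catalog`, `artifacts`, and
# `model_input` objects are assumed to exist in the calling project; their
# names are hypothetical, and `artifacts` is assumed to be a dict mapping each
# non-input dataset of the inference pipeline to a local file path.
#
#   from mlflow.pyfunc import log_model, load_model
#
#   kedro_model = KedroPipelineModel(pipeline_ml=pipeline_ml, catalog=catalog)
#   log_model(
#       artifact_path="model",
#       python_model=kedro_model,
#       artifacts=artifacts,  # must match pipeline_ml.inference.inputs() minus input_name
#   )
#
#   loaded = load_model("runs:/<run_id>/model")  # <run_id> is a placeholder
#   predictions = loaded.predict(model_input)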