Skip to content

Commit ced56d9

Browse files
authored
Making sklearn estimator compliant based on original commit proposed by pizzooid (#50)
1 parent 7cec4f1 commit ced56d9

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

ffx/api.py

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
'''
44

55
from sklearn.base import BaseEstimator, RegressorMixin
6+
from sklearn.utils import check_array, check_X_y
7+
from sklearn.utils.validation import check_is_fitted
68

79

810
def run(train_X, train_y, test_X, test_y, varnames=None, verbose=False):
@@ -15,6 +17,7 @@ class FFXRegressor(BaseEstimator, RegressorMixin):
1517
'''This class provides a Scikit-learn style estimator.'''
1618

1719
def fit(self, X, y):
20+
X, y = check_X_y(X, y, y_numeric=True, multi_output=False)
1821
# if X is a Pandas DataFrame, we don't have to pass in varnames.
1922
# otherwise we make up placeholders.
2023
if hasattr(X, 'columns'):
@@ -25,8 +28,11 @@ def fit(self, X, y):
2528
X, y, X, y, varnames=varnames
2629
)
2730
self.model_ = self.models_[-1] # pylint: disable=attribute-defined-outside-init
31+
return self
2832

2933
def predict(self, X):
34+
check_is_fitted(self, "model_")
35+
X = check_array(X, accept_sparse=False)
3036
return self.model_.predict(X)
3137

3238
def complexity(self):

ffx_tests/test_sklearn_api.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ffx
22
import numpy as np
3+
from sklearn.utils.estimator_checks import check_estimator
34

45
EXPECTED_MODELS = [
56
(0, 1, '0.298'),
@@ -40,3 +41,9 @@ def test_sklearn_api():
4041
assert [
4142
(model.numBases(), model.complexity(), str(model)) for model in FFX.models_
4243
] == EXPECTED_MODELS
44+
45+
46+
def test_check_estimator():
47+
# Pass instance of estimator to run sklearn's built in estimator check
48+
check_estimator(ffx.FFXRegressor())
49+

0 commit comments

Comments
 (0)