3
3
'''
4
4
5
5
from sklearn .base import BaseEstimator , RegressorMixin
6
+ from sklearn .utils import check_array , check_X_y
7
+ from sklearn .utils .validation import check_is_fitted
6
8
7
9
8
10
def run (train_X , train_y , test_X , test_y , varnames = None , verbose = False ):
@@ -15,6 +17,7 @@ class FFXRegressor(BaseEstimator, RegressorMixin):
15
17
'''This class provides a Scikit-learn style estimator.'''
16
18
17
19
def fit (self , X , y ):
20
+ X , y = check_X_y (X , y , y_numeric = True , multi_output = False )
18
21
# if X is a Pandas DataFrame, we don't have to pass in varnames.
19
22
# otherwise we make up placeholders.
20
23
if hasattr (X , 'columns' ):
@@ -25,8 +28,11 @@ def fit(self, X, y):
25
28
X , y , X , y , varnames = varnames
26
29
)
27
30
self .model_ = self .models_ [- 1 ] # pylint: disable=attribute-defined-outside-init
31
+ return self
28
32
29
33
def predict (self , X ):
34
+ check_is_fitted (self , "model_" )
35
+ X = check_array (X , accept_sparse = False )
30
36
return self .model_ .predict (X )
31
37
32
38
def complexity (self ):
0 commit comments