Skip to content

Commit 03e84e0

Browse files
committed
Removes occurrences of DataFrame.values (ndarray)
Uses the DataFrame everywhere it's possible.
1 parent e8a16e3 commit 03e84e0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/fklearn/training/classification.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -571,21 +571,21 @@ def lgbm_classification_learner(df: pd.DataFrame,
571571
params = assoc(params, "eta", learning_rate)
572572
params = params if "objective" in params else assoc(params, "objective", 'binary')
573573

574-
weights = df[weight_column].values if weight_column else None
574+
weights = df[weight_column] if weight_column else None
575575

576576
features = features if not encode_extra_cols else expand_features_encoded(df, features)
577577

578578
dtrain = lgbm.Dataset(df[features], label=df[target], feature_name=list(map(str, features)), weight=weights,
579579
silent=True, categorical_feature=categorical_features)
580580

581-
bst = lgbm.train(params, dtrain, num_estimators)
581+
bst = lgbm.train(params, dtrain, num_estimators, categorical_feature=categorical_features)
582582

583583
def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
584584
if params["objective"] == "multiclass":
585585
col_dict = {prediction_column + "_" + str(key): value
586-
for (key, value) in enumerate(bst.predict(new_df[features].values).T)}
586+
for (key, value) in enumerate(bst.predict(new_df[features]).T)}
587587
else:
588-
col_dict = {prediction_column: bst.predict(new_df[features].values)}
588+
col_dict = {prediction_column: bst.predict(new_df[features])}
589589

590590
if apply_shap:
591591
import shap

0 commit comments

Comments
 (0)