Skip to content

Commit 6946835

Browse files
authored
Update final scoring method for BinaryClassifierEfficacy precision & recall metrics (make 0.5 a threshold) (#725)
1 parent 7261608 commit 6946835

File tree

5 files changed

+13
-17
lines changed

5 files changed

+13
-17
lines changed

sdmetrics/single_table/data_augmentation/base.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,9 @@ def compute_breakdown(
227227
f'fixed_{metric_to_fix}_value': trainer.fixed_value,
228228
},
229229
}
230-
result['score'] = max(
231-
0,
232-
(
233-
result['augmented_data'][f'{cls.metric_name}_score_validation']
234-
- result['real_data_baseline'][f'{cls.metric_name}_score_validation']
235-
),
236-
)
230+
augmented_score = result['augmented_data'][f'{cls.metric_name}_score_validation']
231+
baseline_score = result['real_data_baseline'][f'{cls.metric_name}_score_validation']
232+
result['score'] = (augmented_score - baseline_score) / 2 + 0.5
237233
return result
238234

239235
@classmethod

sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def compute_breakdown(
1818
metadata,
1919
prediction_column_name,
2020
minority_class_label,
21-
classifier,
22-
fixed_recall_value,
21+
classifier='xgboost',
22+
fixed_recall_value=0.9,
2323
):
2424
"""Compute the score breakdown of the metric."""
2525
return super().compute_breakdown(
@@ -42,8 +42,8 @@ def compute(
4242
metadata,
4343
prediction_column_name,
4444
minority_class_label,
45-
classifier,
46-
fixed_recall_value,
45+
classifier='xgboost',
46+
fixed_recall_value=0.9,
4747
):
4848
"""Compute the score of the metric.
4949

tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_with_nan_target_column(self):
162162
'classifier': 'XGBoost',
163163
'fixed_recall_value': 0.8,
164164
},
165-
'score': 0,
165+
'score': 0.48571428571428577,
166166
}
167167
assert result_breakdown == expected_result
168168

@@ -244,6 +244,6 @@ def test_with_multi_class(self):
244244
'classifier': 'XGBoost',
245245
'fixed_recall_value': 0.8,
246246
},
247-
'score': 0,
247+
'score': 0.4944444444444444,
248248
}
249249
assert score_breakdown == expected_score_breakdown

tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_with_nan_target_column(self):
124124
)
125125

126126
# Assert
127-
assert result_breakdown['score'] in (0, 0.07692307692307698)
127+
assert result_breakdown['score'] in (0.5, 0.5384615384615385)
128128

129129
def test_with_minority_being_majority(self):
130130
"""Test the metric when the minority class is the majority class."""
@@ -148,7 +148,7 @@ def test_with_minority_being_majority(self):
148148
)
149149

150150
# Assert
151-
assert score == 0
151+
assert score == 0.46153846153846156
152152

153153
def test_with_multi_class(self):
154154
"""Test the metric with multi-class classification.
@@ -175,4 +175,4 @@ def test_with_multi_class(self):
175175
)
176176

177177
# Assert
178-
assert score_breakdown['score'] in (0, 0.07692307692307687)
178+
assert score_breakdown['score'] in (0.46153846153846156, 0.5384615384615384)

tests/unit/single_table/data_augmentation/test_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def test_compute_breakdown(
357357

358358
# Assert
359359
expected_result = {
360-
'score': 0.19999999999999996,
360+
'score': 0.6,
361361
'real_data_baseline': real_data_baseline,
362362
'augmented_data': augmented_table_result,
363363
'parameters': {

0 commit comments

Comments
 (0)