Skip to content

Commit 791c7d5

Browse files
committed
address comments
1 parent 1b64e12 commit 791c7d5

File tree

4 files changed

+50
-16
lines changed

4 files changed

+50
-16
lines changed

sdmetrics/single_table/data_augmentation/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _transform_preprocess(self, tables):
6262
6363
Args:
6464
tables (dict[str, pandas.DataFrame]):
65-
The tables to transform.
65+
Dict containing `real_training_data`, `synthetic_data` and `real_validation_data`.
6666
"""
6767
tables_result = {}
6868
for table_name, table in tables.items():
@@ -82,8 +82,8 @@ def _get_best_threshold(self, train_data, train_target):
8282
"""Find the best threshold for the classifier model."""
8383
target_probabilities = self._classifier.predict_proba(train_data)[:, 1]
8484
precision, recall, thresholds = precision_recall_curve(train_target, target_probabilities)
85-
# To assess the preicision efficacy, we have to fix the recall and reciprocally
86-
metric = precision if self.metric_name == 'recall' else recall
85+
metric_map = {'precision': precision, 'recall': recall}
86+
metric = metric_map[self._metric_to_fix]
8787
best_threshold = 0.0
8888
valid_idx = np.where(metric >= self.fixed_value)[0]
8989
if valid_idx.size:

sdmetrics/single_table/data_augmentation/utils.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,64 @@
33
import pandas as pd
44

55

6-
def _validate_parameters(
7-
real_training_data,
8-
synthetic_data,
9-
real_validation_data,
10-
metadata,
11-
prediction_column_name,
12-
classifier,
13-
fixed_recall_value,
14-
):
15-
"""Validate the parameters of the Data Augmentation metrics."""
6+
def _validate_tables(real_training_data, synthetic_data, real_validation_data):
7+
"""Validate the tables of the Data Augmentation metrics."""
168
tables = [real_training_data, synthetic_data, real_validation_data]
179
if any(not isinstance(table, pd.DataFrame) for table in tables):
1810
raise ValueError(
1911
'`real_training_data`, `synthetic_data` and `real_validation_data` must be '
2012
'pandas DataFrames.'
2113
)
2214

15+
16+
def _validate_metadata(metadata):
17+
"""Validate the metadata of the Data Augmentation metrics."""
2318
if not isinstance(metadata, dict):
2419
raise TypeError(
2520
f"Expected a dictionary but received a '{type(metadata).__name__}' instead."
2621
" For SDV metadata objects, please use the 'to_dict' function to convert it"
2722
' to a dictionary.'
2823
)
2924

25+
26+
def _validate_prediction_column_name(prediction_column_name):
27+
"""Validate the prediction column name of the Data Augmentation metrics."""
3028
if not isinstance(prediction_column_name, str):
3129
raise TypeError('`prediction_column_name` must be a string.')
3230

31+
32+
def _validate_classifier(classifier):
33+
"""Validate the classifier of the Data Augmentation metrics."""
3334
if classifier is not None and not isinstance(classifier, str):
3435
raise TypeError('`classifier` must be a string or None.')
3536

3637
if classifier != 'XGBoost':
3738
raise ValueError('Currently only `XGBoost` is supported as classifier.')
3839

40+
41+
def _validate_fixed_recall_value(fixed_recall_value):
42+
"""Validate the fixed recall value of the Data Augmentation metrics."""
3943
if not isinstance(fixed_recall_value, (int, float)) or not (0 < fixed_recall_value < 1):
4044
raise TypeError('`fixed_recall_value` must be a float in the range (0, 1).')
4145

4246

47+
def _validate_parameters(
48+
real_training_data,
49+
synthetic_data,
50+
real_validation_data,
51+
metadata,
52+
prediction_column_name,
53+
classifier,
54+
fixed_recall_value,
55+
):
56+
"""Validate the parameters of the Data Augmentation metrics."""
57+
_validate_tables(real_training_data, synthetic_data, real_validation_data)
58+
_validate_metadata(metadata)
59+
_validate_prediction_column_name(prediction_column_name)
60+
_validate_classifier(classifier)
61+
_validate_fixed_recall_value(fixed_recall_value)
62+
63+
4364
def _validate_data_and_metadata(
4465
real_training_data,
4566
synthetic_data,
@@ -89,10 +110,11 @@ def _validate_data_and_metadata(
89110
synthetic_labels = set(synthetic_data[prediction_column_name].unique())
90111
real_labels = set(real_training_data[prediction_column_name].unique())
91112
if not synthetic_labels.issubset(real_labels):
113+
to_print = "', '".join(sorted(synthetic_labels - real_labels))
92114
raise ValueError(
93115
f'The ``{prediction_column_name}`` column must have the same values in the real '
94-
'and synthetic data. The synthetic data has the following unseen values: '
95-
f'{sorted(synthetic_labels - real_labels)}'
116+
'and synthetic data. The following values are present in the synthetic data and'
117+
f" not the real data: '{to_print}'"
96118
)
97119

98120

tests/unit/single_table/data_augmentation/test_base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
from sklearn.metrics import precision_score, recall_score
9+
from xgboost import XGBClassifier
910

1011
from sdmetrics.single_table.data_augmentation.base import BaseDataAugmentationMetric
1112

@@ -103,7 +104,7 @@ def test__fit(self, real_training_data, metadata):
103104
assert metric.fixed_value == fixed_recall_value
104105
assert metric._metric_method == recall_score
105106
assert metric._classifier_name == classifier
106-
# assert metric._classifier == 'XGBClassifier()'
107+
assert isinstance(metric._classifier, XGBClassifier)
107108

108109
@patch('sdmetrics.single_table.data_augmentation.base.precision_recall_curve')
109110
def test__get_best_threshold(self, mock_precision_recall_curve, real_training_data):
@@ -120,6 +121,7 @@ def test__get_best_threshold(self, mock_precision_recall_curve, real_training_da
120121
np.array([0.02, 0.15, 0.25, 0.35, 0.42, 0.51, 0.63, 0.77, 0.82, 0.93, 0.97]),
121122
]
122123
metric.metric_name = 'recall'
124+
metric._metric_to_fix = 'precision'
123125
metric.fixed_value = 0.69
124126
train_data = real_training_data[['numerical']]
125127
train_target = real_training_data['target']

tests/unit/single_table/data_augmentation/test_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def test__validate_data_and_metadata():
111111
'the column `target` for the real validation data. The `precision`and `recall`'
112112
' are undefined for this case.'
113113
)
114+
expected_error_synthetic_wrong_label = re.escape(
115+
'The ``target`` column must have the same values in the real and synthetic data. '
116+
'The following values are present in the synthetic data and not the real'
117+
" data: 'wrong_1', 'wrong_2'"
118+
)
114119

115120
# Run and Assert
116121
_validate_data_and_metadata(**inputs)
@@ -146,6 +151,11 @@ def test__validate_data_and_metadata():
146151
with pytest.raises(ValueError, match=expected_error_missing_minority):
147152
_validate_data_and_metadata(**missing_minority_class_label_validation)
148153

154+
wrong_synthetic_label = deepcopy(inputs)
155+
wrong_synthetic_label['synthetic_data'] = pd.DataFrame({'target': [0, 1, 'wrong_1', 'wrong_2']})
156+
with pytest.raises(ValueError, match=expected_error_synthetic_wrong_label):
157+
_validate_data_and_metadata(**wrong_synthetic_label)
158+
149159

150160
@patch('sdmetrics.single_table.data_augmentation.utils._validate_parameters')
151161
@patch('sdmetrics.single_table.data_augmentation.utils._validate_data_and_metadata')

0 commit comments

Comments
 (0)