Skip to content

Commit fbc5d37

Browse files
committed
naming: remove preprocess
1 parent 233bde5 commit fbc5d37

File tree

2 files changed

+15
-19
lines changed
  • sdmetrics/single_table/data_augmentation
  • tests/unit/single_table/data_augmentation

2 files changed

+15
-19
lines changed

sdmetrics/single_table/data_augmentation/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class BaseDataAugmentationMetric(SingleTableMetric):
100100
max_value = 1.0
101101

102102
@classmethod
103-
def _fit_preprocess(cls, data, metadata, prediction_column_name):
103+
def _fit(cls, data, metadata, prediction_column_name):
104104
"""Fit preprocessing parameters."""
105105
discrete_columns = []
106106
datetime_columns = []
@@ -115,7 +115,7 @@ def _fit_preprocess(cls, data, metadata, prediction_column_name):
115115
return discrete_columns, datetime_columns
116116

117117
@classmethod
118-
def _transform_preprocess(
118+
def _transform(
119119
cls,
120120
tables,
121121
discrete_columns,
@@ -152,7 +152,7 @@ def _fit_transform(
152152
minority_class_label,
153153
):
154154
"""Fit and transform the metric."""
155-
discrete_columns, datetime_columns = cls._fit_preprocess(
155+
discrete_columns, datetime_columns = cls._fit(
156156
real_training_data, metadata, prediction_column_name
157157
)
158158
tables = {
@@ -161,7 +161,7 @@ def _fit_transform(
161161
'real_validation_data': real_validation_data,
162162
}
163163

164-
return cls._transform_preprocess(
164+
return cls._transform(
165165
tables,
166166
discrete_columns,
167167
datetime_columns,

tests/unit/single_table/data_augmentation/test_base.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -185,22 +185,20 @@ def test_get_scores(self, real_training_data, real_validation_data):
185185
class TestBaseDataAugmentationMetric:
186186
"""Test the BaseDataAugmentationMetric class."""
187187

188-
def test__fit_preprocess(self, real_training_data, metadata):
189-
"""Test the ``_fit_preprocess`` method."""
188+
def test__fit(self, real_training_data, metadata):
189+
"""Test the ``_fit`` method."""
190190
# Setup
191191
metric = BaseDataAugmentationMetric()
192192

193193
# Run
194-
discrete_columns, datetime_columns = metric._fit_preprocess(
195-
real_training_data, metadata, 'target'
196-
)
194+
discrete_columns, datetime_columns = metric._fit(real_training_data, metadata, 'target')
197195

198196
# Assert
199197
assert discrete_columns == ['categorical', 'boolean']
200198
assert datetime_columns == ['datetime']
201199

202-
def test__transform_preprocess(self, real_training_data, synthetic_data, real_validation_data):
203-
"""Test the ``_transform_preprocess`` method."""
200+
def test__transform(self, real_training_data, synthetic_data, real_validation_data):
201+
"""Test the ``_transform`` method."""
204202
# Setup
205203
metric = BaseDataAugmentationMetric()
206204
discrete_columns = ['categorical', 'boolean']
@@ -212,9 +210,7 @@ def test__transform_preprocess(self, real_training_data, synthetic_data, real_va
212210
}
213211

214212
# Run
215-
transformed = metric._transform_preprocess(
216-
tables, discrete_columns, datetime_columns, 'target', 1
217-
)
213+
transformed = metric._transform(tables, discrete_columns, datetime_columns, 'target', 1)
218214

219215
# Assert
220216
expected_transformed = {
@@ -257,10 +253,10 @@ def test__fit_transform(
257253
"""Test the ``_fit_transform`` method."""
258254
# Setup
259255
metric = BaseDataAugmentationMetric()
260-
BaseDataAugmentationMetric._fit_preprocess = Mock()
256+
BaseDataAugmentationMetric._fit = Mock()
261257
discrete_columns = ['categorical', 'boolean']
262258
datetime_columns = ['datetime']
263-
BaseDataAugmentationMetric._fit_preprocess.return_value = (
259+
BaseDataAugmentationMetric._fit.return_value = (
264260
discrete_columns,
265261
datetime_columns,
266262
)
@@ -269,18 +265,18 @@ def test__fit_transform(
269265
'synthetic_data': synthetic_data,
270266
'real_validation_data': real_validation_data,
271267
}
272-
BaseDataAugmentationMetric._transform_preprocess = Mock(return_value=tables)
268+
BaseDataAugmentationMetric._transform = Mock(return_value=tables)
273269

274270
# Run
275271
transformed = metric._fit_transform(
276272
real_training_data, synthetic_data, real_validation_data, metadata, 'target', 1
277273
)
278274

279275
# Assert
280-
BaseDataAugmentationMetric._fit_preprocess.assert_called_once_with(
276+
BaseDataAugmentationMetric._fit.assert_called_once_with(
281277
real_training_data, metadata, 'target'
282278
)
283-
BaseDataAugmentationMetric._transform_preprocess.assert_called_once_with(
279+
BaseDataAugmentationMetric._transform.assert_called_once_with(
284280
tables, discrete_columns, datetime_columns, 'target', 1
285281
)
286282
for table_name, table in transformed.items():

0 commit comments

Comments
 (0)