@@ -185,22 +185,20 @@ def test_get_scores(self, real_training_data, real_validation_data):
185
185
class TestBaseDataAugmentationMetric :
186
186
"""Test the BaseDataAugmentationMetric class."""
187
187
188
- def test__fit_preprocess (self , real_training_data , metadata ):
189
- """Test the ``_fit_preprocess `` method."""
188
+ def test__fit (self , real_training_data , metadata ):
189
+ """Test the ``_fit `` method."""
190
190
# Setup
191
191
metric = BaseDataAugmentationMetric ()
192
192
193
193
# Run
194
- discrete_columns , datetime_columns = metric ._fit_preprocess (
195
- real_training_data , metadata , 'target'
196
- )
194
+ discrete_columns , datetime_columns = metric ._fit (real_training_data , metadata , 'target' )
197
195
198
196
# Assert
199
197
assert discrete_columns == ['categorical' , 'boolean' ]
200
198
assert datetime_columns == ['datetime' ]
201
199
202
- def test__transform_preprocess (self , real_training_data , synthetic_data , real_validation_data ):
203
- """Test the ``_transform_preprocess `` method."""
200
+ def test__transform (self , real_training_data , synthetic_data , real_validation_data ):
201
+ """Test the ``_transform `` method."""
204
202
# Setup
205
203
metric = BaseDataAugmentationMetric ()
206
204
discrete_columns = ['categorical' , 'boolean' ]
@@ -212,9 +210,7 @@ def test__transform_preprocess(self, real_training_data, synthetic_data, real_va
212
210
}
213
211
214
212
# Run
215
- transformed = metric ._transform_preprocess (
216
- tables , discrete_columns , datetime_columns , 'target' , 1
217
- )
213
+ transformed = metric ._transform (tables , discrete_columns , datetime_columns , 'target' , 1 )
218
214
219
215
# Assert
220
216
expected_transformed = {
@@ -257,10 +253,10 @@ def test__fit_transform(
257
253
"""Test the ``_fit_transform`` method."""
258
254
# Setup
259
255
metric = BaseDataAugmentationMetric ()
260
- BaseDataAugmentationMetric ._fit_preprocess = Mock ()
256
+ BaseDataAugmentationMetric ._fit = Mock ()
261
257
discrete_columns = ['categorical' , 'boolean' ]
262
258
datetime_columns = ['datetime' ]
263
- BaseDataAugmentationMetric ._fit_preprocess .return_value = (
259
+ BaseDataAugmentationMetric ._fit .return_value = (
264
260
discrete_columns ,
265
261
datetime_columns ,
266
262
)
@@ -269,18 +265,18 @@ def test__fit_transform(
269
265
'synthetic_data' : synthetic_data ,
270
266
'real_validation_data' : real_validation_data ,
271
267
}
272
- BaseDataAugmentationMetric ._transform_preprocess = Mock (return_value = tables )
268
+ BaseDataAugmentationMetric ._transform = Mock (return_value = tables )
273
269
274
270
# Run
275
271
transformed = metric ._fit_transform (
276
272
real_training_data , synthetic_data , real_validation_data , metadata , 'target' , 1
277
273
)
278
274
279
275
# Assert
280
- BaseDataAugmentationMetric ._fit_preprocess .assert_called_once_with (
276
+ BaseDataAugmentationMetric ._fit .assert_called_once_with (
281
277
real_training_data , metadata , 'target'
282
278
)
283
- BaseDataAugmentationMetric ._transform_preprocess .assert_called_once_with (
279
+ BaseDataAugmentationMetric ._transform .assert_called_once_with (
284
280
tables , discrete_columns , datetime_columns , 'target' , 1
285
281
)
286
282
for table_name , table in transformed .items ():
0 commit comments