Skip to content

Commit 5818ac3

Browse files
committed
Move validation check to a separate function for reuse
1 parent 8097069 commit 5818ac3

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

sdmetrics/single_table/privacy/dcr_overfitting_protection.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sdmetrics.goal import Goal
99
from sdmetrics.single_table.base import SingleTableMetric
1010
from sdmetrics.single_table.privacy.dcr_utils import calculate_dcr
11+
from sdmetrics.single_table.privacy.util import validate_num_samples_num_iteration
1112

1213

1314
class DCROverfittingProtection(SingleTableMetric):
@@ -32,20 +33,7 @@ def _validate_inputs(
3233
num_rows_subsample,
3334
num_iterations,
3435
):
35-
if num_rows_subsample is not None:
36-
if not isinstance(num_rows_subsample, int) or num_rows_subsample < 1:
37-
raise ValueError(
38-
f'num_rows_subsample ({num_rows_subsample}) must be an integer greater than 1.'
39-
)
40-
elif num_rows_subsample is None and num_iterations > 1:
41-
raise ValueError(
42-
'num_iterations should not be greater than 1 if there is no subsampling.'
43-
)
44-
45-
if not isinstance(num_iterations, int) or num_iterations < 1:
46-
raise ValueError(
47-
f'num_iterations ({num_iterations}) must be an integer greater than 1.'
48-
)
36+
validate_num_samples_num_iteration(num_rows_subsample, num_iterations)
4937

5038
if len(real_training_data) * 0.5 > len(real_validation_data):
5139
warnings.warn(

sdmetrics/single_table/privacy/util.py

+13
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,16 @@ def allow_nan_array(attributes):
148148
ret.append(entry)
149149

150150
return ret
151+
152+
153+
def validate_num_samples_num_iteration(num_rows_subsample, num_iterations):
154+
if num_rows_subsample is not None:
155+
if not isinstance(num_rows_subsample, int) or num_rows_subsample < 1:
156+
raise ValueError(
157+
f'num_rows_subsample ({num_rows_subsample}) must be an integer greater than 1.'
158+
)
159+
elif num_rows_subsample is None and num_iterations > 1:
160+
raise ValueError('num_iterations should not be greater than 1 if there is no subsampling.')
161+
162+
if not isinstance(num_iterations, int) or num_iterations < 1:
163+
raise ValueError(f'num_iterations ({num_iterations}) must be an integer greater than 1.')

tests/unit/single_table/privacy/test_util.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
from sdmetrics.single_table.privacy.util import closest_neighbors
1+
import re
2+
3+
import pytest
4+
5+
from sdmetrics.single_table.privacy.util import (
6+
closest_neighbors,
7+
validate_num_samples_num_iteration,
8+
)
29

310

411
def test_closest_neighbors_exact():
@@ -30,3 +37,20 @@ def test_closest_neighbors_non_exact():
3037
assert ('a', '1') in results
3138
assert ('a', '3') in results
3239
assert ('b', '2') in results
40+
41+
42+
def test_validate_num_samples_num_iteration():
43+
# Run and Assert
44+
zero_subsample_msg = re.escape('num_rows_subsample (0) must be an integer greater than 1.')
45+
with pytest.raises(ValueError, match=zero_subsample_msg):
46+
validate_num_samples_num_iteration(0, 1)
47+
48+
subsample_none_msg = re.escape(
49+
'num_iterations should not be greater than 1 if there is no subsampling.'
50+
)
51+
with pytest.raises(ValueError, match=subsample_none_msg):
52+
validate_num_samples_num_iteration(None, 2)
53+
54+
zero_iteration_msg = re.escape('num_iterations (0) must be an integer greater than 1.')
55+
with pytest.raises(ValueError, match=zero_iteration_msg):
56+
validate_num_samples_num_iteration(1, 0)

0 commit comments

Comments
 (0)