Commit 5818ac3 1 parent 8097069 commit 5818ac3 Copy full SHA for 5818ac3
File tree 3 files changed +40
-15
lines changed
sdmetrics/single_table/privacy
tests/unit/single_table/privacy
3 files changed +40
-15
lines changed Original file line number Diff line number Diff line change 8
8
from sdmetrics .goal import Goal
9
9
from sdmetrics .single_table .base import SingleTableMetric
10
10
from sdmetrics .single_table .privacy .dcr_utils import calculate_dcr
11
+ from sdmetrics .single_table .privacy .util import validate_num_samples_num_iteration
11
12
12
13
13
14
class DCROverfittingProtection (SingleTableMetric ):
@@ -32,20 +33,7 @@ def _validate_inputs(
32
33
num_rows_subsample ,
33
34
num_iterations ,
34
35
):
35
- if num_rows_subsample is not None :
36
- if not isinstance (num_rows_subsample , int ) or num_rows_subsample < 1 :
37
- raise ValueError (
38
- f'num_rows_subsample ({ num_rows_subsample } ) must be an integer greater than 1.'
39
- )
40
- elif num_rows_subsample is None and num_iterations > 1 :
41
- raise ValueError (
42
- 'num_iterations should not be greater than 1 if there is no subsampling.'
43
- )
44
-
45
- if not isinstance (num_iterations , int ) or num_iterations < 1 :
46
- raise ValueError (
47
- f'num_iterations ({ num_iterations } ) must be an integer greater than 1.'
48
- )
36
+ validate_num_samples_num_iteration (num_rows_subsample , num_iterations )
49
37
50
38
if len (real_training_data ) * 0.5 > len (real_validation_data ):
51
39
warnings .warn (
Original file line number Diff line number Diff line change @@ -148,3 +148,16 @@ def allow_nan_array(attributes):
148
148
ret .append (entry )
149
149
150
150
return ret
151
+
152
+
153
+ def validate_num_samples_num_iteration (num_rows_subsample , num_iterations ):
154
+ if num_rows_subsample is not None :
155
+ if not isinstance (num_rows_subsample , int ) or num_rows_subsample < 1 :
156
+ raise ValueError (
157
+ f'num_rows_subsample ({ num_rows_subsample } ) must be an integer greater than 1.'
158
+ )
159
+ elif num_rows_subsample is None and num_iterations > 1 :
160
+ raise ValueError ('num_iterations should not be greater than 1 if there is no subsampling.' )
161
+
162
+ if not isinstance (num_iterations , int ) or num_iterations < 1 :
163
+ raise ValueError (f'num_iterations ({ num_iterations } ) must be an integer greater than 1.' )
Original file line number Diff line number Diff line change 1
- from sdmetrics .single_table .privacy .util import closest_neighbors
1
+ import re
2
+
3
+ import pytest
4
+
5
+ from sdmetrics .single_table .privacy .util import (
6
+ closest_neighbors ,
7
+ validate_num_samples_num_iteration ,
8
+ )
2
9
3
10
4
11
def test_closest_neighbors_exact ():
@@ -30,3 +37,20 @@ def test_closest_neighbors_non_exact():
30
37
assert ('a' , '1' ) in results
31
38
assert ('a' , '3' ) in results
32
39
assert ('b' , '2' ) in results
40
+
41
+
42
+ def test_validate_num_samples_num_iteration ():
43
+ # Run and Assert
44
+ zero_subsample_msg = re .escape ('num_rows_subsample (0) must be an integer greater than 1.' )
45
+ with pytest .raises (ValueError , match = zero_subsample_msg ):
46
+ validate_num_samples_num_iteration (0 , 1 )
47
+
48
+ subsample_none_msg = re .escape (
49
+ 'num_iterations should not be greater than 1 if there is no subsampling.'
50
+ )
51
+ with pytest .raises (ValueError , match = subsample_none_msg ):
52
+ validate_num_samples_num_iteration (None , 2 )
53
+
54
+ zero_iteration_msg = re .escape ('num_iterations (0) must be an integer greater than 1.' )
55
+ with pytest .raises (ValueError , match = zero_iteration_msg ):
56
+ validate_num_samples_num_iteration (1 , 0 )
You can’t perform that action at this time.
0 commit comments