Commit 3c2e199

Fix tests
1 parent ea090e1 commit 3c2e199

2 files changed (+110, −49 lines)


tests/integration/single_table/privacy/test_dcr_overfitting_protection.py (−49)

@@ -1,8 +1,6 @@
 import random
-import re
 
 import pandas as pd
-import pytest
 from sklearn.model_selection import train_test_split
 
 from sdmetrics.demos import load_single_table_demo
@@ -127,50 +125,3 @@ def test_compute_breakdown_iterations(self):
 
         assert compute_num_iteration_1 != compute_num_iteration_1000
         assert compute_train_same['score'] == 0.0
-
-    def test_validation(self):
-        # Setup
-        train_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
-        holdout_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
-        synthetic_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
-        metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}
-
-        zero_subsample_msg = re.escape('num_rows_subsample (0) must be an integer greater than 1.')
-        with pytest.raises(ValueError, match=zero_subsample_msg):
-            DCROverfittingProtection.compute_breakdown(
-                train_data, synthetic_data, holdout_data, metadata, 0
-            )
-
-        subsample_none_msg = re.escape(
-            'num_iterations should not be greater than 1 if there is no subsampling.'
-        )
-        with pytest.raises(ValueError, match=subsample_none_msg):
-            DCROverfittingProtection.compute_breakdown(
-                train_data, synthetic_data, holdout_data, metadata, None, 10
-            )
-
-        zero_iteration_msg = re.escape('num_iterations (0) must be an integer greater than 1.')
-        with pytest.raises(ValueError, match=zero_iteration_msg):
-            DCROverfittingProtection.compute_breakdown(
-                train_data, synthetic_data, holdout_data, metadata, 1, 0
-            )
-
-        no_dcr_metadata = {'columns': {'bad_col': {'sdtype': 'unknown'}}}
-        no_dcr_data = pd.DataFrame({'bad_col': [1.0]})
-
-        missing_metric = 'There are no overlapping statistical columns to measure.'
-        with pytest.raises(ValueError, match=missing_metric):
-            DCROverfittingProtection.compute_breakdown(
-                no_dcr_data, no_dcr_data, no_dcr_data, no_dcr_metadata
-            )
-
-        small_holdout_data = holdout_data.sample(frac=0.2)
-        small_validation_msg = (
-            f'Your real_validation_data contains {len(small_holdout_data)} rows while your '
-            f'real_training_data contains {len(holdout_data)} rows. For most accurate '
-            'results, we recommend that the validation data at least half the size of the training data.'
-        )
-        with pytest.warns(UserWarning, match=small_validation_msg):
-            DCROverfittingProtection.compute_breakdown(
-                train_data, synthetic_data, small_holdout_data, metadata
-            )
(new file, +110)

@@ -0,0 +1,110 @@
+import random
+import re
+from unittest.mock import patch
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from sdmetrics.single_table.privacy.dcr_overfitting_protection import DCROverfittingProtection
+
+
+@pytest.fixture()
+def test_data():
+    train_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
+    holdout_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
+    synthetic_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
+    metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}
+    return (train_data, holdout_data, synthetic_data, metadata)
+
+
+class TestDCROverfittingProtection:
+    def test__validate_inputs(self, test_data):
+        """Test that we properly validate inputs to our DCROverfittingProtection."""
+        # Setup
+        train_data, holdout_data, synthetic_data, metadata = test_data
+
+        # Run and Assert
+        zero_subsample_msg = re.escape('num_rows_subsample (0) must be an integer greater than 1.')
+        with pytest.raises(ValueError, match=zero_subsample_msg):
+            DCROverfittingProtection.compute_breakdown(
+                train_data, synthetic_data, holdout_data, metadata, 0
+            )
+
+        subsample_none_msg = re.escape(
+            'num_iterations should not be greater than 1 if there is no subsampling.'
+        )
+        with pytest.raises(ValueError, match=subsample_none_msg):
+            DCROverfittingProtection.compute_breakdown(
+                train_data, synthetic_data, holdout_data, metadata, None, 10
+            )
+
+        zero_iteration_msg = re.escape('num_iterations (0) must be an integer greater than 1.')
+        with pytest.raises(ValueError, match=zero_iteration_msg):
+            DCROverfittingProtection.compute_breakdown(
+                train_data, synthetic_data, holdout_data, metadata, 1, 0
+            )
+
+        no_dcr_metadata = {'columns': {'bad_col': {'sdtype': 'unknown'}}}
+        no_dcr_data = pd.DataFrame({'bad_col': [1.0]})
+
+        missing_metric = 'There are no overlapping statistical columns to measure.'
+        with pytest.raises(ValueError, match=missing_metric):
+            DCROverfittingProtection.compute_breakdown(
+                no_dcr_data, no_dcr_data, no_dcr_data, no_dcr_metadata
+            )
+
+        small_holdout_data = holdout_data.sample(frac=0.2)
+        small_validation_msg = (
+            f'Your real_validation_data contains {len(small_holdout_data)} rows while your '
+            f'real_training_data contains {len(holdout_data)} rows. For most accurate '
+            'results, we recommend that the validation data at least half the size of the training data.'
+        )
+        with pytest.warns(UserWarning, match=small_validation_msg):
+            DCROverfittingProtection.compute_breakdown(
+                train_data, synthetic_data, small_holdout_data, metadata
+            )
+
+    @patch('numpy.where')
+    @patch('sdmetrics.single_table.privacy.dcr_overfitting_protection.calculate_dcr')
+    def test_compute_breakdown(self, mock_calculate_dcr, mock_numpy_where, test_data):
+        """Test that compute breakdown correctly measures the fraction of data overfitted."""
+        # Setup
+        train_data, holdout_data, synthetic_data, metadata = test_data
+        num_iterations = 2
+        num_rows_subsample = 2
+        mock_calculate_dcr_array = np.array([0.0] * 50)
+        mock_calculate_dcr.return_value = pd.DataFrame(mock_calculate_dcr_array, columns=['dcr'])
+        data = np.array([1] * 25 + [0] * 25)
+        mock_numpy_where.return_value = pd.Series(data)
+
+        # Run
+        result = DCROverfittingProtection.compute_breakdown(
+            train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, num_iterations
+        )
+
+        # Assert
+        assert mock_calculate_dcr.call_count == 2 * num_iterations
+        assert result['score'] == 1.0
+        assert result['synthetic_data_percentages']['closer_to_training'] == 0.5
+        assert result['synthetic_data_percentages']['closer_to_holdout'] == 0.5
+
+    @patch(
+        'sdmetrics.single_table.privacy.dcr_overfitting_protection.DCROverfittingProtection.compute_breakdown'
+    )
+    def test_compute(self, mock_compute_breakdown, test_data):
+        """Test that compute makes a call to compute_breakdown."""
+        # Setup
+        train_data, holdout_data, synthetic_data, metadata = test_data
+        num_iterations = 2
+        num_rows_subsample = 2
+
+        # Run
+        DCROverfittingProtection.compute(
+            train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, num_iterations
+        )
+
+        # Assert
+        mock_compute_breakdown.assert_called_once_with(
+            train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, num_iterations
+        )
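
Usage note: the tests above exercise DCROverfittingProtection only through its public entry points, and the argument order and result keys can be read directly off the diff. Below is a minimal sketch of calling the metric on the same toy data shape as the test_data fixture; the values 10 and 2 for num_rows_subsample and num_iterations are arbitrary example choices, not values taken from this commit.

import random

import pandas as pd

from sdmetrics.single_table.privacy.dcr_overfitting_protection import DCROverfittingProtection

# Toy single-column numerical data, mirroring the test_data fixture above.
train_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
holdout_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
synthetic_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}

# Positional order follows the calls in the tests:
# (real_training_data, synthetic_data, real_validation_data, metadata,
#  num_rows_subsample, num_iterations).
result = DCROverfittingProtection.compute_breakdown(
    train_data, synthetic_data, holdout_data, metadata, 10, 2
)

# Result keys observed in the assertions of test_compute_breakdown.
print(result['score'])
print(result['synthetic_data_percentages']['closer_to_training'])
print(result['synthetic_data_percentages']['closer_to_holdout'])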
