Skip to content

Commit b1d3612

Browse files
Add DCROverfittingProtection metric (#733)
Co-authored-by: R-Palazzo <romainpalazzo@gmail.com>
1 parent a6f655c commit b1d3612

File tree

7 files changed

+508
-1
lines changed

7 files changed

+508
-1
lines changed

sdmetrics/single_table/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
DisclosureProtection,
7272
DisclosureProtectionEstimate,
7373
)
74+
from sdmetrics.single_table.privacy.dcr_overfitting_protection import DCROverfittingProtection
7475
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
7576
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
7677
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -136,4 +137,5 @@
136137
'RangeCoverage',
137138
'NewRowSynthesis',
138139
'TableStructure',
140+
'DCROverfittingProtection',
139141
]

sdmetrics/single_table/privacy/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DisclosureProtection,
1717
DisclosureProtectionEstimate,
1818
)
19+
from sdmetrics.single_table.privacy.dcr_overfitting_protection import DCROverfittingProtection
1920
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
2021
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
2122
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -37,4 +38,5 @@
3738
'NumericalPrivacyMetric',
3839
'NumericalRadiusNearestNeighbor',
3940
'NumericalSVR',
41+
'DCROverfittingProtection',
4042
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""DCR Overfitting Protection metrics."""
2+
3+
import warnings
4+
5+
import numpy as np
6+
7+
from sdmetrics._utils_metadata import _process_data_with_metadata
8+
from sdmetrics.goal import Goal
9+
from sdmetrics.single_table.base import SingleTableMetric
10+
from sdmetrics.single_table.privacy.dcr_utils import calculate_dcr
11+
from sdmetrics.single_table.privacy.util import validate_num_samples_num_iteration
12+
13+
14+
class DCROverfittingProtection(SingleTableMetric):
15+
"""DCR Overfitting Protection metric.
16+
17+
This metric uses a DCR (distance to closest record) computation to measure whether the
18+
synthetic data has been overfit to the real data, as compared to a holdout set.
19+
"""
20+
21+
name = 'DCROverfittingProtection'
22+
goal = Goal.MAXIMIZE
23+
min_value = 0.0
24+
max_value = 1.0
25+
26+
@classmethod
27+
def _validate_inputs(
28+
cls,
29+
real_training_data,
30+
synthetic_data,
31+
real_validation_data,
32+
metadata,
33+
num_rows_subsample,
34+
num_iterations,
35+
):
36+
validate_num_samples_num_iteration(num_rows_subsample, num_iterations)
37+
38+
if num_rows_subsample and num_rows_subsample > len(synthetic_data):
39+
warnings.warn(
40+
f'num_rows_subsample ({num_rows_subsample}) is greater than the length of the '
41+
f'synthetic data ({len(synthetic_data)}). Ignoring the num_rows_subsample and '
42+
'num_iterations args.',
43+
)
44+
num_rows_subsample = None
45+
num_iterations = 1
46+
47+
if len(real_training_data) * 0.5 > len(real_validation_data):
48+
warnings.warn(
49+
f'Your real_validation_data contains {len(real_validation_data)} rows while your '
50+
f'real_training_data contains {len(real_training_data)} rows. For most accurate '
51+
'results, we recommend that the validation data at least half the size of the training data.'
52+
)
53+
54+
real_data_copy = real_training_data.copy()
55+
synthetic_data_copy = synthetic_data.copy()
56+
real_validation_copy = real_validation_data.copy()
57+
real_data_copy = _process_data_with_metadata(real_data_copy, metadata, True)
58+
synthetic_data_copy = _process_data_with_metadata(synthetic_data_copy, metadata, True)
59+
real_validation_copy = _process_data_with_metadata(real_validation_copy, metadata, True)
60+
61+
return (
62+
real_data_copy,
63+
synthetic_data_copy,
64+
real_validation_copy,
65+
num_rows_subsample,
66+
num_iterations,
67+
)
68+
69+
@classmethod
70+
def compute_breakdown(
71+
cls,
72+
real_training_data,
73+
synthetic_data,
74+
real_validation_data,
75+
metadata,
76+
num_rows_subsample=None,
77+
num_iterations=1,
78+
):
79+
"""Compute the DCROverfittingProtection metric.
80+
81+
Args:
82+
real_training_data (pd.DataFrame):
83+
A pd.DataFrame object containing the real data used for training the synthesizer.
84+
synthetic_data (pd.DataFrame):
85+
A pandas.DataFrame object containing the synthetic data sampled
86+
from the synthesizer.
87+
real_validation_data (pd.DataFrame):
88+
A pandas.DataFrame object containing a validation set of real data.
89+
This data should not have been used to train the synthesizer.
90+
metadata (dict):
91+
A metadata dictionary that describes the table of data.
92+
num_rows_subsample (int or None):
93+
The number of synthetic data rows to subsample from the synthetic data.
94+
This is used to increase the speed of the computation, if the dataset is large.
95+
Defaults to None which means no subsampling will be done.
96+
num_iterations (int):
97+
The number of iterations to perform when subsampling.
98+
The final score will be the average of all iterations. Default is 1 iteration.
99+
100+
Returns:
101+
dict:
102+
Returns a dictionary that contains the overall score, the % of synthetic data rows
103+
that were closer to the validation set, and the % of synthetic data rows that were
104+
closer to the real dataset. Averages of the medians are returned in the case of
105+
multiple iterations.
106+
"""
107+
sanitized_data = cls._validate_inputs(
108+
real_training_data,
109+
synthetic_data,
110+
real_validation_data,
111+
metadata,
112+
num_rows_subsample,
113+
num_iterations,
114+
)
115+
116+
training_data = sanitized_data[0]
117+
sanitized_synthetic_data = sanitized_data[1]
118+
validation_data = sanitized_data[2]
119+
num_rows_subsample = sanitized_data[3]
120+
num_iterations = sanitized_data[4]
121+
122+
sum_of_scores = 0
123+
sum_percent_close_to_real = 0
124+
sum_percent_close_to_random = 0
125+
for _ in range(num_iterations):
126+
synthetic_sample = sanitized_synthetic_data
127+
if num_rows_subsample is not None:
128+
synthetic_sample = sanitized_synthetic_data.sample(n=num_rows_subsample)
129+
130+
dcr_real = calculate_dcr(training_data, synthetic_sample, metadata)
131+
dcr_holdout = calculate_dcr(validation_data, synthetic_sample, metadata)
132+
133+
num_rows_closer_to_real = np.where(dcr_real < dcr_holdout, 1.0, 0.0).sum()
134+
total_rows = dcr_real.size
135+
percentage_close_to_real = num_rows_closer_to_real / total_rows
136+
percentage_close_to_random = 1 - percentage_close_to_real
137+
score = min((1.0 - percentage_close_to_real) * 2, 1.0)
138+
sum_of_scores += score
139+
sum_percent_close_to_real += percentage_close_to_real
140+
sum_percent_close_to_random += percentage_close_to_random
141+
142+
result = {
143+
'score': sum_of_scores / num_iterations,
144+
'synthetic_data_percentages': {
145+
'closer_to_training': sum_percent_close_to_real / num_iterations,
146+
'closer_to_holdout': sum_percent_close_to_random / num_iterations,
147+
},
148+
}
149+
150+
return result
151+
152+
@classmethod
153+
def compute(
154+
cls,
155+
real_training_data,
156+
synthetic_data,
157+
real_validation_data,
158+
metadata,
159+
num_rows_subsample=None,
160+
num_iterations=1,
161+
):
162+
"""Compute the DCROverfittingProtection metric.
163+
164+
Args:
165+
real_training_data (pd.DataFrame):
166+
A pd.DataFrame object containing the real data used for training the synthesizer.
167+
synthetic_data (pd.DataFrame):
168+
A pandas.DataFrame object containing the synthetic data sampled
169+
from the synthesizer.
170+
real_validation_data (pd.DataFrame):
171+
A pandas.DataFrame object containing a validation set of real data.
172+
This data should not have been used to train the synthesizer.
173+
metadata (dict):
174+
A metadata dictionary that describes the table of data.
175+
num_rows_subsample (int or None):
176+
The number of synthetic data rows to subsample from the synthetic data.
177+
This is used to increase the speed of the computation, if the dataset is large.
178+
Defaults to None which means no subsampling will be done.
179+
num_iterations (int):
180+
The number of iterations to perform when subsampling.
181+
The final score will be the average of all iterations. Default is 1 iteration.
182+
183+
Returns:
184+
float:
185+
The score for the DCROverfittingProtection metric.
186+
"""
187+
result = cls.compute_breakdown(
188+
real_training_data,
189+
synthetic_data,
190+
real_validation_data,
191+
metadata,
192+
num_rows_subsample,
193+
num_iterations,
194+
)
195+
196+
return result.get('score')

sdmetrics/single_table/privacy/util.py

+14
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,17 @@ def allow_nan_array(attributes):
148148
ret.append(entry)
149149

150150
return ret
151+
152+
153+
def validate_num_samples_num_iteration(num_rows_subsample, num_iterations):
154+
if num_rows_subsample is not None:
155+
if not isinstance(num_rows_subsample, int) or num_rows_subsample < 1:
156+
raise ValueError(
157+
f'num_rows_subsample ({num_rows_subsample}) must be an integer greater than 1.'
158+
)
159+
160+
elif num_rows_subsample is None and num_iterations > 1:
161+
raise ValueError('num_iterations should not be greater than 1 if there is no subsampling.')
162+
163+
if not isinstance(num_iterations, int) or num_iterations < 1:
164+
raise ValueError(f'num_iterations ({num_iterations}) must be an integer greater than 1.')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import random
2+
import re
3+
4+
import pandas as pd
5+
import pytest
6+
from sklearn.model_selection import train_test_split
7+
8+
from sdmetrics.demos import load_single_table_demo
9+
from sdmetrics.single_table.privacy import DCROverfittingProtection
10+
11+
12+
class TestDCROverfittingProtection:
13+
def test_end_to_end_with_demo(self):
14+
"""Test end to end for DCROverfittingProtection metric against the demo dataset.
15+
16+
In this end to end test, test against demo dataset. Use subsampling to speed
17+
up the test. Make sure that if hold two datasets to be the same we get expected
18+
values even with subsampling. Note that if synthetic data is equally distant from
19+
the training data and the holdout data, it is labeled as closer to holdout data.
20+
"""
21+
# Setup
22+
real_data, synthetic_data, metadata = load_single_table_demo()
23+
train_df, holdout_df = train_test_split(real_data, test_size=0.2)
24+
25+
# Run
26+
num_rows_subsample = 50
27+
compute_breakdown_result = DCROverfittingProtection.compute_breakdown(
28+
train_df, synthetic_data, holdout_df, metadata
29+
)
30+
compute_result = DCROverfittingProtection.compute(
31+
train_df, synthetic_data, holdout_df, metadata
32+
)
33+
compute_holdout_same = DCROverfittingProtection.compute_breakdown(
34+
train_df, synthetic_data, synthetic_data, metadata, num_rows_subsample
35+
)
36+
compute_train_same = DCROverfittingProtection.compute_breakdown(
37+
synthetic_data, synthetic_data, holdout_df, metadata, num_rows_subsample
38+
)
39+
compute_all_same = DCROverfittingProtection.compute_breakdown(
40+
synthetic_data,
41+
synthetic_data,
42+
synthetic_data,
43+
metadata,
44+
num_rows_subsample,
45+
)
46+
47+
synth_percentages_key = 'synthetic_data_percentages'
48+
synth_train_key = 'closer_to_training'
49+
synth_holdout_key = 'closer_to_holdout'
50+
score_key = 'score'
51+
52+
# Assert
53+
assert compute_result == compute_breakdown_result[score_key]
54+
assert compute_holdout_same[score_key] == 1.0
55+
assert compute_holdout_same[synth_percentages_key][synth_train_key] == 0.0
56+
assert compute_holdout_same[synth_percentages_key][synth_holdout_key] == 1.0
57+
assert compute_train_same[score_key] == 0.0
58+
assert compute_train_same[synth_percentages_key][synth_train_key] == 1.0
59+
assert compute_train_same[synth_percentages_key][synth_holdout_key] == 0.0
60+
assert compute_all_same[score_key] == 1.0
61+
assert compute_all_same[synth_percentages_key][synth_train_key] == 0.0
62+
assert compute_all_same[synth_percentages_key][synth_holdout_key] == 1.0
63+
64+
def test_compute_breakdown_drop_all_columns(self):
65+
"""Testing invalid sdtypes and ensure only appropriate columns are measured."""
66+
# Setup
67+
train_data = pd.DataFrame({'bad_col': [10.0, 15.0], 'num_col': [1.0, 2.0]})
68+
synth_data = pd.DataFrame({'bad_col': [2.0, 1.0], 'num_col': [1.0, 2.0]})
69+
holdout_data = pd.DataFrame({'bad_col': [2.0, 1.0], 'num_col': [3.0, 4.0]})
70+
metadata = {
71+
'columns': {
72+
'bad_col': {'sdtype': 'unknown'},
73+
'num_col': {'sdtype': 'numerical'},
74+
}
75+
}
76+
77+
# Run
78+
result = DCROverfittingProtection.compute_breakdown(
79+
train_data, synth_data, holdout_data, metadata
80+
)
81+
82+
# Assert
83+
assert result['score'] == 0.0
84+
assert result['synthetic_data_percentages']['closer_to_training'] == 1.0
85+
assert result['synthetic_data_percentages']['closer_to_holdout'] == 0.0
86+
87+
def test_compute_breakdown_subsampling(self):
88+
"""Test subsampling produces different values."""
89+
# Setup
90+
train_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
91+
holdout_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
92+
synthetic_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(50)]})
93+
metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}
94+
num_rows_subsample = 4
95+
large_num_subsample = len(synthetic_data) * 2
96+
97+
# Run
98+
compute_subsample = DCROverfittingProtection.compute_breakdown(
99+
train_data, synthetic_data, holdout_data, metadata, num_rows_subsample
100+
)
101+
102+
large_subsample_msg = re.escape('Ignoring the num_rows_subsample and num_iterations args.')
103+
with pytest.warns(UserWarning, match=large_subsample_msg):
104+
compute_large_subsample = DCROverfittingProtection.compute_breakdown(
105+
train_data, synthetic_data, holdout_data, metadata, large_num_subsample
106+
)
107+
108+
compute_full_1 = DCROverfittingProtection.compute_breakdown(
109+
train_data, synthetic_data, holdout_data, metadata
110+
)
111+
compute_full_2 = DCROverfittingProtection.compute_breakdown(
112+
train_data, synthetic_data, holdout_data, metadata
113+
)
114+
115+
# Assert that subsampling provides different values if smaller than data length.
116+
assert compute_subsample != compute_full_1
117+
assert compute_full_1 == compute_full_2
118+
assert compute_large_subsample == compute_full_1
119+
120+
def test_compute_breakdown_iterations(self):
121+
"""Test that number iterations for subsampling works as expected."""
122+
# Setup
123+
train_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(10)]})
124+
holdout_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(10)]})
125+
synthetic_data = pd.DataFrame({'num_col': [random.randint(1, 1000) for _ in range(10)]})
126+
metadata = {'columns': {'num_col': {'sdtype': 'numerical'}}}
127+
num_rows_subsample = 3
128+
num_iterations = 1000
129+
130+
# Run
131+
compute_num_iteration_1 = DCROverfittingProtection.compute_breakdown(
132+
train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, 1
133+
)
134+
compute_num_iteration_1000 = DCROverfittingProtection.compute_breakdown(
135+
train_data, synthetic_data, holdout_data, metadata, num_rows_subsample, num_iterations
136+
)
137+
compute_train_same = DCROverfittingProtection.compute_breakdown(
138+
synthetic_data,
139+
synthetic_data,
140+
holdout_data,
141+
metadata,
142+
num_rows_subsample,
143+
num_iterations,
144+
)
145+
146+
# Assert
147+
assert compute_num_iteration_1 != compute_num_iteration_1000
148+
assert compute_train_same['score'] == 0.0

0 commit comments

Comments
 (0)