Skip to content

Commit 285cd08

Browse files
committed
Add warnings
1 parent 5815408 commit 285cd08

File tree

4 files changed

+316
-7
lines changed

4 files changed

+316
-7
lines changed

sdmetrics/single_table/privacy/cap.py

+224
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
"""CAP modules and their attackers."""
22

3+
import warnings
4+
35
from sdmetrics.single_table.privacy.base import CategoricalPrivacyMetric, PrivacyAttackerModel
46
from sdmetrics.single_table.privacy.util import closest_neighbors, count_frequency, majority
57

8+
# Message emitted as a DeprecationWarning when a CAP metric's public
# ``compute`` is called directly.
DEPRECATION_MSG = ''.join((
    'Computing CAP metrics directly is deprecated. For improved privacy metrics, ',
    "please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' ",
    'metrics instead.',
))
13+
614

715
class CAPAttacker(PrivacyAttackerModel):
816
"""The CAP (Correct Attribution Probability) privacy attacker.
@@ -78,6 +86,78 @@ class CategoricalCAP(CategoricalPrivacyMetric):
7886
MODEL = CAPAttacker
7987
ACCURACY_BASE = False
8088

89+
@classmethod
def _compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Run the parent class ``compute`` with the given arguments.

    This private entry point forwards every argument unchanged to the
    parent implementation and does not emit any warning itself.

    Args:
        real_data:
            The values from the real dataset.
        synthetic_data:
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model.

    Returns:
        Whatever the parent class ``compute`` returns for these arguments.
    """
    forwarded = {
        'real_data': real_data,
        'synthetic_data': synthetic_data,
        'metadata': metadata,
        'key_fields': key_fields,
        'sensitive_fields': sensitive_fields,
        'model_kwargs': model_kwargs,
    }
    return super().compute(**forwarded)
107+
108+
@classmethod
def compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Compute this metric.

    Calling this method directly is deprecated: it emits a
    ``DeprecationWarning`` carrying ``DEPRECATION_MSG`` and then delegates
    to ``cls._compute`` with the same arguments.

    This fits an adversarial attacker model on the synthetic data and
    then evaluates it making predictions on the real data.

    A ``key_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the key column(s) for the
    attack.

    A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
    for the attack.

    Args:
        real_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the real dataset.
        synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict. If not passed, it is built based on the
            real_data fields and dtypes.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
            if none is provided.

    Returns:
        Union[float, tuple[float]]:
            Scores obtained by the attackers when evaluated on the real data.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # module. Python's default filters only show DeprecationWarning when it is
    # attributed to code running in ``__main__``, so without it the warning
    # would be hidden from users calling this from their own scripts.
    warnings.warn(DEPRECATION_MSG, DeprecationWarning, stacklevel=2)
    return cls._compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        key_fields=key_fields,
        sensitive_fields=sensitive_fields,
        model_kwargs=model_kwargs,
    )
160+
81161

82162
class ZeroCAPAttacker(CAPAttacker):
83163
"""The 0CAP privacy attacker, which operates in the same way as CAP does.
@@ -113,6 +193,78 @@ class CategoricalZeroCAP(CategoricalPrivacyMetric):
113193
MODEL = ZeroCAPAttacker
114194
ACCURACY_BASE = False
115195

196+
@classmethod
def _compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Run the parent class ``compute`` with the given arguments.

    This private entry point forwards every argument unchanged to the
    parent implementation and does not emit any warning itself.

    Args:
        real_data:
            The values from the real dataset.
        synthetic_data:
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model.

    Returns:
        Whatever the parent class ``compute`` returns for these arguments.
    """
    forwarded = {
        'real_data': real_data,
        'synthetic_data': synthetic_data,
        'metadata': metadata,
        'key_fields': key_fields,
        'sensitive_fields': sensitive_fields,
        'model_kwargs': model_kwargs,
    }
    return super().compute(**forwarded)
214+
215+
@classmethod
def compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Compute this metric.

    Calling this method directly is deprecated: it emits a
    ``DeprecationWarning`` carrying ``DEPRECATION_MSG`` and then delegates
    to ``cls._compute`` with the same arguments.

    This fits an adversarial attacker model on the synthetic data and
    then evaluates it making predictions on the real data.

    A ``key_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the key column(s) for the
    attack.

    A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
    for the attack.

    Args:
        real_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the real dataset.
        synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict. If not passed, it is built based on the
            real_data fields and dtypes.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
            if none is provided.

    Returns:
        Union[float, tuple[float]]:
            Scores obtained by the attackers when evaluated on the real data.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # module. Python's default filters only show DeprecationWarning when it is
    # attributed to code running in ``__main__``, so without it the warning
    # would be hidden from users calling this from their own scripts.
    warnings.warn(DEPRECATION_MSG, DeprecationWarning, stacklevel=2)
    return cls._compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        key_fields=key_fields,
        sensitive_fields=sensitive_fields,
        model_kwargs=model_kwargs,
    )
267+
116268

117269
class GeneralizedCAPAttacker(CAPAttacker):
118270
"""The GeneralizedCAP privacy attacker.
@@ -169,3 +321,75 @@ class CategoricalGeneralizedCAP(CategoricalPrivacyMetric):
169321
name = 'Categorical GeneralizedCAP'
170322
MODEL = GeneralizedCAPAttacker
171323
ACCURACY_BASE = False
324+
325+
@classmethod
def _compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Run the parent class ``compute`` with the given arguments.

    This private entry point forwards every argument unchanged to the
    parent implementation and does not emit any warning itself.

    Args:
        real_data:
            The values from the real dataset.
        synthetic_data:
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model.

    Returns:
        Whatever the parent class ``compute`` returns for these arguments.
    """
    forwarded = {
        'real_data': real_data,
        'synthetic_data': synthetic_data,
        'metadata': metadata,
        'key_fields': key_fields,
        'sensitive_fields': sensitive_fields,
        'model_kwargs': model_kwargs,
    }
    return super().compute(**forwarded)
343+
344+
@classmethod
def compute(
    cls,
    real_data,
    synthetic_data,
    metadata=None,
    key_fields=None,
    sensitive_fields=None,
    model_kwargs=None,
):
    """Compute this metric.

    Calling this method directly is deprecated: it emits a
    ``DeprecationWarning`` carrying ``DEPRECATION_MSG`` and then delegates
    to ``cls._compute`` with the same arguments.

    This fits an adversarial attacker model on the synthetic data and
    then evaluates it making predictions on the real data.

    A ``key_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the key column(s) for the
    attack.

    A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
    entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
    for the attack.

    Args:
        real_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the real dataset.
        synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
            The values from the synthetic dataset.
        metadata (dict):
            Table metadata dict. If not passed, it is built based on the
            real_data fields and dtypes.
        key_fields (list(str)):
            Name of the column(s) to use as the key attributes.
        sensitive_fields (list(str)):
            Name of the column(s) to use as the sensitive attributes.
        model_kwargs (dict):
            Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
            if none is provided.

    Returns:
        Union[float, tuple[float]]:
            Scores obtained by the attackers when evaluated on the real data.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # module. Python's default filters only show DeprecationWarning when it is
    # attributed to code running in ``__main__``, so without it the warning
    # would be hidden from users calling this from their own scripts.
    warnings.warn(DEPRECATION_MSG, DeprecationWarning, stacklevel=2)
    return cls._compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        metadata=metadata,
        key_fields=key_fields,
        sensitive_fields=sensitive_fields,
        model_kwargs=model_kwargs,
    )

sdmetrics/single_table/privacy/disclosure_protection.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Disclosure protection metrics."""
22

3+
import warnings
4+
35
import numpy as np
46
import pandas as pd
57
import tqdm
@@ -12,6 +14,8 @@
1214
CategoricalZeroCAP,
1315
)
1416

17+
# Row-count threshold: when the real or synthetic data has more rows than
# this, ``compute_breakdown`` warns that the computation may be slow.
MAX_NUM_ROWS = 50000
18+
1519
CAP_METHODS = {
1620
'CAP': CategoricalCAP,
1721
'ZERO_CAP': CategoricalZeroCAP,
@@ -204,7 +208,14 @@ def compute_breakdown(
204208
continuous_column_names,
205209
num_discrete_bins,
206210
)
211+
207212
computation_method = computation_method.upper()
213+
if len(real_data) > MAX_NUM_ROWS or len(synthetic_data) > MAX_NUM_ROWS:
    # Adjacent string literals are joined with no separator, so the first
    # fragment must end with a space to keep the two sentences apart.
    warnings.warn(
        f'Data exceeds {MAX_NUM_ROWS} rows, performance may be slow. '
        'Consider using the `DisclosureProtectionEstimate` for faster computation.'
    )
218+
208219
real_data, synthetic_data = cls._discretize_and_fillna(
209220
real_data,
210221
synthetic_data,
@@ -219,7 +230,7 @@ def compute_breakdown(
219230

220231
# Compute CAP metric
221232
cap_metric = CAP_METHODS.get(computation_method)
222-
cap_protection = cap_metric.compute(
233+
cap_protection = cap_metric._compute(
223234
real_data,
224235
synthetic_data,
225236
key_fields=known_column_names,
@@ -343,7 +354,7 @@ def _compute_estimated_cap_metric(
343354
real_data_samp = real_data.sample(min(num_rows_subsample, len(real_data)))
344355
synth_data_samp = synthetic_data.sample(min(num_rows_subsample, len(synthetic_data)))
345356

346-
estimated_cap_protection = cap_metric.compute(
357+
estimated_cap_protection = cap_metric._compute(
347358
real_data_samp,
348359
synth_data_samp,
349360
key_fields=known_column_names,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import re
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from sdmetrics.single_table.privacy.cap import (
7+
CategoricalCAP,
8+
CategoricalGeneralizedCAP,
9+
CategoricalZeroCAP,
10+
)
11+
12+
13+
@pytest.mark.parametrize('metric', [CategoricalCAP, CategoricalZeroCAP, CategoricalGeneralizedCAP])
def test_CAP_deprecation_message(metric):
    """Check that calling a CAP metric's ``compute`` directly emits the deprecation warning."""
    # Setup
    synthetic_data = pd.DataFrame({'col1': range(5), 'col2': ['C', 'A', 'A', 'B', 'C']})
    real_data = pd.DataFrame({'col1': range(5), 'col2': ['A', 'B', 'C', 'A', 'B']})
    message = (
        'Computing CAP metrics directly is deprecated. For improved privacy metrics, '
        "please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' "
        'metrics instead.'
    )
    # The message contains regex metacharacters ('.'), so escape it before matching.
    warning_pattern = re.escape(message)

    # Run and Assert
    with pytest.warns(DeprecationWarning, match=warning_pattern):
        metric.compute(
            real_data,
            synthetic_data,
            key_fields=['col1'],
            sensitive_fields=['col2'],
        )

0 commit comments

Comments
 (0)