Skip to content

Commit 0bf95d9

Browse files
authored
Add DisclosureProtectionEstimate metric (#686)
1 parent dafd198 commit 0bf95d9

File tree

7 files changed

+901
-32
lines changed

7 files changed

+901
-32
lines changed

sdmetrics/single_table/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@
6767
CategoricalRF,
6868
CategoricalSVM,
6969
)
70-
from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
70+
from sdmetrics.single_table.privacy.disclosure_protection import (
71+
DisclosureProtection,
72+
DisclosureProtectionEstimate,
73+
)
7174
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
7275
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
7376
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -111,6 +114,7 @@
111114
'CategoricalZeroCAP',
112115
'CategoricalGeneralizedCAP',
113116
'DisclosureProtection',
117+
'DisclosureProtectionEstimate',
114118
'NumericalMLP',
115119
'NumericalLR',
116120
'NumericalSVR',

sdmetrics/single_table/privacy/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
CategoricalRF,
1313
CategoricalSVM,
1414
)
15-
from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
15+
from sdmetrics.single_table.privacy.disclosure_protection import (
16+
DisclosureProtection,
17+
DisclosureProtectionEstimate,
18+
)
1619
from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
1720
from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
1821
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -28,6 +31,7 @@
2831
'CategoricalSVM',
2932
'CategoricalZeroCAP',
3033
'DisclosureProtection',
34+
'DisclosureProtectionEstimate',
3135
'NumericalLR',
3236
'NumericalMLP',
3337
'NumericalPrivacyMetric',

sdmetrics/single_table/privacy/cap.py

+224
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
"""CAP modules and their attackers."""
22

3+
import warnings
4+
35
from sdmetrics.single_table.privacy.base import CategoricalPrivacyMetric, PrivacyAttackerModel
46
from sdmetrics.single_table.privacy.util import closest_neighbors, count_frequency, majority
57

8+
DEPRECATION_MSG = (
9+
'Computing CAP metrics directly is deprecated. For improved privacy metrics, '
10+
"please use the 'DisclosureProtection' and 'DisclosureProtectionEstimate' "
11+
'metrics instead.'
12+
)
13+
614

715
class CAPAttacker(PrivacyAttackerModel):
816
"""The CAP (Correct Attribution Probability) privacy attacker.
@@ -78,6 +86,78 @@ class CategoricalCAP(CategoricalPrivacyMetric):
7886
MODEL = CAPAttacker
7987
ACCURACY_BASE = False
8088

89+
@classmethod
90+
def _compute(
91+
cls,
92+
real_data,
93+
synthetic_data,
94+
metadata=None,
95+
key_fields=None,
96+
sensitive_fields=None,
97+
model_kwargs=None,
98+
):
99+
return super().compute(
100+
real_data=real_data,
101+
synthetic_data=synthetic_data,
102+
metadata=metadata,
103+
key_fields=key_fields,
104+
sensitive_fields=sensitive_fields,
105+
model_kwargs=model_kwargs,
106+
)
107+
108+
@classmethod
109+
def compute(
110+
cls,
111+
real_data,
112+
synthetic_data,
113+
metadata=None,
114+
key_fields=None,
115+
sensitive_fields=None,
116+
model_kwargs=None,
117+
):
118+
"""Compute this metric.
119+
120+
This fits an adversial attacker model on the synthetic data and
121+
then evaluates it making predictions on the real data.
122+
123+
A ``key_fields`` column(s) name must be given, either directly or as a first level
124+
entry in the ``metadata`` dict, which will be used as the key column(s) for the
125+
attack.
126+
127+
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
128+
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
129+
for the attack.
130+
131+
Args:
132+
real_data (Union[numpy.ndarray, pandas.DataFrame]):
133+
The values from the real dataset.
134+
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
135+
The values from the synthetic dataset.
136+
metadata (dict):
137+
Table metadata dict. If not passed, it is build based on the
138+
real_data fields and dtypes.
139+
key_fields (list(str)):
140+
Name of the column(s) to use as the key attributes.
141+
sensitive_fields (list(str)):
142+
Name of the column(s) to use as the sensitive attributes.
143+
model_kwargs (dict):
144+
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
145+
if none is provided.
146+
147+
Returns:
148+
union[float, tuple[float]]:
149+
Scores obtained by the attackers when evaluated on the real data.
150+
"""
151+
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
152+
return cls._compute(
153+
real_data=real_data,
154+
synthetic_data=synthetic_data,
155+
metadata=metadata,
156+
key_fields=key_fields,
157+
sensitive_fields=sensitive_fields,
158+
model_kwargs=model_kwargs,
159+
)
160+
81161

82162
class ZeroCAPAttacker(CAPAttacker):
83163
"""The 0CAP privacy attacker, which operates in the same way as CAP does.
@@ -113,6 +193,78 @@ class CategoricalZeroCAP(CategoricalPrivacyMetric):
113193
MODEL = ZeroCAPAttacker
114194
ACCURACY_BASE = False
115195

196+
@classmethod
197+
def _compute(
198+
cls,
199+
real_data,
200+
synthetic_data,
201+
metadata=None,
202+
key_fields=None,
203+
sensitive_fields=None,
204+
model_kwargs=None,
205+
):
206+
return super().compute(
207+
real_data=real_data,
208+
synthetic_data=synthetic_data,
209+
metadata=metadata,
210+
key_fields=key_fields,
211+
sensitive_fields=sensitive_fields,
212+
model_kwargs=model_kwargs,
213+
)
214+
215+
@classmethod
216+
def compute(
217+
cls,
218+
real_data,
219+
synthetic_data,
220+
metadata=None,
221+
key_fields=None,
222+
sensitive_fields=None,
223+
model_kwargs=None,
224+
):
225+
"""Compute this metric.
226+
227+
This fits an adversial attacker model on the synthetic data and
228+
then evaluates it making predictions on the real data.
229+
230+
A ``key_fields`` column(s) name must be given, either directly or as a first level
231+
entry in the ``metadata`` dict, which will be used as the key column(s) for the
232+
attack.
233+
234+
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
235+
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
236+
for the attack.
237+
238+
Args:
239+
real_data (Union[numpy.ndarray, pandas.DataFrame]):
240+
The values from the real dataset.
241+
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
242+
The values from the synthetic dataset.
243+
metadata (dict):
244+
Table metadata dict. If not passed, it is build based on the
245+
real_data fields and dtypes.
246+
key_fields (list(str)):
247+
Name of the column(s) to use as the key attributes.
248+
sensitive_fields (list(str)):
249+
Name of the column(s) to use as the sensitive attributes.
250+
model_kwargs (dict):
251+
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
252+
if none is provided.
253+
254+
Returns:
255+
union[float, tuple[float]]:
256+
Scores obtained by the attackers when evaluated on the real data.
257+
"""
258+
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
259+
return cls._compute(
260+
real_data=real_data,
261+
synthetic_data=synthetic_data,
262+
metadata=metadata,
263+
key_fields=key_fields,
264+
sensitive_fields=sensitive_fields,
265+
model_kwargs=model_kwargs,
266+
)
267+
116268

117269
class GeneralizedCAPAttacker(CAPAttacker):
118270
"""The GeneralizedCAP privacy attacker.
@@ -169,3 +321,75 @@ class CategoricalGeneralizedCAP(CategoricalPrivacyMetric):
169321
name = 'Categorical GeneralizedCAP'
170322
MODEL = GeneralizedCAPAttacker
171323
ACCURACY_BASE = False
324+
325+
@classmethod
326+
def _compute(
327+
cls,
328+
real_data,
329+
synthetic_data,
330+
metadata=None,
331+
key_fields=None,
332+
sensitive_fields=None,
333+
model_kwargs=None,
334+
):
335+
return super().compute(
336+
real_data=real_data,
337+
synthetic_data=synthetic_data,
338+
metadata=metadata,
339+
key_fields=key_fields,
340+
sensitive_fields=sensitive_fields,
341+
model_kwargs=model_kwargs,
342+
)
343+
344+
@classmethod
345+
def compute(
346+
cls,
347+
real_data,
348+
synthetic_data,
349+
metadata=None,
350+
key_fields=None,
351+
sensitive_fields=None,
352+
model_kwargs=None,
353+
):
354+
"""Compute this metric.
355+
356+
This fits an adversial attacker model on the synthetic data and
357+
then evaluates it making predictions on the real data.
358+
359+
A ``key_fields`` column(s) name must be given, either directly or as a first level
360+
entry in the ``metadata`` dict, which will be used as the key column(s) for the
361+
attack.
362+
363+
A ``sensitive_fields`` column(s) name must be given, either directly or as a first level
364+
entry in the ``metadata`` dict, which will be used as the sensitive_fields column(s)
365+
for the attack.
366+
367+
Args:
368+
real_data (Union[numpy.ndarray, pandas.DataFrame]):
369+
The values from the real dataset.
370+
synthetic_data (Union[numpy.ndarray, pandas.DataFrame]):
371+
The values from the synthetic dataset.
372+
metadata (dict):
373+
Table metadata dict. If not passed, it is build based on the
374+
real_data fields and dtypes.
375+
key_fields (list(str)):
376+
Name of the column(s) to use as the key attributes.
377+
sensitive_fields (list(str)):
378+
Name of the column(s) to use as the sensitive attributes.
379+
model_kwargs (dict):
380+
Key word arguments of the attacker model. cls.MODEL_KWARGS will be used
381+
if none is provided.
382+
383+
Returns:
384+
union[float, tuple[float]]:
385+
Scores obtained by the attackers when evaluated on the real data.
386+
"""
387+
warnings.warn(DEPRECATION_MSG, DeprecationWarning)
388+
return cls._compute(
389+
real_data=real_data,
390+
synthetic_data=synthetic_data,
391+
metadata=metadata,
392+
key_fields=key_fields,
393+
sensitive_fields=sensitive_fields,
394+
model_kwargs=model_kwargs,
395+
)

0 commit comments

Comments
 (0)