Skip to content

Commit 77df479

Browse files
authored
CategoricalCAP metric returns 0 if no overlap in known fields (#695)
1 parent 0bf95d9 commit 77df479

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

sdmetrics/single_table/privacy/base.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def compute(
166166
score += row_score
167167

168168
if count == 0:
169-
return 0
169+
return np.nan
170170

171171
return 1.0 - score / count
172172

tests/integration/single_table/privacy/test_privacy.py

+22
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@
33
import pytest
44

55
from sdmetrics.single_table.privacy import (
6+
CategoricalCAP,
67
CategoricalEnsemble,
8+
CategoricalGeneralizedCAP,
79
CategoricalPrivacyMetric,
810
NumericalPrivacyMetric,
911
)
@@ -53,6 +55,15 @@ def cat_bad_synthetic_data():
5355
})
5456

5557

58+
def cat_disjoint_synthetic_data():
59+
return pd.DataFrame({
60+
'key1': ['v', 'w', 'x', 'y', 'z'] * 20,
61+
'key2': [5, 6, 7, 8, 9] * 20,
62+
'sensitive1': ['a', 'b', 'c', 'e', 'd'] * 20,
63+
'sensitive2': [0, 1, 2, 3, 4] * 20,
64+
})
65+
66+
5667
@pytest.mark.parametrize('metric', categorical_metrics.values()) # noqa: PD011
5768
def test_categoricals_non_ens(metric):
5869
if metric != CategoricalEnsemble: # Ensemble needs additional args to work
@@ -84,7 +95,18 @@ def test_categoricals_non_ens(metric):
8495
sensitive_fields=['sensitive1', 'sensitive2'],
8596
)
8697

98+
disjoint = metric.compute(
99+
cat_real_data(),
100+
cat_disjoint_synthetic_data(),
101+
key_fields=['key1', 'key2'],
102+
sensitive_fields=['sensitive1', 'sensitive2'],
103+
)
104+
87105
assert metric.min_value <= horrible <= bad <= good <= perfect <= metric.max_value
106+
if metric == CategoricalCAP or metric == CategoricalEnsemble:
107+
assert np.isnan(disjoint)
108+
elif metric != CategoricalGeneralizedCAP:
109+
assert disjoint == metric.max_value
88110

89111

90112
def test_categorical_ens():

0 commit comments

Comments
 (0)