|
3 | 3 | import pytest
|
4 | 4 |
|
5 | 5 | from sdmetrics.single_table.privacy import (
|
| 6 | + CategoricalCAP, |
6 | 7 | CategoricalEnsemble,
|
| 8 | + CategoricalGeneralizedCAP, |
7 | 9 | CategoricalPrivacyMetric,
|
8 | 10 | NumericalPrivacyMetric,
|
9 | 11 | )
|
@@ -53,6 +55,15 @@ def cat_bad_synthetic_data():
|
53 | 55 | })
|
54 | 56 |
|
55 | 57 |
|
| 58 | +def cat_disjoint_synthetic_data(): |
| 59 | + return pd.DataFrame({ |
| 60 | + 'key1': ['v', 'w', 'x', 'y', 'z'] * 20, |
| 61 | + 'key2': [5, 6, 7, 8, 9] * 20, |
| 62 | + 'sensitive1': ['a', 'b', 'c', 'e', 'd'] * 20, |
| 63 | + 'sensitive2': [0, 1, 2, 3, 4] * 20, |
| 64 | + }) |
| 65 | + |
| 66 | + |
56 | 67 | @pytest.mark.parametrize('metric', categorical_metrics.values()) # noqa: PD011
|
57 | 68 | def test_categoricals_non_ens(metric):
|
58 | 69 | if metric != CategoricalEnsemble: # Ensemble needs additional args to work
|
@@ -84,7 +95,18 @@ def test_categoricals_non_ens(metric):
|
84 | 95 | sensitive_fields=['sensitive1', 'sensitive2'],
|
85 | 96 | )
|
86 | 97 |
|
| 98 | + disjoint = metric.compute( |
| 99 | + cat_real_data(), |
| 100 | + cat_disjoint_synthetic_data(), |
| 101 | + key_fields=['key1', 'key2'], |
| 102 | + sensitive_fields=['sensitive1', 'sensitive2'], |
| 103 | + ) |
| 104 | + |
87 | 105 | assert metric.min_value <= horrible <= bad <= good <= perfect <= metric.max_value
|
| 106 | + if metric == CategoricalCAP or metric == CategoricalEnsemble: |
| 107 | + assert np.isnan(disjoint) |
| 108 | + elif metric != CategoricalGeneralizedCAP: |
| 109 | + assert disjoint == metric.max_value |
88 | 110 |
|
89 | 111 |
|
90 | 112 | def test_categorical_ens():
|
|
0 commit comments