
Commit 189e2b8

Add discretization to ContingencySimilarity metric
1 parent b83bcf9

6 files changed: +166, -22 lines

sdmetrics/column_pairs/statistical/contingency_similarity.py (+44, -7)

@@ -1,7 +1,10 @@
 """Contingency Similarity Metric."""
 
+import pandas as pd
+
 from sdmetrics.column_pairs.base import ColumnPairsMetric
 from sdmetrics.goal import Goal
+from sdmetrics.utils import discretize_column
 
 
 class ContingencySimilarity(ColumnPairsMetric):
@@ -23,23 +26,57 @@ class ContingencySimilarity(ColumnPairsMetric):
     min_value = 0.0
     max_value = 1.0
 
+    @staticmethod
+    def _validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins):
+        for data in [real_data, synthetic_data]:
+            if not isinstance(data, pd.DataFrame) or len(data.columns) != 2:
+                raise ValueError('The data must be a pandas DataFrame with two columns.')
+
+        if set(real_data.columns) != set(synthetic_data.columns):
+            raise ValueError('The columns in the real and synthetic data must match.')
+
+        if continuous_column_names is not None:
+            bad_continuous_columns = "' ,'".join([
+                column for column in continuous_column_names if column not in real_data.columns
+            ])
+            if bad_continuous_columns:
+                raise ValueError(
+                    f"Continuous column(s) '{bad_continuous_columns}' not found in the data."
+                )
+
+        if not isinstance(num_discrete_bins, int) or num_discrete_bins <= 0:
+            raise ValueError('`num_discrete_bins` must be an integer greater than zero.')
+
     @classmethod
-    def compute(cls, real_data, synthetic_data):
+    def compute(cls, real_data, synthetic_data, continuous_column_names=None, num_discrete_bins=10):
         """Compare the contingency similarity of two discrete columns.
 
         Args:
-            real_data (Union[numpy.ndarray, pandas.Series]):
-                The values from the real dataset.
-            synthetic_data (Union[numpy.ndarray, pandas.Series]):
-                The values from the synthetic dataset.
+            real_data (pd.DataFrame):
+                The target columns of the real dataset.
+            synthetic_data (pd.DataFrame):
+                The target columns of the synthetic dataset.
+            continuous_column_names (list[str], optional):
+                The list of columns to discretize before running the metric. The column names in
+                this list should match the column names in the real and synthetic data. Defaults
+                to ``None``.
+            num_discrete_bins (int, optional):
+                The number of bins to create for the continuous columns. Defaults to 10.
 
         Returns:
             float:
                 The contingency similarity of the two columns.
         """
+        cls._validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins)
         columns = real_data.columns[:2]
-        real = real_data[columns]
-        synthetic = synthetic_data[columns]
+        real = real_data[columns].copy()
+        synthetic = synthetic_data[columns].copy()
+        if continuous_column_names is not None:
+            for column in continuous_column_names:
+                real[column], synthetic[column] = discretize_column(
+                    real[column], synthetic[column], num_discrete_bins=num_discrete_bins
+                )
+
         contingency_real = real.groupby(list(columns), dropna=False).size() / len(real)
         contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
             synthetic

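For orientation, here is a minimal usage sketch of the new keyword arguments, reusing the fixture from the unit test added further down (the module path comes from the patch target in that test file; the exact score naturally depends on the data):

import pandas as pd

from sdmetrics.column_pairs.statistical.contingency_similarity import ContingencySimilarity

real_data = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2': [1, 2, 3, 4]})
synthetic_data = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2': [2, 3, 7, -10]})

# 'col2' is treated as continuous and binned into 4 buckets before the
# contingency tables of the two columns are compared.
score = ContingencySimilarity.compute(
    real_data,
    synthetic_data,
    continuous_column_names=['col2'],
    num_discrete_bins=4,
)
print(score)  # the new unit test expects 0.25 for this fixture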
sdmetrics/reports/utils.py (+2, -3)

@@ -9,6 +9,7 @@
 from pandas.core.tools.datetimes import _guess_datetime_format_for_array
 
 from sdmetrics.utils import (
+    discretize_column,
     get_alternate_keys,
    get_columns_from_metadata,
    get_type_from_column_meta,
@@ -116,9 +117,7 @@ def discretize_table_data(real_data, synthetic_data, metadata):
                real_col = pd.to_numeric(real_col)
                synthetic_col = pd.to_numeric(synthetic_col)
 
-            bin_edges = np.histogram_bin_edges(real_col.dropna())
-            binned_real_col = np.digitize(real_col, bins=bin_edges)
-            binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges)
+            binned_real_col, binned_synthetic_col = discretize_column(real_col, synthetic_col)
 
             binned_real[column_name] = binned_real_col
             binned_synthetic[column_name] = binned_synthetic_col

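Swapping in discretize_column here also changes how out-of-range values are binned, which is why the expected values in tests/unit/reports/test_utils.py further down move from 11 to 10 and from 0 to 1: the old code kept finite outer edges, so values at or above the real maximum spilled into an overflow bin (index 11) and values below the real minimum fell into bin 0, while the new helper pins the outer edges to -inf/inf so everything lands in bins 1 through num_discrete_bins. A rough before/after sketch (the input values are illustrative, not the actual test fixtures):

import numpy as np
import pandas as pd

real_col = pd.Series([0, 5, 10])        # illustrative values only
synthetic_col = pd.Series([-3, 5, 12])  # one value below and one above the real range

# Old behaviour: 11 finite edges from np.histogram_bin_edges -> overflow bins 0 and 11.
edges = np.histogram_bin_edges(real_col.dropna())
old_real = np.digitize(real_col, bins=edges)        # [1, 6, 11]  (max value -> 11)
old_synth = np.digitize(synthetic_col, bins=edges)  # [0, 6, 11]  (-3 -> 0, 12 -> 11)

# New behaviour: outer edges replaced by -inf/inf -> everything stays in bins 1..10.
edges[0], edges[-1] = -np.inf, np.inf
new_real = np.digitize(real_col, bins=edges)        # [1, 6, 10]
new_synth = np.digitize(synthetic_col, bins=edges)  # [1, 6, 10]  (-3 -> 1, 12 -> 10)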
sdmetrics/utils.py (+22)

@@ -123,6 +123,28 @@ def is_datetime(data):
     )
 
 
+def discretize_column(real_column, synthetic_column, num_discrete_bins=10):
+    """Discretize a real and synthetic column.
+
+    Args:
+        real_column (pd.Series):
+            The real column.
+        synthetic_column (pd.Series):
+            The synthetic column.
+        num_discrete_bins (int, optional):
+            The number of bins to create. Defaults to 10.
+
+    Returns:
+        tuple(pd.Series, pd.Series):
+            The discretized real and synthetic columns.
+    """
+    bin_edges = np.histogram_bin_edges(real_column.dropna(), bins=num_discrete_bins)
+    bin_edges[0], bin_edges[-1] = -np.inf, np.inf
+    binned_real_column = np.digitize(real_column, bins=bin_edges)
+    binned_synthetic_column = np.digitize(synthetic_column, bins=bin_edges)
+    return binned_real_column, binned_synthetic_column
+
+
 class HyperTransformer:
     """HyperTransformer class.

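To make the helper concrete: with num_discrete_bins=5, np.histogram_bin_edges on range(10) yields edges [0, 1.8, 3.6, 5.4, 7.2, 9]; after the outer edges are replaced with -inf and inf, every value falls into one of five bins labelled 1 through 5. A small worked example using the same inputs as the new test_discretize_column test further down:

import pandas as pd

from sdmetrics.utils import discretize_column

real = pd.Series(range(10))
synthetic = pd.Series([-10] + list(range(1, 9)) + [20])

binned_real, binned_synthetic = discretize_column(real, synthetic, num_discrete_bins=5)
print(binned_real)       # [1 1 2 2 3 3 4 4 5 5]
print(binned_synthetic)  # [1 1 2 2 3 3 4 4 5 5] -- the out-of-range -10 and 20 land in the outermost bins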
tests/unit/column_pairs/statistical/test_contingency_similarity.py (+70)

@@ -1,3 +1,4 @@
+import re
 from unittest.mock import patch
 
 import pandas as pd
@@ -7,6 +8,59 @@
 
 
 class TestContingencySimilarity:
+    def test__validate_inputs(self):
+        """Test the ``_validate_inputs`` method."""
+        # Setup
+        bad_data = pd.Series(range(5))
+        real_data = pd.DataFrame({'col1': range(10), 'col2': range(10, 20)})
+        bad_synthetic_data = pd.DataFrame({'bad_column': range(10), 'col2': range(10)})
+        synthetic_data = pd.DataFrame({'col1': range(5), 'col2': range(5)})
+        bad_continous_columns = ['col1', 'missing_col']
+        bad_num_discrete_bins = -1
+
+        # Run and Assert
+        expected_bad_data = re.escape('The data must be a pandas DataFrame with two columns.')
+        with pytest.raises(ValueError, match=expected_bad_data):
+            ContingencySimilarity._validate_inputs(
+                real_data=bad_data,
+                synthetic_data=bad_data,
+                continuous_column_names=None,
+                num_discrete_bins=10,
+            )
+
+        expected_mismatch_columns_error = re.escape(
+            'The columns in the real and synthetic data must match.'
+        )
+        with pytest.raises(ValueError, match=expected_mismatch_columns_error):
+            ContingencySimilarity._validate_inputs(
+                real_data=real_data,
+                synthetic_data=bad_synthetic_data,
+                continuous_column_names=None,
+                num_discrete_bins=10,
+            )
+
+        expected_bad_continous_column_error = re.escape(
+            "Continuous column(s) 'missing_col' not found in the data."
+        )
+        with pytest.raises(ValueError, match=expected_bad_continous_column_error):
+            ContingencySimilarity._validate_inputs(
+                real_data=real_data,
+                synthetic_data=synthetic_data,
+                continuous_column_names=bad_continous_columns,
+                num_discrete_bins=10,
+            )
+
+        expected_bad_num_discrete_bins_error = re.escape(
+            '`num_discrete_bins` must be an integer greater than zero.'
+        )
+        with pytest.raises(ValueError, match=expected_bad_num_discrete_bins_error):
+            ContingencySimilarity._validate_inputs(
+                real_data=real_data,
+                synthetic_data=synthetic_data,
+                continuous_column_names=['col1'],
+                num_discrete_bins=bad_num_discrete_bins,
+            )
+
     def test_compute(self):
         """Test the ``compute`` method.
 
@@ -32,6 +86,22 @@ def test_compute(self):
         # Assert
         assert result == expected_score
 
+    def test_compute_with_discretization(self):
+        """Test the ``compute`` method with continuous columns."""
+        # Setup
+        real_data = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2': [1, 2, 3, 4]})
+        synthetic_data = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2': [2, 3, 7, -10]})
+        expected_score = 0.25
+
+        # Run
+        metric = ContingencySimilarity()
+        result = metric.compute(
+            real_data, synthetic_data, continuous_column_names=['col2'], num_discrete_bins=4
+        )
+
+        # Assert
+        assert result == expected_score
+
     @patch('sdmetrics.column_pairs.statistical.contingency_similarity.ColumnPairsMetric.normalize')
     def test_normalize(self, normalize_mock):
         """Test the ``normalize`` method.

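As a sanity check on expected_score = 0.25 in test_compute_with_discretization above: with num_discrete_bins=4 the real col2 values [1, 2, 3, 4] give edges [1, 1.75, 2.5, 3.25, 4], which become [-inf, 1.75, 2.5, 3.25, inf], so real col2 maps to bins [1, 2, 3, 4] and synthetic col2 ([2, 3, 7, -10]) maps to [2, 3, 4, 1]. Assuming the metric scores 1 - 0.5 * sum(|p_real - p_synth|) over the joint (col1, col2-bin) categories (the tail of compute is outside the hunk shown above), only the pair (1.0, bin 1) occurs in both tables, so six of the seven observed pairs each contribute 0.25 to the sum and the score is 1 - 0.5 * 1.5 = 0.25. A short sketch of that check, with the discretized col2 filled in by hand:

import pandas as pd

# Assumed scoring rule: complement of the total variation distance between the
# joint frequency tables (the relevant lines are not part of the hunk above).
real = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2_bin': [1, 2, 3, 4]})
synth = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2_bin': [2, 3, 4, 1]})

f_real = real.groupby(['col1', 'col2_bin']).size() / len(real)
f_synth = synth.groupby(['col1', 'col2_bin']).size() / len(synth)
total_diff = f_real.subtract(f_synth, fill_value=0).abs().sum()
print(1 - 0.5 * total_diff)  # 0.25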
tests/unit/reports/test_utils.py (+12, -12)

@@ -119,18 +119,18 @@ def test_discretize_table_data():
 
     # Assert
     expected_real = pd.DataFrame({
-        'col1': [1, 6, 11],
+        'col1': [1, 6, 10],
         'col2': ['a', 'b', 'c'],
-        'col3': [2, 1, 11],
+        'col3': [2, 1, 10],
         'col4': [True, False, True],
-        'col5': [10, 1, 11],
+        'col5': [10, 1, 10],
     })
     expected_synth = pd.DataFrame({
-        'col1': [11, 1, 11],
+        'col1': [10, 1, 10],
         'col2': ['c', 'a', 'c'],
-        'col3': [11, 0, 5],
+        'col3': [10, 1, 5],
         'col4': [False, False, True],
-        'col5': [10, 5, 11],
+        'col5': [10, 5, 10],
     })
 
     pd.testing.assert_frame_equal(discretized_real, expected_real)
@@ -193,18 +193,18 @@ def test_discretize_table_data_new_metadata():
 
     # Assert
     expected_real = pd.DataFrame({
-        'col1': [1, 6, 11],
+        'col1': [1, 6, 10],
         'col2': ['a', 'b', 'c'],
-        'col3': [2, 1, 11],
+        'col3': [2, 1, 10],
         'col4': [True, False, True],
-        'col5': [10, 1, 11],
+        'col5': [10, 1, 10],
     })
     expected_synth = pd.DataFrame({
-        'col1': [11, 1, 11],
+        'col1': [10, 1, 10],
         'col2': ['c', 'a', 'c'],
-        'col3': [11, 0, 5],
+        'col3': [10, 1, 5],
         'col4': [False, False, True],
-        'col5': [10, 5, 11],
+        'col5': [10, 5, 10],
     })
 
     pd.testing.assert_frame_equal(discretized_real, expected_real)

tests/unit/test_utils.py (+16)

@@ -6,6 +6,7 @@
 
 from sdmetrics.utils import (
     HyperTransformer,
+    discretize_column,
     get_alternate_keys,
     get_cardinality_distribution,
     get_columns_from_metadata,
@@ -54,6 +55,21 @@ def test_get_missing_percentage():
     assert percentage_nan == 28.57
 
 
+def test_discretize_column():
+    """Test the ``discretize_column`` method."""
+    # Setup
+    real = pd.Series(range(10))
+    synthetic = pd.Series([-10] + list(range(1, 9)) + [20])
+    num_bins = 5
+
+    # Run
+    binned_real, binned_synthetic = discretize_column(real, synthetic, num_discrete_bins=num_bins)
+
+    # Assert
+    np.testing.assert_array_equal([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], binned_real)
+    np.testing.assert_array_equal([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], binned_synthetic)
+
+
 def test_get_columns_from_metadata():
     """Test the ``get_columns_from_metadata`` method with current metadata format.
