Skip to content

Commit ab546b6

Browse files
committed
Add unit tests
1 parent 5f4d071 commit ab546b6

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

sdmetrics/visualization.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,7 @@ def _get_max_between_datasets(real_data, synthetic_data):
372372
return max(synthetic_data)
373373
elif synthetic_data is None:
374374
return max(real_data)
375-
else:
376-
return max(max(real_data), max(synthetic_data))
375+
return max(max(real_data), max(synthetic_data))
377376

378377

379378
def _get_min_between_datasets(real_data, synthetic_data):
@@ -383,8 +382,7 @@ def _get_min_between_datasets(real_data, synthetic_data):
383382
return min(synthetic_data)
384383
elif synthetic_data is None:
385384
return min(real_data)
386-
else:
387-
return min(min(real_data), min(synthetic_data))
385+
return min(min(real_data), min(synthetic_data))
388386

389387

390388
def _generate_cardinality_plot(

tests/unit/test_visualization.py

+50
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
_generate_line_plot,
1616
_generate_scatter_plot,
1717
_get_cardinality,
18+
_get_max_between_datasets,
19+
_get_min_between_datasets,
1820
get_cardinality_plot,
1921
get_column_line_plot,
2022
get_column_pair_plot,
@@ -48,6 +50,54 @@ def test_get_cardinality():
4850
pd.testing.assert_series_equal(result, expected_result)
4951

5052

53+
def test__get_max_between_datasets():
54+
"""Test the ``_get_max_between_datasets`` method"""
55+
# Setup
56+
mock_real_data = pd.Series([1, 1, 2, 2, 2])
57+
mock_synthetic_data = pd.Series([3, 3, 4])
58+
59+
# Run
60+
real_only_val = _get_max_between_datasets(mock_real_data, None)
61+
synth_only_val = _get_max_between_datasets(None, mock_synthetic_data)
62+
all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data)
63+
64+
# Assert
65+
expected_real_only_val = 2
66+
expected_synth_only_val = 4
67+
expected_all_val = 4
68+
assert expected_real_only_val == real_only_val
69+
assert expected_synth_only_val == synth_only_val
70+
assert expected_all_val == all_val
71+
72+
error_msg = re.escape('Cannot get max between two None values.')
73+
with pytest.raises(ValueError, match=error_msg):
74+
_get_max_between_datasets(None, None)
75+
76+
77+
def test__get_min_between_datasets():
78+
"""Test the ``_get_min_between_datasets`` method"""
79+
# Setup
80+
mock_real_data = pd.Series([1, 1, 2, 2, 2])
81+
mock_synthetic_data = pd.Series([3, 3, 4])
82+
83+
# Run
84+
real_only_val = _get_min_between_datasets(mock_real_data, None)
85+
synth_only_val = _get_min_between_datasets(None, mock_synthetic_data)
86+
all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data)
87+
88+
# Assert
89+
expected_real_only_val = 1
90+
expected_synth_only_val = 3
91+
expected_all_val = 1
92+
assert expected_real_only_val == real_only_val
93+
assert expected_synth_only_val == synth_only_val
94+
assert expected_all_val == all_val
95+
96+
error_msg = re.escape('Cannot get min between two None values.')
97+
with pytest.raises(ValueError, match=error_msg):
98+
_get_min_between_datasets(None, None)
99+
100+
51101
@patch('sdmetrics.visualization.px')
52102
def test_generate_cardinality_bar_plot(mock_px):
53103
"""Test the ``_generate_cardinality_plot`` method."""

0 commit comments

Comments
 (0)