|
15 | 15 | _generate_line_plot,
|
16 | 16 | _generate_scatter_plot,
|
17 | 17 | _get_cardinality,
|
| 18 | + _get_max_between_datasets, |
| 19 | + _get_min_between_datasets, |
18 | 20 | get_cardinality_plot,
|
19 | 21 | get_column_line_plot,
|
20 | 22 | get_column_pair_plot,
|
@@ -48,6 +50,54 @@ def test_get_cardinality():
|
48 | 50 | pd.testing.assert_series_equal(result, expected_result)
|
49 | 51 |
|
50 | 52 |
|
| 53 | +def test__get_max_between_datasets(): |
| 54 | + """Test the ``_get_max_between_datasets`` method""" |
| 55 | + # Setup |
| 56 | + mock_real_data = pd.Series([1, 1, 2, 2, 2]) |
| 57 | + mock_synthetic_data = pd.Series([3, 3, 4]) |
| 58 | + |
| 59 | + # Run |
| 60 | + real_only_val = _get_max_between_datasets(mock_real_data, None) |
| 61 | + synth_only_val = _get_max_between_datasets(None, mock_synthetic_data) |
| 62 | + all_val = _get_max_between_datasets(mock_real_data, mock_synthetic_data) |
| 63 | + |
| 64 | + # Assert |
| 65 | + expected_real_only_val = 2 |
| 66 | + expected_synth_only_val = 4 |
| 67 | + expected_all_val = 4 |
| 68 | + assert expected_real_only_val == real_only_val |
| 69 | + assert expected_synth_only_val == synth_only_val |
| 70 | + assert expected_all_val == all_val |
| 71 | + |
| 72 | + error_msg = re.escape('Cannot get max between two None values.') |
| 73 | + with pytest.raises(ValueError, match=error_msg): |
| 74 | + _get_max_between_datasets(None, None) |
| 75 | + |
| 76 | + |
| 77 | +def test__get_min_between_datasets(): |
| 78 | + """Test the ``_get_min_between_datasets`` method""" |
| 79 | + # Setup |
| 80 | + mock_real_data = pd.Series([1, 1, 2, 2, 2]) |
| 81 | + mock_synthetic_data = pd.Series([3, 3, 4]) |
| 82 | + |
| 83 | + # Run |
| 84 | + real_only_val = _get_min_between_datasets(mock_real_data, None) |
| 85 | + synth_only_val = _get_min_between_datasets(None, mock_synthetic_data) |
| 86 | + all_val = _get_min_between_datasets(mock_real_data, mock_synthetic_data) |
| 87 | + |
| 88 | + # Assert |
| 89 | + expected_real_only_val = 1 |
| 90 | + expected_synth_only_val = 3 |
| 91 | + expected_all_val = 1 |
| 92 | + assert expected_real_only_val == real_only_val |
| 93 | + assert expected_synth_only_val == synth_only_val |
| 94 | + assert expected_all_val == all_val |
| 95 | + |
| 96 | + error_msg = re.escape('Cannot get min between two None values.') |
| 97 | + with pytest.raises(ValueError, match=error_msg): |
| 98 | + _get_min_between_datasets(None, None) |
| 99 | + |
| 100 | + |
51 | 101 | @patch('sdmetrics.visualization.px')
|
52 | 102 | def test_generate_cardinality_bar_plot(mock_px):
|
53 | 103 | """Test the ``_generate_cardinality_plot`` method."""
|
|
0 commit comments