1
1
"""Base class for metrics that compare pairs of columns."""
2
2
3
3
from sdmetrics .base import BaseMetric
4
-
4
+ from time import process_time
5
+ import numpy as np
6
+ DEFAULT_NUM_ROWS = None
7
+ DEFAULT_NUM_TRY = None
5
8
6
9
class ColumnPairsMetric (BaseMetric ):
7
10
"""Base class for metrics that compare pairs of columns.
@@ -42,18 +45,19 @@ def compute(real_data, synthetic_data):
42
45
43
46
@classmethod
44
47
def compute_breakdown (cls , real_data , synthetic_data ):
45
- """Compute the breakdown of this metric.
46
-
47
- Args:
48
- real_data (pandas.DataFrame):
49
- The values from the real dataset, passed as pandas.DataFrame
50
- with 2 columns.
51
- synthetic_data (pandas.DataFrame):
52
- The values from the synthetic dataset, passed as a
53
- pandas.DataFrame with 2 columns.
54
-
55
- Returns:
56
- dict
57
- A mapping of the metric output. Must contain the key 'score'.
58
- """
59
- return {'score' : cls .compute (real_data , synthetic_data )}
48
+ """Compute the breakdown of this metric."""
49
+ start = process_time ()
50
+ num_try = 1 if DEFAULT_NUM_TRY is None else DEFAULT_NUM_TRY
51
+ result = np .zeros (num_try )
52
+ for i in range (num_try ):
53
+ if DEFAULT_NUM_ROWS is not None :
54
+ real_to_subsample = min (DEFAULT_NUM_ROWS , len (real_data ))
55
+ real_data_to_compute = real_data .sample (n = real_to_subsample )
56
+ synthetic_data_to_compute = synthetic_data .sample (n = real_to_subsample )
57
+
58
+ result [i ] = cls .compute (real_data_to_compute , synthetic_data_to_compute )
59
+
60
+ score = np .mean (result )
61
+ end = process_time ()
62
+
63
+ return {'score' : score , 'time' : end - start , 'num_rows' : DEFAULT_NUM_ROWS }
0 commit comments