Skip to content

Commit f829f7f

Browse files
committed
use torch.tensor
1 parent 4c00cf7 commit f829f7f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

sdmetrics/single_table/bayesian_network.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
4949
probabilities = []
5050
for _, row in synthetic_data[fields].iterrows():
5151
try:
52-
probabilities.append(bn.probability([row.to_numpy()]))
52+
probabilities.append(torch.tensor(bn.probability([row.to_numpy()])))
5353
except ValueError:
54-
probabilities.append(0)
54+
probabilities.append(torch.tensor(0))
5555

5656
return np.asarray(probabilities)
5757

@@ -125,7 +125,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None):
125125
float:
126126
Mean of the probabilities returned by the Bayesian Network.
127127
"""
128-
return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure))
128+
return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure)).item()
129129

130130

131131
class BNLogLikelihood(BNLikelihoodBase):
@@ -199,7 +199,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None):
199199
"""
200200
likelihoods = cls._likelihoods(real_data, synthetic_data, metadata, structure)
201201
likelihoods[np.where(likelihoods == 0)] = 1e-8
202-
return np.mean(np.log(likelihoods))
202+
return np.mean(np.log(likelihoods)).item()
203203

204204
@classmethod
205205
def normalize(cls, raw_score):

tests/unit/single_table/test_bayesian_network.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_compute(self, real_data, synthetic_data, metadata):
5151
result = metric.compute(real_data, synthetic_data, metadata)
5252

5353
# Assert
54-
assert result == 0.111111104
54+
assert result == 0.1111111044883728
5555

5656

5757
class TestBNLogLikelihood:
@@ -65,4 +65,4 @@ def test_compute(self, real_data, synthetic_data, metadata):
6565
result = metric.compute(real_data, synthetic_data, metadata)
6666

6767
# Assert
68-
assert result == -7.3347335
68+
assert result == -7.334733486175537

0 commit comments

Comments
 (0)