-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
62 lines (56 loc) · 1.56 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
from scipy.spatial import KDTree
def to_numpy(a):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion. Non-tensor inputs are returned unchanged (no conversion).
    """
    if not isinstance(a, torch.Tensor):
        return a
    return a.detach().cpu().numpy()
class Avg:
    """Running average built from accumulated (total, count) pairs.

    ``update(x, n)`` adds ``x`` to the running total and ``n`` to the running
    count; ``value()`` returns ``total / count`` (0 when the count is zero).
    Instances can be pooled with ``+`` and combined with the builtin ``sum``.
    """

    def __init__(self):
        self.sum = 0  # running total of the values passed to update()
        self.cnt = 0  # running count of the items those values cover

    def update(self, x, n=1):
        """Add ``x`` to the running total and ``n`` to the running count.

        For ``value()`` to be a mean, ``x`` should be the total over ``n``
        items (e.g. a per-batch sum with ``n`` = batch size) — NOTE(review):
        confirm against callers. The default ``n=1`` records one observation.
        """
        self.sum += x
        self.cnt += n

    def value(self):
        """Return ``sum / cnt``, or 0 when nothing has been counted yet."""
        if self.cnt == 0:
            return 0
        return self.sum / self.cnt

    def __str__(self):
        return str(self.value())

    def __repr__(self):
        return f"{type(self).__name__}(sum={self.sum!r}, cnt={self.cnt!r})"

    def __add__(self, other):
        """Return a new Avg pooling the totals and counts of both operands."""
        if not isinstance(other, Avg):
            # Let Python try the reflected operation / raise TypeError instead
            # of failing with an AttributeError on other.sum.
            return NotImplemented
        res = Avg()
        res.sum = self.sum + other.sum
        res.cnt = self.cnt + other.cnt
        return res

    def __radd__(self, other):
        """Support ``sum(avgs)``, whose implicit start value is the int 0."""
        if isinstance(other, int) and other == 0:
            return self + Avg()  # fresh copy, so sum() never aliases self
        return NotImplemented
def _hausdorff_2d(x, y):
    """Symmetric Hausdorff distance between NumPy point sets x (N, D) and y (M, D)."""
    if len(x) == 0 or len(y) == 0:
        raise ValueError("hausdorff distance is undefined for empty point sets")
    # For each point of y, distance to its nearest neighbour in x (k=1).
    y_to_x_dist, _ = KDTree(x).query(y, 1)
    # For each point of x, distance to its nearest neighbour in y.
    x_to_y_dist, _ = KDTree(y).query(x, 1)
    return max(np.max(y_to_x_dist), np.max(x_to_y_dist))


def hausdorff(x, y):
    """Symmetric Hausdorff distance between two point sets.

    x, y: (N, D) and (M, D) point sets (torch tensors or array-likes); a scalar
    distance is returned. Batched inputs of shape (..., N, D) and (..., M, D)
    with identical leading dimensions are also accepted; each item is then
    processed independently and a NumPy array shaped like the leading
    dimensions is returned.

    Raises:
        ValueError: if either point set is empty, or batch shapes differ.
    """
    x, y = np.asarray(to_numpy(x)), np.asarray(to_numpy(y))
    if x.ndim <= 2 and y.ndim <= 2:
        return _hausdorff_2d(x, y)
    # scipy's KDTree only accepts 2-D data, so flatten leading dims and loop.
    batch_shape = x.shape[:-2]
    if batch_shape != y.shape[:-2]:
        raise ValueError(
            f"batch dimensions differ: {x.shape[:-2]} vs {y.shape[:-2]}"
        )
    xs = x.reshape(-1, *x.shape[-2:])
    ys = y.reshape(-1, *y.shape[-2:])
    dists = np.array([_hausdorff_2d(xi, yi) for xi, yi in zip(xs, ys)])
    return dists.reshape(batch_shape)
def SNR(ori, wm):
    """Mean signal-to-noise ratio, in dB, of ``wm`` relative to ``ori``.

    Assumes ``ori`` and ``wm`` are torch tensors shaped (B, N, D) — TODO
    confirm with callers. Per batch item, the signal power is the squared
    deviation of ``ori`` from its mean over dim 1, summed over the last two
    dims. The noise power is the squared difference ``wm - ori``, summed the
    same way. The result is ``10 * mean(log10(signal / noise))`` over items.
    """
    centered = ori - ori.mean(dim=1, keepdim=True)
    signal_power = (centered ** 2).sum(dim=-1).sum(dim=-1)
    residual = wm - ori
    noise_power = (residual ** 2).sum(dim=-1).sum(dim=-1)
    ratio = signal_power / noise_power
    return 10 * torch.log10(ratio).mean()
if __name__ == "__main__":
    # Manual smoke test of the metrics on random point clouds.
    torch.manual_seed(0)  # reproducible demo output
    B, N, D = 4, 10, 3
    a, b = torch.randn((B, N, D)), torch.randn((B, N, D))

    # hausdorff's documented input is a single (N, D) cloud (scipy's KDTree
    # rejects 3-D data), so evaluate each batch item separately.
    d = [hausdorff(a_i, b_i) for a_i, b_i in zip(a, b)]
    print(d)

    # Optional third-party dependency, only needed for this demo.
    from chamferdist import ChamferDistance

    chamfer_dist = ChamferDistance()
    dist_forward = chamfer_dist(a, b)
    print(dist_forward)

    # Recompute after scaling both clouds by 10 (presumably to observe how the
    # value scales with coordinate magnitude).
    a, b = a * 10, b * 10
    dist_forward = chamfer_dist(a, b)
    print(dist_forward)