diff --git a/visualize_attention.py b/visualize_attention.py index 4288265b9..e2ae8a3f6 100644 --- a/visualize_attention.py +++ b/visualize_attention.py @@ -190,8 +190,7 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con cumval = torch.cumsum(val, dim=1) th_attn = cumval > (1 - args.threshold) idx2 = torch.argsort(idx) - for head in range(nh): - th_attn[head] = th_attn[head][idx2[head]] + th_attn = torch.gather(th_attn, 1, idx2) th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() # interpolate th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()