From 37875f526ebfbb0fa20c0c69af59398aa7e2d437 Mon Sep 17 00:00:00 2001 From: Xiaojian Ma Date: Thu, 22 Jul 2021 23:04:45 -0700 Subject: [PATCH] Replace the looping with gather ops The for-loop can be replaced with torch.gather to speed up the operation. --- visualize_attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/visualize_attention.py b/visualize_attention.py index 4288265b9..e2ae8a3f6 100644 --- a/visualize_attention.py +++ b/visualize_attention.py @@ -190,8 +190,7 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con cumval = torch.cumsum(val, dim=1) th_attn = cumval > (1 - args.threshold) idx2 = torch.argsort(idx) - for head in range(nh): - th_attn[head] = th_attn[head][idx2[head]] + th_attn = torch.gather(th_attn, 1, idx2) th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() # interpolate th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()