Commit cba6fba

Merge pull request #14353 from Nuullll/ipex-sdpa
[IPEX] Slice SDPA into smaller chunks
2 parents: ac0ecf3 + f586f49
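
For context, a rough back-of-the-envelope sketch of why the slicing is needed (illustrative numbers and shapes, not from the commit): a naive SDPA materializes an attention-weight tensor of batch x heads x L x S elements, and on hires Stable Diffusion workloads that single tensor can approach or exceed the Arc driver's 4GB single-allocation ceiling.

# Illustrative only: bytes of the attention-weight tensor a naive SDPA
# materializes, under hypothetical SD-style shapes.
element_size = 2                  # fp16
batch, heads = 2, 8
L = S = 96 * 96                   # ~9k tokens, e.g. a large hires latent
attn_bytes = batch * heads * L * S * element_size
print(f"{attn_bytes / 2**30:.2f} GiB")  # ~2.53 GiB; a 2x larger batch crosses 4 GiB

The patch below sidesteps the ceiling by slicing the flattened batch so that no single SDPA call needs an allocation above min(TOTAL_VRAM // 8, 4GB).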

File tree

1 file changed (+67 -2 lines)

modules/xpu_specific.py (+67 -2)
@@ -27,6 +27,71 @@ def torch_xpu_gc():
 
 has_xpu = check_for_xpu()
 
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+    # cast to same dtype first
+    key = key.to(query.dtype)
+    value = value.to(query.dtype)
+
+    N = query.shape[:-2] # Batch size
+    L = query.size(-2) # Target sequence length
+    E = query.size(-1) # Embedding dimension of the query and key
+    S = key.size(-2) # Source sequence length
+    Ev = value.size(-1) # Embedding dimension of the value
+
+    total_batch_size = torch.numel(torch.empty(N))
+    device_id = query.device.index
+    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+    if total_batch_size <= batch_size_limit:
+        return orig_sdp_attn_func(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            *args, **kwargs
+        )
+
+    query = torch.reshape(query, (-1, L, E))
+    key = torch.reshape(key, (-1, S, E))
+    value = torch.reshape(value, (-1, S, Ev))
+    if attn_mask is not None:
+        attn_mask = attn_mask.view(-1, L, S)
+    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+    outputs = []
+    for i in range(chunk_count):
+        attn_mask_chunk = (
+            None
+            if attn_mask is None
+            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+        )
+        chunk_output = orig_sdp_attn_func(
+            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            attn_mask_chunk,
+            dropout_p,
+            is_causal,
+            *args, **kwargs
+        )
+        outputs.append(chunk_output)
+    result = torch.cat(outputs, dim=0)
+    return torch.reshape(result, (*N, L, Ev))
+
+
 if has_xpu:
     # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
     CondFunc('torch.Generator',
@@ -55,5 +120,5 @@ def torch_xpu_gc():
         lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
         lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
     CondFunc('torch.nn.functional.scaled_dot_product_attention',
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
+        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+        lambda orig_func, query, *args, **kwargs: query.is_xpu)
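
CondFunc is webui's conditional monkeypatch helper (modules/sd_hijack_utils.py). As a minimal sketch of the pattern it implements (an illustrative re-implementation under assumed semantics, not the project's actual code): resolve a dotted attribute path, then install a wrapper that calls the substitute only when a predicate on the live arguments fires.

import functools

def cond_func(target_path, sub_func, cond):
    # Resolve e.g. 'torch.nn.functional.scaled_dot_product_attention'
    module_path, _, attr = target_path.rpartition('.')
    module = __import__(module_path, fromlist=[attr])
    orig = getattr(module, attr)

    @functools.wraps(orig)
    def wrapper(*args, **kwargs):
        if cond(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)  # substitute path
        return orig(*args, **kwargs)                # untouched fast path

    setattr(module, attr, wrapper)

Read against the hunk above: any SDPA call whose query lives on an XPU device (query.is_xpu) is now routed through torch_xpu_scaled_dot_product_attention, and the old dtype-mismatch workaround is subsumed by the cast-to-query-dtype step at the top of the new function.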

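A quick standalone check that the batch slicing is lossless (CPU-only sketch with assumed small shapes and a forced chunk limit; not part of the commit):

import torch
import torch.nn.functional as F

# Hypothetical shapes; run on CPU, so no XPU limit lookup is involved.
q, k, v = (torch.randn(2, 8, 77, 64) for _ in range(3))
N, L, Ev = q.shape[:-2], q.size(-2), v.size(-1)
total = torch.numel(torch.empty(N))  # 16 flattened batch slices
limit = 5                            # pretend the allocation limit allows 5 slices per call

qf, kf, vf = (t.reshape(-1, t.size(-2), t.size(-1)) for t in (q, k, v))
chunks = [
    F.scaled_dot_product_attention(qf[i:i + limit], kf[i:i + limit], vf[i:i + limit])
    for i in range(0, total, limit)
]
out = torch.cat(chunks, dim=0).reshape(*N, L, Ev)

# Attention is independent across batch slices, so chunking must not change the result.
assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-5)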