@@ -27,6 +27,71 @@ def torch_xpu_gc():
has_xpu = check_for_xpu()

+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm that splits a large batch into smaller chunks,
+# so that the SDPA call for each chunk never needs an allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for the Intel Arc A770 16G and Arc A750 8G,
+# which gives the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+    # cast key and value to the query dtype first
+    key = key.to(query.dtype)
+    value = value.to(query.dtype)
+
+    N = query.shape[:-2]  # Batch dimensions
+    L = query.size(-2)  # Target sequence length
+    E = query.size(-1)  # Embedding dimension of the query and key
+    S = key.size(-2)  # Source sequence length
+    Ev = value.size(-1)  # Embedding dimension of the value
+
+    total_batch_size = torch.numel(torch.empty(N))
+    device_id = query.device.index
+    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+    if total_batch_size <= batch_size_limit:
+        return orig_sdp_attn_func(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            *args, **kwargs
+        )
+
+    query = torch.reshape(query, (-1, L, E))
+    key = torch.reshape(key, (-1, S, E))
+    value = torch.reshape(value, (-1, S, Ev))
+    if attn_mask is not None:
+        attn_mask = attn_mask.view(-1, L, S)
+    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+    outputs = []
+    for i in range(chunk_count):
+        attn_mask_chunk = (
+            None
+            if attn_mask is None
+            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+        )
+        chunk_output = orig_sdp_attn_func(
+            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+            attn_mask_chunk,
+            dropout_p,
+            is_causal,
+            *args, **kwargs
+        )
+        outputs.append(chunk_output)
+    result = torch.cat(outputs, dim=0)
+    return torch.reshape(result, (*N, L, Ev))
+
+
if has_xpu:
    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
    CondFunc('torch.Generator',
@@ -55,5 +120,5 @@ def torch_xpu_gc():
        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
    CondFunc('torch.nn.functional.scaled_dot_product_attention',
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
-        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
+        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+        lambda orig_func, query, *args, **kwargs: query.is_xpu)
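For readers without an Arc card, here is a minimal CPU-only sketch of the batch-slicing idea behind `torch_xpu_scaled_dot_product_attention`: flatten the leading batch dimensions, run scaled dot-product attention chunk by chunk, and stitch the results back together. The helper name `chunked_sdpa`, the fixed `batch_size_limit`, and the tensor shapes are illustrative assumptions; the patch above derives the limit from device VRAM and only engages on XPU tensors.

```python
# Illustrative sketch only: a fixed batch_size_limit stands in for the
# VRAM-derived ARC_SINGLE_ALLOCATION_LIMIT heuristic used in the patch above.
import math

import torch
import torch.nn.functional as F


def chunked_sdpa(query, key, value, attn_mask=None, batch_size_limit=4):
    # Flatten all leading batch dims, slice along the flattened batch,
    # run SDPA per chunk, then restore the original batch shape.
    N = query.shape[:-2]
    L, E = query.size(-2), query.size(-1)
    S, Ev = key.size(-2), value.size(-1)
    total_batch_size = math.prod(N)

    query = query.reshape(-1, L, E)
    key = key.reshape(-1, S, E)
    value = value.reshape(-1, S, Ev)
    if attn_mask is not None:
        attn_mask = attn_mask.reshape(-1, L, S)

    outputs = []
    for start in range(0, total_batch_size, batch_size_limit):
        end = start + batch_size_limit
        mask_chunk = None if attn_mask is None else attn_mask[start:end]
        outputs.append(F.scaled_dot_product_attention(
            query[start:end], key[start:end], value[start:end], mask_chunk))
    return torch.cat(outputs, dim=0).reshape(*N, L, Ev)


if __name__ == "__main__":
    q = torch.randn(2, 8, 16, 32)  # (batch, heads, L, E)
    k = torch.randn(2, 8, 16, 32)
    v = torch.randn(2, 8, 16, 32)
    reference = F.scaled_dot_product_attention(q, k, v)
    chunked = chunked_sdpa(q, k, v, batch_size_limit=4)
    print(torch.allclose(reference, chunked, atol=1e-5))  # expected: True
```

Chunking along the batch dimension is exact because attention is computed independently per batch element, so the only cost is launching several smaller kernels instead of one large one.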
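For context, the `CondFunc` registrations above follow a conditional monkey-patching pattern: a dotted path to the target function, a substitute that receives the original function as its first argument, and a predicate that decides when the substitute runs. A simplified stand-in for that mechanism, inferred from the call sites above rather than taken from the project's actual `CondFunc` helper, could look like this:

```python
# Simplified illustration of a CondFunc-style conditional patch; the name
# patch_when and the dotted-path handling are assumptions for this sketch,
# not the project's real helper.
import importlib


def patch_when(orig_func_path, sub_func, cond):
    module_path, func_name = orig_func_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig_func = getattr(module, func_name)

    def wrapper(*args, **kwargs):
        # Call the substitute only when the predicate fires, mirroring the
        # (orig_func, *args, **kwargs) calling convention seen above.
        if cond(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(module, func_name, wrapper)
    return wrapper
```

Read this way, the last hunk says: whenever `scaled_dot_product_attention` is called on an XPU tensor, route it through `torch_xpu_scaled_dot_product_attention` instead of the original kernel.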