15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/IR/BuiltinTypes.h"
17
17
18
- namespace mlir ::iree_compiler {
19
-
20
18
// Experimental command-line switch. When set, the storage helpers in this file
// treat i1 tensors as packed even if the tensor type carries no packed-storage
// attribute (the flag is OR-ed with the attribute check at each use site).
// TODO: remove this option once the frontend can emit packed i1 tensors.
llvm::cl::opt<bool> clEnableI1Support(
    "iree-experimental-packed-i1-storage",
    llvm::cl::desc(
        // Each fragment ends with a space so the concatenated help text keeps
        // exactly one separator between words and sentences.
        "Experimental feature: force the use of packed storage for i1 "
        "tensors. Turning on this option treats i1 tensors as if they had "
        "the #iree_encoding.packed_storage attribute. This allows an "
        "alternative way to test the packed storage feature before the "
        "frontend can emit packed i1 tensors. This option can be dropped "
        "once the frontend can emit packed i1 tensors."),
    llvm::cl::init(false));
25
29
30
+ namespace mlir ::iree_compiler {
31
+
26
32
bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
33
+ return needToPackSubByteElementBitWidth (
34
+ bitWidth, /* isPackedStorage=*/ clEnableI1Support);
35
+ }
36
+
37
+ bool needToPackSubByteElementBitWidth (unsigned bitWidth, bool isPackedStorage) {
27
38
// Enable i1 support if requested.
28
- if (clEnableI1Support && bitWidth == 1 ) {
39
+ if (isPackedStorage && bitWidth == 1 ) {
29
40
return true ;
30
41
}
31
42
// Require the original bit width to be some power of two for now to avoid
@@ -37,18 +48,28 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
37
48
38
49
/// Returns whether the elements of `shapedType` need sub-byte packing, as
/// decided by needToPackSubByteElementBitWidth() for the element bit width.
bool needToPackSubByteElements(RankedTensorType shapedType) {
  // Packed storage for i1 tensors can be requested either through the
  // packed-storage attribute on the type or through the command-line option.
  // The option is temporary and goes away once the frontend can emit tensors
  // carrying the attribute.
  bool treatAsPacked = IREE::Encoding::hasPackedStorageAttr(shapedType);
  if (!treatAsPacked)
    treatAsPacked = clEnableI1Support;

  const unsigned elementBitWidth =
      IREE::Util::getTypeBitWidth(shapedType.getElementType());
  return needToPackSubByteElementBitWidth(elementBitWidth, treatAsPacked);
}
42
58
43
59
/// Convenience overload: legalizes `elementType` for storage by forwarding to
/// the two-argument overload, with the packed-storage flag taken from the
/// `iree-experimental-packed-i1-storage` command-line option.
Type legalizeStorageElementType(Type elementType) {
  const bool isPackedStorage = clEnableI1Support;
  return legalizeStorageElementType(elementType, isPackedStorage);
}
63
+
64
+ Type legalizeStorageElementType (Type elementType, bool isPackedStorage) {
44
65
// Only handle integers; floats in MLIR all have aligned widths (today).
45
66
auto intType = dyn_cast<IntegerType>(elementType);
46
67
if (!intType)
47
68
return elementType;
48
69
49
70
// For sub-byte elements, default to pack them into bytes.
50
71
unsigned bitWidth = intType.getWidth ();
51
- if (needToPackSubByteElementBitWidth (bitWidth))
72
+ if (needToPackSubByteElementBitWidth (bitWidth, isPackedStorage ))
52
73
return elementType;
53
74
54
75
// Otherwise, extend them to the next power-of-two bit width.
@@ -72,13 +93,16 @@ Value calculateStorageElementCountInBytes(Location loc,
72
93
loc, builder, shapedType, dynamicDims);
73
94
}
74
95
96
+ // TODO: remove cl options once frontend can emit packed i1 tensors.
97
+ bool isPackedStorage =
98
+ IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
75
99
Type alignedElementType =
76
- legalizeStorageElementType (shapedType.getElementType ());
100
+ legalizeStorageElementType (shapedType.getElementType (), isPackedStorage );
77
101
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
78
102
79
103
// Calculate all static dims first, if any.
80
104
int64_t staticCount = 1 ;
81
- if (!needToPackSubByteElementBitWidth (elementBits)) {
105
+ if (!needToPackSubByteElementBitWidth (elementBits, isPackedStorage )) {
82
106
staticCount *= IREE::Util::getRoundedElementByteWidth (alignedElementType);
83
107
}
84
108
@@ -93,13 +117,13 @@ Value calculateStorageElementCountInBytes(Location loc,
93
117
value = builder.createOrFold <arith::MulIOp>(loc, value, dim);
94
118
}
95
119
// Sub-byte packing requires putting multiple elements in the same byte.
96
- if (needToPackSubByteElementBitWidth (elementBits)) {
120
+ if (needToPackSubByteElementBitWidth (elementBits, isPackedStorage )) {
97
121
assert (8 % elementBits == 0 );
98
122
unsigned byteElements = 8 / elementBits;
99
123
// TODO(antiagainst): We may want to emit runtime check to make sure this is
100
124
// divisible.
101
125
auto divisor = builder.create <arith::ConstantIndexOp>(loc, byteElements);
102
- if (!clEnableI1Support && dynamicDims.empty () &&
126
+ if (!isPackedStorage && dynamicDims.empty () &&
103
127
(staticCount * elementBits) % 8 != 0 ) {
104
128
return nullptr ;
105
129
}
@@ -113,12 +137,15 @@ Value calculateStorageElementOffsetInBytes(Location loc,
113
137
RankedTensorType originalType,
114
138
Value linearizedIndex,
115
139
OpBuilder &builder) {
116
- Type alignedElementType =
117
- legalizeStorageElementType (originalType.getElementType ());
140
+ // TODO: remove cl options once frontend can emit packed i1 tensors.
141
+ bool isPackedStorage =
142
+ IREE::Encoding::hasPackedStorageAttr (originalType) || clEnableI1Support;
143
+ Type alignedElementType = legalizeStorageElementType (
144
+ originalType.getElementType (), isPackedStorage);
118
145
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
119
146
120
147
// Sub-byte packing requires putting multiple elements in the same byte.
121
- if (needToPackSubByteElementBitWidth (elementBits)) {
148
+ if (needToPackSubByteElementBitWidth (elementBits, isPackedStorage )) {
122
149
Value byteElements =
123
150
builder.create <arith::ConstantIndexOp>(loc, 8 / elementBits);
124
151
// TODO(antiagainst): We may want to emit runtime check to make sure this is
0 commit comments