15
15
16
16
namespace mlir ::iree_compiler {
17
17
18
- llvm::cl::opt<bool > clEnableI1Support (
19
- " iree-experimental-packed-i1-storage" ,
20
- llvm::cl::desc (
21
- " Experimental feature: enable i1 data type support in codegen" ),
22
- llvm::cl::init(false ));
23
-
24
- bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
18
+ bool needToPackSubByteElementBitWidth (unsigned bitWidth,
19
+ bool isI1PackedStorage) {
25
20
// Enable i1 support if requested.
26
- if (clEnableI1Support && bitWidth == 1 ) {
21
+ if (isI1PackedStorage && bitWidth == 1 ) {
27
22
return true ;
28
23
}
29
24
// Require the original bit width to be some power of two for now to avoid
@@ -35,18 +30,20 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
35
30
36
31
bool needToPackSubByteElements (RankedTensorType shapedType) {
37
32
unsigned bitWidth = IREE::Util::getTypeBitWidth (shapedType.getElementType ());
38
- return needToPackSubByteElementBitWidth (bitWidth);
33
+ auto encoding = IREE::Encoding::getEncodingAttr (shapedType);
34
+ bool isI1PackedStorage = encoding && encoding.i1PackedStorage ();
35
+ return needToPackSubByteElementBitWidth (bitWidth, isI1PackedStorage);
39
36
}
40
37
41
- Type legalizeStorageElementType (Type elementType) {
38
+ Type legalizeStorageElementType (Type elementType, bool isI1PackedStorage ) {
42
39
// Only handle integers; floats in MLIR all have aligned widths (today).
43
40
auto intType = dyn_cast<IntegerType>(elementType);
44
41
if (!intType)
45
42
return elementType;
46
43
47
44
// For sub-byte elements, default to pack them into bytes.
48
45
unsigned bitWidth = intType.getWidth ();
49
- if (needToPackSubByteElementBitWidth (bitWidth))
46
+ if (needToPackSubByteElementBitWidth (bitWidth, isI1PackedStorage ))
50
47
return elementType;
51
48
52
49
// Otherwise, extend them to the next power-of-two bit width.
@@ -62,21 +59,22 @@ Value calculateStorageElementCountInBytes(Location loc,
62
59
RankedTensorType shapedType,
63
60
ValueRange dynamicDims,
64
61
OpBuilder &builder) {
62
+ auto encoding = IREE::Encoding::getEncodingAttr (shapedType);
63
+ bool packedStorage = IREE::Encoding::hasPackedStorageAttr (shapedType);
65
64
Type alignedElementType =
66
- legalizeStorageElementType (shapedType.getElementType ());
65
+ legalizeStorageElementType (shapedType.getElementType (), packedStorage );
67
66
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
68
67
69
68
// Calculate all static dims first, if any.
70
69
int64_t staticCount = 1 ;
71
- if (!needToPackSubByteElementBitWidth (elementBits)) {
70
+ if (!needToPackSubByteElementBitWidth (elementBits, packedStorage )) {
72
71
staticCount *= IREE::Util::getRoundedElementByteWidth (alignedElementType);
73
72
}
74
73
75
74
// TODO: Do we use makeComposedFoldedAffineApply here, so the index
76
75
// computation can be much simpler.
77
76
SmallVector<int64_t > paddedShape (shapedType.getShape ());
78
77
SmallVector<Value> paddedDynamicDims (dynamicDims.begin (), dynamicDims.end ());
79
- auto encoding = IREE::Encoding::getEncodingAttr (shapedType);
80
78
if (encoding && !encoding.getRoundDimsToArray ().empty ()) {
81
79
auto roundDimsTo = encoding.getRoundDimsToArray ();
82
80
FailureOr<linalg::ContractionDimensions> cDims =
@@ -121,13 +119,13 @@ Value calculateStorageElementCountInBytes(Location loc,
121
119
value = builder.createOrFold <arith::MulIOp>(loc, value, dim);
122
120
}
123
121
// Sub-byte packing requires putting multiple elements in the same byte.
124
- if (needToPackSubByteElementBitWidth (elementBits)) {
122
+ if (needToPackSubByteElementBitWidth (elementBits, packedStorage )) {
125
123
assert (8 % elementBits == 0 );
126
124
unsigned byteElements = 8 / elementBits;
127
125
// TODO(antiagainst): We may want to emit runtime check to make sure this is
128
126
// divisible.
129
127
auto divisor = builder.create <arith::ConstantIndexOp>(loc, byteElements);
130
- if (!clEnableI1Support && paddedDynamicDims.empty () &&
128
+ if (!packedStorage && paddedDynamicDims.empty () &&
131
129
(staticCount * elementBits) % 8 != 0 ) {
132
130
return nullptr ;
133
131
}
@@ -145,8 +143,11 @@ Value calculateStorageElementOffsetInBytes(Location loc,
145
143
legalizeStorageElementType (originalType.getElementType ());
146
144
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
147
145
146
+ auto encoding = IREE::Encoding::getEncodingAttr (originalType);
147
+ bool isI1PackedStorage = encoding && encoding.i1PackedStorage ();
148
+
148
149
// Sub-byte packing requires putting multiple elements in the same byte.
149
- if (needToPackSubByteElementBitWidth (elementBits)) {
150
+ if (needToPackSubByteElementBitWidth (elementBits, isI1PackedStorage )) {
150
151
Value byteElements =
151
152
builder.create <arith::ConstantIndexOp>(loc, 8 / elementBits);
152
153
// TODO(antiagainst): We may want to emit runtime check to make sure this is
0 commit comments