15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/IR/BuiltinTypes.h"
17
17
18
- llvm::cl::opt<bool > clEnableI1Support (
19
- " iree-experimental-packed-i1-storage" ,
20
- llvm::cl::desc (
21
- " Experimental feature: force to use packed storage for i1 tensors."
22
- " Turning on this option will see i1 tensors as if it has "
23
- " #iree_encoding.packed_storage attribute."
24
- " This is to allow an alternative way to test the packed storage "
25
- " feature before frontend can emit packed i1 tensors."
26
- " This option can be dropped once the frontend can emit packed i1 "
27
- " tensors." ),
28
- llvm::cl::init(false ));
29
-
30
18
namespace mlir ::iree_compiler {
31
19
32
- static bool needToPackSubByteElementBitWidthImpl (unsigned bitWidth,
33
- bool isPackedStorage) {
34
- // Enable i1 support if requested.
35
- if (isPackedStorage && bitWidth == 1 ) {
36
- return true ;
37
- }
20
+ bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
38
21
// Require the original bit width to be some power of two for now to avoid
39
22
// trickiness and weirdness of packing and cross-byte access.
40
23
// Also disallow boolean values for now--they may require separate interface
41
24
// choices.
42
25
return bitWidth < 8 && llvm::isPowerOf2_32 (bitWidth) && bitWidth != 1 ;
43
26
}
44
27
45
- bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
46
- return needToPackSubByteElementBitWidthImpl (
47
- bitWidth, /* isPackedStorage=*/ clEnableI1Support);
48
- }
49
-
50
28
/// Returns whether the elements of `shapedType` need sub-byte packing.
///
/// i1 tensors are packed only when the tensor carries the packed-storage
/// attribute; every other element width follows
/// needToPackSubByteElementBitWidth().
bool needToPackSubByteElements(RankedTensorType shapedType) {
  unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
  // i1 with packed memory layout does not need to be extended: it is packed
  // as-is. i1 without the attribute falls through to the width-only rule
  // below, which rejects booleans.
  if (bitWidth == 1 && IREE::Encoding::hasPackedStorageAttr(shapedType)) {
    return true;
  }
  return needToPackSubByteElementBitWidth(bitWidth);
}
59
36
60
37
static Type legalizeStorageElementTypeImpl (Type elementType,
@@ -64,9 +41,13 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
64
41
if (!intType)
65
42
return elementType;
66
43
67
- // For sub-byte elements, default to pack them into bytes.
68
44
unsigned bitWidth = intType.getWidth ();
69
- if (needToPackSubByteElementBitWidthImpl (bitWidth, isPackedStorage))
45
+ if (bitWidth == 1 && !isPackedStorage) {
46
+ return elementType;
47
+ }
48
+
49
+ // For sub-byte elements, default to pack them into bytes.
50
+ if (needToPackSubByteElementBitWidth (bitWidth))
70
51
return elementType;
71
52
72
53
// Otherwise, extend them to the next power-of-two bit width.
@@ -78,10 +59,10 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
78
59
intType.getSignedness ());
79
60
}
80
61
81
- Type legalizeStorageElementType (Type elementType ) {
82
- // Consider packed storage for i1 tensors if cl opt is set.
83
- return legalizeStorageElementTypeImpl (elementType,
84
- /* isPackedStorage= */ clEnableI1Support );
62
+ Type legalizeTensorStorageElementType (Type type ) {
63
+ auto tensorType = llvm::dyn_cast<RankedTensorType>(type);
64
+ return legalizeStorageElementTypeImpl (
65
+ type, tensorType && IREE::Encoding::hasPackedStorageAttr (type) );
85
66
}
86
67
87
68
Value calculateStorageElementCountInBytes (Location loc,
@@ -96,16 +77,16 @@ Value calculateStorageElementCountInBytes(Location loc,
96
77
loc, builder, shapedType, dynamicDims);
97
78
}
98
79
99
- // TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
100
- bool isPackedStorage =
101
- IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
102
- Type alignedElementType = legalizeStorageElementTypeImpl (
103
- shapedType.getElementType (), isPackedStorage);
80
+ Type alignedElementType = legalizeTensorStorageElementType (shapedType);
104
81
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
105
82
83
+ bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr (shapedType);
84
+ bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;
85
+
106
86
// Calculate all static dims first, if any.
107
87
int64_t staticCount = 1 ;
108
- if (!needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage)) {
88
+ if (!isI1WithPackedStorage &&
89
+ !needToPackSubByteElementBitWidth (elementBits)) {
109
90
staticCount *= IREE::Util::getRoundedElementByteWidth (alignedElementType);
110
91
}
111
92
@@ -120,7 +101,7 @@ Value calculateStorageElementCountInBytes(Location loc,
120
101
value = builder.createOrFold <arith::MulIOp>(loc, value, dim);
121
102
}
122
103
// Sub-byte packing requires putting multiple elements in the same byte.
123
- if (needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage )) {
104
+ if (isI1WithPackedStorage || needToPackSubByteElementBitWidth (elementBits)) {
124
105
assert (8 % elementBits == 0 );
125
106
unsigned byteElements = 8 / elementBits;
126
107
// TODO(antiagainst): We may want to emit runtime check to make sure this is
@@ -140,15 +121,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
140
121
RankedTensorType originalType,
141
122
Value linearizedIndex,
142
123
OpBuilder &builder) {
143
- // TODO: remove cl options once frontend can emit packed i1 tensors.
144
- bool isPackedStorage =
145
- IREE::Encoding::hasPackedStorageAttr (originalType) || clEnableI1Support;
146
- Type alignedElementType = legalizeStorageElementTypeImpl (
147
- originalType.getElementType (), isPackedStorage);
124
+ Type alignedElementType = legalizeTensorStorageElementType (originalType);
148
125
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
149
126
127
+ bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr (originalType);
128
+ bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;
129
+
150
130
// Sub-byte packing requires putting multiple elements in the same byte.
151
- if (needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage )) {
131
+ if (isI1WithPackedStorage || needToPackSubByteElementBitWidth (elementBits)) {
152
132
Value byteElements =
153
133
builder.create <arith::ConstantIndexOp>(loc, 8 / elementBits);
154
134
// TODO(antiagainst): We may want to emit runtime check to make sure this is
0 commit comments