15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/IR/BuiltinTypes.h"
17
17
18
- namespace mlir ::iree_compiler {
19
-
20
18
llvm::cl::opt<bool > clEnableI1Support (
21
19
" iree-experimental-packed-i1-storage" ,
22
20
llvm::cl::desc (
23
- " Experimental feature: enable i1 data type support in codegen" ),
21
+ " Experimental feature: force to use packed storage for i1 tensors."
22
+ " Turning on this option will see i1 tensors as if it has "
23
+ " #iree_encoding.packed_storage attribute."
24
+ " This is to allow an alternative way to test the packed storage "
25
+ " feature before frontend can emit packed i1 tensors."
26
+ " This option can be dropped once the frontend can emit packed i1 "
27
+ " tensors." ),
24
28
llvm::cl::init(false ));
25
29
30
+ namespace mlir ::iree_compiler {
31
+
26
32
bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
33
+ return needToPackSubByteElementBitWidth (
34
+ bitWidth, /* isI1PackedStorage=*/ clEnableI1Support);
35
+ }
36
+
37
+ bool needToPackSubByteElementBitWidth (unsigned bitWidth,
38
+ bool isI1PackedStorage) {
27
39
// Enable i1 support if requested.
28
- if (clEnableI1Support && bitWidth == 1 ) {
40
+ if (isI1PackedStorage && bitWidth == 1 ) {
29
41
return true ;
30
42
}
31
43
// Require the original bit width to be some power of two for now to avoid
@@ -37,18 +49,28 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
37
49
38
50
bool needToPackSubByteElements (RankedTensorType shapedType) {
39
51
unsigned bitWidth = IREE::Util::getTypeBitWidth (shapedType.getElementType ());
40
- return needToPackSubByteElementBitWidth (bitWidth);
52
+ // Two paths to enable packed storage for i1 tensors: the attribute or cl
53
+ // option. The cl option will be dropped once frontend supports emitting
54
+ // tensors with attributes.
55
+ bool isI1PackedStorage =
56
+ IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
57
+ return needToPackSubByteElementBitWidth (bitWidth, isI1PackedStorage);
41
58
}
42
59
43
60
Type legalizeStorageElementType (Type elementType) {
61
+ return legalizeStorageElementType (elementType,
62
+ /* isI1PackedStorage=*/ clEnableI1Support);
63
+ }
64
+
65
+ Type legalizeStorageElementType (Type elementType, bool isI1PackedStorage) {
44
66
// Only handle integers; floats in MLIR all have aligned widths (today).
45
67
auto intType = dyn_cast<IntegerType>(elementType);
46
68
if (!intType)
47
69
return elementType;
48
70
49
71
// For sub-byte elements, default to pack them into bytes.
50
72
unsigned bitWidth = intType.getWidth ();
51
- if (needToPackSubByteElementBitWidth (bitWidth))
73
+ if (needToPackSubByteElementBitWidth (bitWidth, isI1PackedStorage ))
52
74
return elementType;
53
75
54
76
// Otherwise, extend them to the next power-of-two bit width.
@@ -72,13 +94,16 @@ Value calculateStorageElementCountInBytes(Location loc,
72
94
loc, builder, shapedType, dynamicDims);
73
95
}
74
96
97
+ // TODO: remove cl options once frontend can emit packed i1 tensors.
98
+ bool packedStorage =
99
+ IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
75
100
Type alignedElementType =
76
- legalizeStorageElementType (shapedType.getElementType ());
101
+ legalizeStorageElementType (shapedType.getElementType (), packedStorage );
77
102
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
78
103
79
104
// Calculate all static dims first, if any.
80
105
int64_t staticCount = 1 ;
81
- if (!needToPackSubByteElementBitWidth (elementBits)) {
106
+ if (!needToPackSubByteElementBitWidth (elementBits, packedStorage )) {
82
107
staticCount *= IREE::Util::getRoundedElementByteWidth (alignedElementType);
83
108
}
84
109
@@ -93,13 +118,13 @@ Value calculateStorageElementCountInBytes(Location loc,
93
118
value = builder.createOrFold <arith::MulIOp>(loc, value, dim);
94
119
}
95
120
// Sub-byte packing requires putting multiple elements in the same byte.
96
- if (needToPackSubByteElementBitWidth (elementBits)) {
121
+ if (needToPackSubByteElementBitWidth (elementBits, packedStorage )) {
97
122
assert (8 % elementBits == 0 );
98
123
unsigned byteElements = 8 / elementBits;
99
124
// TODO(antiagainst): We may want to emit runtime check to make sure this is
100
125
// divisible.
101
126
auto divisor = builder.create <arith::ConstantIndexOp>(loc, byteElements);
102
- if (!clEnableI1Support && dynamicDims.empty () &&
127
+ if (!packedStorage && dynamicDims.empty () &&
103
128
(staticCount * elementBits) % 8 != 0 ) {
104
129
return nullptr ;
105
130
}
@@ -113,12 +138,15 @@ Value calculateStorageElementOffsetInBytes(Location loc,
113
138
RankedTensorType originalType,
114
139
Value linearizedIndex,
115
140
OpBuilder &builder) {
116
- Type alignedElementType =
117
- legalizeStorageElementType (originalType.getElementType ());
141
+ // TODO: remove cl options once frontend can emit packed i1 tensors.
142
+ bool isI1PackedStorage =
143
+ IREE::Encoding::hasPackedStorageAttr (originalType) || clEnableI1Support;
144
+ Type alignedElementType = legalizeStorageElementType (
145
+ originalType.getElementType (), isI1PackedStorage);
118
146
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
119
147
120
148
// Sub-byte packing requires putting multiple elements in the same byte.
121
- if (needToPackSubByteElementBitWidth (elementBits)) {
149
+ if (needToPackSubByteElementBitWidth (elementBits, isI1PackedStorage )) {
122
150
Value byteElements =
123
151
builder.create <arith::ConstantIndexOp>(loc, 8 / elementBits);
124
152
// TODO(antiagainst): We may want to emit runtime check to make sure this is
0 commit comments