15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/IR/BuiltinTypes.h"
17
17
18
- namespace mlir ::iree_compiler {
19
-
18
+ // TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
20
19
// Boolean command-line option `iree-experimental-packed-i1-storage`; its
// initial value is false. Within this file the flag is OR-ed with the result
// of IREE::Encoding::hasPackedStorageAttr() to decide whether i1 tensor
// elements are treated as packed storage.
llvm::cl::opt<bool > clEnableI1Support (
21
20
" iree-experimental-packed-i1-storage" ,
22
21
llvm::cl::desc (
23
- " Experimental feature: enable i1 data type support in codegen" ),
22
+ " Experimental feature: force to use packed storage for i1 tensors."
23
+ " Turning on this option will see i1 tensors as if it has "
24
+ " #iree_encoding.packed_storage attribute."
25
+ " This is to allow an alternative way to test the packed storage "
26
+ " feature before frontend can emit packed i1 tensors."
27
+ " This option can be dropped once the frontend can emit packed i1 "
28
+ " tensors." ),
24
29
llvm::cl::init(false ));
25
30
26
- bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
31
+ namespace mlir ::iree_compiler {
32
+
33
+ static bool needToPackSubByteElementBitWidthImpl (unsigned bitWidth,
34
+ bool isPackedStorage) {
27
35
// Enable i1 support if requested.
28
- if (clEnableI1Support && bitWidth == 1 ) {
36
+ if (isPackedStorage && bitWidth == 1 ) {
29
37
return true ;
30
38
}
31
39
// Require the original bit width to be some power of two for now to avoid
@@ -35,20 +43,31 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
35
43
return bitWidth < 8 && llvm::isPowerOf2_32 (bitWidth) && bitWidth != 1 ;
36
44
}
37
45
46
// Bit-width-only query: forwards to needToPackSubByteElementBitWidthImpl(),
// passing the current value of the clEnableI1Support command-line option as
// the packed-storage flag. No tensor type is available here, so only the
// option (not an encoding attribute) can enable the i1 case.
+ bool needToPackSubByteElementBitWidth (unsigned bitWidth) {
47
+ return needToPackSubByteElementBitWidthImpl (
48
+ bitWidth, /* isPackedStorage=*/ clEnableI1Support);
49
+ }
50
+
38
51
// Tensor-type query: reads the element bit width of `shapedType` via
// IREE::Util::getTypeBitWidth() and asks
// needToPackSubByteElementBitWidthImpl() whether that width is packed.
// Packed storage is considered enabled when either the type carries the
// packed-storage attribute or the clEnableI1Support option is set.
bool needToPackSubByteElements (RankedTensorType shapedType) {
39
52
unsigned bitWidth = IREE::Util::getTypeBitWidth (shapedType.getElementType ());
40
- return needToPackSubByteElementBitWidth (bitWidth);
53
+ // Two paths to enable packed storage for i1 tensors: the attribute or cl
54
+ // option. The cl option will be dropped once frontend supports emitting
55
+ // tensors with attributes.
56
+ bool isPackedStorage =
57
+ IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
58
+ return needToPackSubByteElementBitWidthImpl (bitWidth, isPackedStorage);
41
59
}
42
60
43
- Type legalizeStorageElementType (Type elementType) {
61
+ static Type legalizeStorageElementTypeImpl (Type elementType,
62
+ bool isPackedStorage) {
44
63
// Only handle integers; floats in MLIR all have aligned widths (today).
45
64
auto intType = dyn_cast<IntegerType>(elementType);
46
65
if (!intType)
47
66
return elementType;
48
67
49
68
// For sub-byte elements, default to pack them into bytes.
50
69
unsigned bitWidth = intType.getWidth ();
51
- if (needToPackSubByteElementBitWidth (bitWidth))
70
+ if (needToPackSubByteElementBitWidthImpl (bitWidth, isPackedStorage ))
52
71
return elementType;
53
72
54
73
// Otherwise, extend them to the next power-of-two bit width.
@@ -60,6 +79,12 @@ Type legalizeStorageElementType(Type elementType) {
60
79
intType.getSignedness ());
61
80
}
62
81
82
// Element-type-only entry point: forwards to
// legalizeStorageElementTypeImpl() and returns its result unchanged. Only
// the clEnableI1Support command-line option is used as the packed-storage
// flag, since no tensor type (and hence no encoding attribute) is available.
+ Type legalizeStorageElementType (Type elementType) {
83
+ // Consider packed storage for i1 tensors if cl opt is set.
84
+ return legalizeStorageElementTypeImpl (elementType,
85
+ /* isPackedStorage=*/ clEnableI1Support);
86
+ }
87
+
63
88
Value calculateStorageElementCountInBytes (Location loc,
64
89
RankedTensorType shapedType,
65
90
ValueRange dynamicDims,
@@ -72,13 +97,15 @@ Value calculateStorageElementCountInBytes(Location loc,
72
97
loc, builder, shapedType, dynamicDims);
73
98
}
74
99
75
- Type alignedElementType =
76
- legalizeStorageElementType (shapedType.getElementType ());
100
+ bool isPackedStorage =
101
+ IREE::Encoding::hasPackedStorageAttr (shapedType) || clEnableI1Support;
102
+ Type alignedElementType = legalizeStorageElementTypeImpl (
103
+ shapedType.getElementType (), isPackedStorage);
77
104
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
78
105
79
106
// Calculate all static dims first, if any.
80
107
int64_t staticCount = 1 ;
81
- if (!needToPackSubByteElementBitWidth (elementBits)) {
108
+ if (!needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage )) {
82
109
staticCount *= IREE::Util::getRoundedElementByteWidth (alignedElementType);
83
110
}
84
111
@@ -93,13 +120,13 @@ Value calculateStorageElementCountInBytes(Location loc,
93
120
value = builder.createOrFold <arith::MulIOp>(loc, value, dim);
94
121
}
95
122
// Sub-byte packing requires putting multiple elements in the same byte.
96
- if (needToPackSubByteElementBitWidth (elementBits)) {
123
+ if (needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage )) {
97
124
assert (8 % elementBits == 0 );
98
125
unsigned byteElements = 8 / elementBits;
99
126
// TODO(antiagainst): We may want to emit runtime check to make sure this is
100
127
// divisible.
101
128
auto divisor = builder.create <arith::ConstantIndexOp>(loc, byteElements);
102
- if (!clEnableI1Support && dynamicDims.empty () &&
129
+ if (!isPackedStorage && dynamicDims.empty () &&
103
130
(staticCount * elementBits) % 8 != 0 ) {
104
131
return nullptr ;
105
132
}
@@ -113,12 +140,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
113
140
RankedTensorType originalType,
114
141
Value linearizedIndex,
115
142
OpBuilder &builder) {
116
- Type alignedElementType =
117
- legalizeStorageElementType (originalType.getElementType ());
143
+ bool isPackedStorage =
144
+ IREE::Encoding::hasPackedStorageAttr (originalType) || clEnableI1Support;
145
+ Type alignedElementType = legalizeStorageElementTypeImpl (
146
+ originalType.getElementType (), isPackedStorage);
118
147
unsigned elementBits = IREE::Util::getTypeBitWidth (alignedElementType);
119
148
120
149
// Sub-byte packing requires putting multiple elements in the same byte.
121
- if (needToPackSubByteElementBitWidth (elementBits)) {
150
+ if (needToPackSubByteElementBitWidthImpl (elementBits, isPackedStorage )) {
122
151
Value byteElements =
123
152
builder.create <arith::ConstantIndexOp>(loc, 8 / elementBits);
124
153
// TODO(antiagainst): We may want to emit runtime check to make sure this is
0 commit comments