Skip to content

Commit c793f90

Browse files
authored
[i1] Implement packed_storage layout encoding attribute (#19354)
* make `packed_storage` as a type of `iree_encoding` attribute, and make type converters accept it. * `i1` tensors with `#iree_encoding.packed_storage` will be interpreted as packed i1 type, same as specifying `--iree-experimental-packed-i1-storage`. Other i1 tensors are treated as non-packed datatype, and will be extended. * `--iree-experimental-packed-i1-storage` are kept for testing purposes. * We can drop this option after frontend enables emitting `i1` tensors with attributes. Signed-off-by: Alan Li <me@alanli.org>
1 parent 801e2c1 commit c793f90

15 files changed

+117
-35
lines changed

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
88
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1011
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1112
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -62,16 +63,18 @@ MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
6263
// Mark any operation that has operands/results with encoding as
6364
// illegal.
6465
markUnknownOpDynamicallyLegal([](Operation *op) {
65-
auto typeHasEncoding = [](Type t) -> bool {
66+
auto typeHasDataTilingEncoding = [](Type t) -> bool {
6667
auto tensorType = dyn_cast<RankedTensorType>(t);
67-
return tensorType && tensorType.getEncoding();
68+
if (!tensorType)
69+
return false;
70+
return getEncodingAttr(tensorType) != nullptr;
6871
};
69-
auto valueHasEncoding = [=](Value v) -> bool {
70-
return typeHasEncoding(v.getType());
72+
auto valueHasDataTilingEncoding = [=](Value v) -> bool {
73+
return typeHasDataTilingEncoding(v.getType());
7174
};
7275
bool hasOperandOrResultsWithEncoding =
73-
llvm::any_of(op->getOperands(), valueHasEncoding) ||
74-
llvm::any_of(op->getResultTypes(), typeHasEncoding);
76+
llvm::any_of(op->getOperands(), valueHasDataTilingEncoding) ||
77+
llvm::any_of(op->getResultTypes(), typeHasDataTilingEncoding);
7578
return !hasOperandOrResultsWithEncoding;
7679
});
7780
}

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,10 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
253253
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
254254
}
255255

256+
bool hasPackedStorageAttr(RankedTensorType type) {
257+
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
258+
}
259+
256260
FailureOr<linalg::ContractionDimensions>
257261
getEncodingContractionDims(EncodingAttr encoding) {
258262
auto indexingMapsAttr = encoding.getUserIndexingMaps();

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

+11
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
4141
def EncodingOpTypeAttr:
4242
IREEEncoding_EnumAttr<EncodingOpType, "optype">;
4343

44+
45+
def PackedStorageAttr : IREEEncoding_Attr<"PackedStorage"> {
46+
let mnemonic = "packed_storage";
47+
let summary = [{Indicates packed storage data type.}];
48+
let description = [{
49+
This attribute indicates this is a back-to-back packed storage in memory.
50+
This attribute takes no arguments.
51+
}];
52+
let genVerifyDecl = 0;
53+
}
54+
4455
def EncodingAttr :
4556
IREEEncoding_Attr<"Encoding", [
4657
DeclareAttrInterfaceMethods<IREEEncoding_EncodingLayoutAttrInterface, [

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ namespace mlir::iree_compiler::IREE::Encoding {
3838
/// Otherwise, returns null.
3939
EncodingAttr getEncodingAttr(RankedTensorType type);
4040

41+
/// Returns true if the type contains packed_storage attribute.
42+
bool hasPackedStorageAttr(RankedTensorType type);
43+
4144
/// Returns the ContractionDimensions for the encoding user_indexing_maps.
4245
FailureOr<linalg::ContractionDimensions>
4346
getEncodingContractionDims(EncodingAttr encoding);

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ iree_compiler_cc_library(
2121
"Patterns.h",
2222
],
2323
deps = [
24+
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
2425
"//compiler/src/iree/compiler/Dialect/HAL/IR",
2526
"//compiler/src/iree/compiler/Dialect/Stream/Conversion",
2627
"//compiler/src/iree/compiler/Dialect/Stream/IR",

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ iree_cc_library(
2424
MLIRIR
2525
MLIRTransformUtils
2626
MLIRTransforms
27+
iree::compiler::Dialect::Encoding::IR
2728
iree::compiler::Dialect::HAL::IR
2829
iree::compiler::Dialect::Stream::Conversion
2930
iree::compiler::Dialect::Stream::IR

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.h"
88

9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
1011
#include "iree/compiler/Dialect/Stream/Conversion/PatternUtils.h"
1112
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
@@ -100,6 +101,13 @@ struct ConvertTensorImportOp
100101
RankedTensorType tensorType,
101102
ValueRange dynamicDims,
102103
OpBuilder &builder) {
104+
// If the encoding attr is about packed storage then we don't need
105+
// assertion, because packed storage attribute is about memory layout and it
106+
// doesn't affect the tensor shape.
107+
if (IREE::Encoding::hasPackedStorageAttr(tensorType)) {
108+
return success();
109+
}
110+
103111
auto expectedElementType = builder.create<IREE::HAL::ElementTypeOp>(
104112
loc, tensorType.getElementType());
105113
auto expectedEncodingType = builder.create<IREE::HAL::EncodingTypeOp>(

compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
4646
ValueRange encodingDims,
4747
PatternRewriter &rewriter) {
4848
auto encoding = encodingType.getEncoding();
49-
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr>(encoding)) {
49+
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr,
50+
IREE::Encoding::PackedStorageAttr>(encoding)) {
5051
return rewriter.notifyMatchFailure(op, [=](Diagnostic &d) {
5152
d << "unsupported tensor encoding: " << encodingType;
5253
});

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ iree_lit_test_suite(
2929
"encode_device_tensors_packing.mlir",
3030
"encode_host_tensors.mlir",
3131
"encode_host_tensors_packing.mlir",
32-
"encode_host_tensors_packing_i1.mlir",
32+
"encode_host_tensors_packing_i1_attr.mlir",
33+
"encode_host_tensors_packing_i1_experimental_clopt.mlir",
3334
"fold_globals.mlir",
3435
"fold_uniform_operands.mlir",
3536
"fuse_dispatch_bindings.mlir",

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ iree_lit_test_suite(
2727
"encode_device_tensors_packing.mlir"
2828
"encode_host_tensors.mlir"
2929
"encode_host_tensors_packing.mlir"
30-
"encode_host_tensors_packing_i1.mlir"
30+
"encode_host_tensors_packing_i1_attr.mlir"
31+
"encode_host_tensors_packing_i1_experimental_clopt.mlir"
3132
"fold_globals.mlir"
3233
"fold_uniform_operands.mlir"
3334
"fuse_dispatch_bindings.mlir"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: iree-opt --split-input-file --iree-stream-encode-host-tensors %s | FileCheck %s
2+
3+
#packed = #iree_encoding.packed_storage
4+
func.func @unaligned_i1_size() -> index {
5+
%0 = stream.tensor.sizeof tensor<12xi1, #packed> : index
6+
return %0 : index
7+
}
8+
// CHECK: func @unaligned_i1_size() -> index {
9+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
10+
// CHECK: return %[[C2]] : index
11+
12+
// -----
13+
14+
#packed = #iree_encoding.packed_storage
15+
func.func @aligned_i1_size() -> index {
16+
%0 = stream.tensor.sizeof tensor<24xi1, #packed> : index
17+
return %0 : index
18+
}
19+
20+
// CHECK: func @aligned_i1_size() -> index {
21+
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
22+
// CHECK: return %[[C3]] : index

compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp

+45-16
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/IR/BuiltinTypes.h"
1717

18-
namespace mlir::iree_compiler {
19-
18+
// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
2019
llvm::cl::opt<bool> clEnableI1Support(
2120
"iree-experimental-packed-i1-storage",
2221
llvm::cl::desc(
23-
"Experimental feature: enable i1 data type support in codegen"),
22+
"Experimental feature: force to use packed storage for i1 tensors."
23+
"Turning on this option will see i1 tensors as if it has "
24+
"#iree_encoding.packed_storage attribute."
25+
"This is to allow an alternative way to test the packed storage "
26+
"feature before frontend can emit packed i1 tensors."
27+
"This option can be dropped once the frontend can emit packed i1 "
28+
"tensors."),
2429
llvm::cl::init(false));
2530

26-
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
31+
namespace mlir::iree_compiler {
32+
33+
static bool needToPackSubByteElementBitWidthImpl(unsigned bitWidth,
34+
bool isPackedStorage) {
2735
// Enable i1 support if requested.
28-
if (clEnableI1Support && bitWidth == 1) {
36+
if (isPackedStorage && bitWidth == 1) {
2937
return true;
3038
}
3139
// Require the original bit width to be some power of two for now to avoid
@@ -35,20 +43,31 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
3543
return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
3644
}
3745

46+
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
47+
return needToPackSubByteElementBitWidthImpl(
48+
bitWidth, /*isPackedStorage=*/clEnableI1Support);
49+
}
50+
3851
bool needToPackSubByteElements(RankedTensorType shapedType) {
3952
unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
40-
return needToPackSubByteElementBitWidth(bitWidth);
53+
// Two paths to enable packed storage for i1 tensors: the attribute or cl
54+
// option. The cl option will be dropped once frontend supports emitting
55+
// tensors with attributes.
56+
bool isPackedStorage =
57+
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
58+
return needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage);
4159
}
4260

43-
Type legalizeStorageElementType(Type elementType) {
61+
static Type legalizeStorageElementTypeImpl(Type elementType,
62+
bool isPackedStorage) {
4463
// Only handle integers; floats in MLIR all have aligned widths (today).
4564
auto intType = dyn_cast<IntegerType>(elementType);
4665
if (!intType)
4766
return elementType;
4867

4968
// For sub-byte elements, default to pack them into bytes.
5069
unsigned bitWidth = intType.getWidth();
51-
if (needToPackSubByteElementBitWidth(bitWidth))
70+
if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage))
5271
return elementType;
5372

5473
// Otherwise, extend them to the next power-of-two bit width.
@@ -60,6 +79,12 @@ Type legalizeStorageElementType(Type elementType) {
6079
intType.getSignedness());
6180
}
6281

82+
Type legalizeStorageElementType(Type elementType) {
83+
// Consider packed storage for i1 tensors if cl opt is set.
84+
return legalizeStorageElementTypeImpl(elementType,
85+
/*isPackedStorage=*/clEnableI1Support);
86+
}
87+
6388
Value calculateStorageElementCountInBytes(Location loc,
6489
RankedTensorType shapedType,
6590
ValueRange dynamicDims,
@@ -72,13 +97,15 @@ Value calculateStorageElementCountInBytes(Location loc,
7297
loc, builder, shapedType, dynamicDims);
7398
}
7499

75-
Type alignedElementType =
76-
legalizeStorageElementType(shapedType.getElementType());
100+
bool isPackedStorage =
101+
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
102+
Type alignedElementType = legalizeStorageElementTypeImpl(
103+
shapedType.getElementType(), isPackedStorage);
77104
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
78105

79106
// Calculate all static dims first, if any.
80107
int64_t staticCount = 1;
81-
if (!needToPackSubByteElementBitWidth(elementBits)) {
108+
if (!needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
82109
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
83110
}
84111

@@ -93,13 +120,13 @@ Value calculateStorageElementCountInBytes(Location loc,
93120
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
94121
}
95122
// Sub-byte packing requires putting multiple elements in the same byte.
96-
if (needToPackSubByteElementBitWidth(elementBits)) {
123+
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
97124
assert(8 % elementBits == 0);
98125
unsigned byteElements = 8 / elementBits;
99126
// TODO(antiagainst): We may want to emit runtime check to make sure this is
100127
// divisible.
101128
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
102-
if (!clEnableI1Support && dynamicDims.empty() &&
129+
if (!isPackedStorage && dynamicDims.empty() &&
103130
(staticCount * elementBits) % 8 != 0) {
104131
return nullptr;
105132
}
@@ -113,12 +140,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
113140
RankedTensorType originalType,
114141
Value linearizedIndex,
115142
OpBuilder &builder) {
116-
Type alignedElementType =
117-
legalizeStorageElementType(originalType.getElementType());
143+
bool isPackedStorage =
144+
IREE::Encoding::hasPackedStorageAttr(originalType) || clEnableI1Support;
145+
Type alignedElementType = legalizeStorageElementTypeImpl(
146+
originalType.getElementType(), isPackedStorage);
118147
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
119148

120149
// Sub-byte packing requires putting multiple elements in the same byte.
121-
if (needToPackSubByteElementBitWidth(elementBits)) {
150+
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
122151
Value byteElements =
123152
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
124153
// TODO(antiagainst): We may want to emit runtime check to make sure this is

compiler/src/iree/compiler/Utils/ElementPackingUtils.h

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace mlir::iree_compiler {
1616
/// Returns true if the given |bitWidth|, if appearing at runtime-kernel
1717
/// interface, is less than a byte that should be tightly packed together.
1818
bool needToPackSubByteElementBitWidth(unsigned bitWidth);
19+
1920
/// Returns true if the given |shapedType|, if appearing at runtime-kernel
2021
/// interface, has sub-byte element types that should be tightly packed
2122
/// together.

tests/e2e/subbyte_types/BUILD.bazel

+6-10
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,14 @@ package(
1818
licenses = ["notice"], # Apache 2.0
1919
)
2020

21-
LLVM_SRCS = enforce_glob(
22-
# keep sorted
23-
[
24-
"subbyte_types.mlir",
25-
],
26-
include = ["*.mlir"],
27-
exclude = [],
28-
)
29-
3021
iree_check_single_backend_test_suite(
3122
name = "check_llvm-cpu_subbyte_emulation",
32-
srcs = LLVM_SRCS,
23+
srcs = enforce_glob(
24+
[
25+
"subbyte_types.mlir",
26+
],
27+
include = ["*.mlir"],
28+
),
3329
compiler_flags = [
3430
"--iree-llvmcpu-target-cpu=generic",
3531
"--iree-experimental-packed-i1-storage",

0 commit comments

Comments
 (0)