Skip to content

Commit 97a0ed4

Browse files
committed
interface
1 parent 900ef1d commit 97a0ed4

File tree

10 files changed

+100
-28
lines changed

10 files changed

+100
-28
lines changed

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
88
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1011
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1112
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -83,7 +84,8 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
8384
}
8485
auto newEncoding = IREE::Encoding::EncodingAttr::get(
8586
context, operandIndex, encoding.getOpType().getValue(),
86-
encoding.getElementTypesArray(), maps, newBcastMap, newRoundDimsTo);
87+
encoding.getElementTypesArray(), maps, newBcastMap, newRoundDimsTo,
88+
encoding.getLayouts());
8789
return RankedTensorType::get(newShape, elemType, newEncoding);
8890
}
8991

@@ -133,7 +135,9 @@ MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
133135
markUnknownOpDynamicallyLegal([](Operation *op) {
134136
auto typeHasEncoding = [](Type t) -> bool {
135137
auto tensorType = dyn_cast<RankedTensorType>(t);
136-
return tensorType && tensorType.getEncoding();
138+
if (!(tensorType && tensorType.getEncoding()))
139+
return false;
140+
return !IREE::Encoding::hasPackedStorageAttr(tensorType);
137141
};
138142
auto valueHasEncoding = [=](Value v) -> bool {
139143
return typeHasEncoding(v.getType());

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
111111
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
112112
}
113113

114+
bool EncodingAttr::i1PackedStorage() const {
115+
return llvm::any_of(getLayouts(), [&](const Attribute &a) {
116+
return llvm::isa<PackedStorageAttr>(a);
117+
});
118+
}
119+
114120
MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
115121
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
116122
return {};
@@ -142,6 +148,10 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
142148
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
143149
}
144150

151+
bool hasPackedStorageAttr(RankedTensorType type) {
152+
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
153+
}
154+
145155
FailureOr<linalg::ContractionDimensions>
146156
getEncodingContractionDims(EncodingAttr encoding) {
147157
auto indexingMapsAttr = encoding.getUserIndexingMaps();

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

+14
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
4040
def EncodingOpTypeAttr:
4141
IREEEncoding_EnumAttr<EncodingOpType, "optype">;
4242

43+
44+
def PackedStorageAttr : IREEEncoding_Attr<"PackedStorage"> {
45+
let mnemonic = "packed_storage";
46+
let summary = [{Indicates packed storage datatype.}];
47+
let description = [{
48+
This attribute indicates this is a back-to-back packed storage in memory.
49+
This attribute takes no arguments.
50+
}];
51+
let genVerifyDecl = 0;
52+
}
53+
4354
def EncodingAttr :
4455
IREEEncoding_Attr<"Encoding"> {
4556
let mnemonic = "encoding";
@@ -108,6 +119,9 @@ def EncodingAttr :
108119

109120
/// Clones an encoding with a new bcast_map
110121
EncodingAttr clone(AffineMap bcastMap);
122+
123+
/// Returns whether the encoding indicates packed storage for i1 datatype.
124+
bool i1PackedStorage() const;
111125
}];
112126

113127
let genVerifyDecl = 0;

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ namespace mlir::iree_compiler::IREE::Encoding {
3838
/// Otherwise, returns null.
3939
EncodingAttr getEncodingAttr(RankedTensorType type);
4040

41+
/// Returns true if the type contains packed_storage attribute.
42+
bool hasPackedStorageAttr(RankedTensorType type);
43+
4144
/// Returns the ContractionDimensions for the encoding user_indexing_maps.
4245
FailureOr<linalg::ContractionDimensions>
4346
getEncodingContractionDims(EncodingAttr encoding);

compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414

15+
namespace mlir::iree_compiler::IREE::Encoding {
16+
bool hasPackedStorageAttr(mlir::RankedTensorType);
17+
} // namespace mlir::iree_compiler::IREE::Encoding
18+
1519
namespace mlir::iree_compiler {
1620

1721
namespace {
@@ -90,6 +94,11 @@ struct ConvertTensorImportOp
9094
RankedTensorType tensorType,
9195
ValueRange dynamicDims,
9296
OpBuilder &builder) {
97+
// If the encoding attr is about packed storage then we don't need all this
98+
if (IREE::Encoding::hasPackedStorageAttr(tensorType)) {
99+
return success();
100+
}
101+
93102
auto expectedElementType = builder.create<IREE::HAL::ElementTypeOp>(
94103
loc, tensorType.getElementType());
95104
auto expectedEncodingType = builder.create<IREE::HAL::EncodingTypeOp>(

compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
88
// TODO(benvanik): have a stream/upstream equivalent of the flow.dispatch.* ops.
9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
1011
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1112
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
@@ -46,7 +47,8 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
4647
ValueRange encodingDims,
4748
PatternRewriter &rewriter) {
4849
auto encoding = encodingType.getEncoding();
49-
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr>(encoding)) {
50+
if (encoding && !llvm::isa<IREE::Encoding::EncodingAttr,
51+
IREE::Encoding::PackedStorageAttr>(encoding)) {
5052
return rewriter.notifyMatchFailure(op, [=](Diagnostic &d) {
5153
d << "unsupported tensor encoding: " << encodingType;
5254
});

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_packing_i1.mlir

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
// RUN: iree-opt --split-input-file --iree-stream-encode-host-tensors --iree-experimental-packed-i1-storage %s | FileCheck %s
1+
// RUN: iree-opt --split-input-file --iree-stream-encode-host-tensors %s | FileCheck %s
22

3+
#packed = #iree_encoding.packed_storage
34
func.func @unaligned_i1_size() -> index {
4-
%0 = stream.tensor.sizeof tensor<12xi1> : index
5+
%0 = stream.tensor.sizeof tensor<12xi1, #packed> : index
56
return %0 : index
67
}
78
// CHECK: func @unaligned_i1_size() -> index {
@@ -10,8 +11,9 @@ func.func @unaligned_i1_size() -> index {
1011

1112
// -----
1213

14+
#packed = #iree_encoding.packed_storage
1315
func.func @aligned_i1_size() -> index {
14-
%0 = stream.tensor.sizeof tensor<24xi1> : index
16+
%0 = stream.tensor.sizeof tensor<24xi1, #packed> : index
1517
return %0 : index
1618
}
1719

compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp

+27-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1313
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
1414
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
15+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1516
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
1617
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
1718
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
@@ -25,18 +26,33 @@
2526
#include "mlir/Support/LLVM.h"
2627
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2728

29+
llvm::cl::opt<bool> clEnableI1Support(
30+
"iree-experimental-packed-i1-storage",
31+
llvm::cl::desc(
32+
"Experimental feature: use packed storage for i1 tensors. This feature "
33+
"can be dropped once the frontend can emit packed i1 tensors."),
34+
llvm::cl::init(false));
35+
2836
#define DEBUG_TYPE "iree-dispatch-creation-set-encoding"
2937

3038
namespace mlir::iree_compiler::DispatchCreation {
3139
#define GEN_PASS_DEF_SETENCODINGPASS
3240
#include "iree/compiler/DispatchCreation/Passes.h.inc"
3341

3442
using IREE::Encoding::EncodingAttr;
43+
using IREE::Encoding::PackedStorageAttr;
3544

3645
//===---------------------------------------------------------------------===//
3746
// Utility functions
3847
//===---------------------------------------------------------------------===//
3948

49+
static std::optional<Attribute> getI1PackedStorageAttr(MLIRContext *context) {
50+
if (clEnableI1Support) {
51+
return PackedStorageAttr::get(context);
52+
}
53+
return {};
54+
}
55+
4056
Value setEncoding(OpBuilder &builder, Location loc, Value source,
4157
EncodingAttr encodingAttr) {
4258
auto sourceType = cast<RankedTensorType>(source.getType());
@@ -216,9 +232,17 @@ class setContractionOpEncoding
216232
if (narrowDim.isN()) {
217233
roundDimsTo[1] = llvm::PowerOf2Ceil(narrowDim.size);
218234
}
219-
auto encoding = EncodingAttr::get(linalgOp.getContext(), operandIndex,
220-
opType, elemTypes, maps,
221-
/*bcastMap=*/std::nullopt, roundDimsTo);
235+
// This is a temporary solution that we can use
236+
// `--iree-experimental-packed-i1-storage` to enable packed storage for i1
237+
// tensors without specifying attributes.
238+
auto optI1PackedStorageAttr =
239+
getI1PackedStorageAttr(rewriter.getContext());
240+
auto layouts = optI1PackedStorageAttr
241+
? ArrayRef<Attribute>(*optI1PackedStorageAttr)
242+
: ArrayRef<Attribute>{};
243+
auto encoding = EncodingAttr::get(
244+
linalgOp.getContext(), operandIndex, opType, elemTypes, maps,
245+
/*bcastMap=*/std::nullopt, roundDimsTo, layouts);
222246
return setEncoding(rewriter, loc, src, encoding);
223247
};
224248
Value encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);

compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp

+18-17
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,10 @@
1515

1616
namespace mlir::iree_compiler {
1717

18-
llvm::cl::opt<bool> clEnableI1Support(
19-
"iree-experimental-packed-i1-storage",
20-
llvm::cl::desc(
21-
"Experimental feature: enable i1 data type support in codegen"),
22-
llvm::cl::init(false));
23-
24-
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
18+
bool needToPackSubByteElementBitWidth(unsigned bitWidth,
19+
bool isI1PackedStorage) {
2520
// Enable i1 support if requested.
26-
if (clEnableI1Support && bitWidth == 1) {
21+
if (isI1PackedStorage && bitWidth == 1) {
2722
return true;
2823
}
2924
// Require the original bit width to be some power of two for now to avoid
@@ -35,18 +30,20 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
3530

3631
bool needToPackSubByteElements(RankedTensorType shapedType) {
3732
unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
38-
return needToPackSubByteElementBitWidth(bitWidth);
33+
auto encoding = IREE::Encoding::getEncodingAttr(shapedType);
34+
bool isI1PackedStorage = encoding && encoding.i1PackedStorage();
35+
return needToPackSubByteElementBitWidth(bitWidth, isI1PackedStorage);
3936
}
4037

41-
Type legalizeStorageElementType(Type elementType) {
38+
Type legalizeStorageElementType(Type elementType, bool isI1PackedStorage) {
4239
// Only handle integers; floats in MLIR all have aligned widths (today).
4340
auto intType = dyn_cast<IntegerType>(elementType);
4441
if (!intType)
4542
return elementType;
4643

4744
// For sub-byte elements, default to pack them into bytes.
4845
unsigned bitWidth = intType.getWidth();
49-
if (needToPackSubByteElementBitWidth(bitWidth))
46+
if (needToPackSubByteElementBitWidth(bitWidth, isI1PackedStorage))
5047
return elementType;
5148

5249
// Otherwise, extend them to the next power-of-two bit width.
@@ -62,21 +59,22 @@ Value calculateStorageElementCountInBytes(Location loc,
6259
RankedTensorType shapedType,
6360
ValueRange dynamicDims,
6461
OpBuilder &builder) {
62+
auto encoding = IREE::Encoding::getEncodingAttr(shapedType);
63+
bool packedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
6564
Type alignedElementType =
66-
legalizeStorageElementType(shapedType.getElementType());
65+
legalizeStorageElementType(shapedType.getElementType(), packedStorage);
6766
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
6867

6968
// Calculate all static dims first, if any.
7069
int64_t staticCount = 1;
71-
if (!needToPackSubByteElementBitWidth(elementBits)) {
70+
if (!needToPackSubByteElementBitWidth(elementBits, packedStorage)) {
7271
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
7372
}
7473

7574
// TODO: Do we use makeComposedFoldedAffineApply here, so the index
7675
// computation an be much simpler.
7776
SmallVector<int64_t> paddedShape(shapedType.getShape());
7877
SmallVector<Value> paddedDynamicDims(dynamicDims.begin(), dynamicDims.end());
79-
auto encoding = IREE::Encoding::getEncodingAttr(shapedType);
8078
if (encoding && !encoding.getRoundDimsToArray().empty()) {
8179
auto roundDimsTo = encoding.getRoundDimsToArray();
8280
FailureOr<linalg::ContractionDimensions> cDims =
@@ -121,13 +119,13 @@ Value calculateStorageElementCountInBytes(Location loc,
121119
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
122120
}
123121
// Sub-byte packing requires putting multiple elements in the same byte.
124-
if (needToPackSubByteElementBitWidth(elementBits)) {
122+
if (needToPackSubByteElementBitWidth(elementBits, packedStorage)) {
125123
assert(8 % elementBits == 0);
126124
unsigned byteElements = 8 / elementBits;
127125
// TODO(antiagainst): We may want to emit runtime check to make sure this is
128126
// divisible.
129127
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
130-
if (!clEnableI1Support && paddedDynamicDims.empty() &&
128+
if (!packedStorage && paddedDynamicDims.empty() &&
131129
(staticCount * elementBits) % 8 != 0) {
132130
return nullptr;
133131
}
@@ -145,8 +143,11 @@ Value calculateStorageElementOffsetInBytes(Location loc,
145143
legalizeStorageElementType(originalType.getElementType());
146144
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
147145

146+
auto encoding = IREE::Encoding::getEncodingAttr(originalType);
147+
bool isI1PackedStorage = encoding && encoding.i1PackedStorage();
148+
148149
// Sub-byte packing requires putting multiple elements in the same byte.
149-
if (needToPackSubByteElementBitWidth(elementBits)) {
150+
if (needToPackSubByteElementBitWidth(elementBits, isI1PackedStorage)) {
150151
Value byteElements =
151152
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
152153
// TODO(antiagainst): We may want to emit runtime check to make sure this is

compiler/src/iree/compiler/Utils/ElementPackingUtils.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ namespace mlir::iree_compiler {
1515

1616
/// Returns true if the given |bitWidth|, if appearing at runtime-kernel
1717
/// interface, is less than a byte that should be tightly packed together.
18-
bool needToPackSubByteElementBitWidth(unsigned bitWidth);
18+
bool needToPackSubByteElementBitWidth(unsigned bitWidth,
19+
bool isI1PackedStorage = false);
20+
1921
/// Returns true if the given |shapedType|, if appearing at runtime-kernel
2022
/// interface, has sub-byte element types that should be tightly packed
2123
/// together.
@@ -27,7 +29,8 @@ bool needToPackSubByteElements(RankedTensorType shapedType);
2729
/// runtime and kernel. For such cases, we perform tight packing for supported
2830
/// sub-byte elements, and expand to the next power-of-two bit width for other
2931
/// cases.
30-
Type legalizeStorageElementType(Type elementType);
32+
Type legalizeStorageElementType(Type elementType,
33+
bool isI1PackedStorage = false);
3134

3235
/// Emits IR with the given |builder| to calculate the total number of bytes
3336
/// required for the given |shapedType| in storage. Returns the value for the

0 commit comments

Comments
 (0)