Skip to content

Commit e27a766

Browse files
committed
Remove option and refactor
1 parent 419ee6b commit e27a766

File tree

14 files changed

+119
-65
lines changed

14 files changed

+119
-65
lines changed

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,10 @@ RankedTensorType dropEncoding(RankedTensorType type) {
8383
return RankedTensorType::get(type.getShape(), type.getElementType());
8484
}
8585

86+
/// Strips the encoding from `type` only when that encoding is the packed
/// storage attribute. Any other type is handed back unchanged.
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type) {
  const bool isPacked = IREE::Encoding::hasPackedStorageAttr(type);
  return isPacked
             ? RankedTensorType::get(type.getShape(), type.getElementType())
             : type;
}
91+
8692
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
1111
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
1212
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
13+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1314
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
1415
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1516
#include "mlir/Transforms/DialectConversion.h"
@@ -79,6 +80,9 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
7980
/// Returns the RankedTensorType without encodings.
8081
RankedTensorType dropEncoding(RankedTensorType type);
8182

83+
/// Returns the RankedTensorType without packed storage encoding (if any).
84+
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type);
85+
8286
/// Utility method to convert from `set_encoding` op to `pack` operation.
8387
/// NOTE: `source` could be returned when packing is not needed.
8488
FailureOr<Value> lowerSetEncodingOpToPackOp(

compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp

+14-10
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
//===---------------------------------------------------------------------===//
2626

2727
#include "iree/compiler/Codegen/Common/Passes.h"
28+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
2829
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
2930
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
3031
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -65,9 +66,8 @@ static Value convertElementType(OpBuilder &b, Location loc, Type targetType,
6566
/// std::nullopt.
6667
static std::optional<Type> getLegalizedType(Type t) {
6768
if (auto shapedType = llvm::dyn_cast<RankedTensorType>(t)) {
68-
Type elementType = shapedType.getElementType();
6969
std::optional<Type> legalizedElementType =
70-
legalizeStorageElementType(elementType);
70+
legalizeTensorStorageElementType(shapedType);
7171
if (!legalizedElementType)
7272
return std::nullopt;
7373
return RankedTensorType::get(shapedType.getShape(),
@@ -114,7 +114,7 @@ struct ConstantOpTypeConversion
114114
constantOp, "expected attribute type to be shaped type");
115115
}
116116
std::optional<Type> legalizedElementType =
117-
legalizeStorageElementType(attrType.getElementType());
117+
legalizeTensorStorageElementType(attrType);
118118
if (!legalizedElementType) {
119119
return rewriter.notifyMatchFailure(constantOp,
120120
"cannot legalize elementType");
@@ -220,8 +220,10 @@ struct GenericOpTypePropagation
220220
signatureConverter.addInputs(index, argType);
221221
continue;
222222
}
223+
auto inputOperandType =
224+
llvm::cast<RankedTensorType>(genericOp->getOperandTypes()[index]);
223225
std::optional<Type> legalizedArgType =
224-
legalizeStorageElementType(argType);
226+
legalizeTensorStorageElementType(inputOperandType);
225227
if (!legalizedArgType) {
226228
return genericOp.emitOpError("failed to get legalized type for arg ")
227229
<< index;
@@ -251,8 +253,8 @@ struct GenericOpTypePropagation
251253
modifyYield = true;
252254
OpOperand *yieldOperand =
253255
modifiedOp.getMatchingYieldValue(modifiedOpOperand);
254-
std::optional<Type> legalizedType =
255-
legalizeStorageElementType(yieldOperand->get().getType());
256+
std::optional<Type> legalizedType = legalizeTensorStorageElementType(
257+
modifiedOpOperand->get().getType());
256258
if (!legalizedType) {
257259
return genericOp.emitOpError(
258260
"failed to get legalized type for yield value");
@@ -282,7 +284,7 @@ struct LinalgFillTypePropagation
282284
ConversionPatternRewriter &rewriter) const final {
283285
Value value = adaptor.getInputs().front();
284286
std::optional<Type> legalizedElementType =
285-
legalizeStorageElementType(value.getType());
287+
legalizeTensorStorageElementType(adaptor.getOutputs()[0].getType());
286288
if (!legalizedElementType) {
287289
return fillOp.emitOpError("failed to get legalized type for value");
288290
}
@@ -348,8 +350,8 @@ struct IREELinalgExtScatterTypePropagation
348350
// type.
349351
TypeConverter::SignatureConversion signatureConverter(
350352
modifiedOpRegion.getNumArguments());
351-
Type argType = modifiedOpRegion.getArguments()[0].getType();
352-
std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
353+
std::optional<Type> legalizedArgType =
354+
legalizeTensorStorageElementType(inputType);
353355
if (!legalizedArgType) {
354356
return scatterOp.emitOpError("failed to get legalized type for argument");
355357
}
@@ -411,8 +413,10 @@ struct IREELinalgExtSortTypePropagation
411413
TypeConverter::SignatureConversion signatureConverter(
412414
modifiedOpRegion.getNumArguments());
413415
for (auto [index, arg] : llvm::enumerate(modifiedOpRegion.getArguments())) {
416+
// Refer to input types of the original operation to determine the
417+
// corresponding legal arg type.
414418
std::optional<Type> legalizedArgType =
415-
legalizeStorageElementType(arg.getType());
419+
legalizeTensorStorageElementType(sortOp->getOperandTypes()[index]);
416420
if (!legalizedArgType) {
417421
return sortOp.emitOpError("failed to get legalized type for argument");
418422
}

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,12 @@ EncodingAttr getEncodingAttr(RankedTensorType type) {
243243
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
244244
}
245245

246-
bool hasPackedStorageAttr(RankedTensorType type) {
247-
return dyn_cast_or_null<PackedStorageAttr>(type.getEncoding()) != nullptr;
246+
/// Returns true iff `type` is a ranked tensor whose encoding attribute is a
/// PackedStorageAttr. Non-tensor types and tensors without such an encoding
/// yield false.
bool hasPackedStorageAttr(Type type) {
  auto rankedType = dyn_cast<RankedTensorType>(type);
  if (!rankedType)
    return false;
  auto encoding = rankedType.getEncoding();
  return dyn_cast_or_null<PackedStorageAttr>(encoding) != nullptr;
}
249253

250254
FailureOr<linalg::ContractionDimensions>

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingTypes.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace mlir::iree_compiler::IREE::Encoding {
3939
EncodingAttr getEncodingAttr(RankedTensorType type);
4040

4141
/// Returns true if the type contains packed_storage attribute.
42-
bool hasPackedStorageAttr(RankedTensorType type);
42+
bool hasPackedStorageAttr(Type type);
4343

4444
/// Returns the ContractionDimensions for the encoding user_indexing_maps.
4545
FailureOr<linalg::ContractionDimensions>

compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.h"
88

9+
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
10+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
911
#include "iree/compiler/Dialect/HAL/Analysis/Captures.h"
1012
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
1113
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
@@ -478,7 +480,8 @@ struct TensorExportBufferViewOpPattern
478480
}
479481

480482
auto loc = exportOp.getLoc();
481-
auto tensorType = llvm::cast<RankedTensorType>(adaptor.getSourceEncoding());
483+
auto tensorType = dropPackedStorageEncodingIfAny(
484+
llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
482485
auto dynamicDims = adaptor.getSourceEncodingDims();
483486

484487
// NOTE: we should have verified supported encodings/types at entry into the

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
88

9+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
910
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
1011
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
1112
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -27,6 +28,10 @@
2728
#include "mlir/Support/LogicalResult.h"
2829
#include "mlir/Transforms/RegionUtils.h"
2930

31+
namespace mlir::iree_compiler {
32+
using IREE::Encoding::getEncodingAttr;
33+
}
34+
3035
namespace mlir::iree_compiler::IREE::Stream {
3136

3237
//===----------------------------------------------------------------------===//
@@ -1512,7 +1517,7 @@ LogicalResult TensorCloneOp::verify() {
15121517
// information.
15131518
auto sourceEncoding = llvm::cast<RankedTensorType>(op.getSourceEncoding());
15141519
auto resultEncoding = llvm::cast<RankedTensorType>(op.getResultEncoding());
1515-
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
1520+
if (getEncodingAttr(sourceEncoding) != getEncodingAttr(resultEncoding)) {
15161521
return op.emitOpError() << "clones changing tensor encoding from "
15171522
<< sourceEncoding.getEncoding() << " to "
15181523
<< resultEncoding.getEncoding() << "; not allowed";

compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
78
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
89
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
910
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
@@ -22,6 +23,7 @@
2223
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
2324
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
2425
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
26+
#include "llvm/Support/Casting.h"
2527
#include "mlir/Dialect/Arith/IR/Arith.h"
2628
#include "mlir/Dialect/Func/IR/FuncOps.h"
2729
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -247,6 +249,12 @@ struct ConvertToStreamPass final
247249
if (llvm::isa<IREE::Flow::ChannelType>(type)) {
248250
return IREE::Stream::ChannelType::get(context);
249251
}
252+
if (auto rankedType = llvm::dyn_cast_or_null<RankedTensorType>(type)) {
253+
if (IREE::Encoding::hasPackedStorageAttr(rankedType)) {
254+
return RankedTensorType::get(rankedType.getShape(),
255+
rankedType.getElementType());
256+
}
257+
}
250258
return !llvm::isa<TensorType>(type) ? type : Type{};
251259
});
252260

compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType,
5858
// Aligns the element type of a tensor<> to a byte-aligned power of 2 bit width.
5959
static RankedTensorType alignTensorType(RankedTensorType originalType) {
6060
Type elementType = originalType.getElementType();
61-
Type alignedType = legalizeStorageElementType(elementType);
61+
Type alignedType = legalizeTensorStorageElementType(originalType);
6262
if (alignedType == elementType)
6363
return originalType;
6464
return RankedTensorType::get(originalType.getShape(), alignedType,
@@ -620,7 +620,8 @@ struct EncodeHostTensorsPass
620620
static IREE::Flow::DispatchTensorType
621621
alignDispatchTensorType(IREE::Flow::DispatchTensorType originalType) {
622622
Type elementType = originalType.getBoundElementType();
623-
Type alignedType = legalizeStorageElementType(elementType);
623+
Type alignedType =
624+
legalizeTensorStorageElementType(originalType.asRankedTensorType());
624625
if (alignedType == elementType)
625626
return originalType;
626627
return IREE::Flow::DispatchTensorType::get(

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_packing_i1_attr.mlir

+7
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ func.func @aligned_i1_size() -> index {
2020
// CHECK: func @aligned_i1_size() -> index {
2121
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
2222
// CHECK: return %[[C3]] : index
23+
24+
// -----
25+
26+
#packed = #iree_encoding.packed_storage
27+
func.func @packed_i1_input_output(%input : tensor<16xi1, #packed>) -> tensor<16xi1, #packed> {
28+
return %input : tensor<16xi1, #packed>
29+
}

compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp

+28-48
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,23 @@
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/IR/BuiltinTypes.h"
1717

18-
llvm::cl::opt<bool> clEnableI1Support(
19-
"iree-experimental-packed-i1-storage",
20-
llvm::cl::desc(
21-
"Experimental feature: force to use packed storage for i1 tensors."
22-
"Turning on this option will see i1 tensors as if it has "
23-
"#iree_encoding.packed_storage attribute."
24-
"This is to allow an alternative way to test the packed storage "
25-
"feature before frontend can emit packed i1 tensors."
26-
"This option can be dropped once the frontend can emit packed i1 "
27-
"tensors."),
28-
llvm::cl::init(false));
29-
3018
namespace mlir::iree_compiler {
3119

32-
static bool needToPackSubByteElementBitWidthImpl(unsigned bitWidth,
33-
bool isPackedStorage) {
34-
// Enable i1 support if requested.
35-
if (isPackedStorage && bitWidth == 1) {
36-
return true;
37-
}
20+
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
3821
// Require the original bit width to be some power of two for now to avoid
3922
// trickiness and weirdness of packing and cross-byte access.
4023
// Also disallow boolean values for now--they may require separate interface
4124
// choices.
4225
return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
4326
}
4427

45-
bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
46-
return needToPackSubByteElementBitWidthImpl(
47-
bitWidth, /*isPackedStorage=*/clEnableI1Support);
48-
}
49-
5028
/// Returns true if the elements of |shapedType| are stored packed (several
/// elements per byte) rather than each being extended to a whole byte.
///
/// i1 tensors are packed only when the tensor type carries the
/// #iree_encoding.packed_storage attribute; all other element types defer to
/// needToPackSubByteElementBitWidth().
bool needToPackSubByteElements(RankedTensorType shapedType) {
  unsigned bitWidth = IREE::Util::getTypeBitWidth(shapedType.getElementType());
  // i1 with packed memory layout does not need to be extended, so it must be
  // reported as packed here. The check is on the *presence* of the attribute:
  // plain i1 tensors fall through and are rejected by the bit-width helper.
  if (bitWidth == 1 && IREE::Encoding::hasPackedStorageAttr(shapedType)) {
    return true;
  }
  return needToPackSubByteElementBitWidth(bitWidth);
}
5936

6037
static Type legalizeStorageElementTypeImpl(Type elementType,
@@ -64,9 +41,13 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
6441
if (!intType)
6542
return elementType;
6643

67-
// For sub-byte elements, default to pack them into bytes.
6844
unsigned bitWidth = intType.getWidth();
69-
if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage))
45+
if (bitWidth == 1 && !isPackedStorage) {
46+
return elementType;
47+
}
48+
49+
// For sub-byte elements, default to pack them into bytes.
50+
if (needToPackSubByteElementBitWidth(bitWidth))
7051
return elementType;
7152

7253
// Otherwise, extend them to the next power-of-two bit width.
@@ -78,10 +59,10 @@ static Type legalizeStorageElementTypeImpl(Type elementType,
7859
intType.getSignedness());
7960
}
8061

81-
Type legalizeStorageElementType(Type elementType) {
82-
// Consider packed storage for i1 tensors if cl opt is set.
83-
return legalizeStorageElementTypeImpl(elementType,
84-
/*isPackedStorage=*/clEnableI1Support);
62+
/// Legalizes the storage element type of |type|.
///
/// When |type| is a ranked tensor, its element type is legalized and the
/// tensor's encoding is consulted for the packed storage attribute. Any other
/// type is treated as a bare element type with no packed storage.
Type legalizeTensorStorageElementType(Type type) {
  auto tensorType = llvm::dyn_cast<RankedTensorType>(type);
  if (!tensorType) {
    // Not a ranked tensor: there is no encoding to inspect, so legalize the
    // type itself as an element type.
    return legalizeStorageElementTypeImpl(type, /*isPackedStorage=*/false);
  }
  // Pass the element type, not the tensor type: the impl works on element
  // types and would otherwise hand the tensor type straight back.
  return legalizeStorageElementTypeImpl(
      tensorType.getElementType(),
      IREE::Encoding::hasPackedStorageAttr(tensorType));
}
8667

8768
Value calculateStorageElementCountInBytes(Location loc,
@@ -96,16 +77,16 @@ Value calculateStorageElementCountInBytes(Location loc,
9677
loc, builder, shapedType, dynamicDims);
9778
}
9879

99-
// TODO(lialan): remove cl options once frontend can emit packed i1 tensors.
100-
bool isPackedStorage =
101-
IREE::Encoding::hasPackedStorageAttr(shapedType) || clEnableI1Support;
102-
Type alignedElementType = legalizeStorageElementTypeImpl(
103-
shapedType.getElementType(), isPackedStorage);
80+
Type alignedElementType = legalizeTensorStorageElementType(shapedType);
10481
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
10582

83+
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(shapedType);
84+
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;
85+
10686
// Calculate all static dims first, if any.
10787
int64_t staticCount = 1;
108-
if (!needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
88+
if (!isI1WithPackedStorage &&
89+
!needToPackSubByteElementBitWidth(elementBits)) {
10990
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
11091
}
11192

@@ -120,7 +101,7 @@ Value calculateStorageElementCountInBytes(Location loc,
120101
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
121102
}
122103
// Sub-byte packing requires putting multiple elements in the same byte.
123-
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
104+
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
124105
assert(8 % elementBits == 0);
125106
unsigned byteElements = 8 / elementBits;
126107
// TODO(antiagainst): We may want to emit runtime check to make sure this is
@@ -140,15 +121,14 @@ Value calculateStorageElementOffsetInBytes(Location loc,
140121
RankedTensorType originalType,
141122
Value linearizedIndex,
142123
OpBuilder &builder) {
143-
// TODO: remove cl options once frontend can emit packed i1 tensors.
144-
bool isPackedStorage =
145-
IREE::Encoding::hasPackedStorageAttr(originalType) || clEnableI1Support;
146-
Type alignedElementType = legalizeStorageElementTypeImpl(
147-
originalType.getElementType(), isPackedStorage);
124+
Type alignedElementType = legalizeTensorStorageElementType(originalType);
148125
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
149126

127+
bool isPackedStorage = IREE::Encoding::hasPackedStorageAttr(originalType);
128+
bool isI1WithPackedStorage = elementBits == 1 && isPackedStorage;
129+
150130
// Sub-byte packing requires putting multiple elements in the same byte.
151-
if (needToPackSubByteElementBitWidthImpl(elementBits, isPackedStorage)) {
131+
if (isI1WithPackedStorage || needToPackSubByteElementBitWidth(elementBits)) {
152132
Value byteElements =
153133
builder.create<arith::ConstantIndexOp>(loc, 8 / elementBits);
154134
// TODO(antiagainst): We may want to emit runtime check to make sure this is

compiler/src/iree/compiler/Utils/ElementPackingUtils.h

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ bool needToPackSubByteElements(RankedTensorType shapedType);
3030
/// cases.
3131
Type legalizeStorageElementType(Type elementType);
3232

33+
/// Variant of legalizeStorageElementType that takes the tensor type instead of
/// the bare element type, so that a packed storage encoding on the tensor (if
/// any) can be taken into account when choosing the storage element type.
Type legalizeTensorStorageElementType(Type tensorType);
34+
3335
/// Emits IR with the given |builder| to calculate the total number of bytes
3436
/// required for the given |shapedType| in storage. Returns the value for the
3537
/// final count on success; returns nullptr on failure. Dynamic dimensions in

0 commit comments

Comments
 (0)