Skip to content

Commit 30ab65f

Browse files
committed
Updates
1 parent b7a7d07 commit 30ab65f

File tree

1 file changed

+44
-11
lines changed

1 file changed

+44
-11
lines changed

compiler/src/iree/compiler/GlobalOptimization/PackedStorage.cpp

+44-11
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1212
#include "mlir/Pass/Pass.h"
1313
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1415

1516
#define DEBUG_TYPE "iree-global-opt-pack-storage"
1617

@@ -19,34 +20,66 @@ namespace mlir::iree_compiler::GlobalOptimization {
1920
#define GEN_PASS_DEF_PACKSTORAGEPASS
2021
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
2122

23+
2224
static RankedTensorType appendAttributeToTensor(RankedTensorType type) {
23-
IREE::Encoding::PackedStorageAttr packedAttr;
24-
return RankedTensorType::get(type.getShape(), type.getElementType(),
25+
IntegerAttr bitwidthAttr =
26+
IntegerAttr::get(IntegerType::get(type.getContext(), 32),
27+
type.getElementType().getIntOrFloatBitWidth());
28+
IREE::Encoding::PackedStorageAttr packedAttr =
29+
IREE::Encoding::PackedStorageAttr::get(type.getContext(), bitwidthAttr);
30+
auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
2531
packedAttr);
32+
assert(mlir::iree_compiler::IREE::Encoding::hasPackedStorageAttr(newType));
33+
LLVM_DEBUG(llvm::dbgs() << " appending packed tensor type: " << newType << "\n");
34+
return newType;
2635
}
2736

37+
struct PackAttributeSignaturePattern : public OpConversionPattern<func::FuncOp> {
38+
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
39+
40+
LogicalResult
41+
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
42+
ConversionPatternRewriter &rewriter) const override {
43+
44+
TypeConverter::SignatureConversion convertedResult(
45+
funcOp.getNumArguments());
46+
if (failed(getTypeConverter()->convertSignatureArgs(
47+
funcOp.getArgumentTypes(), convertedResult)))
48+
return failure();
49+
rewriter.modifyOpInPlace(funcOp, [&] {
50+
rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
51+
convertedResult);
52+
});
53+
return success();
54+
}
55+
};
56+
57+
2858
struct PackStoragePass : impl::PackStoragePassBase<PackStoragePass> {
2959

3060
void getDependentDialects(DialectRegistry &registry) const override {
3161
registry.insert<tensor::TensorDialect>();
3262
}
3363
void runOnOperation() override;
64+
65+
static bool isPackStorageCandidate(RankedTensorType type) {
66+
auto elementType = type.getElementType();
67+
return elementType.isIntOrFloat() &&
68+
elementType.getIntOrFloatBitWidth() == 1;
69+
}
3470
};
3571

3672
void PackStoragePass::runOnOperation() {
3773
auto funcOp = getOperation();
3874
LLVM_DEBUG(llvm::dbgs() << "== Running PackStoragePass on "
3975
<< funcOp.getName() << "\n");
40-
/*
41-
for (auto & arg : funcOp.getArguments()) {
42-
if(auto tensorType = dyn_cast<RankedTensorType>(arg.getType())) {
43-
auto elementType = tensorType.getElementType();
44-
if (elementType.isIntOrFloat() && elementType.getIntOrFloatBitWidth() ==
45-
1) { arg.setType(appendAttributeToTensor(tensorType));
46-
}
47-
}
76+
RewritePatternSet conversionPatterns(&getContext());
77+
conversionPatterns.add<PackAttributeSignaturePattern>(&getContext());
78+
79+
if (failed(applyPatternsAndFoldGreedily(funcOp,
80+
std::move(conversionPatterns)))) {
81+
signalPassFailure();
4882
}
49-
*/
5083
}
5184

5285
} // namespace mlir::iree_compiler::GlobalOptimization

0 commit comments

Comments
 (0)