11
11
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
12
12
#include " mlir/Pass/Pass.h"
13
13
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
14
+ #include " mlir/Dialect/Func/IR/FuncOps.h"
14
15
15
16
#define DEBUG_TYPE " iree-global-opt-pack-storage"
16
17
@@ -19,34 +20,66 @@ namespace mlir::iree_compiler::GlobalOptimization {
19
20
#define GEN_PASS_DEF_PACKSTORAGEPASS
20
21
#include " iree/compiler/GlobalOptimization/Passes.h.inc"
21
22
23
+
22
24
static RankedTensorType appendAttributeToTensor (RankedTensorType type) {
23
- IREE::Encoding::PackedStorageAttr packedAttr;
24
- return RankedTensorType::get (type.getShape (), type.getElementType (),
25
+ IntegerAttr bitwidthAttr =
26
+ IntegerAttr::get (IntegerType::get (type.getContext (), 32 ),
27
+ type.getElementType ().getIntOrFloatBitWidth ());
28
+ IREE::Encoding::PackedStorageAttr packedAttr =
29
+ IREE::Encoding::PackedStorageAttr::get (type.getContext (), bitwidthAttr);
30
+ auto newType = RankedTensorType::get (type.getShape (), type.getElementType (),
25
31
packedAttr);
32
+ assert (mlir::iree_compiler::IREE::Encoding::hasPackedStorageAttr (newType));
33
+ LLVM_DEBUG (llvm::dbgs () << " appending packed tensor type: " << newType << " \n " );
34
+ return newType;
26
35
}
27
36
37
+ struct PackAttributeSignaturePattern : public OpConversionPattern <func::FuncOp> {
38
+ using OpConversionPattern<func::FuncOp>::OpConversionPattern;
39
+
40
+ LogicalResult
41
+ matchAndRewrite (func::FuncOp funcOp, OpAdaptor adaptor,
42
+ ConversionPatternRewriter &rewriter) const override {
43
+
44
+ TypeConverter::SignatureConversion convertedResult (
45
+ funcOp.getNumArguments ());
46
+ if (failed (getTypeConverter ()->convertSignatureArgs (
47
+ funcOp.getArgumentTypes (), convertedResult)))
48
+ return failure ();
49
+ rewriter.modifyOpInPlace (funcOp, [&] {
50
+ rewriter.applySignatureConversion (&funcOp.getFunctionBody ().front (),
51
+ convertedResult);
52
+ });
53
+ return success ();
54
+ }
55
+ };
56
+
57
+
28
58
struct PackStoragePass : impl::PackStoragePassBase<PackStoragePass> {
29
59
30
60
void getDependentDialects (DialectRegistry ®istry) const override {
31
61
registry.insert <tensor::TensorDialect>();
32
62
}
33
63
void runOnOperation () override ;
64
+
65
+ static bool isPackStorageCandidate (RankedTensorType type) {
66
+ auto elementType = type.getElementType ();
67
+ return elementType.isIntOrFloat () &&
68
+ elementType.getIntOrFloatBitWidth () == 1 ;
69
+ }
34
70
};
35
71
36
72
void PackStoragePass::runOnOperation () {
37
73
auto funcOp = getOperation ();
38
74
LLVM_DEBUG (llvm::dbgs () << " == Running PackStoragePass on "
39
75
<< funcOp.getName () << " \n " );
40
- /*
41
- for (auto & arg : funcOp.getArguments()) {
42
- if(auto tensorType = dyn_cast<RankedTensorType>(arg.getType())) {
43
- auto elementType = tensorType.getElementType();
44
- if (elementType.isIntOrFloat() && elementType.getIntOrFloatBitWidth() ==
45
- 1) { arg.setType(appendAttributeToTensor(tensorType));
46
- }
47
- }
76
+ RewritePatternSet conversionPatterns (&getContext ());
77
+ conversionPatterns.add <PackAttributeSignaturePattern>(&getContext ());
78
+
79
+ if (failed (applyPatternsAndFoldGreedily (funcOp,
80
+ std::move (conversionPatterns)))) {
81
+ signalPassFailure ();
48
82
}
49
- */
50
83
}
51
84
52
85
} // namespace mlir::iree_compiler::GlobalOptimization
0 commit comments