16
16
#include " llvm/Support/Debug.h"
17
17
#include " llvm/Support/LogicalResult.h"
18
18
#include " mlir/IR/BuiltinTypes.h"
19
+ #include " mlir/IR/SymbolTable.h"
19
20
#include " mlir/Interfaces/FunctionInterfaces.h"
20
21
#include " mlir/Pass/Pass.h"
21
22
#include " mlir/Support/LLVM.h"
@@ -52,6 +53,175 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
52
53
return results;
53
54
}
54
55
56
+ // / Returns the affinities of the `dispatchOp`'s resource operands. An empty
57
+ // / array attribute indicates that the resource operand affinity is not found.
58
+ // / Usually, it happens when it fails on affinity analysis.
59
+ // / Note that the size of the result might not equal to the number of resource
60
+ // / operands. If a resource operand type is not AffinityType, it is skipped.
61
+ static SmallVector<Attribute>
62
+ getResourceOperandsAffinities (IREE::Stream::AffinityAnalysis &affinityAnalysis,
63
+ IREE::Stream::AsyncDispatchOp dispatchOp) {
64
+ SmallVector<Attribute> result;
65
+ Builder b (dispatchOp.getContext ());
66
+ auto emptyArray = b.getArrayAttr ({});
67
+ for (auto operand : dispatchOp.getResourceOperands ()) {
68
+ // Skip if the operand type is not AffinityType.
69
+ if (!isa<IREE::Stream::AffinityTypeInterface>(operand.getType ())) {
70
+ continue ;
71
+ }
72
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
73
+ if (!affinityAnalysis.tryLookupResourceAffinity (operand, affinities)) {
74
+ result.push_back (emptyArray);
75
+ continue ;
76
+ }
77
+ result.push_back (b.getArrayAttr (llvm::to_vector_of<Attribute>(affinities)));
78
+ }
79
+ return result;
80
+ }
81
+
82
+ // / Duplicates stream.executables based on the affinity analysis of
83
+ // / stream.async.dispatch ops. Some executables can be launched by different
84
+ // / devices. It can produce wrong codegen artifacts when bindings types are
85
+ // / encoded (i.e., the tensor type has an encoding attribute). Because they can
86
+ // / result in different layouts, especially when multi-device is involved. E.g.,
87
+ // / say that device_a and device_b interpret a tensor type with encodings in
88
+ // / different layouts, and there is an executable that can be launch with
89
+ // / resources from either device_a or device_b. It is confusing what the input
90
+ // / layouts for the executable because there are two possibilities. In this
91
+ // / case, we have to duplicate the executable with updated encoding, and modify
92
+ // / the dispatch to launch proper executable based on device analysis.
93
+ static LogicalResult duplicateExecutablesPerAffinityVariant (
94
+ ModuleOp moduleOp, SymbolTable symbolTable, FunctionOpInterface funcOp,
95
+ IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
96
+ MLIRContext *ctx = moduleOp.getContext ();
97
+ IRRewriter rewriter (ctx);
98
+
99
+ // 1. Gather per-export [execution affinity -> [resource affinities]] map.
100
+ IREE::Stream::AffinityAnalysis affinityAnalysis (moduleOp);
101
+ if (failed (affinityAnalysis.run ())) {
102
+ return moduleOp.emitError (" failed on running affinity analysis" );
103
+ }
104
+ SmallVector<IREE::Stream::AsyncDispatchOp> candidates;
105
+ funcOp.walk (
106
+ [&](IREE::Stream::AsyncDispatchOp op) { candidates.push_back (op); });
107
+
108
+ // export -> [affinity -> array per resource of affinities PVS].
109
+ DenseMap<IREE::Stream::ExecutableExportOp,
110
+ SetVector<std::pair<IREE::Stream::AffinityAttr, ArrayAttr>>>
111
+ exportToDispatchSites;
112
+
113
+ llvm::MapVector<IREE::Stream::AsyncDispatchOp, SmallVector<Attribute>>
114
+ resourceAffinities;
115
+ for (auto dispatchOp : candidates) {
116
+ SmallVector<IREE::Stream::AffinityAttr> execAffinities;
117
+ if (!affinityAnalysis.tryLookupExecutionAffinity (dispatchOp,
118
+ execAffinities)) {
119
+ return dispatchOp.emitError (" failed on execution affinity lookup" );
120
+ }
121
+ assert (execAffinities.size () == 1 &&
122
+ " We should only have a single execution "
123
+ " affinity when running the pass." );
124
+
125
+ SmallVector<Attribute> operandAffinityAttrs =
126
+ getResourceOperandsAffinities (affinityAnalysis, dispatchOp);
127
+ resourceAffinities[dispatchOp] = operandAffinityAttrs;
128
+
129
+ dispatchOp.forEachEntryPointAttr ([&](SymbolRefAttr entryPoint) {
130
+ auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
131
+ symbolTable.lookupSymbolIn (moduleOp, entryPoint));
132
+ exportToDispatchSites[exportOp].insert (std::make_pair (
133
+ execAffinities[0 ], rewriter.getArrayAttr (operandAffinityAttrs)));
134
+ });
135
+ }
136
+
137
+ LLVM_DEBUG ({
138
+ llvm::dbgs () << " Dump of exportToDispatchSites\n " ;
139
+ for (auto [exportOp, affinities] : exportToDispatchSites) {
140
+ llvm::dbgs () << " ExportOp: " << exportOp.getSymName () << " \n " ;
141
+ for (auto [execAffinity, resourceAffinities] : affinities) {
142
+ llvm::dbgs () << " executaion affinity: " << execAffinity << " \n " ;
143
+ llvm::dbgs () << " resource affinities: " << resourceAffinities
144
+ << " \n " ;
145
+ }
146
+ }
147
+ });
148
+
149
+ // 2. Duplicate executables for each unqiue resource affinities.
150
+
151
+ // Mapping from [execution affinity, resource operands affinities, export] to
152
+ // the executable op.
153
+ using DispatchSiteInfo = std::tuple<IREE::Stream::AffinityAttr, ArrayAttr,
154
+ IREE::Stream::ExecutableExportOp>;
155
+ DenseMap<DispatchSiteInfo, IREE::Stream::ExecutableOp>
156
+ dispatchSiteToExecutableOp;
157
+ for (auto [exportOp, execAndResourceAffinities] : exportToDispatchSites) {
158
+ auto executableOp = exportOp->getParentOfType <IREE::Stream::ExecutableOp>();
159
+ // No need to duplicate the executable if all the uses have the same
160
+ // affinities.
161
+ // TODO(hanchung): Do not duplicate the executables if bindings are not
162
+ // encoded. I.e., all the tensor types do not have encodings.
163
+ if (execAndResourceAffinities.size () == 1 ) {
164
+ auto [execAffinity, resourceAffinities] = execAndResourceAffinities[0 ];
165
+ dispatchSiteToExecutableOp[DispatchSiteInfo (
166
+ execAffinity, resourceAffinities, exportOp)] = executableOp;
167
+ continue ;
168
+ }
169
+
170
+ int64_t dupId = -1 ;
171
+ for (auto [execAffinity, resourceAffinities] : execAndResourceAffinities) {
172
+ rewriter.setInsertionPointAfter (executableOp);
173
+ IREE::Stream::ExecutableOp dupOp = executableOp;
174
+ if (dupId != -1 ) {
175
+ auto symName = std::string (executableOp.getSymName ());
176
+ symName += " _dup" + std::to_string (dupId);
177
+ dupOp = rewriter.cloneWithoutRegions (executableOp);
178
+ rewriter.modifyOpInPlace (dupOp, [&] {
179
+ dupOp.setSymName (symName);
180
+ IRMapping mapping;
181
+ executableOp.getRegion ().cloneInto (&dupOp.getRegion (), mapping);
182
+ });
183
+ }
184
+ dispatchSiteToExecutableOp[DispatchSiteInfo (
185
+ execAffinity, resourceAffinities, exportOp)] = dupOp;
186
+ dupId++;
187
+ }
188
+ }
189
+
190
+ // 3. Update dispatch sites, i.e., point dispatch entry points to
191
+ // corresponding cloned executables.
192
+ for (auto dispatchOp : candidates) {
193
+ SmallVector<Attribute> newEntryPoints;
194
+ SmallVector<IREE::Stream::AffinityAttr> execAffinities;
195
+ // Sanity checks. It should already meet the requirement because they are
196
+ // checked in step 1.
197
+ assert (affinityAnalysis.tryLookupExecutionAffinity (dispatchOp,
198
+ execAffinities));
199
+ assert (execAffinities.size () == 1 );
200
+ SmallVector<Attribute> operandAttrs = resourceAffinities[dispatchOp];
201
+ dispatchOp.forEachEntryPointAttr ([&](SymbolRefAttr entryPoint) {
202
+ auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
203
+ symbolTable.lookupSymbolIn (moduleOp, entryPoint));
204
+ auto info = DispatchSiteInfo (
205
+ execAffinities[0 ], rewriter.getArrayAttr (operandAttrs), exportOp);
206
+ assert (dispatchSiteToExecutableOp.count (info));
207
+
208
+ auto executableOp = dispatchSiteToExecutableOp[info];
209
+ auto newSym = SymbolRefAttr::get (executableOp->getAttrOfType <StringAttr>(
210
+ SymbolTable::getSymbolAttrName ()),
211
+ entryPoint.getNestedReferences ());
212
+ newEntryPoints.push_back (newSym);
213
+ });
214
+
215
+ rewriter.modifyOpInPlace (dispatchOp, [&] {
216
+ dispatchOp.setEntryPointsAttr (rewriter.getArrayAttr (newEntryPoints));
217
+ });
218
+ }
219
+
220
+ // TODO(hanchung): Update encodings in executables.
221
+
222
+ return success ();
223
+ }
224
+
55
225
// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
56
226
static RankedTensorType cloneWithEncoding (RankedTensorType type,
57
227
Attribute encodingAttr) {
@@ -149,6 +319,7 @@ struct SpecializeEncodingsPass
149
319
return signalPassFailure ();
150
320
}
151
321
322
+ SymbolTable symbolTable (moduleOp);
152
323
llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
153
324
for (auto executableOp : moduleOp.getOps <IREE::Stream::ExecutableOp>()) {
154
325
executableOps[executableOp.getName ()] = executableOp;
@@ -164,7 +335,11 @@ struct SpecializeEncodingsPass
164
335
return signalPassFailure ();
165
336
}
166
337
167
- // TODO(hanchung): Duplicate executables and update dispatch ops.
338
+ if (failed (duplicateExecutablesPerAffinityVariant (
339
+ moduleOp, symbolTable, funcOp, resolveLayoutAttr))) {
340
+ funcOp.emitError (" failed on executable duplication" );
341
+ return signalPassFailure ();
342
+ }
168
343
}
169
344
}
170
345
};
0 commit comments