Skip to content

Commit 931032c

Browse files
committed
Add comments to the tests.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 92a9908 commit 931032c

File tree

2 files changed

+229
-4
lines changed

2 files changed

+229
-4
lines changed

compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp

+176-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/Support/Debug.h"
1717
#include "llvm/Support/LogicalResult.h"
1818
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/IR/SymbolTable.h"
1920
#include "mlir/Interfaces/FunctionInterfaces.h"
2021
#include "mlir/Pass/Pass.h"
2122
#include "mlir/Support/LLVM.h"
@@ -52,6 +53,175 @@ SmallVector<const T *> gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) {
5253
return results;
5354
}
5455

56+
/// Returns the affinities of the `dispatchOp`'s resource operands. An empty
57+
/// array attribute indicates that the resource operand affinity is not found.
58+
/// Usually, it happens when it fails on affinity analysis.
59+
/// Note that the size of the result might not equal to the number of resource
60+
/// operands. If a resource operand type is not AffinityType, it is skipped.
61+
static SmallVector<Attribute>
62+
getResourceOperandsAffinities(IREE::Stream::AffinityAnalysis &affinityAnalysis,
63+
IREE::Stream::AsyncDispatchOp dispatchOp) {
64+
SmallVector<Attribute> result;
65+
Builder b(dispatchOp.getContext());
66+
auto emptyArray = b.getArrayAttr({});
67+
for (auto operand : dispatchOp.getResourceOperands()) {
68+
// Skip if the operand type is not AffinityType.
69+
if (!isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
70+
continue;
71+
}
72+
SmallVector<IREE::Stream::AffinityAttr> affinities;
73+
if (!affinityAnalysis.tryLookupResourceAffinity(operand, affinities)) {
74+
result.push_back(emptyArray);
75+
continue;
76+
}
77+
result.push_back(b.getArrayAttr(llvm::to_vector_of<Attribute>(affinities)));
78+
}
79+
return result;
80+
}
81+
82+
/// Duplicates stream.executables based on the affinity analysis of
83+
/// stream.async.dispatch ops. Some executables can be launched by different
84+
/// devices. It can produce wrong codegen artifacts when bindings types are
85+
/// encoded (i.e., the tensor type has an encoding attribute). Because they can
86+
/// result in different layouts, especially when multi-device is involved. E.g.,
87+
/// say that device_a and device_b interpret a tensor type with encodings in
88+
/// different layouts, and there is an executable that can be launch with
89+
/// resources from either device_a or device_b. It is confusing what the input
90+
/// layouts for the executable because there are two possibilities. In this
91+
/// case, we have to duplicate the executable with updated encoding, and modify
92+
/// the dispatch to launch proper executable based on device analysis.
93+
static LogicalResult duplicateExecutablesPerAffinityVariant(
94+
ModuleOp moduleOp, SymbolTable symbolTable, FunctionOpInterface funcOp,
95+
IREE::Stream::ResolveLayoutAttrFn resolveLayoutAttr) {
96+
MLIRContext *ctx = moduleOp.getContext();
97+
IRRewriter rewriter(ctx);
98+
99+
// 1. Gather per-export [execution affinity -> [resource affinities]] map.
100+
IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
101+
if (failed(affinityAnalysis.run())) {
102+
return moduleOp.emitError("failed on running affinity analysis");
103+
}
104+
SmallVector<IREE::Stream::AsyncDispatchOp> candidates;
105+
funcOp.walk(
106+
[&](IREE::Stream::AsyncDispatchOp op) { candidates.push_back(op); });
107+
108+
// export -> [affinity -> array per resource of affinities PVS].
109+
DenseMap<IREE::Stream::ExecutableExportOp,
110+
SetVector<std::pair<IREE::Stream::AffinityAttr, ArrayAttr>>>
111+
exportToDispatchSites;
112+
113+
llvm::MapVector<IREE::Stream::AsyncDispatchOp, SmallVector<Attribute>>
114+
resourceAffinities;
115+
for (auto dispatchOp : candidates) {
116+
SmallVector<IREE::Stream::AffinityAttr> execAffinities;
117+
if (!affinityAnalysis.tryLookupExecutionAffinity(dispatchOp,
118+
execAffinities)) {
119+
return dispatchOp.emitError("failed on execution affinity lookup");
120+
}
121+
assert(execAffinities.size() == 1 &&
122+
"We should only have a single execution "
123+
"affinity when running the pass.");
124+
125+
SmallVector<Attribute> operandAffinityAttrs =
126+
getResourceOperandsAffinities(affinityAnalysis, dispatchOp);
127+
resourceAffinities[dispatchOp] = operandAffinityAttrs;
128+
129+
dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPoint) {
130+
auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
131+
symbolTable.lookupSymbolIn(moduleOp, entryPoint));
132+
exportToDispatchSites[exportOp].insert(std::make_pair(
133+
execAffinities[0], rewriter.getArrayAttr(operandAffinityAttrs)));
134+
});
135+
}
136+
137+
LLVM_DEBUG({
138+
llvm::dbgs() << "Dump of exportToDispatchSites\n";
139+
for (auto [exportOp, affinities] : exportToDispatchSites) {
140+
llvm::dbgs() << " ExportOp: " << exportOp.getSymName() << "\n";
141+
for (auto [execAffinity, resourceAffinities] : affinities) {
142+
llvm::dbgs() << " executaion affinity: " << execAffinity << "\n";
143+
llvm::dbgs() << " resource affinities: " << resourceAffinities
144+
<< "\n";
145+
}
146+
}
147+
});
148+
149+
// 2. Duplicate executables for each unqiue resource affinities.
150+
151+
// Mapping from [execution affinity, resource operands affinities, export] to
152+
// the executable op.
153+
using DispatchSiteInfo = std::tuple<IREE::Stream::AffinityAttr, ArrayAttr,
154+
IREE::Stream::ExecutableExportOp>;
155+
DenseMap<DispatchSiteInfo, IREE::Stream::ExecutableOp>
156+
dispatchSiteToExecutableOp;
157+
for (auto [exportOp, execAndResourceAffinities] : exportToDispatchSites) {
158+
auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
159+
// No need to duplicate the executable if all the uses have the same
160+
// affinities.
161+
// TODO(hanchung): Do not duplicate the executables if bindings are not
162+
// encoded. I.e., all the tensor types do not have encodings.
163+
if (execAndResourceAffinities.size() == 1) {
164+
auto [execAffinity, resourceAffinities] = execAndResourceAffinities[0];
165+
dispatchSiteToExecutableOp[DispatchSiteInfo(
166+
execAffinity, resourceAffinities, exportOp)] = executableOp;
167+
continue;
168+
}
169+
170+
int64_t dupId = -1;
171+
for (auto [execAffinity, resourceAffinities] : execAndResourceAffinities) {
172+
rewriter.setInsertionPointAfter(executableOp);
173+
IREE::Stream::ExecutableOp dupOp = executableOp;
174+
if (dupId != -1) {
175+
auto symName = std::string(executableOp.getSymName());
176+
symName += "_dup" + std::to_string(dupId);
177+
dupOp = rewriter.cloneWithoutRegions(executableOp);
178+
rewriter.modifyOpInPlace(dupOp, [&] {
179+
dupOp.setSymName(symName);
180+
IRMapping mapping;
181+
executableOp.getRegion().cloneInto(&dupOp.getRegion(), mapping);
182+
});
183+
}
184+
dispatchSiteToExecutableOp[DispatchSiteInfo(
185+
execAffinity, resourceAffinities, exportOp)] = dupOp;
186+
dupId++;
187+
}
188+
}
189+
190+
// 3. Update dispatch sites, i.e., point dispatch entry points to
191+
// corresponding cloned executables.
192+
for (auto dispatchOp : candidates) {
193+
SmallVector<Attribute> newEntryPoints;
194+
SmallVector<IREE::Stream::AffinityAttr> execAffinities;
195+
// Sanity checks. It should already meet the requirement because they are
196+
// checked in step 1.
197+
assert(affinityAnalysis.tryLookupExecutionAffinity(dispatchOp,
198+
execAffinities));
199+
assert(execAffinities.size() == 1);
200+
SmallVector<Attribute> operandAttrs = resourceAffinities[dispatchOp];
201+
dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPoint) {
202+
auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
203+
symbolTable.lookupSymbolIn(moduleOp, entryPoint));
204+
auto info = DispatchSiteInfo(
205+
execAffinities[0], rewriter.getArrayAttr(operandAttrs), exportOp);
206+
assert(dispatchSiteToExecutableOp.count(info));
207+
208+
auto executableOp = dispatchSiteToExecutableOp[info];
209+
auto newSym = SymbolRefAttr::get(executableOp->getAttrOfType<StringAttr>(
210+
SymbolTable::getSymbolAttrName()),
211+
entryPoint.getNestedReferences());
212+
newEntryPoints.push_back(newSym);
213+
});
214+
215+
rewriter.modifyOpInPlace(dispatchOp, [&] {
216+
dispatchOp.setEntryPointsAttr(rewriter.getArrayAttr(newEntryPoints));
217+
});
218+
}
219+
220+
// TODO(hanchung): Update encodings in executables.
221+
222+
return success();
223+
}
224+
55225
// TODO(hanchung): Add "cloneWithEncoding" method to RankedTensorType.
56226
static RankedTensorType cloneWithEncoding(RankedTensorType type,
57227
Attribute encodingAttr) {
@@ -149,6 +319,7 @@ struct SpecializeEncodingsPass
149319
return signalPassFailure();
150320
}
151321

322+
SymbolTable symbolTable(moduleOp);
152323
llvm::MapVector<StringRef, IREE::Stream::ExecutableOp> executableOps;
153324
for (auto executableOp : moduleOp.getOps<IREE::Stream::ExecutableOp>()) {
154325
executableOps[executableOp.getName()] = executableOp;
@@ -164,7 +335,11 @@ struct SpecializeEncodingsPass
164335
return signalPassFailure();
165336
}
166337

167-
// TODO(hanchung): Duplicate executables and update dispatch ops.
338+
if (failed(duplicateExecutablesPerAffinityVariant(
339+
moduleOp, symbolTable, funcOp, resolveLayoutAttr))) {
340+
funcOp.emitError("failed on executable duplication");
341+
return signalPassFailure();
342+
}
168343
}
169344
}
170345
};
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,69 @@
11
// RUN: iree-opt --split-input-file --iree-stream-specialize-encodings %s | FileCheck %s
22

3-
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_layout = #iree_cpu.vmvx_encoding_layout<>, ukernels = "all"}>
3+
//------------------------------------------------------------------------------
4+
// Stream ops that have the TensorPhaseOp trait. This test suite checks that
5+
// the encoding is updated to carry the resolved layouts.
6+
//------------------------------------------------------------------------------
7+
8+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding_layout = #iree_cpu.vmvx_encoding_layout<>}>
49
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
510
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32]>
611
module {
712
util.global private @device_a = #device_target_local_0_
813

9-
util.func public @main(%d0: index, %d1: index) -> index {
14+
util.func public @tensor_sizeof(%d0: index, %d1: index) -> index {
1015
%size = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?xf32, #encoding>{%d0, %d1} : index
1116
util.return %size : index
1217
}
1318
}
1419
// CHECK: #[[EXECUTABLE:.+]] = #hal.executable.target<"vmvx",
1520
// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding
1621
// CHECK-SAME: layouts = [#[[EXECUTABLE]]]
17-
// CHECK-LABEL: util.func public @main
22+
// CHECK-LABEL: util.func public @tensor_sizeof
1823
// CHECK: %[[RES:.+]] = stream.tensor.sizeof {{.+}} tensor<?x?xf32, #[[$ENCODING]]>
1924
// CHECK: return %[[RES]]
25+
26+
// -----
27+
28+
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {ukernels = "none"}>
29+
#map = affine_map<(d0) -> (d0)>
30+
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
31+
#device_target_local_1_ = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
32+
module attributes {stream.affinity.default = #hal.device.affinity<@device_a>} {
33+
util.global private @device_a = #device_target_local_0_
34+
util.global private @device_b = #device_target_local_1_
35+
stream.executable private @ex {
36+
stream.executable.export public @dispatch
37+
}
38+
util.func public @multi_device(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view {
39+
%c16 = arith.constant 16 : index
40+
%c0 = arith.constant 0 : index
41+
%c4 = arith.constant 4 : index
42+
%element_type_f32 = hal.element_type<f32> : i32
43+
%dense_row_major = hal.encoding_type<dense_row_major> : i32
44+
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c4]) type(%element_type_f32) encoding(%dense_row_major)
45+
%0 = stream.tensor.import on(#hal.device.affinity<@device_a>) %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}
46+
%1 = stream.timepoint.import on(#hal.device.affinity<@device_a>) %arg1 : (!hal.fence) => !stream.timepoint
47+
%2 = stream.timepoint.await %1 => %0 : !stream.resource<external>{%c16}
48+
%3 = stream.async.transfer %2 : !stream.resource<external>{%c16} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%c16}
49+
%4 = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@dispatch(%3[%c0 to %c16 for %c16]) : (!stream.resource<*>{%c16}) -> !stream.resource<*>{%c16}
50+
%5 = stream.async.transfer %4 : !stream.resource<*>{%c16} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%c16}
51+
%6 = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@dispatch(%5[%c0 to %c16 for %c16]) : (!stream.resource<*>{%c16}) -> !stream.resource<*>{%c16}
52+
%7 = stream.async.transfer %6 : !stream.resource<*>{%c16} from(#hal.device.affinity<@device_b>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%c16}
53+
%result, %result_timepoint = stream.timepoint.barrier on(#hal.device.affinity<@device_a>) %7 : !stream.resource<*>{%c16} => !stream.timepoint
54+
stream.timepoint.chain_external on(#hal.device.affinity<@device_a>) %result_timepoint => (%arg2 : !hal.fence)
55+
%8 = stream.async.transfer %result : !stream.resource<*>{%c16} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<external>{%c16}
56+
%9 = stream.tensor.export on(#hal.device.affinity<@device_a>) %8 : tensor<4xf32> in !stream.resource<external>{%c16} -> !hal.buffer_view
57+
util.return %9 : !hal.buffer_view
58+
}
59+
}
60+
61+
// CHECK: #[[DEVICE_LOCAL_0:.+]] = #hal.device.target
62+
// CHECK: #[[DEVICE_LOCAL_1:.+]] = #hal.device.target
63+
// CHECK: util.global private @[[$DEVICE_A:.+]] = #[[DEVICE_LOCAL_0]]
64+
// CHECK: util.global private @[[$DEVICE_B:.+]] = #[[DEVICE_LOCAL_1]]
65+
// CHECK: stream.executable private @[[$EX0:.+]] {
66+
// CHECK: stream.executable private @[[$EX1:.+]] {
67+
// CHECK-LABEL: util.func public @multi_device
68+
// CHECK: stream.async.dispatch on(#hal.device.affinity<@[[$DEVICE_A]]>) @[[$EX0]]::@dispatch
69+
// CHECK: stream.async.dispatch on(#hal.device.affinity<@[[$DEVICE_B]]>) @[[$EX1]]::@dispatch

0 commit comments

Comments
 (0)