Skip to content

Commit b0e8f3c

Browse files
authored
[VectorLayoutAnalysis] Fix bug in scf.for transfer functions (#15989)
Prior to this patch, the scf.for transfer functions were not propagating change on resolution of scf.for operands/results.
1 parent 4592b8f commit b0e8f3c

File tree

3 files changed

+68
-26
lines changed

3 files changed

+68
-26
lines changed

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

+20-23
Original file line numberDiff line numberDiff line change
@@ -924,31 +924,28 @@ void transform_dialect::WorkgroupSwizzleOp::getEffects(
924924

925925
static void setAnchorOpsFromAttributes(VectorLayoutAnalysis &analysis,
926926
func::FuncOp funcOp) {
927-
for (Block &block : funcOp) {
928-
for (Operation &op : block) {
929-
for (NamedAttribute attr : op.getAttrs()) {
930-
StringRef name = attr.getName().strref();
931-
if (name.find("__vector_layout_test_anchor_operand_") !=
932-
std::string::npos) {
933-
int operandNum;
934-
name.substr(name.find_last_of("_") + 1)
935-
.getAsInteger(/*Radix=*/10, operandNum);
936-
assert(operandNum < op.getNumOperands() &&
937-
"operand number out of range");
938-
analysis.setAnchor(op.getOperand(operandNum), attr.getValue());
939-
}
940-
if (name.find("__vector_layout_test_anchor_result_") !=
941-
std::string::npos) {
942-
int resultNum;
943-
name.substr(name.find_last_of("_") + 1)
944-
.getAsInteger(/*Radix=*/10, resultNum);
945-
assert(resultNum < op.getNumResults() &&
946-
"result number out of range");
947-
analysis.setAnchor(op.getResult(resultNum), attr.getValue());
948-
}
927+
funcOp.walk([&](Operation *op) {
928+
for (NamedAttribute attr : op->getAttrs()) {
929+
StringRef name = attr.getName().strref();
930+
if (name.find("__vector_layout_test_anchor_operand_") !=
931+
std::string::npos) {
932+
int operandNum;
933+
name.substr(name.find_last_of("_") + 1)
934+
.getAsInteger(/*Radix=*/10, operandNum);
935+
assert(operandNum < op->getNumOperands() &&
936+
"operand number out of range");
937+
analysis.setAnchor(op->getOperand(operandNum), attr.getValue());
938+
}
939+
if (name.find("__vector_layout_test_anchor_result_") !=
940+
std::string::npos) {
941+
int resultNum;
942+
name.substr(name.find_last_of("_") + 1)
943+
.getAsInteger(/*Radix=*/10, resultNum);
944+
assert(resultNum < op->getNumResults() && "result number out of range");
945+
analysis.setAnchor(op->getResult(resultNum), attr.getValue());
949946
}
950947
}
951-
}
948+
});
952949
}
953950

954951
static void emitLayoutRemarks(VectorLayoutAnalysis &analysis,

compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,8 @@ void PropagateLayout::visitRegionSuccessors(RegionBranchOpInterface branch,
763763
// Propagate the layouts.
764764
for (auto [forwardedLattice, inputLattice] :
765765
llvm::zip(forwardedLattices, inputLattices)) {
766-
inputLattice->resolve(forwardedLattice);
766+
ChangeResult changed = inputLattice->resolve(forwardedLattice);
767+
propagateIfChanged(inputLattice, changed);
767768
}
768769
}
769770
}
@@ -887,8 +888,9 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch,
887888
int64_t curr = 0;
888889
for (auto [forwardedLattice, inputLattice] :
889890
llvm::zip(forwardedLattices, inputLattices)) {
890-
forwardedLattice->resolveWithPossibleConflict(inputLattice,
891-
*forwardedOperands[curr]);
891+
ChangeResult changed = forwardedLattice->resolveWithPossibleConflict(
892+
inputLattice, *forwardedOperands[curr]);
893+
propagateIfChanged(forwardedLattice, changed);
892894
curr++;
893895
}
894896
}

compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir

+43
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,46 @@ builtin.module attributes { transform.with_named_sequence } {
145145
transform.yield
146146
}
147147
}
148+
149+
// -----
150+
151+
#layout = #iree_vector_ext.layout<<[VECTORY], [16]>, <[BATCHY, VECTORX], [2, 8]>>
152+
153+
// Propagate and enforce through scf.for
154+
builtin.module attributes { transform.with_named_sequence } {
155+
func.func @scffor(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector<16xf16> {
156+
%c0 = arith.constant 0 : index
157+
%c1 = arith.constant 1 : index
158+
%c1024 = arith.constant 1024 : index
159+
%cst_0 = arith.constant 0.0 : f16
160+
%cst0_1 = arith.constant dense<0.0> : vector<16xf16>
161+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
162+
163+
%out = scf.for %iv = %c0 to %c1024 step %c1 iter_args(%arg1 = %cst0_1) -> (vector<16xf16>) {
164+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
165+
%root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true], "__vector_layout_test_anchor_result_0" = #layout} : memref<16x16xf16>, vector<16x16xf16>
166+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>, <[ BATCHY, VECTORX], [2, 8]>>}}
167+
%root2 = vector.transfer_read %arr2[%c0], %cst_0 {in_bounds = [true]} : memref<16xf16>, vector<16xf16>
168+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
169+
%root_transpose = vector.transpose %root, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
170+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHY, VECTORX], [2, 8]>, <[ VECTORY], [16]>>}}
171+
%root_red = vector.multi_reduction<add>, %root_transpose, %arg1 [0] : vector<16x16xf16> to vector<16xf16>
172+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
173+
%c = arith.mulf %root_red, %b : vector<16xf16>
174+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
175+
%d = arith.addf %c, %a : vector<16xf16>
176+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
177+
%e = arith.divf %d, %root2 : vector<16xf16>
178+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>>}}
179+
scf.yield %e : vector<16xf16>
180+
}
181+
182+
func.return %out : vector<16xf16>
183+
}
184+
185+
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
186+
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
187+
transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op
188+
transform.yield
189+
}
190+
}

0 commit comments

Comments
 (0)