Skip to content

Commit 4c00a22

Browse files
Enable scatter fusion with index operand. (#19198)
This drops a pessimistic check during analysis of the indexing maps of the fused `OpOperand` in the producer and consumer that was preventing fusion of the scatter operation with its index operand producer.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent cbdcdd0 commit 4c00a22

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,14 @@ matchIteratorTypes(const llvm::SmallBitVector &rootOuterParallelLoop,
267267

268268
// If the candidate is all parallel, then it should be at least as parallel as
269269
// the root.
270-
for (int pos : llvm::seq<int>(0, rootOuterParallelLoop.size())) {
270+
for (int pos : llvm::seq<int>(0, std::min(candidateOuterParallelLoop.size(),
271+
rootOuterParallelLoop.size()))) {
271272
// If we reach the end of the outer loops of the root, break out of the
272273
// loop.
273274
if (!rootOuterParallelLoop.test(pos))
274275
break;
275276
// If the root loop is parallel, the candidate loop should also be parallel.
276-
if (pos >= candidateOuterParallelLoop.size() ||
277-
!candidateOuterParallelLoop.test(pos))
277+
if (!candidateOuterParallelLoop.test(pos))
278278
return false;
279279
}
280280
return true;

compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir

+32
Original file line numberDiff line numberDiff line change
@@ -922,3 +922,35 @@ util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
922922
// CHECK-SAME: ins(%[[DISPATCH1]],
923923
// CHECK: flow.return %[[CUSTOM_OP]]
924924
// CHECK: util.return %[[DISPATCH2]]
925+
926+
// -----
927+
928+
util.func @scatter_index_producer_fusion(%arg0 : tensor<?x1xi64>,
929+
%arg1 : index, %arg2 : tensor<?x1x32x8x128xf16>,
930+
%arg3 : tensor<?x32x8x128xf16>) -> tensor<?x32x8x128xf16> {
931+
%empty = tensor.empty(%arg1) : tensor<?x1xi32>
932+
%0 = linalg.generic {
933+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
934+
affine_map<(d0, d1) -> (d0, d1)>],
935+
iterator_types = ["parallel", "parallel"]}
936+
ins(%arg0 : tensor<?x1xi64>) outs(%empty : tensor<?x1xi32>) {
937+
^bb0(%in: i64, %out: i32):
938+
%1 = arith.trunci %in : i64 to i32
939+
linalg.yield %1 : i32
940+
} -> tensor<?x1xi32>
941+
%1 = iree_linalg_ext.scatter
942+
dimension_map = [0] unique_indices(true)
943+
ins(%arg2, %0 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>)
944+
outs(%arg3 : tensor<?x32x8x128xf16>) {
945+
^bb0(%arg6: f16, %arg7: f16):
946+
iree_linalg_ext.yield %arg6 : f16
947+
} -> tensor<?x32x8x128xf16>
948+
util.return %1 : tensor<?x32x8x128xf16>
949+
}
950+
// CHECK-LABEL: func public @scatter_index_producer_fusion
951+
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
952+
// CHECK: %[[GENERIC:.+]] = linalg.generic
953+
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
954+
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC]] :
955+
// CHECK: flow.return %[[SCATTER]]
956+
// CHECK: util.return %[[DISPATCH]]

0 commit comments

Comments (0)