
Commit 09fb0e6

Remove the >3D restriction on reshaping for dot_general (#13521)
Previously, collapsing dimensions together (when there are multiple dims of the same type) happened only if the input rank was larger than 3. This made the incorrect assumption that a rank-3 tensor is already in a standard form (BxCxP or BxPxC), when in reality it could have multiple dims of one type and none of another (BxBxC, for example). In those cases we still need to perform a reshape.
1 parent e2aa9f2 commit 09fb0e6
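
To make the rank-3 corner case concrete, here is a minimal standalone sketch of the new "already standard form" check, assuming the same group borders the pass computes (the first group ends at dimsBorder0, the middle group at dimsBorder1); the helper name isStandardForm is hypothetical and not part of the commit:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // A shape is "standard" when each of the three groups delimited by
    // the borders has at most one dimension, so no collapse is needed.
    static bool isStandardForm(const std::vector<int64_t> &shape,
                               size_t dimsBorder0, size_t dimsBorder1) {
      return dimsBorder0 <= 1 && dimsBorder1 - dimsBorder0 <= 1 &&
             shape.size() - dimsBorder1 <= 1;
    }

    int main() {
      // One dim per group (e.g. BxPxC) -> no reshape needed.
      assert(isStandardForm({4, 8, 16}, /*dimsBorder0=*/1, /*dimsBorder1=*/2));
      // BxBxC: two batching dims, empty middle group, one contracting dim.
      // Rank is only 3, so the old `rank <= 3` check skipped the reshape,
      // but the shape is not standard and still needs one.
      assert(!isStandardForm({4, 4, 8}, /*dimsBorder0=*/2, /*dimsBorder1=*/2));
      return 0;
    }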

2 files changed: +27 -7 lines changed

compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp (+9 -7)
@@ -292,11 +292,13 @@ class TransposeReshapeGenericDotGeneral
         b.getI64TensorAttr(targetOrder));
   }
 
-  Value ReshapeIfMorethan3D(OpBuilder &b, Location loc, Value src,
-                            size_t dimsBorder0, size_t dimsBorder1) const {
+  Value ReshapeIfNonStandard(OpBuilder &b, Location loc, Value src,
+                             size_t dimsBorder0, size_t dimsBorder1) const {
     auto type = src.getType().cast<RankedTensorType>();
-    if (type.getRank() <= 3) return src;
     auto shape = type.getShape();
+    if (dimsBorder0 <= 1 && dimsBorder1 - dimsBorder0 <= 1 &&
+        shape.size() - dimsBorder1 <= 1)
+      return src;
     SmallVector<int64_t, 4> result_shape = {
         std::accumulate(shape.begin(), shape.begin() + dimsBorder0, 1,
                         std::multiplies<int64_t>()),
@@ -387,10 +389,10 @@ class TransposeReshapeGenericDotGeneral
     int64_t numRhsContractionDims =
         rhsContractionBase + rhsContractingDims.size();
 
-    lhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), lhs,
-                              lhsBatchingDims.size(), lhsContractionBase);
-    rhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), rhs,
-                              rhsBatchingDims.size(), numRhsContractionDims);
+    lhs = ReshapeIfNonStandard(rewriter, op.getLoc(), lhs,
+                               lhsBatchingDims.size(), lhsContractionBase);
+    rhs = ReshapeIfNonStandard(rewriter, op.getLoc(), rhs,
+                               rhsBatchingDims.size(), numRhsContractionDims);
 
     if (lhs == op.getLhs() && rhs == op.getRhs())
       return rewriter.notifyMatchFailure(op, "already in canonical form");
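
For intuition on what the collapse produces once the check fails, here is a minimal standalone sketch of the same three-group product reduction (an approximation of the result_shape computation above, not the pass itself; the helper name collapse is hypothetical), traced on the shapes from the new test below:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Collapse a shape into three dims: the products of the dims in
    // [0, b0), [b0, b1), and [b1, end) -- mirroring result_shape above.
    static std::vector<int64_t> collapse(const std::vector<int64_t> &shape,
                                         size_t b0, size_t b1) {
      auto prod = [&](size_t lo, size_t hi) {
        return std::accumulate(shape.begin() + lo, shape.begin() + hi,
                               int64_t{1}, std::multiplies<int64_t>());
      };
      return {prod(0, b0), prod(b0, b1), prod(b1, shape.size())};
    }

    int main() {
      // LHS of the new test: tensor<1x2x3xf32> with batching dims [0, 1]
      // and contracting dim [2] -> borders (2, 2) -> collapsed to 2x1x3.
      for (int64_t d : collapse({1, 2, 3}, 2, 2)) std::cout << d << ' ';
      std::cout << '\n';  // prints: 2 1 3
      // Transposed RHS: tensor<1x2x3x4xf32> -> borders (2, 3) -> 2x3x4.
      for (int64_t d : collapse({1, 2, 3, 4}, 2, 3)) std::cout << d << ' ';
      std::cout << '\n';  // prints: 2 3 4
      return 0;
    }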

compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir (+18 -0)
@@ -15,3 +15,21 @@ func.func public @dot_general_2d(%arg0: tensor<4x3xf32> {mhlo.sharding = ""}, %a
 // CHECK-SAME: precision_config = [#mhlo<precision HIGHEST>, #mhlo<precision HIGHEST>]
   return %0 : tensor<3xf32>
 }
+
+// CHECK-LABEL: @dot_general_4d
+func.func public @dot_general_4d(%arg0: tensor<1x2x3xf32> {mhlo.sharding = ""}, %arg1: tensor<1x4x2x3xf32> {mhlo.sharding = ""}) -> tensor<1x2x4xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], rhs_batching_dimensions = [0, 2], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [3]>, precision_config = [#mhlo<precision HIGHEST>, #mhlo<precision HIGHEST>]} : (tensor<1x2x3xf32>, tensor<1x4x2x3xf32>) -> tensor<1x2x4xf32>
+
+// CHECK: %[[RHS_T:.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x2x3xf32>) -> tensor<1x2x3x4xf32>
+// CHECK: %[[LHS_R:.+]] = mhlo.reshape %arg0 : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32>
+// CHECK: %[[RHS_R:.+]] = mhlo.reshape %[[RHS_T]] : (tensor<1x2x3x4xf32>) -> tensor<2x3x4xf32>
+// CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS_R]], %[[RHS_R]])
+// CHECK-SAME: dot_dimension_numbers = #mhlo.dot<
+// CHECK-SAME:   lhs_batching_dimensions = [0]
+// CHECK-SAME:   rhs_batching_dimensions = [0]
+// CHECK-SAME:   lhs_contracting_dimensions = [2]
+// CHECK-SAME:   rhs_contracting_dimensions = [1]>
+// CHECK-SAME: precision_config = [#mhlo<precision HIGHEST>, #mhlo<precision HIGHEST>]
+// CHECK: mhlo.reshape %[[DOT]] : (tensor<2x1x4xf32>) -> tensor<1x2x4xf32>
+  return %0 : tensor<1x2x4xf32>
+}
