Skip to content

Commit 9580468

Browse files
9Tempestjoker-eph
authored andcommitted
[mlir][affine] Enforce each result type to match Reduction ops in affine.parallel verifier
This patch updates AffineParallelOp::verify() to check each result type matches its corresponding reduction op (i.e, the result type must be a `FloatType` if the reduction attribute is `addf`) affine.parallel will crash on --lower-affine if the corresponding result type cannot match the reduction attribute. ``` %128 = affine.parallel (%arg2, %arg3) = (0, 0) to (8, 7) reduce ("maxf") -> (memref<8x7xf32>) { %alloc_33 = memref.alloc() : memref<8x7xf32> affine.yield %alloc_33 : memref<8x7xf32> } ``` This will crash and report a type conversion issue when we run `mlir-opt --lower-affine` ``` Assertion failed: (isa<To>(Val) && "cast<Ty>() argument of incompatible type!"), function cast, file Casting.h, line 572. PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. Stack dump: 0. Program arguments: mlir-opt --lower-affine temp.mlir #0 0x0000000102a18f18 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/workspacebin/mlir-opt+0x1002f8f18) #1 0x0000000102a171b4 llvm::sys::RunSignalHandlers() (/workspacebin/mlir-opt+0x1002f71b4) #2 0x0000000102a195c4 SignalHandler(int) (/workspacebin/mlir-opt+0x1002f95c4) #3 0x00000001be7894c4 (/usr/lib/system/libsystem_platform.dylib+0x1803414c4) #4 0x00000001be771ee0 (/usr/lib/system/libsystem_pthread.dylib+0x180329ee0) #5 0x00000001be6ac340 (/usr/lib/system/libsystem_c.dylib+0x180264340) #6 0x00000001be6ab754 (/usr/lib/system/libsystem_c.dylib+0x180263754) #7 0x0000000106864790 mlir::arith::getIdentityValueAttr(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (.cold.4) (/workspacebin/mlir-opt+0x104144790) #8 0x0000000102ba66ac mlir::arith::getIdentityValueAttr(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (/workspacebin/mlir-opt+0x1004866ac) #9 0x0000000102ba6910 mlir::arith::getIdentityValue(mlir::arith::AtomicRMWKind, mlir::Type, mlir::OpBuilder&, mlir::Location) (/workspacebin/mlir-opt+0x100486910) ... ``` Fixes #64068 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D157985
1 parent e39de2b commit 9580468

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

+50-2
Original file line numberDiff line numberDiff line change
@@ -3915,6 +3915,49 @@ void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
39153915
setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
39163916
}
39173917

3918+
// check whether resultType match op or not in affine.parallel
3919+
static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3920+
arith::AtomicRMWKind op) {
3921+
switch (op) {
3922+
case arith::AtomicRMWKind::addf:
3923+
return isa<FloatType>(resultType);
3924+
case arith::AtomicRMWKind::addi:
3925+
return isa<IntegerType>(resultType);
3926+
case arith::AtomicRMWKind::assign:
3927+
return true;
3928+
case arith::AtomicRMWKind::mulf:
3929+
return isa<FloatType>(resultType);
3930+
case arith::AtomicRMWKind::muli:
3931+
return isa<IntegerType>(resultType);
3932+
case arith::AtomicRMWKind::maximumf:
3933+
return isa<FloatType>(resultType);
3934+
case arith::AtomicRMWKind::minimumf:
3935+
return isa<FloatType>(resultType);
3936+
case arith::AtomicRMWKind::maxs: {
3937+
auto intType = llvm::dyn_cast<IntegerType>(resultType);
3938+
return intType && intType.isSigned();
3939+
}
3940+
case arith::AtomicRMWKind::mins: {
3941+
auto intType = llvm::dyn_cast<IntegerType>(resultType);
3942+
return intType && intType.isSigned();
3943+
}
3944+
case arith::AtomicRMWKind::maxu: {
3945+
auto intType = llvm::dyn_cast<IntegerType>(resultType);
3946+
return intType && intType.isUnsigned();
3947+
}
3948+
case arith::AtomicRMWKind::minu: {
3949+
auto intType = llvm::dyn_cast<IntegerType>(resultType);
3950+
return intType && intType.isUnsigned();
3951+
}
3952+
case arith::AtomicRMWKind::ori:
3953+
return isa<IntegerType>(resultType);
3954+
case arith::AtomicRMWKind::andi:
3955+
return isa<IntegerType>(resultType);
3956+
default:
3957+
return false;
3958+
}
3959+
}
3960+
39183961
LogicalResult AffineParallelOp::verify() {
39193962
auto numDims = getNumDims();
39203963
if (getLowerBoundsGroups().getNumElements() != numDims ||
@@ -3946,11 +3989,16 @@ LogicalResult AffineParallelOp::verify() {
39463989
if (getReductions().size() != getNumResults())
39473990
return emitOpError("a reduction must be specified for each output");
39483991

3949-
// Verify reduction ops are all valid
3950-
for (Attribute attr : getReductions()) {
3992+
// Verify reduction ops are all valid and each result type matches reduction
3993+
// ops
3994+
for (auto it : llvm::enumerate((getReductions()))) {
3995+
Attribute attr = it.value();
39513996
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
39523997
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
39533998
return emitOpError("invalid reduction attribute");
3999+
auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4000+
if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4001+
return emitOpError("result type cannot match reduction attribute");
39544002
}
39554003

39564004
// Verify that the bound operands are valid dimension/symbols.

mlir/test/Dialect/Affine/invalid.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,18 @@ func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
297297

298298
// -----
299299

300+
func.func @affine_parallel(%arg0 : index, %arg1 : index, %arg2 : index) {
301+
%0 = memref.alloc() : memref<100x100xi32>
302+
// expected-error@+1 {{result type cannot match reduction attribute}}
303+
%1 = affine.parallel (%i, %j) = (0, 0) to (100, 100) step (10, 10) reduce ("minimumf") -> (i32) {
304+
%2 = affine.load %0[%i, %j] : memref<100x100xi32>
305+
affine.yield %2 : i32
306+
}
307+
return
308+
}
309+
310+
// -----
311+
300312
func.func @vector_load_invalid_vector_type() {
301313
%0 = memref.alloc() : memref<100xf32>
302314
affine.for %i0 = 0 to 16 step 8 {

0 commit comments

Comments
 (0)