Skip to content

Commit e8032e3

Browse files
authored
Adding flags to most HAL methods and extending existing ones to i64. (#20368)
**This is a breaking HAL change and will require recompilation of VMFBs.** There _should_ be a useful error indicating that emitted if attempting to load an existing VMFB. There's been TODOs in the HAL imports for awhile waiting for a breakage and #20240 does that. To prevent more PRs like that this adds flag arguments to most HAL methods so that we can add new flags in the future. Some existing flags were i32 and have been changed to i64 for consistency. Of note: * `hal.device.queue.execute` split to `hal.device.queue.barrier` instead of relying on empty command buffer list * `hal.command_buffer.advise_buffer` added, but not yet in IR * `hal.command_buffer.fill_buffer` max pattern length changed to i64 to match `hal.device.queue.fill`
1 parent f301696 commit e8032e3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+890
-469
lines changed

compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,9 @@ void createCoarseFencesSyncWrapper(StringRef syncFunctionName,
541541
.getResults();
542542

543543
// Wait forever for signal.
544-
rewriter.create<IREE::HAL::FenceAwaitOp>(loc, rewriter.getI32Type(),
545-
timeoutMillis, signalFence);
544+
rewriter.create<IREE::HAL::FenceAwaitOp>(
545+
loc, rewriter.getI32Type(), timeoutMillis,
546+
IREE::HAL::WaitFlagBitfield::None, signalFence);
546547

547548
rewriter.create<IREE::Util::ReturnOp>(loc, callResults);
548549
}

compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ createImportWrapperFunc(IREE::ABI::InvocationModel invocationModel,
259259
if (hasSideEffects && signalFence) {
260260
auto timeoutMillis =
261261
entryBuilder.create<arith::ConstantIntOp>(importOp.getLoc(), -1, 32);
262-
entryBuilder.create<IREE::HAL::FenceAwaitOp>(importOp.getLoc(),
263-
entryBuilder.getI32Type(),
264-
timeoutMillis, signalFence);
262+
entryBuilder.create<IREE::HAL::FenceAwaitOp>(
263+
importOp.getLoc(), entryBuilder.getI32Type(), timeoutMillis,
264+
IREE::HAL::WaitFlagBitfield::None, signalFence);
265265
}
266266

267267
// Marshal results.

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp

+12-19
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h"
78
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
89
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
910
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
@@ -137,13 +138,16 @@ class CommandBufferFillBufferOpConversion
137138
auto patternLengthConst = rewriter.createOrFold<mlir::arith::ConstantIntOp>(
138139
op.getLoc(), patternLengthBytes, 32);
139140
Value pattern = op.getPattern();
140-
if (patternBitWidth < 32) {
141+
if (patternBitWidth < 64) {
141142
pattern = rewriter.createOrFold<arith::ExtUIOp>(
142-
op.getLoc(), rewriter.getIntegerType(32), pattern);
143+
op.getLoc(), rewriter.getIntegerType(64), pattern);
143144
}
144145
callOperands.push_back(pattern);
145146
callOperands.push_back(patternLengthConst);
146147

148+
callOperands.push_back(
149+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter));
150+
147151
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
148152
op, SymbolRefAttr::get(importOp), importType.getResults(),
149153
callOperands);
@@ -183,7 +187,9 @@ class CommandBufferUpdateBufferOpConversion
183187
castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
184188
rewriter),
185189
castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
186-
targetBufferSlot};
190+
targetBufferSlot,
191+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
192+
};
187193
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
188194
op, SymbolRefAttr::get(importOp), importType.getResults(),
189195
callOperands);
@@ -226,6 +232,7 @@ class CommandBufferCopyBufferOpConversion
226232
castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
227233
rewriter),
228234
castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
235+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
229236
};
230237
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
231238
op, SymbolRefAttr::get(importOp), importType.getResults(),
@@ -341,21 +348,14 @@ class CommandBufferDispatchOpConversion
341348
auto i64Type = rewriter.getI64Type();
342349
Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
343350

344-
auto flags = adaptor.getFlagsAttr()
345-
? rewriter
346-
.create<IREE::VM::ConstI64Op>(
347-
op.getLoc(), adaptor.getFlagsAttr().getInt())
348-
.getResult()
349-
: rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
350-
.getResult();
351351
SmallVector<Value, 8> callOperands = {
352352
adaptor.getCommandBuffer(),
353353
adaptor.getExecutable(),
354354
castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
355355
castToImportType(adaptor.getWorkgroupX(), i32Type, rewriter),
356356
castToImportType(adaptor.getWorkgroupY(), i32Type, rewriter),
357357
castToImportType(adaptor.getWorkgroupZ(), i32Type, rewriter),
358-
flags,
358+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
359359
};
360360
SmallVector<int16_t, 5> segmentSizes = {
361361
/*command_buffer=*/-1,
@@ -421,21 +421,14 @@ class CommandBufferDispatchIndirectOpConversion
421421

422422
auto [workgroupsBufferSlot, workgroupsBuffer] =
423423
splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
424-
auto flags = adaptor.getFlagsAttr()
425-
? rewriter
426-
.create<IREE::VM::ConstI64Op>(
427-
op.getLoc(), adaptor.getFlagsAttr().getInt())
428-
.getResult()
429-
: rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
430-
.getResult();
431424
SmallVector<Value, 8> callOperands = {
432425
adaptor.getCommandBuffer(),
433426
adaptor.getExecutable(),
434427
castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
435428
workgroupsBufferSlot,
436429
workgroupsBuffer,
437430
castToImportType(adaptor.getWorkgroupsOffset(), i64Type, rewriter),
438-
flags,
431+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
439432
};
440433
SmallVector<int16_t, 5> segmentSizes = {
441434
/*command_buffer=*/-1,

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h"
78
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
89
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
910
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -134,8 +135,6 @@ class DeviceQueueFillOpConversion
134135
auto patternLength = rewriter.create<IREE::VM::ConstI32Op>(
135136
op.getLoc(),
136137
llvm::divideCeil(op.getPattern().getType().getIntOrFloatBitWidth(), 8));
137-
auto flags =
138-
rewriter.create<IREE::VM::ConstI64Op>(op.getLoc(), op.getFlags());
139138
std::array<Value, 10> callOperands = {
140139
adaptor.getDevice(),
141140
castToImportType(adaptor.getQueueAffinity(), i64Type, rewriter),
@@ -146,7 +145,7 @@ class DeviceQueueFillOpConversion
146145
castToImportType(adaptor.getLength(), i64Type, rewriter),
147146
castToImportType(adaptor.getPattern(), i64Type, rewriter),
148147
patternLength,
149-
flags,
148+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
150149
};
151150
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
152151
op, SymbolRefAttr::get(importOp), importType.getResults(),
@@ -177,19 +176,23 @@ class DeviceQueueExecuteIndirectOpConversion
177176
auto importType = importOp.getFunctionType();
178177
auto i64Type = rewriter.getI64Type();
179178

179+
Value queueAffinity =
180+
castToImportType(adaptor.getQueueAffinity(), i64Type, rewriter);
180181
SmallVector<Value, 8> callOperands = {
181182
adaptor.getDevice(),
182-
castToImportType(adaptor.getQueueAffinity(), i64Type, rewriter),
183+
queueAffinity,
183184
adaptor.getWaitFence(),
184185
adaptor.getSignalFence(),
185186
adaptor.getCommandBuffer(),
187+
getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter),
186188
};
187189
SmallVector<int16_t, 5> segmentSizes = {
188190
/*device=*/-1,
189191
/*queue_affinity=*/-1,
190192
/*wait_fence=*/-1,
191193
/*signal_fence=*/-1,
192194
/*command_buffer=*/-1,
195+
/*flags=*/-1,
193196
/*bindings=*/
194197
static_cast<int16_t>(adaptor.getBindingBuffers().size()),
195198
};
@@ -239,6 +242,8 @@ void populateHALDeviceToVMPatterns(MLIRContext *context,
239242
context, importSymbols, typeConverter, "hal.device.queue.read");
240243
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueWriteOp>>(
241244
context, importSymbols, typeConverter, "hal.device.queue.write");
245+
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueBarrierOp>>(
246+
context, importSymbols, typeConverter, "hal.device.queue.barrier");
242247
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueExecuteOp>>(
243248
context, importSymbols, typeConverter, "hal.device.queue.execute");
244249
patterns.insert<DeviceQueueExecuteIndirectOpConversion>(

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ class ExecutableCreateOpConversion
121121
createOp.getLoc(), adaptor.getConstants(), rewriter);
122122

123123
SmallVector<Value, 8> callOperands = {
124-
adaptor.getDevice(),
125-
executableFormatStr,
126-
rodataOp,
124+
adaptor.getDevice(), adaptor.getQueueAffinity(),
125+
executableFormatStr, rodataOp,
127126
constantBuffer,
128127
};
129128
auto importType = importOp.getFunctionType();

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010

1111
namespace mlir::iree_compiler {
1212

13+
Value getFlagsI64(Location loc, IntegerAttr flagsAttr, OpBuilder &builder) {
14+
return flagsAttr
15+
? builder.create<IREE::VM::ConstI64Op>(loc, flagsAttr.getInt())
16+
.getResult()
17+
: builder.create<IREE::VM::ConstI64ZeroOp>(loc).getResult();
18+
}
19+
1320
extern void populateHALAllocatorToVMPatterns(MLIRContext *context,
1421
SymbolTable &importSymbols,
1522
TypeConverter &typeConverter,

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ void populateHALToVMPatterns(MLIRContext *context, SymbolTable &importSymbols,
2323
Value createPackedConstantBuffer(Location loc, ValueRange constantValues,
2424
OpBuilder &builder);
2525

26+
// Returns an i64 value initialized with the bits of |flagsAttr| or 0.
27+
Value getFlagsI64(Location loc, IntegerAttr flagsAttr, OpBuilder &builder);
28+
2629
} // namespace mlir::iree_compiler
2730

2831
#endif // IREE_COMPILER_DIALECT_HAL_CONVERSION_HALTOVM_PATTERNS_H_

compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/channel_ops.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
// CHECK-LABEL: @channel_create
44
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64, %[[ID:.+]]: !vm.buffer, %[[GROUP:.+]]: !vm.buffer, %[[RANK:.+]]: i32, %[[COUNT:.+]]: i32) -> !vm.ref<!hal.channel>
55
util.func public @channel_create(%device: !hal.device, %affinity: i64, %id: !util.buffer, %group: !util.buffer, %rank: i32, %count: i32) -> !hal.channel {
6-
// CHECK: %[[FLAGS:.+]] = vm.const.i32.zero
6+
// CHECK: %[[FLAGS:.+]] = vm.const.i64.zero
77
// CHECK: %[[CHANNEL:.+]] = vm.call @hal.channel.create(%[[DEVICE]], %[[AFFINITY]], %[[FLAGS]], %[[ID]], %[[GROUP]], %[[RANK]], %[[COUNT]])
88
%channel = hal.channel.create device(%device : !hal.device)
99
affinity(%affinity)
10-
flags(0)
10+
flags("None")
1111
id(%id)
1212
group(%group)
1313
rank(%rank)
@@ -21,12 +21,12 @@ util.func public @channel_create(%device: !hal.device, %affinity: i64, %id: !uti
2121
// CHECK-LABEL: @channel_split
2222
// CHECK-SAME: (%[[BASE_CHANNEL:.+]]: !vm.ref<!hal.channel>, %[[COLOR:.+]]: i32, %[[KEY:.+]]: i32)
2323
util.func public @channel_split(%base_channel: !hal.channel, %color: i32, %key: i32) -> !hal.channel {
24-
// CHECK: %[[FLAGS:.+]] = vm.const.i32.zero
24+
// CHECK: %[[FLAGS:.+]] = vm.const.i64.zero
2525
// CHECK: %[[SPLIT_CHANNEL:.+]] = vm.call @hal.channel.split(%[[BASE_CHANNEL]], %[[COLOR]], %[[KEY]], %[[FLAGS]])
2626
%split_channel = hal.channel.split<%base_channel : !hal.channel>
2727
color(%color)
2828
key(%key)
29-
flags(0) : !hal.channel
29+
flags("None") : !hal.channel
3030
// CHECK: vm.return %[[SPLIT_CHANNEL]]
3131
util.return %split_channel : !hal.channel
3232
}

0 commit comments

Comments
 (0)