|
4 | 4 | // See https://llvm.org/LICENSE.txt for license information.
|
5 | 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
6 | 6 |
|
| 7 | +#include "iree/compiler/Dialect/HAL/Conversion/HALToVM/Patterns.h" |
7 | 8 | #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
|
8 | 9 | #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
|
9 | 10 | #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
|
@@ -137,13 +138,16 @@ class CommandBufferFillBufferOpConversion
|
137 | 138 | auto patternLengthConst = rewriter.createOrFold<mlir::arith::ConstantIntOp>(
|
138 | 139 | op.getLoc(), patternLengthBytes, 32);
|
139 | 140 | Value pattern = op.getPattern();
|
140 |
| - if (patternBitWidth < 32) { |
| 141 | + if (patternBitWidth < 64) { |
141 | 142 | pattern = rewriter.createOrFold<arith::ExtUIOp>(
|
142 |
| - op.getLoc(), rewriter.getIntegerType(32), pattern); |
| 143 | + op.getLoc(), rewriter.getIntegerType(64), pattern); |
143 | 144 | }
|
144 | 145 | callOperands.push_back(pattern);
|
145 | 146 | callOperands.push_back(patternLengthConst);
|
146 | 147 |
|
| 148 | + callOperands.push_back( |
| 149 | + getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter)); |
| 150 | + |
147 | 151 | auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
|
148 | 152 | op, SymbolRefAttr::get(importOp), importType.getResults(),
|
149 | 153 | callOperands);
|
@@ -183,7 +187,9 @@ class CommandBufferUpdateBufferOpConversion
|
183 | 187 | castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
|
184 | 188 | rewriter),
|
185 | 189 | castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
|
186 |
| - targetBufferSlot}; |
| 190 | + targetBufferSlot, |
| 191 | + getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter), |
| 192 | + }; |
187 | 193 | auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
|
188 | 194 | op, SymbolRefAttr::get(importOp), importType.getResults(),
|
189 | 195 | callOperands);
|
@@ -226,6 +232,7 @@ class CommandBufferCopyBufferOpConversion
|
226 | 232 | castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
|
227 | 233 | rewriter),
|
228 | 234 | castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
|
| 235 | + getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter), |
229 | 236 | };
|
230 | 237 | auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
|
231 | 238 | op, SymbolRefAttr::get(importOp), importType.getResults(),
|
@@ -341,21 +348,14 @@ class CommandBufferDispatchOpConversion
|
341 | 348 | auto i64Type = rewriter.getI64Type();
|
342 | 349 | Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
|
343 | 350 |
|
344 |
| - auto flags = adaptor.getFlagsAttr() |
345 |
| - ? rewriter |
346 |
| - .create<IREE::VM::ConstI64Op>( |
347 |
| - op.getLoc(), adaptor.getFlagsAttr().getInt()) |
348 |
| - .getResult() |
349 |
| - : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc()) |
350 |
| - .getResult(); |
351 | 351 | SmallVector<Value, 8> callOperands = {
|
352 | 352 | adaptor.getCommandBuffer(),
|
353 | 353 | adaptor.getExecutable(),
|
354 | 354 | castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
|
355 | 355 | castToImportType(adaptor.getWorkgroupX(), i32Type, rewriter),
|
356 | 356 | castToImportType(adaptor.getWorkgroupY(), i32Type, rewriter),
|
357 | 357 | castToImportType(adaptor.getWorkgroupZ(), i32Type, rewriter),
|
358 |
| - flags, |
| 358 | + getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter), |
359 | 359 | };
|
360 | 360 | SmallVector<int16_t, 5> segmentSizes = {
|
361 | 361 | /*command_buffer=*/-1,
|
@@ -421,21 +421,14 @@ class CommandBufferDispatchIndirectOpConversion
|
421 | 421 |
|
422 | 422 | auto [workgroupsBufferSlot, workgroupsBuffer] =
|
423 | 423 | splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
|
424 |
| - auto flags = adaptor.getFlagsAttr() |
425 |
| - ? rewriter |
426 |
| - .create<IREE::VM::ConstI64Op>( |
427 |
| - op.getLoc(), adaptor.getFlagsAttr().getInt()) |
428 |
| - .getResult() |
429 |
| - : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc()) |
430 |
| - .getResult(); |
431 | 424 | SmallVector<Value, 8> callOperands = {
|
432 | 425 | adaptor.getCommandBuffer(),
|
433 | 426 | adaptor.getExecutable(),
|
434 | 427 | castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
|
435 | 428 | workgroupsBufferSlot,
|
436 | 429 | workgroupsBuffer,
|
437 | 430 | castToImportType(adaptor.getWorkgroupsOffset(), i64Type, rewriter),
|
438 |
| - flags, |
| 431 | + getFlagsI64(op.getLoc(), adaptor.getFlagsAttr(), rewriter), |
439 | 432 | };
|
440 | 433 | SmallVector<int16_t, 5> segmentSizes = {
|
441 | 434 | /*command_buffer=*/-1,
|
|
0 commit comments