From 5f6e3def14121b93aa8eb63deac355d40c9b0ca9 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 7 Mar 2023 11:40:46 +0000 Subject: [PATCH 01/15] Add GenericReduceOp to ttir and add `tl.prod` using it --- include/triton/Dialect/Triton/IR/TritonOps.td | 17 +++++ lib/Dialect/Triton/IR/Ops.cpp | 72 +++++++++++++++---- python/src/triton.cc | 22 ++++++ python/triton/language/__init__.py | 2 + python/triton/language/core.py | 7 ++ python/triton/language/semantic.py | 36 ++++++++++ 6 files changed, 142 insertions(+), 14 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ba88798403ae..d4dad7bf6980 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -369,6 +369,23 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure, }]; } +def TT_GenericReduceOp: TT_Op<"generic_reduce", + [Pure, DeclareOpInterfaceMethods, SingleBlock]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins TT_Tensor:$operand, I32Attr:$axis); + let results = (outs TT_Type:$result); + let regions = (region SizedRegion<1>:$region); + let hasRegionVerifier = 1; +} + +def TT_GenericReduceReturnOp: TT_Op<"generic_reduce.return", + [HasParent<"GenericReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins AnyType:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + // // External elementwise op // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 15d62af3db60..2ddcd5c5461c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -238,21 +238,10 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( } //-- ReduceOp -- -mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // infer shape - Value arg = operands[0]; - auto argTy = arg.getType().cast(); - auto argEltTy = argTy.getElementType(); - auto i32Ty = IntegerType::get(argEltTy.getContext(), 32); - auto redOp = - attributes.get("redOp").cast().getValue(); - bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); - auto retEltTy = withIndex ? i32Ty : argEltTy; +static mlir::LogicalResult inferReduceReturnShape( + const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { auto retShape = argTy.getShape().vec(); - int axis = attributes.get("axis").cast().getInt(); retShape.erase(retShape.begin() + axis); if (retShape.empty()) { // 0d-tensor -> scalar @@ -280,6 +269,22 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( return mlir::success(); } +mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Value arg = operands[0]; + auto argTy = arg.getType().cast(); + auto argEltTy = argTy.getElementType(); + auto i32Ty = IntegerType::get(argEltTy.getContext(), 32); + auto redOp = + attributes.get("redOp").cast().getValue(); + bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); + auto retEltTy = withIndex ? 
i32Ty : argEltTy; + int axis = attributes.get("axis").cast().getInt(); + return inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); +} + bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) { return redOp == mlir::triton::RedOp::ARGMIN || redOp == mlir::triton::RedOp::ARGMAX || @@ -289,6 +294,45 @@ bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) { redOp == mlir::triton::RedOp::ARGFMAX; } +//-- GenericReduceOp -- +mlir::LogicalResult mlir::triton::GenericReduceOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Value arg = operands[0]; + auto argTy = arg.getType().cast(); + auto retEltTy = argTy.getElementType(); + int axis = attributes.get("axis").cast().getInt(); + return inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); +} + +mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() { + auto argTy = getOperand().getType().cast(); + auto argElemTy = argTy.getElementType(); + + constexpr unsigned num_args = 2; + auto &block = this->getBody(); + if (block.getNumArguments() != num_args) { + return emitOpError() << "nested block must take " << num_args + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + for (const auto & blockArgTy: block.getArgumentTypes()) { + if (blockArgTy != argElemTy) { + return this->emitOpError() << "types mismatch on reduction block. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + ++i; + } + + if (!mlir::isa(block.getTerminator())) { + return this->emitOpError("the GenericReduceOp region must be terminated " + "with a GenericReduceReturnOp but got") << block.getTerminator(); + } + return mlir::success(); +} + //-- SplatOp -- OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto value = adaptor.getSrc(); diff --git a/python/src/triton.cc b/python/src/triton.cc index 1cce609016a7..4841fe85a978 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -440,6 +440,8 @@ void init_triton_ir(py::module &&m) { .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType) .def("reset_type", &mlir::func::FuncOp::setType); + py::class_(m, "GenericReduceOp"); + py::class_(m, "InsertPoint"); py::class_(m, "builder", py::dynamic_attr()) @@ -1326,6 +1328,26 @@ void init_triton_ir(py::module &&m) { return self.create(loc, resType, redOp, operand, axis); }) + .def("create_generic_reduce", + [](mlir::OpBuilder &self, mlir::Value &operand, int axis) -> mlir::triton::GenericReduceOp { + auto loc = self.getUnknownLoc(); + auto inputTensorType = + operand.getType().dyn_cast(); + std::vector shape = inputTensorType.getShape(); + shape.erase(shape.begin() + axis); + mlir::Type resType = inputTensorType.getElementType(); + if (!shape.empty()) { + resType = mlir::RankedTensorType::get(shape, resType); + } + return self.create( + loc, resType, operand, axis); + }) + .def("create_reduce_ret", + [](mlir::OpBuilder &self, mlir::Value &return_value) -> mlir::OpState { + auto loc = self.getUnknownLoc(); + return self.create( + loc, return_value); + }) .def("create_ptr_to_int", [](mlir::OpBuilder &self, mlir::Value &val, mlir::Type &type) -> mlir::Value { diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 3df8c243a0b0..8f26b55448ef 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -57,6 +57,7 @@ 
num_programs, pi32_t, pointer_type, + prod, program_id, ravel, reshape, @@ -156,6 +157,7 @@ "philox_impl", "pi32_t", "pointer_type", + "prod", "program_id", "rand", "rand4x", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 82a53934f874..af79d74a3a2f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1168,6 +1168,13 @@ def xor_sum(input, axis, _builder=None): return semantic.xor_sum(input, axis, _builder) +@builtin +@_add_reduction_docstr("prod") +def prod(input, axis, _builder): + axis = _constexpr_to_value(axis) + return semantic.prod(input, axis, _builder) + + # ----------------------- # Internal for debugging # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 79750e82a098..53172a6d53c6 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1,5 +1,6 @@ from __future__ import annotations # remove after python 3.11 +from contextlib import contextmanager from typing import List, Optional, Tuple from . import core as tl @@ -1185,6 +1186,41 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: raise ValueError("xor_sum only supported for integers") return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) +def reduction(input: tl.tensor, axis: int, region_builder_fn, builder: ir.builder) -> tl.tensor: + scalar_ty = input.type.scalar + + # get result type + shape = input.type.shape + ret_shape = [s for i, s in enumerate(shape) if i != axis] + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = out_scalar_ty + + reduce_op = builder.create_generic_reduce(input.handle, axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tl.tensor(reduce_op.get_result(0), res_ty) + +@contextmanager +def insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + +def prod(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + + def make_mul(reduce_op): + ir_scalar_ty = input.type.scalar.to_ir(builder) + region = reduce_op.get_region(0) + with insertion_guard(builder): + block = builder.create_block_with_parent(region, [ir_scalar_ty] * 2) + fmul = builder.create_fmul(block.arg(0), block.arg(1)) + builder.create_reduce_ret(fmul) + + return reduction(input, axis, make_mul, builder) # ===----------------------------------------------------------------------=== # Math From 786433e519336c29cc7660aad27de14dfa1ab9fc Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 8 Mar 2023 21:55:32 +0000 Subject: [PATCH 02/15] Lower tt.generic_reduce to LLVM IR --- include/triton/Analysis/Utility.h | 23 +- lib/Analysis/Allocation.cpp | 4 + lib/Analysis/Membar.cpp | 3 + lib/Analysis/Utility.cpp | 12 +- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/GenericReduceOpToLLVM.cpp | 340 ++++++++++++++++++ .../TritonGPUToLLVM/GenericReduceOpToLLVM.h | 16 + .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 2 + .../TritonToTritonGPUPass.cpp | 33 ++ 9 files changed, 422 insertions(+), 12 deletions(-) create mode 100644 lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp create mode 100644 lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index ee7fadb59df1..8397bb28932e 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -11,9 +11,24 @@ namespace 
mlir { class ReduceOpHelper { + ReduceOpHelper(Operation *op, int axis, bool withIndex) + : op(op), axis(axis), withIndex(withIndex) { + srcTy = op->getOperands().front().getType().cast(); + } + public: - explicit ReduceOpHelper(triton::ReduceOp op) : op(op) { - srcTy = op.getOperand().getType().cast(); + explicit ReduceOpHelper(triton::ReduceOp op): + ReduceOpHelper( + op.getOperation(), + op.getAxis(), + triton::ReduceOp::withIndex(op.getRedOp())) { + } + + explicit ReduceOpHelper(triton::GenericReduceOp op): + ReduceOpHelper( + op.getOperation(), + op.getAxis(), + /*withIndex*/false) { } ArrayRef getSrcShape() { return srcTy.getShape(); } @@ -35,8 +50,10 @@ class ReduceOpHelper { unsigned getScratchSizeInBytes(); private: - triton::ReduceOp op; + Operation *op; RankedTensorType srcTy{}; + int axis; + bool withIndex; }; bool isSharedEncoding(Value value); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 3cee8e1d13be..29cdec5b53b6 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -166,6 +166,10 @@ class AllocationAnalysis { ReduceOpHelper helper(reduceOp); unsigned bytes = helper.getScratchSizeInBytes(); allocation->addBuffer(op, bytes); + } else if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + unsigned bytes = helper.getScratchSizeInBytes(); + allocation->addBuffer(op, bytes); } else if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.getSrc().getType().cast(); auto dstTy = cvtLayout.getResult().getType().cast(); diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index e9e095445fa8..e23165a3f16c 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -72,6 +72,9 @@ void MembarAnalysis::visitTerminator(Operation *op, } return; } + if (isa(op)) { + return; + } // Otherwise, it could be a return op assert(isa(op) && "Unknown terminator"); } diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 4165dc31ec8d..85915ecad340 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -11,14 +11,12 @@ namespace mlir { bool ReduceOpHelper::isFastReduction() { auto srcLayout = srcTy.getEncoding(); - auto axis = op.getAxis(); return axis == triton::gpu::getOrder(srcLayout)[0]; } unsigned ReduceOpHelper::getInterWarpSize() { auto srcLayout = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); - auto axis = op.getAxis(); auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); return std::min(srcReduceDimSize / sizeIntraWarps, @@ -28,7 +26,6 @@ unsigned ReduceOpHelper::getInterWarpSize() { unsigned ReduceOpHelper::getIntraWarpSize() { auto srcLayout = srcTy.getEncoding(); auto srcShape = srcTy.getShape(); - auto axis = op.getAxis(); auto srcReduceDimSize = static_cast(srcShape[axis]); return std::min(srcReduceDimSize, triton::gpu::getThreadsPerWarp(srcLayout)[axis]); @@ -36,20 +33,17 @@ unsigned ReduceOpHelper::getIntraWarpSize() { unsigned ReduceOpHelper::getThreadsReductionAxis() { auto srcLayout = srcTy.getEncoding(); - auto axis = op.getAxis(); return triton::gpu::getThreadsPerWarp(srcLayout)[axis] * triton::gpu::getWarpsPerCTA(srcLayout)[axis]; } SmallVector ReduceOpHelper::getScratchConfigBasic() { - auto axis = op.getAxis(); auto smemShape = convertType(getSrcShape()); smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis()); return smemShape; } SmallVector> ReduceOpHelper::getScratchConfigsFast() { - auto axis = op.getAxis(); SmallVector> smemShapes(3); auto argLayout = srcTy.getEncoding(); @@ 
-64,7 +58,7 @@ SmallVector> ReduceOpHelper::getScratchConfigsFast() { /// FIXME(Qingyi): This size is actually larger than required. /// shared memory block1: - auto mod = op.getOperation()->getParentOfType(); + auto mod = op->getParentOfType(); unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); smemShapes[1].push_back(numWarps * 32); @@ -82,10 +76,10 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { elems = product(smemShape); } - auto tensorType = op.getOperand().getType().cast(); + auto tensorType = op->getOperand(0).getType().cast(); unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8; - if (triton::ReduceOp::withIndex(op.getRedOp())) + if (withIndex) bytes += elems * sizeof(int32_t); return bytes; diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 97e32f4e345c..11e4c8ccdf7b 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonGPUToLLVMPass.cpp PTXAsmFormat.cpp ReduceOpToLLVM.cpp + GenericReduceOpToLLVM.cpp Utility.cpp TypeConverter.cpp ViewOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp new file mode 100644 index 000000000000..7d6731638b1a --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp @@ -0,0 +1,340 @@ +#include "GenericReduceOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::shflSync; +using ::mlir::LLVM::storeShared; +using ::mlir::triton::gpu::getElemsPerThread; +using ::mlir::triton::gpu::getOrder; + +struct GenericReduceOpConversion + : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + triton::GenericReduceOp>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GenericReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (ReduceOpHelper(op).isFastReduction()) + return matchAndRewriteFast(op, adaptor, rewriter); + return matchAndRewriteBasic(op, adaptor, rewriter); + } + +private: + + void accumulate(ConversionPatternRewriter &rewriter, + Region &reduceOp, Value &acc, Value cur, bool isFirst) const { + if (isFirst) { + acc = cur; + return; + } + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(reduceOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), {acc, cur}); + acc = returnOp.getResult(); + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + // Use shared memory for reduction within warps and across warps + LogicalResult + matchAndRewriteBasic(triton::GenericReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + unsigned axis = op.getAxis(); + + auto srcTy = op.getOperand().getType().cast(); + auto srcLayout = srcTy.getEncoding().cast(); + auto srcOrd = srcLayout.getOrder(); + auto srcShape = srcTy.getShape(); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto llvmIndexTy = getTypeConverter()->getIndexType(); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + auto indexPtrTy = 
LLVM::LLVMPointerType::get(llvmIndexTy, 3); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(smemBase, elemPtrTy); + + ReduceOpHelper helper(op); + auto smemShape = helper.getScratchConfigBasic(); + unsigned elems = product(smemShape); + Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems)); + indexSmemBase = bitcast(indexSmemBase, indexPtrTy); + + unsigned srcElems = getElemsPerThread(srcTy); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); + auto srcValues = getTypeConverter()->unpackLLElements( + loc, adaptor.getOperand(), rewriter, srcTy); + + SmallVector> offset = + emitOffsetForLayout(srcLayout, srcTy); + + std::map, Value> accs; + std::map, Value> accIndices; + std::map, SmallVector> indices; + + + Region *reduceOp = &op.getRegion(); + + // reduce within threads + for (unsigned i = 0; i < srcElems; ++i) { + SmallVector key = offset[i]; + key[axis] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + + // cached int32 constants + std::map ints; + ints[0] = i32_val(0); + for (int N = smemShape[axis] / 2; N > 0; N >>= 1) + ints[N] = i32_val(N); + Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]); + + // reduce across threads + for (auto it : accs) { + const SmallVector &key = it.first; + Value acc = it.second; + SmallVector writeIdx = indices[key]; + + writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); + Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); + store(acc, writePtr); + + SmallVector readIdx(writeIdx.size(), ints[0]); + for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { + readIdx[axis] = ints[N]; + Value readMask = icmp_slt(writeIdx[axis], ints[N]); + Value readOffset = select( + readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), + ints[0]); + Value readPtr = gep(elemPtrTy, writePtr, readOffset); + barrier(); + Value cur = load(readPtr); + accumulate(rewriter, *reduceOp, acc, cur, false); + barrier(); + store(acc, writePtr); + } + } + + barrier(); + + // set output values + if (auto resultTy = op.getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding(); + + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (unsigned i = 0; i < resultElems; ++i) { + SmallVector readIdx = resultIndices[i]; + readIdx.insert(readIdx.begin() + axis, ints[0]); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); + resultVals[i] = load(readPtr); + } + Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, + resultTy); + rewriter.replaceOp(op, ret); + } else { + // 0d-tensor -> scalar + Value resultVal = load(smemBase); + rewriter.replaceOp(op, resultVal); + } + + return success(); + } + + // Use warp shuffle for reduction within warps and shared memory for data + // exchange across warps + LogicalResult matchAndRewriteFast(triton::GenericReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Location loc = op->getLoc(); + unsigned 
axis = adaptor.getAxis(); + + auto srcTy = op.getOperand().getType().cast(); + auto srcLayout = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto order = getOrder(srcLayout); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto llvmIndexTy = getTypeConverter()->getIndexType(); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); + Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); + smemBase = bitcast(smemBase, elemPtrTy); + + ReduceOpHelper helper(op); + auto smemShapes = helper.getScratchConfigsFast(); + unsigned elems = product(smemShapes[0]); + unsigned maxElems = std::max(elems, product(smemShapes[1])); + Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); + indexSmemBase = bitcast(indexSmemBase, indexPtrTy); + + unsigned sizeIntraWarps = helper.getIntraWarpSize(); + unsigned sizeInterWarps = helper.getInterWarpSize(); + + unsigned srcElems = getElemsPerThread(srcTy); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); + auto srcValues = getTypeConverter()->unpackLLElements( + loc, adaptor.getOperand(), rewriter, srcTy); + + SmallVector> offset = + emitOffsetForLayout(srcLayout, srcTy); + + std::map, Value> accs; + std::map, Value> accIndices; + std::map, SmallVector> indices; + + auto ¤tBlock = *rewriter.getBlock(); + auto *reduceOp = &op.getRegion(); + + // reduce within threads + for (unsigned i = 0; i < srcElems; ++i) { + SmallVector key = offset[i]; + key[axis] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); + + for (auto it : accs) { + const SmallVector &key = it.first; + Value acc = it.second; + Value accIndex; + + // Reduce within warps + for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { + Value shfl = shflSync(loc, rewriter, acc, N); + accumulate(rewriter, *reduceOp, acc, shfl, false); + } + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShapes[0], order); + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + storeShared(rewriter, loc, writePtr, acc, laneZero); + } + + barrier(); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. 
Sn / numThreads + unsigned numThreads = + product(triton::gpu::getWarpsPerCTA(srcLayout)) * 32; + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + // FIXME(Qingyi): need predicate icmp_slt(threadId, + // i32_val(sizeInerWarps)) + Value acc = load(readPtr); + Value accIndex; + + for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { + Value shfl = shflSync(loc, rewriter, acc, N); + accumulate(rewriter, *reduceOp, acc, shfl, false); + } + + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + storeShared(rewriter, loc, writePtr, acc, pred); + + if (round != elemsPerThread - 1) { + readOffset = add(readOffset, i32_val(numThreads)); + } + } + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier incase the layouts are accepted. + barrier(); + + // set output values + if (auto resultTy = op.getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding().cast(); + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t i = 0; i < resultElems; ++i) { + SmallVector readIdx = resultIndices[i]; + readIdx.insert(readIdx.begin() + axis, i32_val(0)); + Value readOffset = + linearize(rewriter, loc, readIdx, smemShapes[0], order); + Value readPtr = gep(elemPtrTy, smemBase, readOffset); + Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); + resultVals[i] = load(readPtr); + } + + Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, + resultTy); + rewriter.replaceOp(op, ret); + } else { + // 0d-tensor -> scalar + Value resultVal = load(smemBase); + rewriter.replaceOp(op, resultVal); + } + + return success(); + } +}; + +void populateGenericReduceOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, allocation, smem, + indexCacheInfo, benefit); +} diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h new file mode 100644 index 000000000000..2280e2c64e89 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h @@ -0,0 +1,16 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_GENERIC_REDUCE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_GENERIC_REDUCE_OP_H + +#include "TritonGPUToLLVMBase.h" + +using namespace mlir; +using namespace mlir::triton; + +void populateGenericReduceOpToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo 
&indexCacheInfo, + PatternBenefit benefit); + +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 2bad78e768ac..28142d24fb78 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -23,6 +23,7 @@ #include "ElementwiseOpToLLVM.h" #include "LoadStoreOpToLLVM.h" #include "ReduceOpToLLVM.h" +#include "./GenericReduceOpToLLVM.h" #include "TritonGPUToLLVM.h" #include "TypeConverter.h" #include "ViewOpToLLVM.h" @@ -199,6 +200,7 @@ class ConvertTritonGPUToLLVM populatePatterns2(populateElementwiseOpToLLVMPatterns); populatePatterns1(populateLoadStoreOpToLLVMPatterns); populatePatterns1(populateReduceOpToLLVMPatterns); + populatePatterns1(populateGenericReduceOpToLLVMPatterns); populatePatterns2(populateViewOpToLLVMPatterns); // Native lowering patterns mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 4f347a532bfd..676041fe14e7 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -477,6 +477,38 @@ struct TritonReducePattern : public OpConversionPattern { } }; +struct TritonGenericReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::GenericReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperand(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newRegion = newReduce.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), newRegion, newRegion.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonGenericReduceReturnPattern : + public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::GenericReduceReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + addNamedAttrs( + rewriter.replaceOpWithNewOp( + op, adaptor.getResult()), + adaptor.getAttributes()); + return success(); + } +}; + struct TritonPrintPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -518,6 +550,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonCatPattern, TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, + TritonGenericReducePattern, TritonGenericReduceReturnPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context); From ab461e8a4d780b7c300071a8b247ab993a823f89 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 13 Mar 2023 17:47:03 +0000 Subject: [PATCH 03/15] Support simultaneous reduction of multiple tensors --- include/triton/Analysis/Utility.h | 49 +-- include/triton/Dialect/Triton/IR/TritonOps.td | 17 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 10 +- lib/Analysis/Utility.cpp | 27 +- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 + .../TritonGPUToLLVM/GenericReduceOpToLLVM.cpp | 316 +++++++++++------- .../TritonGPUToLLVM/TypeConverter.cpp | 24 +- 
.../TritonToTritonGPUPass.cpp | 22 +- lib/Dialect/Triton/IR/Ops.cpp | 108 ++++-- .../Transforms/TritonGPUConversion.cpp | 2 + python/src/triton.cc | 35 +- python/triton/language/__init__.py | 1 + python/triton/language/core.py | 18 +- python/triton/language/semantic.py | 55 ++- 14 files changed, 459 insertions(+), 227 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 8397bb28932e..2a17c0e6d6d3 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -11,29 +11,39 @@ namespace mlir { class ReduceOpHelper { - ReduceOpHelper(Operation *op, int axis, bool withIndex) - : op(op), axis(axis), withIndex(withIndex) { - srcTy = op->getOperands().front().getType().cast(); - } - public: - explicit ReduceOpHelper(triton::ReduceOp op): - ReduceOpHelper( - op.getOperation(), - op.getAxis(), - triton::ReduceOp::withIndex(op.getRedOp())) { + explicit ReduceOpHelper(triton::ReduceOp rop): + op(rop.getOperation()), axis(rop.getAxis()) { + auto srcTy = rop.getOperand().getType().cast(); + srcShape = srcTy.getShape(); + srcEncoding = srcTy.getEncoding(); + srcElementTypes.push_back(srcTy.getElementType()); + + if (triton::ReduceOp::withIndex(rop.getRedOp())) { + srcElementTypes.push_back(Builder(op).getI32Type()); + } } - explicit ReduceOpHelper(triton::GenericReduceOp op): - ReduceOpHelper( - op.getOperation(), - op.getAxis(), - /*withIndex*/false) { + explicit ReduceOpHelper(triton::GenericReduceOp rop): + op(rop.getOperation()), axis(rop.getAxis()) { + auto firstTy = rop.getOperands()[0].getType().cast(); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = rop.getElementTypes(); + + for (const auto &t : rop.getInputTypes()) { + if (t.getShape() != srcShape) { + rop.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + rop.emitError() << "encoding mismatch"; + } + } } - ArrayRef getSrcShape() { return srcTy.getShape(); } + ArrayRef getSrcShape() { return srcShape; } - Attribute getSrcLayout() { return srcTy.getEncoding(); } + Attribute getSrcLayout() { return srcEncoding; } bool isFastReduction(); @@ -51,9 +61,10 @@ class ReduceOpHelper { private: Operation *op; - RankedTensorType srcTy{}; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; int axis; - bool withIndex; }; bool isSharedEncoding(Value value); diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d4dad7bf6980..6436660fef73 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -372,16 +372,25 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure, def TT_GenericReduceOp: TT_Op<"generic_reduce", [Pure, DeclareOpInterfaceMethods, SingleBlock]> { let summary = "Reduction using generic combination algorithm"; - let arguments = (ins TT_Tensor:$operand, I32Attr:$axis); - let results = (outs TT_Type:$result); - let regions = (region SizedRegion<1>:$region); + let arguments = (ins Variadic:$operands, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$operands, "int":$axis)>, + ]; + let hasVerifier = 1; let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; } def TT_GenericReduceReturnOp: TT_Op<"generic_reduce.return", [HasParent<"GenericReduceOp">, Pure, 
Terminator, ReturnLike]> { let summary = "terminator for reduce operator"; - let arguments = (ins AnyType:$result); + let arguments = (ins Variadic:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 6cb541b21a86..9d5162373e4b 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -64,7 +64,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { // handle encodings // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise, - SameOperandsAndResultShape, + SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "integer comparison operation"; @@ -78,7 +78,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise, } def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise, - SameOperandsAndResultShape, + SameOperandsAndResultShape, SameOperandsAndResultEncoding]> { let summary = "floating-point comparison operation"; @@ -100,10 +100,10 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise, let description = [{}]; let arguments = (ins TT_BoolLike:$condition, - TT_Tensor:$true_value, - TT_Tensor:$false_value); + TT_Type:$true_value, + TT_Type:$false_value); - let results = (outs TT_Tensor:$result); + let results = (outs TT_Type:$result); } diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 85915ecad340..a5e594b40357 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -10,29 +10,24 @@ namespace mlir { bool ReduceOpHelper::isFastReduction() { - auto srcLayout = srcTy.getEncoding(); - return axis == triton::gpu::getOrder(srcLayout)[0]; + return axis == triton::gpu::getOrder(getSrcLayout())[0]; } unsigned ReduceOpHelper::getInterWarpSize() { - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); return std::min(srcReduceDimSize / sizeIntraWarps, - triton::gpu::getWarpsPerCTA(srcLayout)[axis]); + triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getIntraWarpSize() { - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); auto srcReduceDimSize = static_cast(srcShape[axis]); return std::min(srcReduceDimSize, - triton::gpu::getThreadsPerWarp(srcLayout)[axis]); + triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getThreadsReductionAxis() { - auto srcLayout = srcTy.getEncoding(); + auto srcLayout = getSrcLayout(); return triton::gpu::getThreadsPerWarp(srcLayout)[axis] * triton::gpu::getWarpsPerCTA(srcLayout)[axis]; } @@ -46,7 +41,7 @@ SmallVector ReduceOpHelper::getScratchConfigBasic() { SmallVector> ReduceOpHelper::getScratchConfigsFast() { SmallVector> smemShapes(3); - auto argLayout = srcTy.getEncoding(); + auto argLayout = getSrcLayout(); auto argLayoutMma = argLayout.dyn_cast(); if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1) @@ -76,13 +71,11 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { elems = product(smemShape); } - auto tensorType = op->getOperand(0).getType().cast(); - unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8; - - if (withIndex) - bytes += elems * sizeof(int32_t); - - return bytes; + unsigned bytes_per_elem = 0; + for (const auto 
&ty: srcElementTypes) { + bytes_per_elem += ty.getIntOrFloatBitWidth() / 8; + } + return bytes_per_elem * elems; } bool isSharedEncoding(Value value) { diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 1b4966a1bf8a..483c54f34b35 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -972,6 +972,8 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin #undef POPULATE_BINARY_OP #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp index 7d6731638b1a..37fddd329bb2 100644 --- a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp @@ -24,71 +24,111 @@ struct GenericReduceOpConversion private: - void accumulate(ConversionPatternRewriter &rewriter, - Region &reduceOp, Value &acc, Value cur, bool isFirst) const { + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + llvm::SmallVectorImpl &acc, ValueRange cur, bool isFirst) const { if (isFirst) { - acc = cur; + acc.resize(cur.size()); + for (unsigned i = 0; i < cur.size(); ++i) { + acc[i] = cur[i]; + } return; } // Create a new copy of the reduce block, and inline it Block *currentBlock = rewriter.getBlock(); Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(reduceOp, &parent.front()); + rewriter.cloneRegionBefore(combineOp, &parent.front()); auto &newReduce = parent.front(); auto returnOp = dyn_cast(newReduce.getTerminator()); - rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), {acc, cur}); - acc = returnOp.getResult(); + + llvm::SmallVector combineArgs(2*acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + // Delete the terminator, which is no longer used rewriter.eraseOp(returnOp); } + SmallVector> unpackInputs( + Location loc, triton::GenericReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = getTypeConverter()->unpackLLElements( + loc, operands[i], rewriter, types[i]); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + // Use shared memory for reduction within warps and across warps LogicalResult matchAndRewriteBasic(triton::GenericReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); + Location loc = op.getLoc(); unsigned axis = op.getAxis(); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding().cast(); + ReduceOpHelper helper(op); + auto 
srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout().cast(); auto srcOrd = srcLayout.getOrder(); - auto srcShape = srcTy.getShape(); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); - ReduceOpHelper helper(op); auto smemShape = helper.getScratchConfigBasic(); unsigned elems = product(smemShape); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); - unsigned srcElems = getElemsPerThread(srcTy); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)); + } + + unsigned srcElems = getElemsPerThread(srcTys[0]); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + // Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); + emitOffsetForLayout(srcLayout, srcTys[0]); - std::map, Value> accs; - std::map, Value> accIndices; + std::map, SmallVector> accs; std::map, SmallVector> indices; - Region *reduceOp = &op.getRegion(); + Region *combineOp = &op.getCombineOp(); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -103,14 +143,17 @@ struct GenericReduceOpConversion // reduce across threads for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; + auto &acc = it.second; SmallVector writeIdx = indices[key]; writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); - store(acc, writePtr); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + store(acc[i], writePtrs[i]); + } + SmallVector readIdx(writeIdx.size(), ints[0]); for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { @@ -119,44 +162,56 @@ struct GenericReduceOpConversion Value readOffset = select( readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), ints[0]); - Value readPtr = gep(elemPtrTy, writePtr, readOffset); + SmallVector readPtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { 
+ readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); + } + barrier(); - Value cur = load(readPtr); - accumulate(rewriter, *reduceOp, acc, cur, false); + SmallVector cur(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + cur[i] = load(gep(elemPtrTys[i], readPtrs[i], readOffset)); + } + accumulate(rewriter, *combineOp, acc, cur, false); barrier(); - store(acc, writePtr); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + store(acc[i], writePtrs[i]); + } } } barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding(); - - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned i = 0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] = load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + + auto resultLayout = resultTy.getEncoding(); + + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (unsigned j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + axis, ints[0]); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + results[i] = getTypeConverter()->packLLElements( + loc, resultVals, rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = load(smemBase); - rewriter.replaceOp(op, resultVal); } + auto parentBlock = op.getOperation()->getBlock(); + rewriter.replaceOp(op, results); return success(); } @@ -168,52 +223,54 @@ struct GenericReduceOpConversion Location loc = op->getLoc(); unsigned axis = adaptor.getAxis(); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); - auto order = getOrder(srcLayout); - - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + ReduceOpHelper helper(op); + auto srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout().cast(); + auto srcOrd = srcLayout.getOrder(); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = 
LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); - ReduceOpHelper helper(op); auto smemShapes = helper.getScratchConfigsFast(); unsigned elems = product(smemShapes[0]); unsigned maxElems = std::max(elems, product(smemShapes[1])); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); + + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)); + } unsigned sizeIntraWarps = helper.getIntraWarpSize(); unsigned sizeInterWarps = helper.getInterWarpSize(); - unsigned srcElems = getElemsPerThread(srcTy); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + unsigned srcElems = getElemsPerThread(srcTys[0]); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); - - std::map, Value> accs; - std::map, Value> accIndices; + std::map, SmallVector> accs; std::map, SmallVector> indices; - auto ¤tBlock = *rewriter.getBlock(); - auto *reduceOp = &op.getRegion(); + // Assumes offsets don't actually depend on type + SmallVector> offset = + emitOffsetForLayout(srcLayout, srcTys[0]); + + auto *combineOp = &op.getCombineOp(); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -223,6 +280,9 @@ struct GenericReduceOpConversion Value warpId = udiv(threadId, warpSize); Value laneId = urem(threadId, warpSize); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); SmallVector multiDimWarpId = @@ -236,21 +296,25 @@ struct GenericReduceOpConversion for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; - Value accIndex; + SmallVector acc = it.second; // Reduce within warps for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - accumulate(rewriter, *reduceOp, acc, shfl, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); + } + accumulate(rewriter, *combineOp, acc, shfl, false); } SmallVector writeIdx = indices[key]; writeIdx[axis] = (sizeInterWarps == 1) ? 
zero : warpIdAxis; Value writeOffset = linearize(rewriter, loc, writeIdx, smemShapes[0], order); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, laneZero); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset); + storeShared(rewriter, loc, writePtr, acc[i], laneZero); + } } barrier(); @@ -266,26 +330,37 @@ struct GenericReduceOpConversion unsigned elemsPerThread = std::max(elems / numThreads, 1); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { - Value readPtr = gep(elemPtrTy, smemBase, readOffset); // FIXME(Qingyi): need predicate icmp_slt(threadId, // i32_val(sizeInerWarps)) - Value acc = load(readPtr); - Value accIndex; + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + acc[i] = load(readPtr); + } for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - accumulate(rewriter, *reduceOp, acc, shfl, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); + } + accumulate(rewriter, *combineOp, acc, shfl, false); } // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + } Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = icmp_eq(laneIdModSizeInterWarps, zero); Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); - storeShared(rewriter, loc, writePtr, acc, pred); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } if (round != elemsPerThread - 1) { readOffset = add(readOffset, i32_val(numThreads)); @@ -298,32 +373,33 @@ struct GenericReduceOpConversion barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding().cast(); - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (size_t i = 0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] = load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding().cast(); + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + 
readIdx.insert(readIdx.begin() + axis, i32_val(0)); + Value readOffset = + linearize(rewriter, loc, readIdx, smemShapes[0], order); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + + results[i] = getTypeConverter()->packLLElements( + loc, resultVals, rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = load(smemBase); - rewriter.replaceOp(op, resultVal); } + rewriter.replaceOp(op, results); return success(); } diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 7bf36c2a6d31..80d80e8852a2 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -50,15 +50,29 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( Value TritonGPUToLLVMTypeConverter::packLLElements( Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type type) { - auto structType = this->convertType(type); - if (!structType.isa()) { + auto structType = this->convertType(type).dyn_cast(); + if (!structType) { + assert(resultVals.size() == 1); return *resultVals.begin(); } + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } Value llvmStruct = rewriter.create(loc, structType); - // llvm::outs() << structType << "\n"; for (const auto &v : llvm::enumerate(resultVals)) { - assert(v.value() && "can not insert null values"); + if (!v.value()) { + emitError(loc) << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + emitError(loc) << "invalid element type in packLLEElements. Expected " + << elementTypes[v.index()] << " but got " << v.value().getType(); + + } llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); } return llvmStruct; @@ -160,4 +174,4 @@ TritonGPUToLLVMTypeConverter::convertTritonTensorType(RankedTensorType type) { } return std::nullopt; -} \ No newline at end of file +} diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 676041fe14e7..8eeba607bc4f 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -70,13 +70,15 @@ class ArithConstantPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); auto value = adaptor.getValue().dyn_cast(); - assert(value); - if (value.getElementType().isInteger(1) && value.isSplat()) - // Workaround until https://reviews.llvm.org/D133743 is included. - value = DenseElementsAttr::get(retType, value.getSplatValue()); - else - // This is a hack. We just want to add encoding - value = value.reshape(retType); + if (dyn_cast(retType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = DenseElementsAttr::get(retType, value.getSplatValue()); + else + // This is a hack. 
We just want to add encoding + value = value.reshape(retType); + } addNamedAttrs( rewriter.replaceOpWithNewOp(op, retType, value), adaptor.getAttributes()); @@ -484,11 +486,11 @@ struct TritonGenericReducePattern : public OpConversionPattern( - op.getLoc(), adaptor.getOperand(), adaptor.getAxis()); + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); addNamedAttrs(newReduce, adaptor.getAttributes()); - auto &newRegion = newReduce.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), newRegion, newRegion.end()); + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp, newCombineOp.end()); rewriter.replaceOp(op, newReduce.getResult()); return success(); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 2ddcd5c5461c..26270605a711 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -295,44 +295,112 @@ bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) { } //-- GenericReduceOp -- +void GenericReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = operands[i].getType().cast(); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + GenericReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + mlir::LogicalResult mlir::triton::GenericReduceOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - Value arg = operands[0]; - auto argTy = arg.getType().cast(); - auto retEltTy = argTy.getElementType(); - int axis = attributes.get("axis").cast().getInt(); - return inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + for (auto arg : operands) { + auto argTy = arg.getType().cast(); + auto retEltTy = argTy.getElementType(); + int axis = attributes.get("axis").cast().getInt(); + if ( + inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +mlir::LogicalResult mlir::triton::GenericReduceOp::verify() { + if (this->getOperands().size() < 1) { + return this->emitOpError() << "tt.generic_reduce must have at least 1 operand"; + } + for (const auto &operand: this->getOperands()) { + if (!dyn_cast(operand.getType())) { + return this->emitOpError() << "tt.generic_reduce operands must be RankedTensorType"; + } + } + return success(); } mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() { - auto argTy = getOperand().getType().cast(); - auto argElemTy = argTy.getElementType(); - - constexpr unsigned num_args = 2; - auto &block = this->getBody(); - if (block.getNumArguments() != num_args) { - return emitOpError() << "nested block must take " << num_args - << " arguments, but given block with " - << block.getNumArguments() << " arguments"; + auto argElementTypes = this->getElementTypes(); + const auto &operands = this->getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *this->getBody(); + if (block.getNumArguments() != numArgs) { + return this->emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; } unsigned i = 0; - for (const auto & blockArgTy: block.getArgumentTypes()) { + const auto &blockArgTypes 
= block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; if (blockArgTy != argElemTy) { - return this->emitOpError() << "types mismatch on reduction block. Expected argument " << i + return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i << " to have type " << argElemTy << " but got " << blockArgTy; } - ++i; } - if (!mlir::isa(block.getTerminator())) { - return this->emitOpError("the GenericReduceOp region must be terminated " - "with a GenericReduceReturnOp but got") << block.getTerminator(); + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return this->emitOpError() << "combine operation must be terminated " + << "with a GenericReduceReturnOp but got " + << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return this->emitOpError() << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } } return mlir::success(); } +llvm::SmallVector GenericReduceOp::getInputTypes() { + llvm::SmallVector srcTys; + srcTys.reserve(this->getNumOperands()); + for (const auto &ty: this->getOperands().getTypes()) { + srcTys.push_back(ty.cast()); + } + return srcTys; +} + +llvm::SmallVector GenericReduceOp::getElementTypes() { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(this->getNumOperands()); + for (const auto &op: this->getOperands()) { + srcElemTys.push_back(op.getType().cast().getElementType()); + } + return srcElemTys; +} + +unsigned GenericReduceOp::getNumOperands() { + return this->getOperands().size(); +} + //-- SplatOp -- OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto value = adaptor.getSrc(); diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 68ecd4caee28..cc1c1d245c44 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -79,6 +79,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( // Some ops from SCF are illegal addIllegalOp(); + // We have custom versions of some arith operators + addIllegalOp(); addDynamicallyLegalDialect(loc, lhs, rhs); }) + .def("create_fmin", + [](mlir::OpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, lhs, rhs); + }) + .def("create_smin", + [](mlir::OpBuilder &self, mlir::Value &lhs, + mlir::Value &rhs) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, lhs, rhs); + }) .def("create_add", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { @@ -1329,24 +1341,21 @@ void init_triton_ir(py::module &&m) { operand, axis); }) .def("create_generic_reduce", - [](mlir::OpBuilder &self, mlir::Value &operand, int axis) -> mlir::triton::GenericReduceOp { + []( + mlir::OpBuilder &self, std::vector operands, int axis + ) -> mlir::triton::GenericReduceOp { auto loc = self.getUnknownLoc(); - auto 
inputTensorType = - operand.getType().dyn_cast(); - std::vector shape = inputTensorType.getShape(); - shape.erase(shape.begin() + axis); - mlir::Type resType = inputTensorType.getElementType(); - if (!shape.empty()) { - resType = mlir::RankedTensorType::get(shape, resType); - } - return self.create( - loc, resType, operand, axis); + return self.create(loc, operands, axis); }) .def("create_reduce_ret", - [](mlir::OpBuilder &self, mlir::Value &return_value) -> mlir::OpState { + [](mlir::OpBuilder &self, py::args args) -> mlir::OpState { auto loc = self.getUnknownLoc(); + llvm::SmallVector return_values; + for (const auto & arg: args) { + return_values.push_back(py::cast(arg)); + } return self.create( - loc, return_value); + loc, return_values); }) .def("create_ptr_to_int", [](mlir::OpBuilder &self, mlir::Value &val, diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 8f26b55448ef..dd34684da4b0 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -10,6 +10,7 @@ abs, arange, argmin, + argmin2, argmax, atomic_add, atomic_and, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index af79d74a3a2f..1e75616079a4 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1170,11 +1170,27 @@ def xor_sum(input, axis, _builder=None): @builtin @_add_reduction_docstr("prod") -def prod(input, axis, _builder): +def prod(input, axis, _builder=None): axis = _constexpr_to_value(axis) return semantic.prod(input, axis, _builder) +@builtin +@_add_reduction_docstr("argmin2") +def argmin2(input, axis, _builder=None): + + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + new_shape = [constexpr(1)] * len(input.shape) + new_shape[axis] = constexpr(n) + index = view(index, new_shape, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + values, indices = semantic.min_with_index(input, index, axis, _builder) + return indices + + # ----------------------- # Internal for debugging # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 53172a6d53c6..226b7e986f2c 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1186,23 +1186,32 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: raise ValueError("xor_sum only supported for integers") return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) -def reduction(input: tl.tensor, axis: int, region_builder_fn, builder: ir.builder) -> tl.tensor: - scalar_ty = input.type.scalar - - # get result type - shape = input.type.shape +def reduction( + inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder +) -> Tuple[tl.tensor, ...]: + # get result shape + shape = inputs[0].type.shape + print(shape, axis) ret_shape = [s for i, s in enumerate(shape) if i != axis] - if ret_shape: - res_ty = tl.block_type(scalar_ty, ret_shape) - else: - # 0d-tensor -> scalar - res_ty = out_scalar_ty + for t in inputs: + assert t.type.shape == shape + + def wrap_tensor(x, scalar_ty): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = out_scalar_ty + return tl.tensor(x, res_ty) - reduce_op = builder.create_generic_reduce(input.handle, axis) + reduce_op = builder.create_generic_reduce([t.handle for t in inputs], axis) region_builder_fn(reduce_op) reduce_op.verify() - 
return tl.tensor(reduce_op.get_result(0), res_ty) + return tuple( + wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) + for i in range(len(inputs)) + ) @contextmanager def insertion_guard(builder): @@ -1220,7 +1229,27 @@ def make_mul(reduce_op): fmul = builder.create_fmul(block.arg(0), block.arg(1)) builder.create_reduce_ret(fmul) - return reduction(input, axis, make_mul, builder) + return reduction((input,), axis, make_mul, builder)[0] + +def min_with_index(keys: tl.tensor, values: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + + def make_min_with_index_combine(reduce_op): + ir_key_ty = keys.type.scalar.to_ir(builder) + ir_value_ty = values.type.scalar.to_ir(builder) + region = reduce_op.get_region(0) + with insertion_guard(builder): + block = builder.create_block_with_parent(region, [ir_key_ty, ir_value_ty] * 2) + value1, index1, value2, index2 = [block.arg(i) for i in range(4)] + lt = builder.create_fcmpOLT(value1, value2) + gt = builder.create_fcmpOGT(value1, value2) + index_min = builder.create_smin(index1, index2) + index_ret = builder.create_select( + lt, index1, builder.create_select(gt, index2, index_min)) + + value_min = builder.create_fmin(value1, value2) + builder.create_reduce_ret(value_min, index_ret) + + return reduction((keys, values), axis, make_min_with_index_combine, builder) # ===----------------------------------------------------------------------=== # Math From 0aba718980effa1a108dd751256bd57f40677118 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 14 Mar 2023 01:47:33 +0000 Subject: [PATCH 04/15] Automatically build reduction combine op region from JITFunction --- python/triton/compiler.py | 81 +++++++------ python/triton/language/core.py | 177 ++++++++++++++++++++++------- python/triton/language/semantic.py | 2 +- 3 files changed, 180 insertions(+), 80 deletions(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index e53bc6260ebb..6d1e76ff1f6d 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -17,6 +17,7 @@ from collections import namedtuple from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Union +import inspect import setuptools import torch @@ -802,6 +803,43 @@ def visit_Assert(self, node) -> Any: # Convert assert to triton's device_assert which happens on the device return triton.language.core.device_assert(test, msg, _builder=self.builder) + def call_JitFunction(self, fn: torch.runtime.JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if isinstance(arg, triton.language.tensor) + else triton.language.constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = triton.language.function_type([], arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug) + 
generator.visit(fn.parse()) + callee_ret_type = generator.last_ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return triton.language.tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + def visit_Call(self, node): fn = self.visit(node.func) if isinstance(fn, triton.language.constexpr): @@ -817,44 +855,13 @@ def visit_Call(self, node): if not self.debug: return if isinstance(fn, triton.runtime.JITFunction): - from inspect import getcallargs - args = getcallargs(fn.fn, *args, **kws) - args = [args[name] for name in fn.arg_names] - args = [arg if isinstance(arg, triton.language.tensor) - else triton.language.constexpr(arg) for arg in args] - # generate function def - attributes = dict() - constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) - # generate function def if necessary - if not self.module.has_function(fn_name): - prototype = triton.language.function_type([], arg_types) - gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug) - generator.visit(fn.parse()) - callee_ret_type = generator.last_ret_type - self.function_ret_types[fn_name] = callee_ret_type - else: - callee_ret_type = self.function_ret_types[fn_name] - symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: - return None - elif call_op.get_num_results() == 1: - return triton.language.tensor(call_op.get_result(0), callee_ret_type) - else: - # should return a tuple of tl.tensor - results = [] - for i in range(call_op.get_num_results()): - results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + return self.call_JitFunction(fn, args, kws) if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) or impl.is_builtin(fn): - return fn(*args, _builder=self.builder, **kws) + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + return fn(*args, **kws, **extra_kwargs) if fn in self.builtin_namespace.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg for arg in args] return fn(*args, **kws) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 1e75616079a4..0ccbd8e5cb62 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Callable, List, TypeVar +from contextlib import contextmanager import triton from . 
import builtin, semantic @@ -1125,70 +1126,162 @@ def _decorator(func: T) -> T: return _decorator +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) @builtin -@_add_reduction_docstr("maximum") -def max(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.max(input, axis, _builder) +def reduction(input, axis, combine_fn, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + + """ + if isinstance(input, tensor): + return reduction((input,), axis, combine_fn, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) -@builtin -@_add_reduction_docstr("maximum index") -def argmax(input, axis, _builder=None): axis = _constexpr_to_value(axis) - return semantic.argmax(input, axis, _builder) + return semantic.reduction(input, axis, make_combine_region, _builder) @builtin -@_add_reduction_docstr("minimum") -def min(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.min(input, axis, _builder) +def _promote_reduction_input(t, _builder=None): + scalar_ty = t.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32: + return t.to(int32, _builder=_builder) + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) -@builtin -@_add_reduction_docstr("minimum index") -def argmin(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.argmin(input, axis, _builder) + return t @builtin -@_add_reduction_docstr("sum") -def sum(input, axis, _builder=None): +def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): axis = _constexpr_to_value(axis) - return semantic.sum(input, axis, _builder) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + if len(input.shape) > 1: + new_shape = [constexpr(1)] * len(input.shape) + new_shape[axis] = constexpr(n) + index = view(index, new_shape, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) -@builtin -@_add_reduction_docstr("xor sum") -def xor_sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.xor_sum(input, axis, _builder) + rvalue, rindices = reduction((input, index), axis, combine_fn, + _builder=_builder, _generator=_generator) + return rindices -@builtin -@_add_reduction_docstr("prod") -def prod(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return 
semantic.prod(input, axis, _builder) +@triton.jit +def _max_combine(a, b): + return maximum(a, b) +@triton.jit +@_add_reduction_docstr("maximum") +def max(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _max_combine) -@builtin -@_add_reduction_docstr("argmin2") -def argmin2(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - n = input.shape[axis] - index = arange(0, n, _builder=_builder) - new_shape = [constexpr(1)] * len(input.shape) - new_shape[axis] = constexpr(n) - index = view(index, new_shape, _builder=_builder) - index = broadcast_to(index, input.shape, _builder=_builder) +@triton.jit +def _argmax_combine(value1, index1, value2, index2): + gt = value1 > value2 + lt = value1 < value2 + index_min = minimum(index1, index2) + index_ret = where(gt, index1, where(lt, index2, index_min)) + value_ret = maximum(value1, value2) + return value_ret, index_ret + - values, indices = semantic.min_with_index(input, index, axis, _builder) - return indices +@triton.jit +@_add_reduction_docstr("maximum index") +def argmax(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmax_combine) + + +@triton.jit +def _min_combine(a, b): + # TODO: minimum/maximum doesn't get lowered to fmin/fmax... + return minimum(a, b) + + +@triton.jit +@_add_reduction_docstr("minimum") +def min(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _min_combine) + + +@triton.jit +def _argmin_combine(value1, index1, value2, index2): + lt = value1 < value2 + gt = value1 > value2 + index_min = minimum(index1, index2) + index_ret = where(lt, index1, where(gt, index2, index_min)) + value_ret = minimum(value1, value2) + return value_ret, index_ret + + +@triton.jit +@_add_reduction_docstr("minimum index") +def argmin(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmin_combine) + + +@triton.jit +def _sum_combine(a, b): + return a + b + + +@triton.jit +@_add_reduction_docstr("sum") +def sum(input, axis): + input = _promote_reduction_input(input) + return reduction(input, axis, _sum_combine) + + +@triton.jit +def _xor_combine(a, b): + return a ^ b + + +@builtin +@_add_reduction_docstr("xor sum") +def xor_sum(input, axis, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = _promote_reduction_input(input, _builder=_builder) + return reduction(input, axis, _xor_combine, + _builder=_builder, _generator=_generator) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 226b7e986f2c..d53ed3f59727 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1201,7 +1201,7 @@ def wrap_tensor(x, scalar_ty): res_ty = tl.block_type(scalar_ty, ret_shape) else: # 0d-tensor -> scalar - res_ty = out_scalar_ty + res_ty = scalar_ty return tl.tensor(x, res_ty) reduce_op = builder.create_generic_reduce([t.handle for t in inputs], axis) From 08c2e35be4a548b6e69a7dde7ec6dbfdcef172e6 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 14 Mar 2023 02:50:52 +0000 Subject: [PATCH 05/15] Replace old ReduceOp entirely --- include/triton/Analysis/Utility.h | 12 - .../Dialect/Triton/IR/TritonAttrDefs.td | 24 - include/triton/Dialect/Triton/IR/TritonOps.td | 29 +- lib/Analysis/Allocation.cpp | 4 - lib/Analysis/Membar.cpp | 5 +- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 - 
.../TritonGPUToLLVM/GenericReduceOpToLLVM.cpp | 416 ---------------- .../TritonGPUToLLVM/GenericReduceOpToLLVM.h | 16 - .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 466 ++++++++---------- .../TritonGPUToLLVM/ReduceOpToLLVM.h | 2 +- .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 2 - .../TritonToTritonGPUPass.cpp | 32 +- lib/Dialect/Triton/IR/Ops.cpp | 52 +- .../Transforms/RemoveLayoutConversions.cpp | 72 ++- python/src/triton.cc | 42 +- python/triton/language/__init__.py | 2 - python/triton/language/semantic.py | 124 +---- 17 files changed, 285 insertions(+), 1016 deletions(-) delete mode 100644 lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 2a17c0e6d6d3..bc9a46c5bf9d 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -14,18 +14,6 @@ class ReduceOpHelper { public: explicit ReduceOpHelper(triton::ReduceOp rop): op(rop.getOperation()), axis(rop.getAxis()) { - auto srcTy = rop.getOperand().getType().cast(); - srcShape = srcTy.getShape(); - srcEncoding = srcTy.getEncoding(); - srcElementTypes.push_back(srcTy.getElementType()); - - if (triton::ReduceOp::withIndex(rop.getRedOp())) { - srcElementTypes.push_back(Builder(op).getI32Type()); - } - } - - explicit ReduceOpHelper(triton::GenericReduceOp rop): - op(rop.getOperation()), axis(rop.getAxis()) { auto firstTy = rop.getOperands()[0].getType().cast(); srcShape = firstTy.getShape(); srcEncoding = firstTy.getEncoding(); diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index bb0f4c0676dc..794e02a3377c 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -23,30 +23,6 @@ def TT_EvictionPolicyAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } -// reduction -def TT_RedOpAttr : I32EnumAttr< - /*name*/"RedOp", /*summary*/"", - /*case*/ - [ - I32EnumAttrCase, - I32EnumAttrCase<"FADD", 2, "fadd">, - I32EnumAttrCase<"MIN", 3, "min">, - I32EnumAttrCase<"MAX", 4, "max">, - I32EnumAttrCase<"UMIN", 5, "umin">, - I32EnumAttrCase<"UMAX", 6, "umax">, - I32EnumAttrCase<"ARGMIN", 7, "argmin">, - I32EnumAttrCase<"ARGMAX", 8, "argmax">, - I32EnumAttrCase<"ARGUMIN", 9, "argumin">, - I32EnumAttrCase<"ARGUMAX", 10, "argumax">, - I32EnumAttrCase<"FMIN", 11, "fmin">, - I32EnumAttrCase<"FMAX", 12, "fmax">, - I32EnumAttrCase<"ARGFMIN", 13, "argfmin">, - I32EnumAttrCase<"ARGFMAX", 14, "argfmax">, - I32EnumAttrCase<"XOR", 15, "xor"> - ]> { - let cppNamespace = "::mlir::triton"; -} - // atomic def TT_AtomicRMWAttr : I32EnumAttr< "RMWOp", "", diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 6436660fef73..46797cff8bdd 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -348,29 +348,8 @@ def TT_DotOp : TT_Op<"dot", [Pure, // // Reduce Op // -def TT_ReduceOp : TT_Op<"reduce", [Pure, - DeclareOpInterfaceMethods]> { - let summary = "reduce"; - - let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); - - let results = (outs TT_Type:$result); - - let builders = [ - OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>, - ]; - - let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)"; - - let extraClassDeclaration = [{ - // This member function is marked 
static because we need to call it before the ReduceOp - // is constructed, see the implementation of create_reduce in triton.cc. - static bool withIndex(mlir::triton::RedOp redOp); - }]; -} - -def TT_GenericReduceOp: TT_Op<"generic_reduce", - [Pure, DeclareOpInterfaceMethods, SingleBlock]> { +def TT_ReduceOp: TT_Op<"reduce", + [Pure, DeclareOpInterfaceMethods, SingleBlock]> { let summary = "Reduction using generic combination algorithm"; let arguments = (ins Variadic:$operands, I32Attr:$axis); let results = (outs Variadic:$result); @@ -387,8 +366,8 @@ def TT_GenericReduceOp: TT_Op<"generic_reduce", }]; } -def TT_GenericReduceReturnOp: TT_Op<"generic_reduce.return", - [HasParent<"GenericReduceOp">, Pure, Terminator, ReturnLike]> { +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { let summary = "terminator for reduce operator"; let arguments = (ins Variadic:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 29cdec5b53b6..3cee8e1d13be 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -166,10 +166,6 @@ class AllocationAnalysis { ReduceOpHelper helper(reduceOp); unsigned bytes = helper.getScratchSizeInBytes(); allocation->addBuffer(op, bytes); - } else if (auto reduceOp = dyn_cast(op)) { - ReduceOpHelper helper(reduceOp); - unsigned bytes = helper.getScratchSizeInBytes(); - allocation->addBuffer(op, bytes); } else if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.getSrc().getType().cast(); auto dstTy = cvtLayout.getResult().getType().cast(); diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index e23165a3f16c..9db5dfb23e91 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -72,11 +72,10 @@ void MembarAnalysis::visitTerminator(Operation *op, } return; } - if (isa(op)) { + if (isa(op) || isa(op)) { return; } - // Otherwise, it could be a return op - assert(isa(op) && "Unknown terminator"); + op->emitOpError("Unknown terminator encountered in membar analysis"); } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 11e4c8ccdf7b..97e32f4e345c 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,7 +7,6 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonGPUToLLVMPass.cpp PTXAsmFormat.cpp ReduceOpToLLVM.cpp - GenericReduceOpToLLVM.cpp Utility.cpp TypeConverter.cpp ViewOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp deleted file mode 100644 index 37fddd329bb2..000000000000 --- a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp +++ /dev/null @@ -1,416 +0,0 @@ -#include "GenericReduceOpToLLVM.h" - -using namespace mlir; -using namespace mlir::triton; - -using ::mlir::LLVM::shflSync; -using ::mlir::LLVM::storeShared; -using ::mlir::triton::gpu::getElemsPerThread; -using ::mlir::triton::gpu::getOrder; - -struct GenericReduceOpConversion - : public ConvertTritonGPUOpToLLVMPattern { -public: - using ConvertTritonGPUOpToLLVMPattern< - triton::GenericReduceOp>::ConvertTritonGPUOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::GenericReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (ReduceOpHelper(op).isFastReduction()) - return matchAndRewriteFast(op, 
adaptor, rewriter); - return matchAndRewriteBasic(op, adaptor, rewriter); - } - -private: - - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - llvm::SmallVectorImpl &acc, ValueRange cur, bool isFirst) const { - if (isFirst) { - acc.resize(cur.size()); - for (unsigned i = 0; i < cur.size(); ++i) { - acc[i] = cur[i]; - } - return; - } - - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newReduce = parent.front(); - auto returnOp = dyn_cast(newReduce.getTerminator()); - - llvm::SmallVector combineArgs(2*acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), combineArgs); - - auto results = returnOp.getResult(); - for (unsigned i = 0; i < acc.size(); ++i) { - acc[i] = results[i]; - } - - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); - } - - SmallVector> unpackInputs( - Location loc, triton::GenericReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto types = op.getInputTypes(); - auto operands = adaptor.getOperands(); - unsigned srcElems = getElemsPerThread(types[0]); - SmallVector> srcValues(srcElems); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto values = getTypeConverter()->unpackLLElements( - loc, operands[i], rewriter, types[i]); - - assert(values.size() == srcValues.size()); - for (unsigned j = 0; j < srcValues.size(); ++j) { - srcValues[j].push_back(values[j]); - } - } - return srcValues; - } - - // Use shared memory for reduction within warps and across warps - LogicalResult - matchAndRewriteBasic(triton::GenericReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Location loc = op.getLoc(); - unsigned axis = op.getAxis(); - - ReduceOpHelper helper(op); - auto srcTys = op.getInputTypes(); - auto srcLayout = helper.getSrcLayout().cast(); - auto srcOrd = srcLayout.getOrder(); - auto srcShape = helper.getSrcShape(); - - SmallVector elemPtrTys(srcTys.size()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto ty = srcTys[i].getElementType(); - auto llvmElemTy = getTypeConverter()->convertType(ty); - elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); - } - auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - - auto smemShape = helper.getScratchConfigBasic(); - unsigned elems = product(smemShape); - - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)); - } - - unsigned srcElems = getElemsPerThread(srcTys[0]); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - - // Assumes offsets don't actually depend on type - SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTys[0]); - - std::map, SmallVector> accs; - std::map, SmallVector> indices; - - - Region *combineOp = &op.getCombineOp(); - - // reduce within threads - for (unsigned i = 0; i < srcElems; ++i) { - SmallVector key = offset[i]; - key[axis] = 0; - bool isFirst = accs.find(key) == accs.end(); - 
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); - if (isFirst) - indices[key] = srcIndices[i]; - } - - // cached int32 constants - std::map ints; - ints[0] = i32_val(0); - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) - ints[N] = i32_val(N); - Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]); - - // reduce across threads - for (auto it : accs) { - const SmallVector &key = it.first; - auto &acc = it.second; - SmallVector writeIdx = indices[key]; - - writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); - store(acc[i], writePtrs[i]); - } - - - SmallVector readIdx(writeIdx.size(), ints[0]); - for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { - readIdx[axis] = ints[N]; - Value readMask = icmp_slt(writeIdx[axis], ints[N]); - Value readOffset = select( - readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), - ints[0]); - SmallVector readPtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); - } - - barrier(); - SmallVector cur(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - cur[i] = load(gep(elemPtrTys[i], readPtrs[i], readOffset)); - } - accumulate(rewriter, *combineOp, acc, cur, false); - barrier(); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - store(acc[i], writePtrs[i]); - } - } - } - - barrier(); - - // set output values - SmallVector results(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 - - auto resultLayout = resultTy.getEncoding(); - - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned j = 0; j < resultElems; ++j) { - SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); - resultVals[j] = load(readPtr); - } - results[i] = getTypeConverter()->packLLElements( - loc, resultVals, rewriter, resultTy); - } else { - // 0d-tensor -> scalar - results[i] = load(smemBases[i]); - } - } - - auto parentBlock = op.getOperation()->getBlock(); - rewriter.replaceOp(op, results); - return success(); - } - - // Use warp shuffle for reduction within warps and shared memory for data - // exchange across warps - LogicalResult matchAndRewriteFast(triton::GenericReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - Location loc = op->getLoc(); - unsigned axis = adaptor.getAxis(); - - ReduceOpHelper helper(op); - auto srcTys = op.getInputTypes(); - auto srcLayout = helper.getSrcLayout().cast(); - auto srcOrd = srcLayout.getOrder(); - auto srcShape = helper.getSrcShape(); - - SmallVector elemPtrTys(srcTys.size()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto ty = srcTys[i].getElementType(); - auto llvmElemTy = getTypeConverter()->convertType(ty); - elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); - } - auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto 
indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); - unsigned maxElems = std::max(elems, product(smemShapes[1])); - - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)); - } - - unsigned sizeIntraWarps = helper.getIntraWarpSize(); - unsigned sizeInterWarps = helper.getInterWarpSize(); - - unsigned srcElems = getElemsPerThread(srcTys[0]); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - - std::map, SmallVector> accs; - std::map, SmallVector> indices; - - // Assumes offsets don't actually depend on type - SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTys[0]); - - auto *combineOp = &op.getCombineOp(); - - // reduce within threads - for (unsigned i = 0; i < srcElems; ++i) { - SmallVector key = offset[i]; - key[axis] = 0; - bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); - if (isFirst) - indices[key] = srcIndices[i]; - } - - Value threadId = getThreadId(rewriter, loc); - Value warpSize = i32_val(32); - Value warpId = udiv(threadId, warpSize); - Value laneId = urem(threadId, warpSize); - - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); - auto order = getOrder(srcLayout); - SmallVector multiDimLaneId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); - - Value laneIdAxis = multiDimLaneId[axis]; - Value warpIdAxis = multiDimWarpId[axis]; - - Value zero = i32_val(0); - Value laneZero = icmp_eq(laneIdAxis, zero); - - for (auto it : accs) { - const SmallVector &key = it.first; - SmallVector acc = it.second; - - // Reduce within warps - for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - SmallVector shfl(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - shfl[i] = shflSync(loc, rewriter, acc[i], N); - } - accumulate(rewriter, *combineOp, acc, shfl, false); - } - - SmallVector writeIdx = indices[key]; - writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis; - Value writeOffset = - linearize(rewriter, loc, writeIdx, smemShapes[0], order); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset); - storeShared(rewriter, loc, writePtr, acc[i], laneZero); - } - } - - barrier(); - - // The second round of shuffle reduction - // now the problem size: sizeInterWarps, s1, s2, .. , sn - // where sizeInterWarps is 2^m - // - // Each thread needs to process: - // elemsPerThread = sizeInterWarps * s1 * s2 .. 
Sn / numThreads - unsigned numThreads = - product(triton::gpu::getWarpsPerCTA(srcLayout)) * 32; - unsigned elemsPerThread = std::max(elems / numThreads, 1); - Value readOffset = threadId; - for (unsigned round = 0; round < elemsPerThread; ++round) { - // FIXME(Qingyi): need predicate icmp_slt(threadId, - // i32_val(sizeInerWarps)) - SmallVector acc(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); - acc[i] = load(readPtr); - } - - for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - SmallVector shfl(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - shfl[i] = shflSync(loc, rewriter, acc[i], N); - } - accumulate(rewriter, *combineOp, acc, shfl, false); - } - - // only the first thread in each sizeInterWarps is writing - Value writeOffset = readOffset; - SmallVector writePtrs(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); - } - Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); - Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); - Value laneIdModSizeInterWarpsIsZero = - icmp_eq(laneIdModSizeInterWarps, zero); - Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); - - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - storeShared(rewriter, loc, writePtrs[i], acc[i], pred); - } - - if (round != elemsPerThread - 1) { - readOffset = add(readOffset, i32_val(numThreads)); - } - } - - // We could avoid this barrier in some of the layouts, however this is not - // the general case. - // TODO: optimize the barrier incase the layouts are accepted. - barrier(); - - // set output values - SmallVector results(op.getNumOperands()); - for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding().cast(); - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (size_t j = 0; j < resultElems; ++j) { - SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); - resultVals[j] = load(readPtr); - } - - results[i] = getTypeConverter()->packLLElements( - loc, resultVals, rewriter, resultTy); - } else { - // 0d-tensor -> scalar - results[i] = load(smemBases[i]); - } - } - rewriter.replaceOp(op, results); - - return success(); - } -}; - -void populateGenericReduceOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, - indexCacheInfo, benefit); -} diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h deleted file mode 100644 index 2280e2c64e89..000000000000 --- a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_GENERIC_REDUCE_OP_H -#define 
TRITON_CONVERSION_TRITONGPU_TO_LLVM_GENERIC_REDUCE_OP_H - -#include "TritonGPUToLLVMBase.h" - -using namespace mlir; -using namespace mlir::triton; - -void populateGenericReduceOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - PatternBenefit benefit); - -#endif diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index b15f99d2a5b3..3fe98aac34f5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -23,165 +23,113 @@ struct ReduceOpConversion } private: - void accumulate(ConversionPatternRewriter &rewriter, Location loc, - RedOp redOp, Value &acc, Value cur, bool isFirst) const { + + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + llvm::SmallVectorImpl &acc, ValueRange cur, bool isFirst) const { if (isFirst) { - acc = cur; + acc.resize(cur.size()); + for (unsigned i = 0; i < cur.size(); ++i) { + acc[i] = cur[i]; + } return; } - switch (redOp) { - case RedOp::ADD: - acc = add(acc, cur); - break; - case RedOp::FADD: - acc = fadd(acc.getType(), acc, cur); - break; - case RedOp::MIN: - acc = smin(acc, cur); - break; - case RedOp::MAX: - acc = smax(acc, cur); - break; - case RedOp::UMIN: - acc = umin(acc, cur); - break; - case RedOp::UMAX: - acc = umax(acc, cur); - break; - case RedOp::FMIN: - acc = fmin(acc, cur); - break; - case RedOp::FMAX: - acc = fmax(acc, cur); - break; - case RedOp::XOR: - acc = xor_(acc, cur); - break; - case RedOp::ARGMIN: - case RedOp::ARGMAX: - case RedOp::ARGUMIN: - case RedOp::ARGUMAX: - case RedOp::ARGFMIN: - case RedOp::ARGFMAX: - llvm::report_fatal_error( - "This accumulate implementation is not for argmin / argmax"); - default: - llvm::report_fatal_error("Unsupported reduce op"); + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2*acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; } - } - void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc, - RedOp redOp, Value &acc, Value &accIndex, Value cur, - Value curIndex, bool isFirst) const { - if (isFirst) { - acc = cur; - accIndex = curIndex; - return; + rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; } - switch (redOp) { - case RedOp::ARGMIN: - accIndex = select( - icmp_slt(acc, cur), accIndex, - select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = smin(acc, cur); - break; - case RedOp::ARGMAX: - accIndex = select( - icmp_sgt(acc, cur), accIndex, - select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = smax(acc, cur); - break; - case RedOp::ARGUMIN: - accIndex = select( - icmp_ult(acc, cur), accIndex, - select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = umin(acc, cur); - break; - case RedOp::ARGUMAX: - accIndex = select( - icmp_ugt(acc, cur), accIndex, - select(icmp_ult(acc, 
cur), curIndex, smin(accIndex, curIndex))); - acc = umax(acc, cur); - break; - case RedOp::ARGFMIN: - accIndex = select( - fcmp_olt(acc, cur), accIndex, - select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = fmin(acc, cur); - break; - case RedOp::ARGFMAX: - accIndex = select( - fcmp_ogt(acc, cur), accIndex, - select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex))); - acc = fmax(acc, cur); - break; - case RedOp::ADD: - case RedOp::FADD: - case RedOp::MIN: - case RedOp::MAX: - case RedOp::UMIN: - case RedOp::UMAX: - case RedOp::FMIN: - case RedOp::FMAX: - case RedOp::XOR: - llvm::report_fatal_error( - "This accumulate implementation is only for argmin / argmax"); - default: - llvm::report_fatal_error("Unsupported reduce op"); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + SmallVector> unpackInputs( + Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = getTypeConverter()->unpackLLElements( + loc, operands[i], rewriter, types[i]); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } } + return srcValues; } // Use shared memory for reduction within warps and across warps LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); + Location loc = op.getLoc(); unsigned axis = op.getAxis(); - bool withIndex = triton::ReduceOp::withIndex(op.getRedOp()); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding().cast(); + ReduceOpHelper helper(op); + auto srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout().cast(); auto srcOrd = srcLayout.getOrder(); - auto srcShape = srcTy.getShape(); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); - ReduceOpHelper helper(op); auto smemShape = helper.getScratchConfigBasic(); unsigned elems = product(smemShape); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); - unsigned srcElems = getElemsPerThread(srcTy); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = bitcast( + gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), elemPtrTys[i]); + } + + unsigned srcElems = 
getElemsPerThread(srcTys[0]); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + // Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); + emitOffsetForLayout(srcLayout, srcTys[0]); - std::map, Value> accs; - std::map, Value> accIndices; + std::map, SmallVector> accs; std::map, SmallVector> indices; + + Region *combineOp = &op.getCombineOp(); + // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i], - isFirst); - } else { - Value curIndex = srcIndices[i][axis]; - accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key], - accIndices[key], srcValues[i], curIndex, isFirst); - } + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -196,19 +144,17 @@ struct ReduceOpConversion // reduce across threads for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; - Value accIndex; - if (withIndex) - accIndex = accIndices[key]; + auto &acc = it.second; SmallVector writeIdx = indices[key]; writeIdx[axis] = udiv(writeIdx[axis], sizePerThread); Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); - store(acc, writePtr); - if (withIndex) - store(accIndex, indexWritePtr); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + store(acc[i], writePtrs[i]); + } + SmallVector readIdx(writeIdx.size(), ints[0]); for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { @@ -217,22 +163,20 @@ struct ReduceOpConversion Value readOffset = select( readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd), ints[0]); - Value readPtr = gep(elemPtrTy, writePtr, readOffset); + SmallVector readPtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset); + } + + barrier(); + SmallVector cur(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + cur[i] = load(readPtrs[i]); + } + accumulate(rewriter, *combineOp, acc, cur, false); barrier(); - if (!withIndex) { - Value cur = load(readPtr); - accumulate(rewriter, loc, op.getRedOp(), acc, cur, false); - barrier(); - store(acc, writePtr); - } else { - Value cur = load(readPtr); - Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset); - Value curIndex = load(indexReadPtr); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, cur, - curIndex, false); - barrier(); - store(acc, writePtr); - store(accIndex, indexWritePtr); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + store(acc[i], writePtrs[i]); } } } @@ -240,33 +184,35 @@ struct ReduceOpConversion barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding(); - auto resultShape = resultTy.getShape(); - - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned i = 
0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + + auto resultLayout = resultTy.getEncoding(); + + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (unsigned j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + axis, ints[0]); + Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + results[i] = getTypeConverter()->packLLElements( + loc, resultVals, rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase); - rewriter.replaceOp(op, resultVal); } + auto parentBlock = op.getOperation()->getBlock(); + rewriter.replaceOp(op, results); return success(); } @@ -274,60 +220,59 @@ struct ReduceOpConversion // exchange across warps LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); unsigned axis = adaptor.getAxis(); - bool withIndex = triton::ReduceOp::withIndex(op.getRedOp()); - auto srcTy = op.getOperand().getType().cast(); - auto srcLayout = srcTy.getEncoding(); - auto srcShape = srcTy.getShape(); - auto order = getOrder(srcLayout); - - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + ReduceOpHelper helper(op); + auto srcTys = op.getInputTypes(); + auto srcLayout = helper.getSrcLayout().cast(); + auto srcOrd = srcLayout.getOrder(); + auto srcShape = helper.getSrcShape(); - auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + SmallVector elemPtrTys(srcTys.size()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto ty = srcTys[i].getElementType(); + auto llvmElemTy = getTypeConverter()->convertType(ty); + elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3); + } auto llvmIndexTy = getTypeConverter()->getIndexType(); - auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3); - Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(smemBase, elemPtrTy); - ReduceOpHelper helper(op); auto smemShapes = helper.getScratchConfigsFast(); unsigned elems = product(smemShapes[0]); unsigned maxElems = std::max(elems, product(smemShapes[1])); - Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); - indexSmemBase = bitcast(indexSmemBase, indexPtrTy); + + SmallVector smemBases(op.getNumOperands()); + smemBases[0] = bitcast( + getSharedMemoryBase(loc, rewriter, 
op.getOperation()), elemPtrTys[0]); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + smemBases[i] = bitcast( + gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)), elemPtrTys[i]); + } unsigned sizeIntraWarps = helper.getIntraWarpSize(); unsigned sizeInterWarps = helper.getInterWarpSize(); - unsigned srcElems = getElemsPerThread(srcTy); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy); - auto srcValues = getTypeConverter()->unpackLLElements( - loc, adaptor.getOperand(), rewriter, srcTy); + unsigned srcElems = getElemsPerThread(srcTys[0]); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + + // Assumes offsets don't actually depend on type SmallVector> offset = - emitOffsetForLayout(srcLayout, srcTy); + emitOffsetForLayout(srcLayout, srcTys[0]); - std::map, Value> accs; - std::map, Value> accIndices; - std::map, SmallVector> indices; + auto *combineOp = &op.getCombineOp(); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; key[axis] = 0; bool isFirst = accs.find(key) == accs.end(); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), accs[key], srcValues[i], - isFirst); - } else { - Value curIndex = srcIndices[i][axis]; - accumulateWithIndex(rewriter, loc, op.getRedOp(), accs[key], - accIndices[key], srcValues[i], curIndex, isFirst); - } + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); if (isFirst) indices[key] = srcIndices[i]; } @@ -337,6 +282,9 @@ struct ReduceOpConversion Value warpId = udiv(threadId, warpSize); Value laneId = urem(threadId, warpSize); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout); + auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); SmallVector multiDimWarpId = @@ -350,32 +298,24 @@ struct ReduceOpConversion for (auto it : accs) { const SmallVector &key = it.first; - Value acc = it.second; - Value accIndex; - if (withIndex) - accIndex = accIndices[key]; + SmallVector acc = it.second; // Reduce within warps for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false); - } else { - Value shflIndex = shflSync(loc, rewriter, accIndex, N); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl, - shflIndex, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); } + accumulate(rewriter, *combineOp, acc, shfl, false); } SmallVector writeIdx = indices[key]; writeIdx[axis] = (sizeInterWarps == 1) ? 
zero : warpIdAxis; Value writeOffset = linearize(rewriter, loc, writeIdx, smemShapes[0], order); - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); - storeShared(rewriter, loc, writePtr, acc, laneZero); - if (withIndex) { - Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset); - storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset); + storeShared(rewriter, loc, writePtr, acc[i], laneZero); } } @@ -392,39 +332,36 @@ struct ReduceOpConversion unsigned elemsPerThread = std::max(elems / numThreads, 1); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { - Value readPtr = gep(elemPtrTy, smemBase, readOffset); // FIXME(Qingyi): need predicate icmp_slt(threadId, // i32_val(sizeInerWarps)) - Value acc = load(readPtr); - Value accIndex; - if (withIndex) { - Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset); - accIndex = load(readIndexPtr); + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + acc[i] = load(readPtr); } for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { - Value shfl = shflSync(loc, rewriter, acc, N); - if (!withIndex) { - accumulate(rewriter, loc, op.getRedOp(), acc, shfl, false); - } else { - Value shflIndex = shflSync(loc, rewriter, accIndex, N); - accumulateWithIndex(rewriter, loc, op.getRedOp(), acc, accIndex, shfl, - shflIndex, false); + SmallVector shfl(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + shfl[i] = shflSync(loc, rewriter, acc[i], N); } + accumulate(rewriter, *combineOp, acc, shfl, false); } // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; - Value writePtr = gep(elemPtrTy, smemBase, writeOffset); + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset); + } Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = icmp_eq(laneIdModSizeInterWarps, zero); Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); - storeShared(rewriter, loc, writePtr, acc, pred); - if (withIndex) { - Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset); - storeShared(rewriter, loc, writeIndexPtr, accIndex, pred); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + storeShared(rewriter, loc, writePtrs[i], acc[i], pred); } if (round != elemsPerThread - 1) { @@ -438,32 +375,33 @@ struct ReduceOpConversion barrier(); // set output values - if (auto resultTy = op.getType().dyn_cast()) { - // nd-tensor where n >= 1 - auto resultLayout = resultTy.getEncoding().cast(); - auto resultShape = resultTy.getShape(); - unsigned resultElems = getElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (size_t i = 0; i < resultElems; ++i) { - SmallVector readIdx = resultIndices[i]; - readIdx.insert(readIdx.begin() + axis, i32_val(0)); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); - Value readPtr = gep(elemPtrTy, smemBase, readOffset); - Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); - resultVals[i] 
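The fast path mirrors the same idea with warp shuffles: each operand's accumulator is shuffled down by N lanes independently and the inlined combine region merges the shuffled values back in, first within a warp and then across warps through shared memory. A small Python sketch of that butterfly step, with the caveat that it only models the data movement, not the generated PTX:

    # Lane l combines with lane l + N (wrapping in this model); halving N leaves
    # the full reduction in lane 0. One column per reduce operand.
    def warp_reduce(columns, combine, width):
        n = width // 2
        while n > 0:
            nxt = [list(col) for col in columns]
            for lane in range(width):
                acc = [col[lane] for col in columns]
                shfl = [col[(lane + n) % width] for col in columns]
                merged = combine(*acc, *shfl)
                for k in range(len(columns)):
                    nxt[k][lane] = merged[k]
            columns = nxt
            n //= 2
        return [col[0] for col in columns]

    print(warp_reduce([[3.0, 1.0, 4.0, 2.0]], lambda a, b: (min(a, b),), 4))  # [1.0]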
= withIndex ? load(indexReadPtr) : load(readPtr); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + // nd-tensor where n >= 1 + auto resultLayout = resultTy.getEncoding().cast(); + unsigned resultElems = getElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + axis, i32_val(0)); + Value readOffset = + linearize(rewriter, loc, readIdx, smemShapes[0], order); + Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); + resultVals[j] = load(readPtr); + } + + results[i] = getTypeConverter()->packLLElements( + loc, resultVals, rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(smemBases[i]); } - Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter, - resultTy); - rewriter.replaceOp(op, ret); - } else { - // 0d-tensor -> scalar - Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase); - rewriter.replaceOp(op, resultVal); } + rewriter.replaceOp(op, results); return success(); } @@ -476,5 +414,5 @@ void populateReduceOpToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { patterns.add(typeConverter, allocation, smem, - indexCacheInfo, benefit); + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h index 0e03e0adda88..a7c1c9912193 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h @@ -13,4 +13,4 @@ void populateReduceOpToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); -#endif \ No newline at end of file +#endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 28142d24fb78..2bad78e768ac 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -23,7 +23,6 @@ #include "ElementwiseOpToLLVM.h" #include "LoadStoreOpToLLVM.h" #include "ReduceOpToLLVM.h" -#include "./GenericReduceOpToLLVM.h" #include "TritonGPUToLLVM.h" #include "TypeConverter.h" #include "ViewOpToLLVM.h" @@ -200,7 +199,6 @@ class ConvertTritonGPUToLLVM populatePatterns2(populateElementwiseOpToLLVMPatterns); populatePatterns1(populateLoadStoreOpToLLVMPatterns); populatePatterns1(populateReduceOpToLLVMPatterns); - populatePatterns1(populateGenericReduceOpToLLVMPatterns); populatePatterns2(populateViewOpToLLVMPatterns); // Native lowering patterns mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 8eeba607bc4f..55280914fdeb 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -469,23 +469,9 @@ struct TritonReducePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs( - 
rewriter.replaceOpWithNewOp( - op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonGenericReducePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::GenericReduceOp op, OpAdaptor adaptor, + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newReduce = rewriter.create( + auto newReduce = rewriter.create( op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); addNamedAttrs(newReduce, adaptor.getAttributes()); @@ -496,15 +482,15 @@ struct TritonGenericReducePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct TritonReduceReturnPattern : + public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(triton::GenericReduceReturnOp op, OpAdaptor adaptor, + matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { addNamedAttrs( - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, adaptor.getResult()), adaptor.getAttributes()); return success(); @@ -551,8 +537,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonCatPattern, - TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, - TritonGenericReducePattern, TritonGenericReduceReturnPattern, + TritonReducePattern, TritonReduceReturnPattern, + TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 26270605a711..3fb15a7db35a 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -269,34 +269,8 @@ static mlir::LogicalResult inferReduceReturnShape( return mlir::success(); } -mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - Value arg = operands[0]; - auto argTy = arg.getType().cast(); - auto argEltTy = argTy.getElementType(); - auto i32Ty = IntegerType::get(argEltTy.getContext(), 32); - auto redOp = - attributes.get("redOp").cast().getValue(); - bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); - auto retEltTy = withIndex ? 
i32Ty : argEltTy; - int axis = attributes.get("axis").cast().getInt(); - return inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); -} - -bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) { - return redOp == mlir::triton::RedOp::ARGMIN || - redOp == mlir::triton::RedOp::ARGMAX || - redOp == mlir::triton::RedOp::ARGUMIN || - redOp == mlir::triton::RedOp::ARGUMAX || - redOp == mlir::triton::RedOp::ARGFMIN || - redOp == mlir::triton::RedOp::ARGFMAX; -} - -//-- GenericReduceOp -- -void GenericReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::ValueRange operands, int axis) { +void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::ValueRange operands, int axis) { SmallVector inferredReturnTypes; for (unsigned i = 0; i < operands.size(); ++i) { auto argTy = operands[i].getType().cast(); @@ -304,10 +278,10 @@ void GenericReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); } - GenericReduceOp::build(builder, state, inferredReturnTypes, operands, axis); + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); } -mlir::LogicalResult mlir::triton::GenericReduceOp::inferReturnTypes( +mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { @@ -324,19 +298,19 @@ mlir::LogicalResult mlir::triton::GenericReduceOp::inferReturnTypes( return success(); } -mlir::LogicalResult mlir::triton::GenericReduceOp::verify() { +mlir::LogicalResult mlir::triton::ReduceOp::verify() { if (this->getOperands().size() < 1) { - return this->emitOpError() << "tt.generic_reduce must have at least 1 operand"; + return this->emitOpError() << "must have at least 1 operand"; } for (const auto &operand: this->getOperands()) { if (!dyn_cast(operand.getType())) { - return this->emitOpError() << "tt.generic_reduce operands must be RankedTensorType"; + return this->emitOpError() << "operands must be RankedTensorType"; } } return success(); } -mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() { +mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() { auto argElementTypes = this->getElementTypes(); const auto &operands = this->getOperands(); const auto numArgs = 2 * operands.size(); @@ -357,10 +331,10 @@ mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() { } } - auto terminator = dyn_cast(block.getTerminator()); + auto terminator = dyn_cast(block.getTerminator()); if (!terminator) { return this->emitOpError() << "combine operation must be terminated " - << "with a GenericReduceReturnOp but got " + << "with a ReduceReturnOp but got " << block.getTerminator(); } const auto &combineResults = terminator->getOperands(); @@ -379,7 +353,7 @@ mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() { return mlir::success(); } -llvm::SmallVector GenericReduceOp::getInputTypes() { +llvm::SmallVector ReduceOp::getInputTypes() { llvm::SmallVector srcTys; srcTys.reserve(this->getNumOperands()); for (const auto &ty: this->getOperands().getTypes()) { @@ -388,7 +362,7 @@ llvm::SmallVector GenericReduceOp::getInputTypes() { return srcTys; } -llvm::SmallVector GenericReduceOp::getElementTypes() { +llvm::SmallVector ReduceOp::getElementTypes() { llvm::SmallVector srcElemTys; srcElemTys.reserve(this->getNumOperands()); for (const auto &op: this->getOperands()) { @@ 
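Type inference is now uniform across operands: every result keeps its operand's element type (the i32-for-index special case is gone) and its shape is the operand shape with the reduced axis dropped, degenerating to a scalar for 1-d inputs. A minimal Python sketch of that rule (encoding and slice-layout handling omitted):

    # One result per operand: same element type, reduced axis removed.
    def reduce_result_types(operand_types, axis):
        results = []
        for shape, elem_ty in operand_types:
            ret_shape = [s for d, s in enumerate(shape) if d != axis]
            # an empty ret_shape means the result is a scalar of elem_ty
            results.append((ret_shape, elem_ty))
        return results

    print(reduce_result_types([((4, 8), "f32"), ((4, 8), "i32")], axis=1))
    # [([4], 'f32'), ([4], 'i32')]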
-397,7 +371,7 @@ llvm::SmallVector GenericReduceOp::getElementTypes() { return srcElemTys; } -unsigned GenericReduceOp::getNumOperands() { +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 3cbe0a1a8e68..da110c98ce1e 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -94,30 +94,56 @@ class SimplifyReduceCvt : public mlir::RewritePattern { matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { auto reduce = cast(*op); - auto reduceArg = dyn_cast_or_null( - reduce.getOperand().getDefiningOp()); - if (!reduceArg) - return mlir::failure(); - // this may generate unsupported conversions in the LLVM codegen - if (reduceArg.getOperand() - .getType() - .cast() - .getEncoding() - .isa()) - return mlir::failure(); - auto newReduce = rewriter.create( - op->getLoc(), reduce.getRedOp(), reduceArg.getOperand(), - reduce.getAxis()); + SmallVector newOperands = reduce.getOperands(); + + // TODO: This always takes layout from the first argument which + // is fine for argmin/argmax but may not be optimal generally + auto firstArgConversionOp = dyn_cast_or_null( + reduce.getOperands()[0].getDefiningOp()); + if (!firstArgConversionOp) { + return failure(); + } + newOperands[0] = firstArgConversionOp.getOperand(); + auto newEncoding = + newOperands[0] + .getType() + .cast() + .getEncoding(); + + if (!newEncoding.isa()) { + // ReduceOpToLLVM requires block encoding + return failure(); + } + if (isa_and_nonnull( - *reduceArg.getOperand().getDefiningOp())) - return mlir::failure(); - Value newRet = newReduce.getResult(); - // it's still beneficial to move the conversion - // to after the reduce if necessary since it will be - // done on a rank-reduced tensor hence cheaper - if (newRet.getType() != reduce.getResult().getType()) - newRet = rewriter.create( - op->getLoc(), reduce.getResult().getType(), newRet); + *newOperands[0].getDefiningOp())) { + return failure(); + } + + + for (unsigned i = 1; i < newOperands.size(); ++i) { + auto oldTy = newOperands[i].getType().cast(); + RankedTensorType newTy = RankedTensorType::Builder(oldTy).setEncoding(newEncoding); + + newOperands[i] = rewriter.create( + op->getLoc(), newTy, newOperands[i]); + } + + auto newReduce = rewriter.create( + op->getLoc(), newOperands, reduce.getAxis()); + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp, newCombineOp.end()); + + SmallVector newRet = newReduce.getResult(); + auto oldTypes = reduce.getResult().getType(); + for (unsigned i = 0; i < reduce.getNumOperands(); ++i) { + // it's still beneficial to move the conversion + // to after the reduce if necessary since it will be + // done on a rank-reduced tensor hence cheaper + if (newRet[i].getType() != oldTypes[i]) + newRet[i] = rewriter.create( + op->getLoc(), oldTypes[i], newRet[i]); + } rewriter.replaceOp(op, newRet); return success(); diff --git a/python/src/triton.cc b/python/src/triton.cc index 874b79208076..396def954a2a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -88,23 +88,6 @@ void init_triton_ir(py::module &&m) { .value("EVICT_LAST", mlir::triton::EvictionPolicy::EVICT_LAST) .export_values(); - py::enum_(m, "REDUCE_OP") - .value("ADD", mlir::triton::RedOp::ADD) - .value("FADD", mlir::triton::RedOp::FADD) - 
.value("MIN", mlir::triton::RedOp::MIN) - .value("MAX", mlir::triton::RedOp::MAX) - .value("UMIN", mlir::triton::RedOp::UMIN) - .value("UMAX", mlir::triton::RedOp::UMAX) - .value("ARGMIN", mlir::triton::RedOp::ARGMIN) - .value("ARGMAX", mlir::triton::RedOp::ARGMAX) - .value("ARGUMIN", mlir::triton::RedOp::ARGUMIN) - .value("ARGUMAX", mlir::triton::RedOp::ARGUMAX) - .value("FMIN", mlir::triton::RedOp::FMIN) - .value("FMAX", mlir::triton::RedOp::FMAX) - .value("ARGFMIN", mlir::triton::RedOp::ARGFMIN) - .value("ARGFMAX", mlir::triton::RedOp::ARGFMAX) - .value("XOR", mlir::triton::RedOp::XOR); - py::enum_(m, "ATOMIC_OP") .value("ADD", mlir::triton::RMWOp::ADD) .value("FADD", mlir::triton::RMWOp::FADD) @@ -440,8 +423,6 @@ void init_triton_ir(py::module &&m) { .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType) .def("reset_type", &mlir::func::FuncOp::setType); - py::class_(m, "GenericReduceOp"); - py::class_(m, "InsertPoint"); py::class_(m, "builder", py::dynamic_attr()) @@ -1324,28 +1305,11 @@ void init_triton_ir(py::module &&m) { return self.create(loc, val); }) .def("create_reduce", - [](mlir::OpBuilder &self, mlir::Value &operand, - mlir::triton::RedOp redOp, int axis) -> mlir::Value { - auto loc = self.getUnknownLoc(); - auto inputTensorType = - operand.getType().dyn_cast(); - std::vector shape = inputTensorType.getShape(); - shape.erase(shape.begin() + axis); - bool withIndex = mlir::triton::ReduceOp::withIndex(redOp); - mlir::Type resType = withIndex ? self.getI32Type() - : inputTensorType.getElementType(); - if (!shape.empty()) { - resType = mlir::RankedTensorType::get(shape, resType); - } - return self.create(loc, resType, redOp, - operand, axis); - }) - .def("create_generic_reduce", []( mlir::OpBuilder &self, std::vector operands, int axis - ) -> mlir::triton::GenericReduceOp { + ) -> mlir::OpState { auto loc = self.getUnknownLoc(); - return self.create(loc, operands, axis); + return self.create(loc, operands, axis); }) .def("create_reduce_ret", [](mlir::OpBuilder &self, py::args args) -> mlir::OpState { @@ -1354,7 +1318,7 @@ void init_triton_ir(py::module &&m) { for (const auto & arg: args) { return_values.push_back(py::cast(arg)); } - return self.create( + return self.create( loc, return_values); }) .def("create_ptr_to_int", diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index dd34684da4b0..10e67268189b 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -10,7 +10,6 @@ abs, arange, argmin, - argmin2, argmax, atomic_add, atomic_and, @@ -58,7 +57,6 @@ num_programs, pi32_t, pointer_type, - prod, program_id, ravel, reshape, diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index d53ed3f59727..47eea698fac6 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1100,92 +1100,9 @@ def where(condition: tl.tensor, return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) # ===----------------------------------------------------------------------===// -# Reductions +# Reduction # ===----------------------------------------------------------------------=== - -def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, - FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: - scalar_ty = input.type.scalar - out_scalar_ty = scalar_ty - # input is extended to 32-bits if necessary - # this increases numerical accuracy and can be done pretty much for free - # on GPUs - if 
scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: - input = cast(input, tl.int32, builder) - out_scalar_ty = tl.int32 - - # hardware doesn't support FMAX, FMIN, CMP for bfloat16 - if scalar_ty is tl.bfloat16: - input = cast(input, tl.float32, builder) - out_scalar_ty = tl.float32 - - # choose the right unsigned operation - if scalar_ty.is_int_unsigned(): - int_op_to_unit = { - ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN, - ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX, - ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN, - ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX, - } - if INT_OP in int_op_to_unit: - INT_OP = int_op_to_unit[INT_OP] - - # If we are doing an argmin or argmax we want to use an int32 output type - if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX: - out_scalar_ty = tl.int32 - elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN: - out_scalar_ty = tl.int32 - - # get result type - shape = input.type.shape - - rank = len(shape) - assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})" - - ret_shape = [] - for i, s in enumerate(shape): - if i != axis: - ret_shape.append(s) - if ret_shape: - res_ty = tl.block_type(out_scalar_ty, ret_shape) - else: - # 0d-tensor -> scalar - res_ty = out_scalar_ty - - if scalar_ty.is_floating(): - return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) - elif scalar_ty.is_int(): - return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) - assert False - - -def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) - - -def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN) - - -def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) - - -def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX) - - -def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) - - -def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - scalar_ty = input.type.scalar - if not scalar_ty.is_int(): - raise ValueError("xor_sum only supported for integers") - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) - def reduction( inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder ) -> Tuple[tl.tensor, ...]: @@ -1204,7 +1121,7 @@ def wrap_tensor(x, scalar_ty): res_ty = scalar_ty return tl.tensor(x, res_ty) - reduce_op = builder.create_generic_reduce([t.handle for t in inputs], axis) + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) region_builder_fn(reduce_op) reduce_op.verify() @@ -1213,43 +1130,6 @@ def wrap_tensor(x, scalar_ty): for i in range(len(inputs)) ) -@contextmanager -def insertion_guard(builder): - ip = builder.get_insertion_point() - yield - builder.restore_insertion_point(ip) - -def prod(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - - def make_mul(reduce_op): - ir_scalar_ty = input.type.scalar.to_ir(builder) - region = reduce_op.get_region(0) - with insertion_guard(builder): - block = builder.create_block_with_parent(region, 
[ir_scalar_ty] * 2) - fmul = builder.create_fmul(block.arg(0), block.arg(1)) - builder.create_reduce_ret(fmul) - - return reduction((input,), axis, make_mul, builder)[0] - -def min_with_index(keys: tl.tensor, values: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - - def make_min_with_index_combine(reduce_op): - ir_key_ty = keys.type.scalar.to_ir(builder) - ir_value_ty = values.type.scalar.to_ir(builder) - region = reduce_op.get_region(0) - with insertion_guard(builder): - block = builder.create_block_with_parent(region, [ir_key_ty, ir_value_ty] * 2) - value1, index1, value2, index2 = [block.arg(i) for i in range(4)] - lt = builder.create_fcmpOLT(value1, value2) - gt = builder.create_fcmpOGT(value1, value2) - index_min = builder.create_smin(index1, index2) - index_ret = builder.create_select( - lt, index1, builder.create_select(gt, index2, index_min)) - - value_min = builder.create_fmin(value1, value2) - builder.create_reduce_ret(value_min, index_ret) - - return reduction((keys, values), axis, make_min_with_index_combine, builder) # ===----------------------------------------------------------------------=== # Math From 4b74ce383b4c2a1f5ecdff3c5d92fc212b2ed11c Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Tue, 14 Mar 2023 21:10:50 +0000 Subject: [PATCH 06/15] Misc cleanup --- python/src/triton.cc | 12 ------------ python/triton/language/__init__.py | 1 - python/triton/language/core.py | 2 ++ 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 396def954a2a..96da7d5e8f4d 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -870,18 +870,6 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, lhs, rhs); }) - .def("create_fmin", - [](mlir::OpBuilder &self, mlir::Value &lhs, - mlir::Value &rhs) -> mlir::Value { - auto loc = self.getUnknownLoc(); - return self.create(loc, lhs, rhs); - }) - .def("create_smin", - [](mlir::OpBuilder &self, mlir::Value &lhs, - mlir::Value &rhs) -> mlir::Value { - auto loc = self.getUnknownLoc(); - return self.create(loc, lhs, rhs); - }) .def("create_add", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 10e67268189b..3df8c243a0b0 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -156,7 +156,6 @@ "philox_impl", "pi32_t", "pointer_type", - "prod", "program_id", "rand", "rand4x", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 0ccbd8e5cb62..26873e7e9efb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1126,12 +1126,14 @@ def _decorator(func: T) -> T: return _decorator + @contextmanager def _insertion_guard(builder): ip = builder.get_insertion_point() yield builder.restore_insertion_point(ip) + @builtin def reduction(input, axis, combine_fn, _builder=None, _generator=None): """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` From 4b2b16ac47ec997a2548d1d1cee4343e813f7984 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 16 Mar 2023 16:33:39 +0000 Subject: [PATCH 07/15] Add SameOperandsEncoding --- include/triton/Dialect/Triton/IR/TritonOps.td | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 46797cff8bdd..32d884278b65 100644 --- 
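With the reduction builtin in place, reductions that previously needed a dedicated REDUCE_OP enum value can be written as ordinary JIT functions. A hedged sketch of a kernel-level product reduction, assuming the builtin is re-exported from triton.language as tl.reduction (the exported name is not visible in these hunks and may differ):

    import triton
    import triton.language as tl

    @triton.jit
    def _mul_combine(a, b):
        return a * b

    @triton.jit
    def prod_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        # combine_fn is itself a JIT function; the frontend lowers it into the
        # single-block combine region of tt.reduce
        p = tl.reduction(x, 0, _mul_combine)
        tl.store(out_ptr, p)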
a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -349,7 +349,10 @@ def TT_DotOp : TT_Op<"dot", [Pure, // Reduce Op // def TT_ReduceOp: TT_Op<"reduce", - [Pure, DeclareOpInterfaceMethods, SingleBlock]> { + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { let summary = "Reduction using generic combination algorithm"; let arguments = (ins Variadic:$operands, I32Attr:$axis); let results = (outs Variadic:$result); From b3957a7bb902e883596b35c775c856efbd212170 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 16 Mar 2023 17:14:37 +0000 Subject: [PATCH 08/15] Run clang-format --- include/triton/Analysis/Utility.h | 4 +- lib/Analysis/Utility.cpp | 2 +- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 14 +++--- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 50 ++++++++++--------- .../TritonGPUToLLVM/TypeConverter.cpp | 9 ++-- .../TritonToTritonGPUPass.cpp | 28 +++++------ lib/Dialect/Triton/IR/Ops.cpp | 48 +++++++++--------- .../Transforms/RemoveLayoutConversions.cpp | 12 ++--- python/src/triton.cc | 11 ++-- 9 files changed, 91 insertions(+), 87 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index bc9a46c5bf9d..4e7275ccb789 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -12,8 +12,8 @@ namespace mlir { class ReduceOpHelper { public: - explicit ReduceOpHelper(triton::ReduceOp rop): - op(rop.getOperation()), axis(rop.getAxis()) { + explicit ReduceOpHelper(triton::ReduceOp rop) + : op(rop.getOperation()), axis(rop.getAxis()) { auto firstTy = rop.getOperands()[0].getType().cast(); srcShape = firstTy.getShape(); srcEncoding = firstTy.getEncoding(); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index a5e594b40357..9b99c84f3eca 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -72,7 +72,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { } unsigned bytes_per_elem = 0; - for (const auto &ty: srcElementTypes) { + for (const auto &ty : srcElementTypes) { bytes_per_elem += ty.getIntOrFloatBitWidth() / 8; } return bytes_per_elem * elems; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 483c54f34b35..a935d1e00220 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -966,14 +966,14 @@ void populateElementwiseOpToLLVMPatterns( POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) - POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & - POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | - POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ - POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << - POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> - POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin - POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin #undef POPULATE_BINARY_OP #define 
POPULATE_UNARY_OP(SRC_OP, DST_OP) \ diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 3fe98aac34f5..a148d215981b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -23,9 +23,9 @@ struct ReduceOpConversion } private: - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - llvm::SmallVectorImpl &acc, ValueRange cur, bool isFirst) const { + llvm::SmallVectorImpl &acc, ValueRange cur, + bool isFirst) const { if (isFirst) { acc.resize(cur.size()); for (unsigned i = 0; i < cur.size(); ++i) { @@ -41,13 +41,14 @@ struct ReduceOpConversion auto &newReduce = parent.front(); auto returnOp = dyn_cast(newReduce.getTerminator()); - llvm::SmallVector combineArgs(2*acc.size()); + llvm::SmallVector combineArgs(2 * acc.size()); for (unsigned i = 0; i < acc.size(); ++i) { combineArgs[i] = acc[i]; combineArgs[acc.size() + i] = cur[i]; } - rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), combineArgs); + rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); auto results = returnOp.getResult(); for (unsigned i = 0; i < acc.size(); ++i) { @@ -58,16 +59,16 @@ struct ReduceOpConversion rewriter.eraseOp(returnOp); } - SmallVector> unpackInputs( - Location loc, triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto types = op.getInputTypes(); auto operands = adaptor.getOperands(); unsigned srcElems = getElemsPerThread(types[0]); SmallVector> srcValues(srcElems); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - auto values = getTypeConverter()->unpackLLElements( - loc, operands[i], rewriter, types[i]); + auto values = getTypeConverter()->unpackLLElements(loc, operands[i], + rewriter, types[i]); assert(values.size() == srcValues.size()); for (unsigned j = 0; j < srcValues.size(); ++j) { @@ -106,8 +107,9 @@ struct ReduceOpConversion smemBases[0] = bitcast( getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = bitcast( - gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), elemPtrTys[i]); + smemBases[i] = + bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), + elemPtrTys[i]); } unsigned srcElems = getElemsPerThread(srcTys[0]); @@ -121,7 +123,6 @@ struct ReduceOpConversion std::map, SmallVector> accs; std::map, SmallVector> indices; - Region *combineOp = &op.getCombineOp(); // reduce within threads @@ -155,7 +156,6 @@ struct ReduceOpConversion store(acc[i], writePtrs[i]); } - SmallVector readIdx(writeIdx.size(), ints[0]); for (int N = smemShape[axis] / 2; N > 0; N >>= 1) { readIdx[axis] = ints[N]; @@ -186,7 +186,8 @@ struct ReduceOpConversion // set output values SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + if (auto resultTy = + op.getResult()[i].getType().dyn_cast()) { // nd-tensor where n >= 1 auto resultLayout = resultTy.getEncoding(); @@ -199,12 +200,13 @@ struct ReduceOpConversion for (unsigned j = 0; j < resultElems; ++j) { SmallVector readIdx = resultIndices[j]; readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd); + Value readOffset = + 
linearize(rewriter, loc, readIdx, smemShape, srcOrd); Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); resultVals[j] = load(readPtr); } - results[i] = getTypeConverter()->packLLElements( - loc, resultVals, rewriter, resultTy); + results[i] = getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); } else { // 0d-tensor -> scalar results[i] = load(smemBases[i]); @@ -247,8 +249,9 @@ struct ReduceOpConversion smemBases[0] = bitcast( getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = bitcast( - gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)), elemPtrTys[i]); + smemBases[i] = + bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)), + elemPtrTys[i]); } unsigned sizeIntraWarps = helper.getIntraWarpSize(); @@ -377,7 +380,8 @@ struct ReduceOpConversion // set output values SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = op.getResult()[i].getType().dyn_cast()) { + if (auto resultTy = + op.getResult()[i].getType().dyn_cast()) { // nd-tensor where n >= 1 auto resultLayout = resultTy.getEncoding().cast(); unsigned resultElems = getElemsPerThread(resultTy); @@ -394,8 +398,8 @@ struct ReduceOpConversion resultVals[j] = load(readPtr); } - results[i] = getTypeConverter()->packLLElements( - loc, resultVals, rewriter, resultTy); + results[i] = getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); } else { // 0d-tensor -> scalar results[i] = load(smemBases[i]); @@ -414,5 +418,5 @@ void populateReduceOpToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { patterns.add(typeConverter, allocation, smem, - indexCacheInfo, benefit); + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 80d80e8852a2..af8d415d81bd 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -65,13 +65,14 @@ Value TritonGPUToLLVMTypeConverter::packLLElements( Value llvmStruct = rewriter.create(loc, structType); for (const auto &v : llvm::enumerate(resultVals)) { if (!v.value()) { - emitError(loc) << "cannot insert null values into struct, but tried to insert" - << v.value(); + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); } if (v.value().getType() != elementTypes[v.index()]) { emitError(loc) << "invalid element type in packLLEElements. 
Expected " - << elementTypes[v.index()] << " but got " << v.value().getType(); - + << elementTypes[v.index()] << " but got " + << v.value().getType(); } llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 55280914fdeb..1360605483a0 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -469,30 +469,30 @@ struct TritonReducePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto newReduce = rewriter.create( op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); addNamedAttrs(newReduce, adaptor.getAttributes()); auto &newCombineOp = newReduce.getCombineOp(); - rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp, newCombineOp.end()); + rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); rewriter.replaceOp(op, newReduce.getResult()); return success(); } }; -struct TritonReduceReturnPattern : - public OpConversionPattern { +struct TritonReduceReturnPattern + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - addNamedAttrs( - rewriter.replaceOpWithNewOp( - op, adaptor.getResult()), - adaptor.getAttributes()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getResult()), + adaptor.getAttributes()); return success(); } }; @@ -537,11 +537,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, TritonGenericPattern, TritonCatPattern, - TritonReducePattern, TritonReduceReturnPattern, - TritonTransPattern, TritonExpandDimsPattern, - TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, - TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context); + TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern, + TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, + TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern, + TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>( + typeConverter, context); } // diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3fb15a7db35a..3e204b48bd75 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -238,9 +238,9 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( } //-- ReduceOp -- -static mlir::LogicalResult inferReduceReturnShape( - const RankedTensorType &argTy, const Type &retEltTy, - int axis, SmallVectorImpl &inferredReturnTypes) { +static mlir::LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { auto retShape = argTy.getShape().vec(); retShape.erase(retShape.begin() + axis); if (retShape.empty()) { @@ -289,9 +289,8 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( auto argTy = arg.getType().cast(); auto retEltTy = argTy.getElementType(); int axis = 
attributes.get("axis").cast().getInt(); - if ( - inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) - .failed()) { + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { return failure(); } } @@ -302,7 +301,7 @@ mlir::LogicalResult mlir::triton::ReduceOp::verify() { if (this->getOperands().size() < 1) { return this->emitOpError() << "must have at least 1 operand"; } - for (const auto &operand: this->getOperands()) { + for (const auto &operand : this->getOperands()) { if (!dyn_cast(operand.getType())) { return this->emitOpError() << "operands must be RankedTensorType"; } @@ -326,28 +325,32 @@ mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() { const auto &blockArgTy = blockArgTypes[i]; const auto &argElemTy = argElementTypes[i % operands.size()]; if (blockArgTy != argElemTy) { - return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i - << " to have type " << argElemTy << " but got " << blockArgTy; + return this->emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; } } - auto terminator = dyn_cast(block.getTerminator()); + auto terminator = + dyn_cast(block.getTerminator()); if (!terminator) { - return this->emitOpError() << "combine operation must be terminated " - << "with a ReduceReturnOp but got " - << block.getTerminator(); + return this->emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); } const auto &combineResults = terminator->getOperands(); if (combineResults.size() != operands.size()) { - return this->emitOpError() << "expected combine operation to return " << operands.size() - << " values but got " << combineResults.size(); + return this->emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); } for (unsigned i = 0; i < combineResults.size(); ++i) { const auto &resultTy = combineResults[i].getType(); const auto &argElemTy = argElementTypes[i]; if (resultTy != argElemTy) { - return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i - << " to have type " << argElemTy << " but got " << resultTy; + return this->emitOpError() + << "type mismatch on combine operation. 
Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; } } return mlir::success(); @@ -356,7 +359,7 @@ mlir::LogicalResult mlir::triton::ReduceOp::verifyRegions() { llvm::SmallVector ReduceOp::getInputTypes() { llvm::SmallVector srcTys; srcTys.reserve(this->getNumOperands()); - for (const auto &ty: this->getOperands().getTypes()) { + for (const auto &ty : this->getOperands().getTypes()) { srcTys.push_back(ty.cast()); } return srcTys; @@ -365,15 +368,14 @@ llvm::SmallVector ReduceOp::getInputTypes() { llvm::SmallVector ReduceOp::getElementTypes() { llvm::SmallVector srcElemTys; srcElemTys.reserve(this->getNumOperands()); - for (const auto &op: this->getOperands()) { - srcElemTys.push_back(op.getType().cast().getElementType()); + for (const auto &op : this->getOperands()) { + srcElemTys.push_back( + op.getType().cast().getElementType()); } return srcElemTys; } -unsigned ReduceOp::getNumOperands() { - return this->getOperands().size(); -} +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } //-- SplatOp -- OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index da110c98ce1e..d0254d63810e 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -105,10 +105,7 @@ class SimplifyReduceCvt : public mlir::RewritePattern { } newOperands[0] = firstArgConversionOp.getOperand(); auto newEncoding = - newOperands[0] - .getType() - .cast() - .getEncoding(); + newOperands[0].getType().cast().getEncoding(); if (!newEncoding.isa()) { // ReduceOpToLLVM requires block encoding @@ -120,10 +117,10 @@ class SimplifyReduceCvt : public mlir::RewritePattern { return failure(); } - for (unsigned i = 1; i < newOperands.size(); ++i) { auto oldTy = newOperands[i].getType().cast(); - RankedTensorType newTy = RankedTensorType::Builder(oldTy).setEncoding(newEncoding); + RankedTensorType newTy = + RankedTensorType::Builder(oldTy).setEncoding(newEncoding); newOperands[i] = rewriter.create( op->getLoc(), newTy, newOperands[i]); @@ -132,7 +129,8 @@ class SimplifyReduceCvt : public mlir::RewritePattern { auto newReduce = rewriter.create( op->getLoc(), newOperands, reduce.getAxis()); auto &newCombineOp = newReduce.getCombineOp(); - rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp, newCombineOp.end()); + rewriter.inlineRegionBefore(reduce.getCombineOp(), newCombineOp, + newCombineOp.end()); SmallVector newRet = newReduce.getResult(); auto oldTypes = reduce.getResult().getType(); diff --git a/python/src/triton.cc b/python/src/triton.cc index 96da7d5e8f4d..bc700b50b4f3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1293,9 +1293,8 @@ void init_triton_ir(py::module &&m) { return self.create(loc, val); }) .def("create_reduce", - []( - mlir::OpBuilder &self, std::vector operands, int axis - ) -> mlir::OpState { + [](mlir::OpBuilder &self, std::vector operands, + int axis) -> mlir::OpState { auto loc = self.getUnknownLoc(); return self.create(loc, operands, axis); }) @@ -1303,11 +1302,11 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self, py::args args) -> mlir::OpState { auto loc = self.getUnknownLoc(); llvm::SmallVector return_values; - for (const auto & arg: args) { + for (const auto &arg : args) { return_values.push_back(py::cast(arg)); } - return self.create( - loc, return_values); + return self.create(loc, + 
return_values); }) .def("create_ptr_to_int", [](mlir::OpBuilder &self, mlir::Value &val, From bbec2fe6eadeb46c825ccf597e646f324e5e9bbc Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 20 Mar 2023 12:38:25 +0000 Subject: [PATCH 09/15] Fix lit tests --- test/Analysis/test-allocation.mlir | 6 ++- test/Analysis/test-membar.mlir | 12 ++++-- test/Conversion/triton_ops.mlir | 48 ++++++++++++++++++------ test/Conversion/triton_to_tritongpu.mlir | 32 ++++++++++++---- test/TritonGPU/combine.mlir | 18 +++++++-- 5 files changed, 88 insertions(+), 28 deletions(-) diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index c155b1261110..5c3feafdacf1 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -170,7 +170,11 @@ func.func @alloc(%A : !tt.ptr) { func.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: scratch offset = 0, size = 512 - %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> + %b = "tt.reduce" (%cst0) ({ + ^bb0(%arg0: f16, %arg1: f16): + %add = arith.addf %arg0, %arg1 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> return // CHECK-NEXT: size = 512 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 36bf52813a0b..e8327c0528e1 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -79,7 +79,11 @@ func.func @scratch() { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.convert_layout %1 = triton_gpu.convert_layout %0 : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> - %2 = tt.reduce %1 {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f16, %arg2: f16): + %add = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %add : f16 + }) {axis = 0 : i32} : (tensor<32x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> return } @@ -417,7 +421,7 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK: gpu.barrier - %c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> + %c_blocked = triton_gpu.convert_layout %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> { @@ -429,13 +433,13 @@ func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.convert_layout - %c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> + %c_blocked_next = triton_gpu.convert_layout %c_shared_next : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> scf.yield %c_shared : tensor<128x32xf16, #A_SHARED> } scf.yield %c_shared_ : tensor<128x32xf16, #A_SHARED> } // CHECK-NOT: gpu.barrier - %b_blocked_next = 
triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> + %b_blocked_next = triton_gpu.convert_layout %b_shared: (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> scf.yield %a_shared, %b_shared, %c_shared_next_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 0e35c634d2a6..c793f0bc0a65 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -79,18 +79,42 @@ func.func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32} func.func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { // Test if reduce ops infer types correctly - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> - %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32> - %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32> - %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32> - %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32> - %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32> - // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32 - %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32 + // CHECK: }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> + %a = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> + // CHECK: }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32> + %b = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x2x4xf32>) -> tensor<1x4xf32> + // CHECK: }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32> + %c = "tt.reduce" (%v) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 2 : i32} : (tensor<1x2x4xf32>) -> tensor<1x2xf32> + // CHECK: }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32> + %e = "tt.reduce" (%b) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 1 : i32} : (tensor<1x4xf32>) -> tensor<1xf32> + // CHECK: }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32> + %f = "tt.reduce" (%a) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<2x4xf32>) -> tensor<4xf32> + // CHECK: }) {axis = 0 : i32} : (tensor<4xf32>) -> f32 + %g = "tt.reduce" (%f) ({ + ^bb0(%arg0: f32, %arg1: f32): + %add = arith.addf %arg0, %arg1 : f32 + tt.reduce.return %add : f32 + }) {axis = 0 : i32} : (tensor<4xf32>) -> f32 // Avoid optimizations for c, e, and g %ptr1x2 = tt.splat %ptr : (!tt.ptr) -> tensor<1x2x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index dd1e4df2d472..e8d05a6f7912 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -40,14 +40,30 @@ func.func 
@reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) {
   %c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
   %c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
   %c2 = arith.constant dense<3.00e+00> : tensor<16x16xf32>
-  // CHECK: tensor<4x4xf32, #[[blocked0]]> -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
-  %c0_ = tt.reduce %c0 {redOp = 1 : i32, axis = 0 : i32} : tensor<4x4xf32> -> tensor<4xf32>
-  // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
-  %c1_ = tt.reduce %c1 {redOp = 1 : i32, axis = 0 : i32} : tensor<8x2xf32> -> tensor<2xf32>
-  // CHECK: tensor<8x2xf32, #[[blocked1]]> -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
-  %c2_ = tt.reduce %c1 {redOp = 1 : i32, axis = 1 : i32} : tensor<8x2xf32> -> tensor<8xf32>
-  // CHECK: tensor<16x16xf32, #[[blocked2]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
-  %c3_ = tt.reduce %c2 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf32> -> tensor<16xf32>
+  // CHECK: (tensor<4x4xf32, #[[blocked0]]>) -> tensor<4xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked0]]}>>
+  %c0_ = "tt.reduce" (%c0) ({
+  ^bb0(%arg1: f32, %arg2: f32):
+    %add = arith.addf %arg1, %arg2 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<4x4xf32>) -> tensor<4xf32>
+  // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<2xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked1]]}>
+  %c1_ = "tt.reduce" (%c1) ({
+  ^bb0(%arg3: f32, %arg4: f32):
+    %add = arith.addf %arg3, %arg4 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<8x2xf32>) -> tensor<2xf32>
+  // CHECK: (tensor<8x2xf32, #[[blocked1]]>) -> tensor<8xf32, #triton_gpu.slice<{dim = 1, parent = #[[blocked1]]}>>
+  %c2_ = "tt.reduce" (%c1) ({
+  ^bb0(%arg5: f32, %arg6: f32):
+    %add = arith.addf %arg5, %arg6 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 1 : i32} : (tensor<8x2xf32>) -> tensor<8xf32>
+  // CHECK: (tensor<16x16xf32, #[[blocked2]]>) -> tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #[[blocked2]]}>>
+  %c3_ = "tt.reduce" (%c2) ({
+  ^bb0(%arg7: f32, %arg8: f32):
+    %add = arith.addf %arg7, %arg8 : f32
+    tt.reduce.return %add : f32
+  }) {axis = 0 : i32} : (tensor<16x16xf32>) -> tensor<16xf32>
   return
 }

diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 376c96e2d911..531d80b4702c 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -772,7 +772,11 @@ func.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1:
     %27 = "triton_gpu.cmpf"(%cst_2, %26) {predicate = 4 : i64} : (tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xi1, #blocked2>
     %28 = arith.andi %22, %27 : tensor<16x16xi1, #blocked2>
     %29 = "triton_gpu.select"(%28, %26, %cst_2) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
-    %30 = tt.reduce %29 {axis = 1 : i32, redOp = 12 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %30 = "tt.reduce" (%29) ({
+    ^bb0(%arg4: f32, %arg5: f32):
+      %max = arith.maxf %arg4, %arg5 : f32
+      tt.reduce.return %max : f32
+    }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
     %31 = triton_gpu.convert_layout %30 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
     %32 = triton_gpu.convert_layout %31 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     %33 = tt.expand_dims %32 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -788,7 +792,11 @@ func.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1:
     %43 = math.exp %42 : tensor<16x16xf32, #blocked2>
     %44 = arith.addf %36, %43 : tensor<16x16xf32, #blocked2>
     %45 = "triton_gpu.select"(%22, %44, %36) : (tensor<16x16xi1, #blocked2>, tensor<16x16xf32, #blocked2>, tensor<16x16xf32, #blocked2>) -> tensor<16x16xf32, #blocked2>
-    %46 = tt.reduce %45 {axis = 1 : i32, redOp = 2 : i32} : tensor<16x16xf32, #blocked2> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %46 = "tt.reduce" (%45) ({
+    ^bb0(%arg4: f32, %arg5: f32):
+      %add = arith.addf %arg4, %arg5 : f32
+      tt.reduce.return %add : f32
+    }) {axis = 1 : i32} : (tensor<16x16xf32, #blocked2>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
     %47 = triton_gpu.convert_layout %46 : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16xf32, #blocked0>
     %48 = triton_gpu.convert_layout %47 : (tensor<16xf32, #blocked0>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     %49 = tt.expand_dims %48 {axis = 1 : i32} : (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xf32, #blocked1>
@@ -892,7 +900,11 @@ func.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !
       %74 = "triton_gpu.select"(%54, %73, %arg7) : (tensor<64x64xi1, #blocked2>, tensor<64x64xf32, #blocked2>, tensor<64x64xf32, #blocked2>) -> tensor<64x64xf32, #blocked2>
       scf.yield %74 : tensor<64x64xf32, #blocked2>
     }
-    %26 = tt.reduce %25 {axis = 1 : i32, redOp = 2 : i32} : tensor<64x64xf32, #blocked2> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
+    %26 = "tt.reduce" (%25) ({
+    ^bb0(%arg8: f32, %arg9: f32):
+      %add = arith.addf %arg8, %arg9 : f32
+      tt.reduce.return %add : f32
+    }) {axis = 1 : i32} : (tensor<64x64xf32, #blocked2>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
     %27 = triton_gpu.convert_layout %26 : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<64xf32, #blocked0>
     %28 = triton_gpu.convert_layout %27 : (tensor<64xf32, #blocked0>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     %29 = tt.expand_dims %28 {axis = 1 : i32} : (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xf32, #blocked1>

From 7e195a6130b5f7e9f41700e36716a293348c2d07 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Mon, 20 Mar 2023 21:05:08 +0000
Subject: [PATCH 10/15] Update to newer LLVM

---
 lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index a148d215981b..275d5ea1f083 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -47,8 +47,8 @@ struct ReduceOpConversion
       combineArgs[acc.size() + i] = cur[i];
     }
-    rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
-                              combineArgs);
+    rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
+                               combineArgs);

     auto results = returnOp.getResult();
     for (unsigned i = 0; i < acc.size(); ++i) {

From 19d490bb43f58d3d548bb62ed4a4a7ee43b755d8 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Thu, 23 Mar 2023 18:40:48 +0000
Subject: [PATCH 11/15] Lint

---
 python/triton/compiler.py | 2 +-
 python/triton/language/core.py | 8 +++++---
 python/triton/language/semantic.py | 1 +
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 77de93b4fa60..1fa55dac7e00 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -4,6 +4,7 @@
 import contextlib
 import functools
 import hashlib
+import inspect
 import io
 import json
 import os
@@ -17,7 +18,6 @@
 from collections import namedtuple
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
-import inspect
 import setuptools
 import torch

diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 26873e7e9efb..b80ead837366 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1,8 +1,8 @@
 from __future__ import annotations

+from contextlib import contextmanager
 from enum import Enum
 from typing import Callable, List, TypeVar
-from contextlib import contextmanager

 import triton
 from . import builtin, semantic
@@ -1144,7 +1144,8 @@ def reduction(input, axis, combine_fn, _builder=None, _generator=None):
     """
     if isinstance(input, tensor):
-        return reduction((input,), axis, combine_fn, _builder=_builder, _generator=_generator)[0]
+        return reduction((input,), axis, combine_fn,
+                         _builder=_builder, _generator=_generator)[0]

     def make_combine_region(reduce_op):
         in_scalar_tys = [t.type.scalar for t in input]
@@ -1154,7 +1155,8 @@ def make_combine_region(reduce_op):
         with _insertion_guard(_builder):
             param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
             block = _builder.create_block_with_parent(region, param_types)
-            args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+            args = [tensor(block.arg(i), ty)
+                    for i, ty in enumerate(prototype.param_types)]
             results = _generator.call_JitFunction(combine_fn, args, kwargs={})
             if isinstance(results, tensor):
                 handles = [results.handle]

diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 47eea698fac6..4ef12c1cd6bf 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1103,6 +1103,7 @@ def where(condition: tl.tensor,
 # Reduction
 # ===----------------------------------------------------------------------===

+
 def reduction(
     inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
 ) -> Tuple[tl.tensor, ...]:

From c6f777b4577fbd44b5e307213d2962a1ce4e255c Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Thu, 30 Mar 2023 16:24:33 +0100
Subject: [PATCH 12/15] Fix merge conflicts

---
 lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp | 1 -
 python/triton/compiler.py | 2 +-
 python/triton/language/core.py | 1 +
 python/triton/language/semantic.py | 3 +--
 4 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index 005da419ad79..60566814fa93 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -117,7 +117,6 @@ class SimplifyReduceCvt : public mlir::RewritePattern {
     if (!reduce)
       return mlir::failure();
-
     SmallVector newOperands = reduce.getOperands();
     newOperands[0] = convert.getOperand();

diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index b9386a4461a5..e373b1f8d43c 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -865,7 +865,7 @@ def visit_Call(self, node):
             sig = inspect.signature(fn)
             if '_generator' in sig.parameters:
                 extra_kwargs['_generator'] = self
-            return fn(*args, _builder=self.builder, **kws)
+            return fn(*args, **extra_kwargs, **kws)
         if fn in self.builtin_namespace.values():
             args = map(_unwrap_if_constexpr, args)
             return fn(*args, **kws)

diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index ac136e8bbece..f9ea41974b68 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1243,6 +1243,7 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
 def _max_combine(a, b):
     return maximum(a, b)

+
 @triton.jit
 @_add_reduction_docstr("maximum")
 def max(input, axis):

diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 5b74e29be213..36b96a87719d 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1,7 +1,6 @@
 from __future__ import annotations  # remove after python 3.11

-from contextlib import contextmanager
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple

 from . import core as tl
 from triton._C.libtriton.triton import ir

From 440a39d8a85e8c2d25d0ade21b6e83df0ff63183 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Tue, 4 Apr 2023 19:04:22 +0100
Subject: [PATCH 13/15] Respond to some review comments

---
 include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td | 4 ++--
 lib/Analysis/Membar.cpp | 2 +-
 lib/Analysis/Utility.cpp | 6 +++---
 .../TritonGPU/Transforms/RemoveLayoutConversions.cpp | 8 ++++----
 lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp | 2 +-
 5 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 9d5162373e4b..b774fd54684d 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -100,8 +100,8 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
   let description = [{}];

   let arguments = (ins TT_BoolLike:$condition,
-                       TT_Type:$true_value,
-                       TT_Type:$false_value);
+                       TT_Tensor:$true_value,
+                       TT_Tensor:$false_value);
   let results = (outs TT_Type:$result);
 }

diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index ec7d2d1ed2f9..86060c2df0d7 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -76,7 +76,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
   if (isa(op) || isa(op)) {
     return;
   }
-  op->emitOpError("Unknown terminator encountered in membar analysis");
+  llvm_unreachable("Unknown terminator encountered in membar analysis");
 }

 void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,

diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 9b99c84f3eca..3f2dc7dfa993 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -71,11 +71,11 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
     elems = product(smemShape);
   }

-  unsigned bytes_per_elem = 0;
+  unsigned bytesPerElem = 0;
   for (const auto &ty : srcElementTypes) {
-    bytes_per_elem += ty.getIntOrFloatBitWidth() / 8;
+    bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
   }
-  return bytes_per_elem * elems;
+  return bytesPerElem * elems;
 }

 bool isSharedEncoding(Value value) {

diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
index 60566814fa93..fc34862ec7dd 100644
--- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -106,8 +106,8 @@ class SimplifyReduceCvt : public mlir::RewritePattern {
         continue;
       }
-      // TODO: This always takes layout from the first argument which
-      // is fine for argmin/argmax but may not be optimal generally
+      // TODO: This only moves conversions from the first argument which is
+      // fine for argmin/argmax but may not be optimal generally
       if (convert.getResult() != owner.getOperands()[0]) {
         continue;
       }
@@ -123,8 +123,8 @@ class SimplifyReduceCvt : public mlir::RewritePattern {
     auto newEncoding =
        newOperands[0].getType().cast().getEncoding();
-    if (!newEncoding.isa()) {
-      // ReduceOpToLLVM requires block encoding
+    // this may generate unsupported conversions in the LLVM codegen
+    if (newEncoding.isa()) {
       return failure();
     }

diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index cc1c1d245c44..29de81f71549 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -80,7 +80,7 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   addIllegalOp();

   // We have custom versions of some arith operators
-  addIllegalOp();
+  addIllegalOp();
   addDynamicallyLegalDialect

Date: Mon, 10 Apr 2023 18:13:42 +0100
Subject: [PATCH 14/15] Don't rematerialize ReduceOp

---
 lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +-
 test/TritonGPU/combine.mlir | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index fe019fae40a2..f775936045ab 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -123,7 +123,7 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
     return expensiveLoadOrStore(op, targetEncoding);
   if (isa(op))
+          triton::AtomicCASOp, triton::DotOp, triton::ReduceOp>(op))
     return true;
   if (isa(
          op))
     return true;

diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 35c389239508..190e687b7591 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1013,12 +1013,15 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
     %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
     %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
     %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
+    // CHECK-DAG: }) {axis = 1 : i32}
     %4 = "tt.reduce" (%cst) ({
     ^bb0(%arg3: i32, %arg4: i32):
       %add = arith.addi %arg3, %arg4 : i32
       tt.reduce.return %add : i32
     }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    // CHECK-NEXT: triton_gpu.convert_layout {{%.*}} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
     %5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
+    // CHECK-NOT: triton_gpu.convert_layout
     %6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
     %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
     %8 = triton_gpu.convert_layout %7 : (tensor<1x1xi32, #blocked2>) -> tensor<1x1xi32, #blocked>

From c7c8ac12e3172db3c81a7827e983c14141dba4b1 Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Wed, 12 Apr 2023 14:14:19 +0100
Subject: [PATCH 15/15] Revert "Don't rematerialize ReduceOp"

This reverts commit 19d31c6194c306dc2f0e96012567a2da59abda66.

---
 lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +-
 test/TritonGPU/combine.mlir | 3 ---
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index 59a0ef8a4813..1a517acceb52 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -136,7 +136,7 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
     return expensiveLoadOrStore(op, targetEncoding);
   if (isa(op))
+          triton::AtomicCASOp, triton::DotOp>(op))
     return true;
   if (isa(
          op))
     return true;

diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 4a151f698c24..a53eda6f1ecb 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1028,15 +1028,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
     %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
     %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
     %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
-    // CHECK-DAG: }) {axis = 1 : i32}
     %4 = "tt.reduce" (%cst) ({
     ^bb0(%arg3: i32, %arg4: i32):
       %add = arith.addi %arg3, %arg4 : i32
       tt.reduce.return %add : i32
     }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-    // CHECK-NEXT: triton_gpu.convert_layout {{%.*}} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
     %5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
-    // CHECK-NOT: triton_gpu.convert_layout
     %6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
     %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
     %8 = triton_gpu.convert_layout %7 : (tensor<1x1xi32, #blocked2>) -> tensor<1x1xi32, #blocked>
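
Usage sketch (not part of the patch series above): the following kernel shows how a combine-function-based reduction of the kind these patches add might be invoked from the Python side. It is illustrative only; the exported name and signature of the reduction entry point (written here as tl.reduction(input, axis, combine_fn), with tl.prod as the shorthand) follow the core.py changes shown earlier and should be treated as assumptions rather than a stable API, and the kernel and tensor names are hypothetical.

    # Hypothetical example of the region-based reduction from this series.
    import triton
    import triton.language as tl


    @triton.jit
    def _mul_combine(a, b):
        # Combine function; the compiler inlines this into the reduce op's region.
        return a * b


    @triton.jit
    def row_product_kernel(x_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_N)
        mask = cols < N
        # 1.0 is the identity of the product, so masked-off lanes do not change it.
        x = tl.load(x_ptr + row * N + cols, mask=mask, other=1.0)
        # Builds a "tt.reduce" with a single-block combine region (here multiplication),
        # analogous to the addf/maxf regions in the MLIR tests above.
        prod = tl.reduction(x, 0, _mul_combine)
        tl.store(out_ptr + row, prod)

A launch such as row_product_kernel[(M,)](x, out, N, BLOCK_N=triton.next_power_of_2(N)), with x an M-by-N tensor and out an M-element tensor on the GPU, would then compute per-row products through the same lowering path exercised by the tests in this series.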