From 00a9143bb46c112542c13bb032b82cfb6c60a205 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 4 Apr 2023 17:12:00 -0700 Subject: [PATCH 1/9] [FRONTEND] Expose Autotuner to users (#1473) The Autotuner is a handy utility. By allowing external access to the Autotuner, users can overwrite some functions (e.g., `run`) to load/store best configurations, initialize tensors based on configuration values, and change benchmarking standard (e.g., based on bytes instead of time). --- python/triton/runtime/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index 37629c461208..176c9545db32 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,5 +1,6 @@ from . import driver -from .autotuner import Config, Heuristics, OutOfResources, autotune, heuristics +from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, + heuristics) from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret, version_key) @@ -16,4 +17,5 @@ "TensorWrapper", "OutOfResources", "MockTensor", + "Autotuner", ] From 47e73aaddaa5e87e7e824e5f949ae496aad82ecc Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 4 Apr 2023 17:52:26 -0700 Subject: [PATCH 2/9] [BACKEND] Revert inline PTX for conversions supported by LLVM (#1474) No longer needed now that we initialize all registers. Motivation for reverting this workaround now that we can is that it introduced performance regressions --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 151 ------------------ .../TritonGPUToLLVM/ElementwiseOpToLLVM.h | 4 - .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 45 ------ 3 files changed, 200 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 102747d3beea..c4924151bc00 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1122,154 +1122,3 @@ void populateElementwiseOpToLLVMPatterns( // __nv_expf for higher-precision calculation patterns.add(typeConverter, benefit); } - -struct FPExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF32() && srcTy.isF16()) { - return false; - } - return true; - } - - Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - ValueRange operands, Location loc) const { - return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0]); - } -}; - -struct FPTruncOpConversion - : ElementwiseOpConversionBase { - using Base = - ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPTruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF16() && srcTy.isF32()) { - return false; - } - return true; - } - - Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - ValueRange operands, Location loc) const { - return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0]); - } -}; - -struct TruncOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; 
- using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::TruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(16) && srcTy.isInteger(32)) { - return false; - } - return true; - } - - Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - ValueRange operands, Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u16.u32"); - auto res = builder.newOperand("=h"); - auto operand = builder.newOperand(operands[0], "r"); - cvt(res, operand); - return builder.launch(rewriter, loc, i16_ty, false); - } -}; - -struct SExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::SExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - ValueRange operands, Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.s32.s16"); - auto res = builder.newOperand("=r"); - auto operand = builder.newOperand(operands[0], "h"); - cvt(res, operand); - return builder.launch(rewriter, loc, i32_ty, false); - } -}; - -struct ZExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::ZExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, Type elemTy, - ValueRange operands, Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u32.u16"); - auto res = builder.newOperand("=r"); - auto operand = builder.newOperand(operands[0], "h"); - cvt(res, operand); - return builder.launch(rewriter, loc, i32_ty, false); - } -}; - -bool isLegalElementwiseOp(Operation *op) { - if (isa(op)) { - return FPExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return FPTruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return TruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return SExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return ZExtOpConversion::isLegalOp(cast(op)); - } - return true; -} - -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h index d5b2a094955c..20404f875978 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h @@ -13,8 +13,4 @@ void populateElementwiseOpToLLVMPatterns( bool isLegalElementwiseOp(Operation *op); -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - 
PatternBenefit benefit); - #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index d94bb43c25cd..e46bea9c1cd1 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -56,28 +56,6 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { } }; -class TritonPTXConversionTarget : public ConversionTarget { -public: - explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addDynamicallyLegalDialect( - [&](Operation *op) { return isLegalElementwiseOp(op); }); - - addLegalDialect(); - addLegalOp(); - } -}; - -class TritonGCNConversionTarget : public ConversionTarget { -public: - explicit TritonGCNConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addDynamicallyLegalDialect( - [&](Operation *op) { return isLegalElementwiseOp(op); }); - - addLegalDialect(); - addLegalOp(); - } -}; - struct ReturnOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -236,29 +214,6 @@ class ConvertTritonGPUToLLVM patterns); if (failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); - - if (isROCM) { - TritonGCNConversionTarget gcnTarget(*context); - RewritePatternSet gcnPatterns(context); - populateElementwiseOpToPTXPatterns(typeConverter, gcnPatterns, - /*benefits=*/10); - if (failed( - applyPartialConversion(mod, gcnTarget, std::move(gcnPatterns)))) - return signalPassFailure(); - } else { - // Use our custom converters to convert some operations to PTX to avoid - // using NVPTX for two reasons: - // 1. NVPTX backend is flaky on data types like float16 and bfloat16 - // 2. In some cases, we may generate faster PTX code than NVPTX backend - TritonPTXConversionTarget ptxTarget(*context); - RewritePatternSet ptxPatterns(context); - // Add patterns to convert LLVM to PTX - populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns, - /*benefits=*/10); - if (failed( - applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns)))) - return signalPassFailure(); - } } private: From 0e11f1e167624b3a688710f9b4b995b177205bb5 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 4 Apr 2023 21:53:36 -0700 Subject: [PATCH 3/9] [TESTING] Added `triton.allclose` wrapper around `torch.testing.allclose`. This adds a convenience layer to test linear algebra kernels and their perf. --- python/triton/testing.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/python/triton/testing.py b/python/triton/testing.py index 86b18fc1f00a..5cc0135c740e 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -89,6 +89,35 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, return torch.mean(times).item() +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import torch + + def default_atol(dtype): + return 1e-2 + + def default_rtol(dtype): + return 0. 
+ if atol is None: + atol = default_atol + if rtol is None: + rtol = default_rtol + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + atol = atol(x.dtype) if callable(atol) else atol + rtol = rtol(x.dtype) if callable(rtol) else rtol + if x.numel() > 1 or y.numel() > 1: + # we could use a fused kernel for fast `isclose` + # if x.numel()*16 > torch.cuda.mem_get_info()[0]: + torch.testing.assert_close(x.cpu(), y.cpu(), atol=atol, rtol=rtol, equal_nan=True) + # else: + # torch.testing.assert_close(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not torch.isclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + class Benchmark: """ This class is used by the :code:`perf_report` function to generate line plots with a concise API. From 577cafff0a48a99dc3dc8d4eed4ca22cb161c501 Mon Sep 17 00:00:00 2001 From: Eta <24918963+Eta0@users.noreply.github.com> Date: Wed, 5 Apr 2023 00:41:08 -0500 Subject: [PATCH 4/9] [BUILD] Add missing subpackages to build (#1475) The `triton/compiler`, `triton/runtime/driver`, and `triton/third_party` subpackages were missing from the distribution built with the old `setup.py` after #1464, causing an immediate error upon importing Triton with a non-editable installation. This change adds the missing Python subpackages and moves `triton/third_party` inclusion to `MANIFEST.in`, where it will automatically be included in wheels due to the existing `include_package_data` setup flag. --- python/MANIFEST.in | 1 + python/setup.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/MANIFEST.in b/python/MANIFEST.in index 11a5eb0d370b..ae2de1fe4289 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1 +1,2 @@ graft src +graft triton/third_party diff --git a/python/setup.py b/python/setup.py index 711d86a690fb..8ff1a07b3e08 100644 --- a/python/setup.py +++ b/python/setup.py @@ -221,16 +221,18 @@ def build_extension(self, ext): packages=[ "triton", "triton/_C", - "triton/language", - "triton/tools", "triton/common", + "triton/compiler", + "triton/language", "triton/ops", + "triton/ops/blocksparse", "triton/runtime", - "triton/ops/blocksparse"], + "triton/runtime/driver", + "triton/tools", + ], install_requires=[ "filelock", ], - package_data={"triton": ["third_party/**/*"]}, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild}, From 4c1d001ae447c0c9ea6f02b764385f049dea0083 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Tue, 4 Apr 2023 23:56:23 -0700 Subject: [PATCH 5/9] [TESTING] Now using numpy instead of pytorch in `triton.assert_close` More memory-efficient than pytorch --- python/triton/testing.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 5cc0135c740e..100106582e41 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -90,31 +90,37 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np import torch + # absolute tolerance hook def default_atol(dtype): return 1e-2 + if atol is None: + atol = default_atol + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook def default_rtol(dtype): return 0. 
- if atol is None: - atol = default_atol if rtol is None: rtol = default_rtol - if not isinstance(x, torch.Tensor): - x = torch.tensor(x) - if not isinstance(y, torch.Tensor): - y = torch.tensor(y) - atol = atol(x.dtype) if callable(atol) else atol rtol = rtol(x.dtype) if callable(rtol) else rtol - if x.numel() > 1 or y.numel() > 1: - # we could use a fused kernel for fast `isclose` - # if x.numel()*16 > torch.cuda.mem_get_info()[0]: - torch.testing.assert_close(x.cpu(), y.cpu(), atol=atol, rtol=rtol, equal_nan=True) - # else: - # torch.testing.assert_close(x, y, atol=atol, rtol=rtol, equal_nan=True) + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) return - if not torch.isclose(x, y, atol=atol, rtol=rtol): + if not np.allclose(x, y, atol=atol, rtol=rtol): raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') From 8cbf9b40a43966da1a297e2cd2456fd3eb6a77ca Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 6 Apr 2023 00:48:33 -0700 Subject: [PATCH 6/9] [TESTING] Minor fixes (#1479) --- .../test/unit/operators/test_blocksparse.py | 19 ++++++++----------- python/triton/testing.py | 18 ++++++++++-------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 5953a56a1a6d..5f94cd8b31bf 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -81,17 +81,14 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= a_tri.retain_grad() b_tri.retain_grad() op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") - try: - c_tri = op(a_tri, b_tri) - c_tri.backward(dc_tri) - da_tri = a_tri.grad - db_tri = b_tri.grad - # compare - torch.testing.assert_allclose(c_ref, c_tri) - torch.testing.assert_allclose(da_ref, da_tri) - torch.testing.assert_allclose(db_ref, db_tri) - except triton.OutOfResourcesError as e: - pytest.skip(str(e)) + c_tri = op(a_tri, b_tri) + c_tri.backward(dc_tri) + da_tri = a_tri.grad + db_tri = b_tri.grad + # compare + torch.testing.assert_allclose(c_ref, c_tri) + torch.testing.assert_allclose(da_ref, da_tri) + torch.testing.assert_allclose(db_ref, db_tri) configs = [ diff --git a/python/triton/testing.py b/python/triton/testing.py index 100106582e41..6f6520fb3b3f 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -93,18 +93,18 @@ def assert_close(x, y, atol=None, rtol=None, err_msg=''): import numpy as np import torch - # absolute tolerance hook - def default_atol(dtype): - return 1e-2 + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance if atol is None: - atol = default_atol + atol = 1e-2 atol = atol(x.dtype) if callable(atol) else atol # relative tolerance hook - - def default_rtol(dtype): - return 0. if rtol is None: - rtol = default_rtol + rtol = 0. 
rtol = rtol(x.dtype) if callable(rtol) else rtol # we use numpy instead of pytorch # as it seems more memory efficient @@ -117,6 +117,8 @@ def default_rtol(dtype): if y.dtype == torch.bfloat16: y = y.float() y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there if x.size > 1 or y.size > 1: np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) return From 7f3f58f3322d537125c6f6a18d50f070d643994b Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 6 Apr 2023 10:40:40 -0700 Subject: [PATCH 7/9] [FRONTEND] Fix broadcast semantics (#1480) https://github.com/openai/triton/pull/1183 --------- Co-authored-by: Yen-Chen Lin --- python/test/unit/language/test_core.py | 29 ++++++++++++++++++++++++++ python/triton/language/semantic.py | 4 ++-- python/triton/ops/flash_attention.py | 2 +- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index dfd262f7a994..d60f14a73e8e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -408,6 +408,35 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) +# --------------- +# test broadcast +# --------------- +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype): + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device='cuda', dst_type=dtype) + y_tri = to_triton(y, device='cuda', dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype) + + broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + # --------------- # test where # --------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 64675fa293fc..3d4d6d624e28 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -599,13 +599,13 @@ def broadcast_impl_value(lhs: tl.tensor, if len(lhs_shape) < len(rhs_shape): # Add new axes to lhs for dim in range(len(lhs_shape), len(rhs_shape)): - lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1])) + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for dim in range(len(rhs_shape), len(lhs_shape)): - rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1])) + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) diff --git 
a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 923a1ad6644f..33c0da791fb7 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -60,7 +60,7 @@ def _fwd_kernel( l_curr = tl.sum(p, 1) + l_prev # rescale operands of matmuls l_rcp = 1. / l_curr - p *= l_rcp + p *= l_rcp[:, None] acc *= (l_prev * l_rcp)[:, None] # update acc p = p.to(Q.dtype.element_ty) From 6743e42eb5ac980fd351e3a0279d619dfaff795e Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 7 Apr 2023 10:26:19 -0700 Subject: [PATCH 8/9] [FRONTEND] Data type specification for math functions (#1485) --- python/test/unit/language/test_core.py | 4 +-- python/triton/language/core.py | 3 --- python/triton/language/semantic.py | 34 +++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d60f14a73e8e..bbf10310d882 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -543,7 +543,7 @@ def test_unary_op(dtype_x, expr, device='cuda'): # ---------------- -@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in float_dtypes for expr in ['exp', 'log', 'cos', 'sin']]) +@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']]) def test_math_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) @@ -1374,7 +1374,7 @@ def kernel(X, stride_xm, stride_xk, if DO_SOFTMAX: max = tl.max(z, 1) z = z - max[:, None] - num = tl.exp(z) + num = tl.exp(z.to(tl.float32)).to(max.dtype) den = tl.sum(num, 1) z = num / den[:, None] if CHAIN_DOT: diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 92cdeaae31f7..9a06d951a335 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -12,9 +12,6 @@ TRITON_MAX_TENSOR_NUMEL = 131072 - -T = TypeVar("T") - TRITON_BUILTIN = "__triton_builtin__" diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 3d4d6d624e28..77c5354c181a 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1,12 +1,16 @@ from __future__ import annotations # remove after python 3.11 -from typing import List, Optional, Tuple +from functools import wraps +from typing import List, Optional, Tuple, TypeVar from . import core as tl from triton._C.libtriton.triton import ir +T = TypeVar('T') # Create custom exception that prints message "hello" + + class IncompatibleTypeErrorImpl(Exception): def __init__(self, type_a, type_b): self.type_a = type_a @@ -1315,6 +1319,28 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: # Math # ===----------------------------------------------------------------------=== +def _check_dtype(dtypes: List[str]) -> T: + """ + We following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. 
+ """ + def wrapper(fn): + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, tl.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + return check + + return wrapper + + def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: x, y = binary_op_type_checking_impl(x, y, builder) # FIXME(Keren): not portable, should be fixed @@ -1322,28 +1348,34 @@ def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: return math.mulhi(x, y, _builder=builder) +@_check_dtype(dtypes=["fp32", "fp64"]) def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor: # FIXME(Keren): not portable, should be fixed from . import math return math.floor(x, _builder=builder) +@_check_dtype(dtypes=["fp32", "fp64"]) def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_exp(x.handle), x.type) +@_check_dtype(dtypes=["fp32", "fp64"]) def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_log(x.handle), x.type) +@_check_dtype(dtypes=["fp32", "fp64"]) def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_cos(x.handle), x.type) +@_check_dtype(dtypes=["fp32", "fp64"]) def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_sin(x.handle), x.type) +@_check_dtype(dtypes=["fp32", "fp64"]) def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_sqrt(x.handle), x.type) From bc0b007e4b1fef2c8cf3e2aed4b5f231b21e2237 Mon Sep 17 00:00:00 2001 From: Ian O'Connell Date: Fri, 7 Apr 2023 14:38:28 -0600 Subject: [PATCH 9/9] [FRONTEND] Allow cache manager to be overridden, and tweak apis to easier work with remote caches (#1478) The changes here come with a few separate bits: - Allow replacing the cache manager with an ENV variable to make it pluggable - Make the `make_path` api private since its leaking some internal bits of the cache and allowing file access. Use a get operation instead. - For the `compile` operation we have a several files part of a single compile pipeline that are small, this can be not the most performant with remote caches. Also some operations like `_triton.get_shared_memory_size` only work when everything is cached or none(or some key ones aren't). They segfault otherwise. So grouping these as an entity avoids that. 
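For illustration, a custom manager only needs to implement the small CacheManager interface this patch introduces (get_file / has_file / put / get_group / put_group) and be reachable as "module.path:ClassName" through TRITON_CACHE_MANAGER. The sketch below is a hypothetical in-memory variant; the module and class names are invented for the example, and a real remote cache would replace the dict with its own store:

    # Hypothetical plugin, e.g. enabled with
    #   TRITON_CACHE_MANAGER=my_caches.mem:InMemoryCacheManager
    import os
    import tempfile
    from typing import Dict, Optional

    from triton.runtime.cache import CacheManager


    class InMemoryCacheManager(CacheManager):
        """Process-wide dict store; artifacts are spilled to temp files so
        callers that expect a real path (e.g. importlib) keep working."""

        _blobs: Dict[str, bytes] = {}
        _groups: Dict[str, Dict[str, str]] = {}

        def __init__(self, key):
            self.key = key
            self.tmp_dir = tempfile.mkdtemp(prefix=f"triton-{key}-")

        def _qualify(self, filename: str) -> str:
            return f"{self.key}/{filename}"

        def has_file(self, filename) -> bool:
            return self._qualify(filename) in self._blobs

        def get_file(self, filename) -> Optional[str]:
            data = self._blobs.get(self._qualify(filename))
            if data is None:
                return None
            # materialize on disk because callers open the returned path
            path = os.path.join(self.tmp_dir, filename)
            with open(path, "wb") as f:
                f.write(data)
            return path

        def put(self, data, filename, binary=True) -> str:
            if not isinstance(data, bytes):
                data = data.encode("utf-8")
            self._blobs[self._qualify(filename)] = data
            return self.get_file(filename)

        def get_group(self, filename: str) -> Optional[Dict[str, str]]:
            return self._groups.get(self._qualify(filename))

        def put_group(self, filename: str, group: Dict[str, str]):
            self._groups[self._qualify(filename)] = dict(group)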
--- python/triton/compiler/compiler.py | 60 ++++++++++------- python/triton/compiler/make_launcher.py | 12 ++-- python/triton/runtime/cache.py | 88 ++++++++++++++++++++++++- python/triton/runtime/driver/cuda.py | 11 ++-- python/triton/runtime/driver/hip.py | 11 ++-- 5 files changed, 142 insertions(+), 40 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 3825ab0ba83a..13002b64d4e5 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -17,7 +17,7 @@ import triton._C.libtriton.triton as _triton # TODO: runtime.errors from ..runtime.autotuner import OutOfResources -from ..runtime.cache import CacheManager +from ..runtime.cache import get_cache_manager from ..runtime.driver import get_cuda_utils, get_hip_utils from ..tools.disasm import extract from .code_generator import ast_to_ttir @@ -410,7 +410,7 @@ def compile(fn, **kwargs): # cache manager so_path = make_stub(name, signature, constants) # create cache manager - fn_cache_manager = CacheManager(make_hash(fn, **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs)) # determine name and extension type of provided function if isinstance(fn, triton.runtime.JITFunction): name, ext = fn.__name__, "ast" @@ -419,14 +419,22 @@ def compile(fn, **kwargs): # load metadata if any metadata = None - if fn_cache_manager.has_file(f'{name}.json'): - with open(fn_cache_manager._make_path(f"{name}.json")) as f: + metadata_filename = f"{name}.json" + + # The group is addressed by the metadata + metadata_group = fn_cache_manager.get_group( + metadata_filename + ) or {} + + metadata_path = metadata_group.get(metadata_filename) + + if metadata_path is not None: + with open(metadata_path) as f: metadata = json.load(f) else: metadata = {"num_warps": num_warps, "num_stages": num_stages, "constants": _get_jsonable_constants(constants), - "ctime": dict(), "debug": debug} if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" @@ -437,25 +445,30 @@ def compile(fn, **kwargs): module = fn # run compilation pipeline and populate metadata for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]: - path = fn_cache_manager._make_path(f"{name}.{ir}") + ir_filename = f"{name}.{ir}" + if ir == ext: next_module = parse(fn) - elif os.path.exists(path) and\ - ir in metadata["ctime"] and\ - os.path.getctime(path) == metadata["ctime"][ir]: - if ir == "amdgcn": - next_module = (parse(path), parse(fn_cache_manager._make_path(f"{name}.hsaco_path"))) - else: - next_module = parse(path) else: - next_module = compile_kernel(module) - if ir == "amdgcn": - fn_cache_manager.put(next_module[0], f"{name}.{ir}") - fn_cache_manager.put(next_module[1], f"{name}.hsaco_path") + path = metadata_group.get(ir_filename) + if path is None: + next_module = compile_kernel(module) + if ir == "amdgcn": + extra_file_name = f"{name}.hsaco_path" + metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename) + metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) + else: + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + fn_cache_manager.put(next_module, ir_filename) else: - fn_cache_manager.put(next_module, f"{name}.{ir}") - if os.path.exists(path): - metadata["ctime"][ir] = os.path.getctime(path) + if ir == "amdgcn": + extra_file_name = f"{name}.hsaco_path" + hasco_path = metadata_group.get(extra_file_name) + assert hasco_path is not None, "Expected to have hsaco_path in 
metadata when we have the amdgcn" + next_module = (parse(path), parse(hasco_path)) + else: + next_module = parse(path) + if ir == "cubin": asm[ir] = next_module elif ir == "amdgcn": @@ -470,8 +483,11 @@ def compile(fn, **kwargs): metadata["name"] = get_kernel_name(next_module[0], pattern='.globl') asm["hsaco_path"] = next_module[1] module = next_module - # write-back metadata - fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False) + # write-back metadata, if it didn't come from the cache + if metadata_path is None: + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel return CompiledKernel(fn, so_path, metadata, asm) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index b3c01d676dc8..3da8ddccf5c5 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -3,7 +3,7 @@ import tempfile from ..common import _build -from ..runtime.cache import CacheManager +from ..runtime.cache import get_cache_manager from ..runtime.jit import version_key @@ -26,10 +26,11 @@ def make_so_cache_key(version_hash, signature, constants): def make_stub(name, signature, constants): # name of files that are cached so_cache_key = make_so_cache_key(version_key(), signature, constants) - so_cache_manager = CacheManager(so_cache_key) + so_cache_manager = get_cache_manager(so_cache_key) so_name = f"{name}.so" # retrieve stub from cache if it exists - if not so_cache_manager.has_file(so_name): + cache_path = so_cache_manager.get_file(so_name) + if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src = generate_launcher(constants, signature) src_path = os.path.join(tmpdir, "main.c") @@ -37,8 +38,9 @@ def make_stub(name, signature, constants): f.write(src) so = _build(name, src_path, tmpdir) with open(so, "rb") as f: - so_cache_manager.put(f.read(), so_name, binary=True) - return so_cache_manager._make_path(so_name) + return so_cache_manager.put(f.read(), so_name, binary=True) + else: + return cache_path # ----- source code generation -------- diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index ee8071241ddb..3c7b94e3ac86 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -1,5 +1,8 @@ +import json import os +from abc import ABC, abstractmethod from pathlib import Path +from typing import Dict, Optional from filelock import FileLock @@ -8,8 +11,32 @@ def default_cache_dir(): return os.path.join(Path.home(), ".triton", "cache") -class CacheManager: +class CacheManager(ABC): + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def has_file(self, filename) -> bool: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): def __init__(self, key): self.key = key self.lock_path = None @@ -20,7 +47,7 @@ def __init__(self, key): self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) - def _make_path(self, filename): + def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) def has_file(self, 
filename): @@ -28,7 +55,40 @@ def has_file(self, filename): return False return os.path.exists(self._make_path(filename)) - def put(self, data, filename, binary=True): + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c in child_paths: + p = self._make_path(c) + if not os.path.exists(p): + raise Exception(f"Group file {p} does not exist from group {grp_filename} ") + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]): + if not self.cache_dir: + return + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: if not self.cache_dir: return binary = isinstance(data, bytes) @@ -42,3 +102,25 @@ def put(self, data, filename, binary=True): with open(filepath + ".tmp", mode) as f: f.write(data) os.rename(filepath + ".tmp", filepath) + return filepath + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + import importlib + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) diff --git a/python/triton/runtime/driver/cuda.py b/python/triton/runtime/driver/cuda.py index b576ec4c225e..28176868bae5 100644 --- a/python/triton/runtime/driver/cuda.py +++ b/python/triton/runtime/driver/cuda.py @@ -3,7 +3,7 @@ import tempfile from ...common.build import _build -from ..cache import CacheManager +from ..cache import get_cache_manager def get_cuda_utils(): @@ -140,18 +140,19 @@ def _generate_src(): def __init__(self): src = self._generate_src() key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = CacheManager(key) + cache = get_cache_manager(key) fname = "cuda_utils.so" - if not cache.has_file(fname): + cache_path = cache.get_file(fname) + if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, "main.c") with open(src_path, "w") as f: f.write(src) so = _build("cuda_utils", src_path, tmpdir) with open(so, "rb") as f: - cache.put(f.read(), fname, binary=True) + cache_path = cache.put(f.read(), fname, binary=True) import importlib.util - spec = importlib.util.spec_from_file_location("cuda_utils", cache._make_path(fname)) + spec = importlib.util.spec_from_file_location("cuda_utils", cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) self.load_binary = mod.load_binary diff --git a/python/triton/runtime/driver/hip.py b/python/triton/runtime/driver/hip.py index 66cbfe9777f0..a0423b6ca057 100644 --- a/python/triton/runtime/driver/hip.py +++ b/python/triton/runtime/driver/hip.py 
@@ -3,7 +3,7 @@ import tempfile from ...common.build import _build -from ..cache import CacheManager +from ..cache import get_cache_manager def get_hip_utils(): @@ -139,18 +139,19 @@ def _generate_src(self): def __init__(self): src = self._generate_src() key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = CacheManager(key) + cache = get_cache_manager(key) fname = "hip_utils.so" - if not cache.has_file(fname): + cache_path = cache.get_file(fname) + if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, "main.c") with open(src_path, "w") as f: f.write(src) so = _build("hip_utils", src_path, tmpdir) with open(so, "rb") as f: - cache.put(f.read(), fname, binary=True) + cache_path = cache.put(f.read(), fname, binary=True) import importlib.util - spec = importlib.util.spec_from_file_location("hip_utils", cache._make_path(fname)) + spec = importlib.util.spec_from_file_location("hip_utils", cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) self.load_binary = mod.load_binary
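As a usage sketch of the grouped-artifact API added above (put_group / get_group address a set of cached files through their metadata file), the snippet below round-trips a group through the default FileCacheManager; the key and file names are placeholder values, not anything the compiler actually produces:

    from triton.runtime.cache import get_cache_manager

    cache = get_cache_manager("0123abcdef")  # normally a hash of the kernel and options
    group = {}
    # each stage artifact is stored individually; put() now returns its cached path
    group["kernel.ttir"] = cache.put("ttir text", "kernel.ttir", binary=False)
    group["kernel.json"] = cache.put('{"num_warps": 4}', "kernel.json", binary=False)
    # the whole group is then addressed by the metadata file name
    cache.put_group("kernel.json", group)

    # a later lookup with the same key fetches every artifact path in one call
    cached = cache.get_group("kernel.json")
    assert cached is not None and sorted(cached) == sorted(group)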