Commit c404afe

Merge remote-tracking branch 'upstream/main' into HEAD

peterbell10 committed Apr 7, 2023
2 parents: 3db5241 + bc0b007
Showing 17 changed files with 264 additions and 265 deletions.
lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp (0 additions, 151 deletions)

@@ -1124,154 +1124,3 @@ void populateElementwiseOpToLLVMPatterns(
   // __nv_expf for higher-precision calculation
   patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
 }
-
-struct FPExtOpConversion
-    : ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
-  using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  static bool isLegalOp(LLVM::FPExtOp op) {
-    auto retTy = op.getResult().getType();
-    auto srcTy = op.getOperand().getType();
-    if (retTy.isF32() && srcTy.isF16()) {
-      return false;
-    }
-    return true;
-  }
-
-  Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter, Type elemTy,
-                     ValueRange operands, Location loc) const {
-    return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0]);
-  }
-};
-
-struct FPTruncOpConversion
-    : ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
-  using Base =
-      ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  static bool isLegalOp(LLVM::FPTruncOp op) {
-    auto retTy = op.getResult().getType();
-    auto srcTy = op.getOperand().getType();
-    if (retTy.isF16() && srcTy.isF32()) {
-      return false;
-    }
-    return true;
-  }
-
-  Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter, Type elemTy,
-                     ValueRange operands, Location loc) const {
-    return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0]);
-  }
-};
-
-struct TruncOpConversion
-    : ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
-  using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  static bool isLegalOp(LLVM::TruncOp op) {
-    auto retTy = op.getResult().getType();
-    auto srcTy = op.getOperand().getType();
-    if (retTy.isInteger(16) && srcTy.isInteger(32)) {
-      return false;
-    }
-    return true;
-  }
-
-  Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter, Type elemTy,
-                     ValueRange operands, Location loc) const {
-    PTXBuilder builder;
-    auto &cvt = *builder.create("cvt.u16.u32");
-    auto res = builder.newOperand("=h");
-    auto operand = builder.newOperand(operands[0], "r");
-    cvt(res, operand);
-    return builder.launch(rewriter, loc, i16_ty, false);
-  }
-};
-
-struct SExtOpConversion
-    : ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
-  using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  static bool isLegalOp(LLVM::SExtOp op) {
-    auto retTy = op.getResult().getType();
-    auto srcTy = op.getOperand().getType();
-    if (retTy.isInteger(32) && srcTy.isInteger(16)) {
-      return false;
-    }
-    return true;
-  }
-
-  Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter, Type elemTy,
-                     ValueRange operands, Location loc) const {
-    PTXBuilder builder;
-    auto &cvt = *builder.create("cvt.s32.s16");
-    auto res = builder.newOperand("=r");
-    auto operand = builder.newOperand(operands[0], "h");
-    cvt(res, operand);
-    return builder.launch(rewriter, loc, i32_ty, false);
-  }
-};
-
-struct ZExtOpConversion
-    : ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
-  using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  static bool isLegalOp(LLVM::ZExtOp op) {
-    auto retTy = op.getResult().getType();
-    auto srcTy = op.getOperand().getType();
-    if (retTy.isInteger(32) && srcTy.isInteger(16)) {
-      return false;
-    }
-    return true;
-  }
-
-  Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor,
-                     ConversionPatternRewriter &rewriter, Type elemTy,
-                     ValueRange operands, Location loc) const {
-    PTXBuilder builder;
-    auto &cvt = *builder.create("cvt.u32.u16");
-    auto res = builder.newOperand("=r");
-    auto operand = builder.newOperand(operands[0], "h");
-    cvt(res, operand);
-    return builder.launch(rewriter, loc, i32_ty, false);
-  }
-};
-
-bool isLegalElementwiseOp(Operation *op) {
-  if (isa<LLVM::FPExtOp>(op)) {
-    return FPExtOpConversion::isLegalOp(cast<LLVM::FPExtOp>(op));
-  } else if (isa<LLVM::FPTruncOp>(op)) {
-    return FPTruncOpConversion::isLegalOp(cast<LLVM::FPTruncOp>(op));
-  } else if (isa<LLVM::TruncOp>(op)) {
-    return TruncOpConversion::isLegalOp(cast<LLVM::TruncOp>(op));
-  } else if (isa<LLVM::SExtOp>(op)) {
-    return SExtOpConversion::isLegalOp(cast<LLVM::SExtOp>(op));
-  } else if (isa<LLVM::ZExtOp>(op)) {
-    return ZExtOpConversion::isLegalOp(cast<LLVM::ZExtOp>(op));
-  }
-  return true;
-}
-
-void populateElementwiseOpToPTXPatterns(
-    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
-    PatternBenefit benefit) {
-  patterns.add<FPExtOpConversion>(typeConverter, benefit);
-  patterns.add<FPTruncOpConversion>(typeConverter, benefit);
-  patterns.add<TruncOpConversion>(typeConverter, benefit);
-  patterns.add<SExtOpConversion>(typeConverter, benefit);
-  patterns.add<ZExtOpConversion>(typeConverter, benefit);
-}
lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h (0 additions, 4 deletions)

@@ -13,8 +13,4 @@ void populateElementwiseOpToLLVMPatterns(
 
 bool isLegalElementwiseOp(Operation *op);
 
-void populateElementwiseOpToPTXPatterns(
-    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
-    PatternBenefit benefit);
-
 #endif
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp (0 additions, 45 deletions)

@@ -56,28 +56,6 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
   }
 };
 
-class TritonPTXConversionTarget : public ConversionTarget {
-public:
-  explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
-    addDynamicallyLegalDialect<LLVM::LLVMDialect>(
-        [&](Operation *op) { return isLegalElementwiseOp(op); });
-
-    addLegalDialect<NVVM::NVVMDialect>();
-    addLegalOp<mlir::UnrealizedConversionCastOp>();
-  }
-};
-
-class TritonGCNConversionTarget : public ConversionTarget {
-public:
-  explicit TritonGCNConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
-    addDynamicallyLegalDialect<LLVM::LLVMDialect>(
-        [&](Operation *op) { return isLegalElementwiseOp(op); });
-
-    addLegalDialect<ROCDL::ROCDLDialect>();
-    addLegalOp<mlir::UnrealizedConversionCastOp>();
-  }
-};
-
 struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
 
@@ -236,29 +214,6 @@ class ConvertTritonGPUToLLVM
                        patterns);
     if (failed(applyPartialConversion(mod, target, std::move(patterns))))
       return signalPassFailure();
-
-    if (isROCM) {
-      TritonGCNConversionTarget gcnTarget(*context);
-      RewritePatternSet gcnPatterns(context);
-      populateElementwiseOpToPTXPatterns(typeConverter, gcnPatterns,
-                                         /*benefits=*/10);
-      if (failed(
-              applyPartialConversion(mod, gcnTarget, std::move(gcnPatterns))))
-        return signalPassFailure();
-    } else {
-      // Use our custom converters to convert some operations to PTX to avoid
-      // using NVPTX for two reasons:
-      // 1. NVPTX backend is flaky on data types like float16 and bfloat16
-      // 2. In some cases, we may generate faster PTX code than NVPTX backend
-      TritonPTXConversionTarget ptxTarget(*context);
-      RewritePatternSet ptxPatterns(context);
-      // Add patterns to convert LLVM to PTX
-      populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns,
-                                         /*benefits=*/10);
-      if (failed(
-              applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns))))
-        return signalPassFailure();
-    }
   }
 
 private:
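For context, the two conversion targets deleted above follow MLIR's standard dynamic-legality idiom: the LLVM dialect is declared legal except for the specific ops a predicate rejects, and a follow-up applyPartialConversion rewrites only those ops. Below is a minimal sketch of that idiom, not Triton's actual code; the helper names (isLegalFPExt, runSecondaryLowering) are illustrative:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative predicate: an f16 -> f32 fpext must be rewritten by one of
// the supplied patterns; every other LLVM-dialect op is left alone.
static bool isLegalFPExt(LLVM::FPExtOp op) {
  return !(op.getResult().getType().isF32() &&
           op.getOperand().getType().isF16());
}

static LogicalResult runSecondaryLowering(ModuleOp mod,
                                          RewritePatternSet &&patterns) {
  ConversionTarget target(*mod.getContext());
  // The dialect is legal op by op: only ops the callback rejects are
  // handed to `patterns` by the partial conversion below.
  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
    if (auto fpext = dyn_cast<LLVM::FPExtOp>(op))
      return isLegalFPExt(fpext);
    return true;
  });
  return applyPartialConversion(mod, target, std::move(patterns));
}

This is why the deleted stage could run after the main TritonGPUToLLVM conversion: the dynamic-legality check selects just the handful of cast ops, and the high pattern benefit (/*benefits=*/10 above) makes the custom PTX patterns win over any default lowering.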
python/MANIFEST.in (1 addition, 0 deletions)

@@ -1 +1,2 @@
 graft src
+graft triton/third_party
python/setup.py (6 additions, 4 deletions)

@@ -221,16 +221,18 @@ def build_extension(self, ext):
     packages=[
         "triton",
         "triton/_C",
-        "triton/language",
-        "triton/tools",
         "triton/common",
         "triton/compiler",
+        "triton/language",
         "triton/ops",
+        "triton/ops/blocksparse",
         "triton/runtime",
-        "triton/ops/blocksparse"],
+        "triton/runtime/driver",
+        "triton/tools",
+    ],
     install_requires=[
         "filelock",
     ],
+    package_data={"triton": ["third_party/**/*"]},
     include_package_data=True,
     ext_modules=[CMakeExtension("triton", "triton/_C/")],
     cmdclass={"build_ext": CMakeBuild},
python/test/unit/language/test_core.py (31 additions, 2 deletions)

@@ -408,6 +408,35 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
 
+
+# ---------------
+# test broadcast
+# ---------------
+@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
+def test_broadcast(dtype):
+    @triton.jit
+    def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
+        offset1 = tl.arange(0, M)
+        offset2 = tl.arange(0, N)
+        x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
+        y = tl.load(y_ptr + offset2)
+        _, y_broadcasted = tl.broadcast(x, y)
+        tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)
+
+    M = 32
+    N = 64
+    rs = RandomState(17)
+    x = numpy_random((M, N), dtype_str=dtype, rs=rs)
+    y = numpy_random(N, dtype_str=dtype, rs=rs)
+    _, y_broadcasted_np = np.broadcast_arrays(x, y)
+
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    y_tri = to_triton(y, device='cuda', dst_type=dtype)
+    y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
+
+    broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
+    assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
+
 
 # ---------------
 # test where
 # ---------------
@@ -514,7 +543,7 @@ def test_unary_op(dtype_x, expr, device='cuda'):
 # ----------------
 
 
-@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in float_dtypes for expr in ['exp', 'log', 'cos', 'sin']])
+@pytest.mark.parametrize("dtype_x, expr", [(dtype_x, expr) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin']])
 def test_math_op(dtype_x, expr, device='cuda'):
     _test_unary(dtype_x, f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
 
@@ -1345,7 +1374,7 @@ def kernel(X, stride_xm, stride_xk,
         if DO_SOFTMAX:
             max = tl.max(z, 1)
             z = z - max[:, None]
-            num = tl.exp(z)
+            num = tl.exp(z.to(tl.float32)).to(max.dtype)
             den = tl.sum(num, 1)
             z = num / den[:, None]
         if CHAIN_DOT:
python/test/unit/operators/test_blocksparse.py (8 additions, 11 deletions)

@@ -81,17 +81,14 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     a_tri.retain_grad()
     b_tri.retain_grad()
     op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
-    try:
-        c_tri = op(a_tri, b_tri)
-        c_tri.backward(dc_tri)
-        da_tri = a_tri.grad
-        db_tri = b_tri.grad
-        # compare
-        torch.testing.assert_allclose(c_ref, c_tri)
-        torch.testing.assert_allclose(da_ref, da_tri)
-        torch.testing.assert_allclose(db_ref, db_tri)
-    except triton.OutOfResourcesError as e:
-        pytest.skip(str(e))
+    c_tri = op(a_tri, b_tri)
+    c_tri.backward(dc_tri)
+    da_tri = a_tri.grad
+    db_tri = b_tri.grad
+    # compare
+    torch.testing.assert_allclose(c_ref, c_tri)
+    torch.testing.assert_allclose(da_ref, da_tri)
+    torch.testing.assert_allclose(db_ref, db_tri)
 
 
 configs = [
(The remaining 11 changed files were not loaded in this view.)
