Skip to content

Commit

Permalink
Add F4E2M1FN type: conversion codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
sergey-kozub committed Dec 18, 2024
1 parent 70ca820 commit c479f09
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 67 deletions.
105 changes: 55 additions & 50 deletions xla/backends/gpu/codegen/transforms/expand_float_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ int GetExponentBias(mlir::FloatType ty) {
return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics());
}

// Returns true iff `ty` is one of the 8-bit FNUZ float formats. These
// formats have no negative zero; the sign-bit-only bit pattern encodes NaN
// instead (see the sign-bit handling in EmitFloatConversion).
bool IsFNUZ(mlir::FloatType ty) {
  return ty.isFloat8E5M2FNUZ() || ty.isFloat8E4M3FNUZ() ||
         ty.isFloat8E4M3B11FNUZ();
}

Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
auto ty = mlir::cast<mlir::FloatType>(value.getType());
if (mlir::LLVM::isCompatibleOuterType(ty)) {
Expand All @@ -175,7 +180,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) {
return b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, value, inf);
}

assert(ty.getIntOrFloatBitWidth() == 8);
assert(ty.getIntOrFloatBitWidth() <= 8);
// F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities.
if (ty.isFloat8E5M2()) {
Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
Expand All @@ -196,6 +201,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
if (mlir::LLVM::isCompatibleOuterType(ty)) {
return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value);
}
if (ty.isFloat4E2M1FN()) {
return b.create<ma::ConstantIntOp>(false, b.getI1Type());
}

assert(ty.getIntOrFloatBitWidth() == 8);
Val bits{b.create<ma::BitcastOp>(b.getI8Type(), value), &b};
Expand Down Expand Up @@ -281,7 +289,7 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth());

mlir::IntegerType wide_int_ty;
if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) {
if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) {
wide_int_ty = b.getI16Type();
} else {
wide_int_ty = b.getIntegerType(
Expand All @@ -300,21 +308,20 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
int64_t exp_offset = to_bias - from_bias;
int digit_shift = to_mantissa - from_mantissa;

Val from_bits{
b.create<ma::BitcastOp>(
b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value),
&b};
int from_width = value.getType().getIntOrFloatBitWidth();
Val from_bits{b.create<ma::BitcastOp>(b.getIntegerType(from_width), value),
&b};
if (from_width < 8) {
from_bits = convert_int(b.getIntegerType(8), from_bits);
}

auto cst = [&](mlir::Type ty, int64_t n) -> Val {
return {b.create<ma::ConstantIntOp>(n, ty), &b};
};

// Shift bits to destination type, without sign bit.
Val from_sign_bit =
from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0;

from_bits =
from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1);
Val from_sign_bit = from_bits.shrui(from_width - 1) != 0;
from_bits = from_bits & ((1ULL << (from_width - 1)) - 1);

Value result_is_inf = IsInf(value, b);
Value input_is_nan = IsNaN(value, b);
Expand All @@ -327,6 +334,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics()));
Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics()));

// MX float types have neither infinities nor NaNs.
if (to_ty.isFloat4E2M1FN()) {
to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics()));
to_nan = to_zero | 0x8;
}

auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
assert(bits.value.getType() == roundoff.value.getType());
// Round to nearest even by adding a bias term.
Expand Down Expand Up @@ -394,10 +407,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
Val bits = convert_int(wide_int_ty, from_bits);

// Determine exponent in target type.
Value normalization_factor =
convert_int(i32_ty,
b.create<mlir::math::CountLeadingZerosOp>(from_bits)) -
(from_int_ty.getWidth() - from_mantissa - 1);
Value clz = convert_int(
i32_ty, b.create<mlir::math::CountLeadingZerosOp>(from_bits));
Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz;
Value normalization_factor = cst(i32_ty, from_mantissa) - msb;

Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor;
// If the result is subnormal, adjust the subnormal bits to account for
Expand Down Expand Up @@ -469,18 +482,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
result);
}

// Handle types with no unsigned zero.
auto is_nuz = [](mlir::FloatType ty) {
return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() ||
ty.isFloat8E5M2FNUZ();
};

if (is_nuz(to_ty)) {
if (IsFNUZ(to_ty)) {
// Clear the sign bit if the result is zero (the output has no negative
// zero).
Val result_is_non_zero = Val{result, &b} != 0;
// zero). Handle the edge case when the input is zero and the result is not.
Val result_is_non_zero =
(digit_shift > 0 ? from_bits : Val{result, &b}) != 0;
from_sign_bit = from_sign_bit & result_is_non_zero;
} else if (is_nuz(from_ty)) {
} else if (IsFNUZ(from_ty)) {
// Clear the sign bit if the input is NaN (it's positive but encoded as
// negative 0).
from_sign_bit = from_sign_bit ^ input_is_nan;
Expand All @@ -506,8 +514,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern<ma::TruncFOp> {
using FloatValue = mlir::TypedValue<mlir::FloatType>;
auto src = mlir::cast<FloatValue>(op.getOperand());
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
if (dst_ty.getWidth() != 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit truncf");
if (dst_ty.getWidth() > 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf");
}

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Expand All @@ -524,8 +532,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern<ma::ExtFOp> {
using FloatValue = mlir::TypedValue<mlir::FloatType>;
auto src = mlir::cast<FloatValue>(op.getOperand());
auto dst_ty = mlir::cast<mlir::FloatType>(op.getType());
if (src.getType().getWidth() != 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit extf");
if (src.getType().getWidth() > 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf");
}

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Expand All @@ -544,25 +552,25 @@ struct RewriteF8Cst : public mlir::OpRewritePattern<ma::CmpFOp> {
auto lhs = mlir::cast<FloatValue>(op.getLhs());
auto rhs = mlir::cast<FloatValue>(op.getRhs());

if (lhs.getType().getWidth() != 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf");
if (lhs.getType().getWidth() > 8) {
return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf");
}

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Skip the f32 conversion if we're comparing UNE.cst.
llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics());
if (op.getPredicate() == ma::CmpFPredicate::UNE &&
mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) {
Val int_value{b.create<ma::BitcastOp>(rewriter.getI8Type(), lhs), &b};
mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth());
Val int_value{b.create<ma::BitcastOp>(int_ty, lhs), &b};
int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue();
// If we're comparing to +-0, compare the absolute values.
if (rhs_cst.isZero() &&
(lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() ||
lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) {
int_value = int_value & 0x7f;
constant &= 0x7f;
if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) {
int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1;
int_value = int_value & mask;
constant &= mask;
}
auto cst = b.create<ma::ConstantIntOp>(constant, rewriter.getI8Type());
auto cst = b.create<ma::ConstantIntOp>(constant, int_ty);
rewriter.replaceOpWithNewOp<ma::CmpIOp>(op, ma::CmpIPredicate::ne,
int_value, cst);
return mlir::success();
Expand All @@ -586,18 +594,15 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
auto src = mlir::cast<FloatValue>(op.getOperand());
// LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16.
// Once that's removed, remove the code for BF16 here.
if (src.getType().getWidth() != 8 && !src.getType().isBF16()) {
return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf");
if (src.getType().getWidth() > 8 && !src.getType().isBF16()) {
return rewriter.notifyMatchFailure(op,
"not an f8 (or less) or bf16 absf");
}
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth());
Val value{b.create<ma::BitcastOp>(i_ty, src), &b};
if (src.getType().getWidth() == 8) {
value = value & 0x7f;
} else {
CHECK(src.getType().isBF16());
value = value & 0x7fff;
}
int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1;
value = value & mask;
rewriter.replaceOpWithNewOp<ma::BitcastOp>(op, src.getType(), value);
return mlir::success();
}
Expand All @@ -609,8 +614,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern<Op> {

mlir::LogicalResult matchAndRewrite(
Op op, mlir::PatternRewriter& rewriter) const override {
if (op.getType().getIntOrFloatBitWidth() != 8) {
return rewriter.notifyMatchFailure(op, "not an f8 itofp");
if (op.getType().getIntOrFloatBitWidth() > 8) {
return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp");
}
Value to_float =
rewriter.create<Op>(op.getLoc(), rewriter.getF32Type(), op.getIn());
Expand All @@ -625,8 +630,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern<Op> {

mlir::LogicalResult matchAndRewrite(
Op op, mlir::PatternRewriter& rewriter) const override {
if (op.getIn().getType().getIntOrFloatBitWidth() != 8) {
return rewriter.notifyMatchFailure(op, "not an f8 fptoi");
if (op.getIn().getType().getIntOrFloatBitWidth() > 8) {
return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi");
}
Value to_f32 = rewriter.create<ma::ExtFOp>(
op.getLoc(), rewriter.getF32Type(), op.getIn());
Expand Down
37 changes: 37 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,40 @@ module {
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
// CHECK: arith.truncf %[[EXT]] : f32 to f16
// CHECK-NOT: arith.truncf

// -----

module {
  // Extending f4E2M1FN to f16 must be expanded into integer bit manipulation
  // by this pass (the CHECK-NOT below verifies no arith.extf survives).
  func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
    %ret = arith.extf %arg0 : f4E2M1FN to f16
    return %ret : f16
  }
}

// CHECK-LABEL: @f4_to_f16
// CHECK-NOT: arith.extf

// -----

module {
  // Truncating f16 to f4E2M1FN must be expanded into integer bit manipulation
  // by this pass (the CHECK-NOT below verifies no arith.truncf survives).
  func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
    %ret = arith.truncf %arg0 : f16 to f4E2M1FN
    return %ret : f4E2M1FN
  }
}

// CHECK-LABEL: @f16_to_f4
// CHECK-NOT: arith.truncf

// -----

module {
  // absf on f4E2M1FN must be rewritten to a bitwise AND with the sign-bit
  // mask 0x7 (the CHECK below expects `arith.constant 7 : i4`).
  func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
    %ret = math.absf %arg0 : f4E2M1FN
    return %ret : f4E2M1FN
  }
}

// CHECK-LABEL: @f4_abs
// CHECK-NOT: math.absf
// CHECK: arith.constant 7 : i4
Loading

0 comments on commit c479f09

Please sign in to comment.