Skip to content

Commit

Permalink
[APFloat] Add APFloat support for FP4 data type (#95392)
Browse files Browse the repository at this point in the history
This patch adds APFloat type support for the E2M1
FP4 datatype. The definitions for this format are
detailed in section 5.3.3 of the OCP specification,
which can be accessed here:

https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
  • Loading branch information
durga4github authored Jun 14, 2024
1 parent e83adfe commit 880d370
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 7 deletions.
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_FloatTF32:
case APFloat::S_Float6E3M2FN:
case APFloat::S_Float6E2M3FN:
case APFloat::S_Float4E2M1FN:
llvm_unreachable("Tried to mangle unexpected APFloat semantics");
}

Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ struct APFloatBase {
// types, there are no infinity or NaN values. The format is detailed in
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
S_Float6E2M3FN,
// 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754
// types, there are no infinity or NaN values. The format is detailed in
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
S_Float4E2M1FN,

S_x87DoubleExtended,
S_MaxSemantics = S_x87DoubleExtended,
Expand All @@ -219,6 +223,7 @@ struct APFloatBase {
static const fltSemantics &FloatTF32() LLVM_READNONE;
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
static const fltSemantics &x87DoubleExtended() LLVM_READNONE;

/// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
Expand Down Expand Up @@ -639,6 +644,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloatTF32APFloatToAPInt() const;
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
APInt convertFloat4E2M1FNAPFloatToAPInt() const;
void initFromAPInt(const fltSemantics *Sem, const APInt &api);
template <const fltSemantics &S> void initFromIEEEAPInt(const APInt &api);
void initFromHalfAPInt(const APInt &api);
Expand All @@ -656,6 +662,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloatTF32APInt(const APInt &api);
void initFromFloat6E3M2FNAPInt(const APInt &api);
void initFromFloat6E2M3FNAPInt(const APInt &api);
void initFromFloat4E2M1FNAPInt(const APInt &api);

void assign(const IEEEFloat &);
void copySignificand(const IEEEFloat &);
Expand Down Expand Up @@ -1067,6 +1074,7 @@ class APFloat : public APFloatBase {
// Below Semantics do not support {NaN or Inf}
case APFloat::S_Float6E3M2FN:
case APFloat::S_Float6E2M3FN:
case APFloat::S_Float4E2M1FN:
return false;
}
}
Expand Down
25 changes: 23 additions & 2 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ enum class fltNonfiniteBehavior {
// encodings do not distinguish between signalling and quiet NaN.
NanOnly,

// This behavior is present in Float6E3M2FN and Float6E2M3FN types,
// which do not support Inf or NaN values.
// This behavior is present in Float6E3M2FN, Float6E2M3FN, and
// Float4E2M1FN types, which do not support Inf or NaN values.
FiniteOnly,
};

Expand Down Expand Up @@ -147,6 +147,8 @@ static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat6E2M3FN = {
2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat4E2M1FN = {
2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static constexpr fltSemantics semBogus = {0, 0, 0, 0};

Expand Down Expand Up @@ -218,6 +220,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float6E3M2FN();
case S_Float6E2M3FN:
return Float6E2M3FN();
case S_Float4E2M1FN:
return Float4E2M1FN();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
Expand Down Expand Up @@ -254,6 +258,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float6E3M2FN;
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
return S_Float6E2M3FN;
else if (&Sem == &llvm::APFloat::Float4E2M1FN())
return S_Float4E2M1FN;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
Expand All @@ -278,6 +284,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
Expand Down Expand Up @@ -3640,6 +3647,11 @@ APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat6E2M3FN>();
}

APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat4E2M1FN>();
}

// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
Expand Down Expand Up @@ -3687,6 +3699,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN)
return convertFloat6E2M3FNAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN)
return convertFloat4E2M1FNAPFloatToAPInt();

assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
Expand Down Expand Up @@ -3911,6 +3926,10 @@ void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat6E2M3FN>(api);
}

void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat4E2M1FN>(api);
}

/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
Expand Down Expand Up @@ -3944,6 +3963,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat6E3M2FNAPInt(api);
if (Sem == &semFloat6E2M3FN)
return initFromFloat6E2M3FNAPInt(api);
if (Sem == &semFloat4E2M1FN)
return initFromFloat4E2M1FNAPInt(api);

llvm_unreachable(nullptr);
}
Expand Down
Loading

0 comments on commit 880d370

Please sign in to comment.