From 1e145f92f262b4eec86bc2fb4fbbbf09eb4ec2fb Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 3 Dec 2024 03:25:57 -0800 Subject: [PATCH] PR #19096: Add F4E2M1FN and F8E8M0FNU types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/19096 This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented. This will enable using microscaling (MX) formats ([RFC](https://github.com/openxla/xla/discussions/18085)), such as MXFP4. ```c F4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 F8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - https://github.com/openxla/stablehlo/pull/2582 - https://github.com/jax-ml/ml_dtypes/pull/181 - https://github.com/llvm/llvm-project/pull/95392 - https://github.com/llvm/llvm-project/pull/108877 - https://github.com/jax-ml/ml_dtypes/pull/166 - https://github.com/llvm/llvm-project/pull/107127 - https://github.com/llvm/llvm-project/pull/111028 
The PR is split into multiple commits just to make the review easier; it is possible that some tests could fail if only some (i.e. not all) of these commits are applied. Copybara import of the project: -- fa539fbde987ff6421fd2937fade495baf633630 by Sergey Kozub : Add F4E2M1FN type: import mxfloat.h -- 2c014035923e0394b2cfcb81eaf090a96621b0aa by Sergey Kozub : Add F4E2M1FN type: primitive type -- e919ed54e825f2e905aaf0cc279dd21cd80f1ce9 by Sergey Kozub : Add F4E2M1FN type: literal support -- ca16839096feb93e0454ec380c5c707c30199346 by Sergey Kozub : Add F4E2M1FN type: conversion codegen -- eedc079ca9a4db9e611d84877a25b3da21386f16 by Sergey Kozub : Add F4E2M1FN type: python interface -- 8e0305cd47002f0c1f8668a3cbcbce5428f2a4c6 by Sergey Kozub : Add F4E2M1FN type: FFI -- aabe9c68d964609f78f29e17ee0680798ad0c6ac by Sergey Kozub : Add F4E2M1FN type: HLO evaluator -- 87da2ebfab388f113482e852009401a9e416974a by Sergey Kozub : Add F4E2M1FN type: add tests -- e0ee48c3a37018ba985c850931592d62eadf7c2e by Sergey Kozub : Add F8E8M0FNU type -- be2e457922e2cddeaf5aca13dd022f3ac2a1393b by Sergey Kozub : Addressing PR#19096 review comments Merging this change closes #19096 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19096 from openxla:skozub/e2m1 be2e457922e2cddeaf5aca13dd022f3ac2a1393b PiperOrigin-RevId: 702273510 --- third_party/tsl/tsl/platform/BUILD | 1 + third_party/tsl/tsl/platform/ml_dtypes.h | 3 + xla/array2d_test.cc | 28 ++ xla/comparison_util.h | 9 +- xla/ffi/api/api.h | 4 + xla/ffi/api/c_api.h | 2 + xla/ffi/api/ffi.h | 6 + xla/ffi/api/ffi_test.cc | 6 + xla/ffi/call_frame.cc | 2 + xla/fp_util_test.cc | 70 +++++ xla/hlo/builder/lib/math.cc | 11 +- xla/hlo/builder/lib/math_test.cc | 34 ++- xla/hlo/evaluator/BUILD | 1 + xla/hlo/evaluator/hlo_evaluator.cc | 2 +- .../evaluator/hlo_evaluator_typed_visitor.h | 2 + .../hlo_evaluator_typed_visitor_mxfloat.cc | 23 ++ .../expanders/comparison_expander.cc | 59 ++-- .../simplifiers/float_normalization.cc | 3 + 
.../simplifiers/float_normalization_test.cc | 4 +- .../translate/hlo_to_mhlo/tests/import.hlo | 20 +- .../translate/mhlo_to_hlo/literal_exporter.cc | 6 + .../translate/mhlo_to_hlo/tests/export.mlir | 18 +- xla/literal.cc | 36 ++- xla/literal.h | 29 +- xla/literal_comparison.cc | 3 + xla/literal_comparison_test.cc | 52 ++-- xla/literal_test.cc | 75 +++-- xla/mlir/utils/type_util.cc | 10 +- xla/mlir/utils/type_util_test.cc | 2 + xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 14 + xla/pjrt/c/CHANGELOG.md | 3 + xla/pjrt/c/pjrt_c_api.h | 6 +- xla/pjrt/c/pjrt_c_api_helpers.cc | 8 + xla/primitive_util.cc | 12 + xla/primitive_util.h | 80 ++++- xla/primitive_util_test.cc | 116 ++++++- xla/python/ifrt/dtype.cc | 8 + xla/python/ifrt/dtype.h | 6 +- xla/python/ifrt/dtype.proto | 6 + xla/python/ifrt/dtype_test.cc | 86 ++---- xla/python/pjrt_ifrt/pjrt_dtype.cc | 4 + xla/python/py_values.cc | 16 + xla/python/types.cc | 42 +++ xla/python/types.h | 2 + xla/python/xla.cc | 2 + xla/python/xla_client.py | 4 + xla/python/xla_client.pyi | 2 + xla/python/xla_client_test.py | 4 +- xla/python/xla_extension/__init__.pyi | 2 + xla/service/cpu/cpu_compiler.cc | 4 + xla/service/cpu/onednn_memory_util.h | 2 +- xla/service/elemental_ir_emitter.cc | 275 ++++++++++++++++- xla/service/elemental_ir_emitter_test.cc | 15 +- xla/service/float8_fnuz_ir_emitter.cc | 17 +- .../fusions/transforms/expand_float_ops.cc | 179 ++++++----- .../gpu/fusions/transforms/lower_tensors.cc | 53 ++-- .../transforms/tests/expand_float_ops.mlir | 50 ++++ .../transforms/tests/lower_tensors.mlir | 52 ++++ xla/service/gpu/gpu_compiler.cc | 4 + .../gpu/tests/float_conversions_test.cc | 7 +- xla/service/llvm_ir/llvm_util.cc | 3 + xla/stream_executor/data_type.h | 8 + xla/stream_executor/dnn.cc | 2 + xla/stream_executor/gpu/gpu_blas_lt.cc | 10 + xla/stream_executor/rocm/hip_blas_utils.cc | 6 +- xla/tests/BUILD | 1 + xla/tests/array_elementwise_ops_test.cc | 48 +-- xla/tests/constants_test.cc | 9 +- xla/tests/convert_test.cc | 282 
+++++++++++++++++- xla/tools/driver.cc | 27 +- xla/tsl/framework/type_traits.h | 4 +- xla/tsl/protobuf/dnn.proto | 2 + xla/tsl/python/lib/core/ml_dtypes.cc | 6 + xla/tsl/python/lib/core/ml_dtypes.h | 2 + xla/types.h | 16 + xla/util.cc | 10 + xla/util.h | 25 +- xla/util_test.cc | 28 +- xla/xla_data.proto | 27 +- 79 files changed, 1767 insertions(+), 351 deletions(-) create mode 100644 xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc diff --git a/third_party/tsl/tsl/platform/BUILD b/third_party/tsl/tsl/platform/BUILD index fa4e4d0ce09120..3b51d30a9769d8 100644 --- a/third_party/tsl/tsl/platform/BUILD +++ b/third_party/tsl/tsl/platform/BUILD @@ -1066,6 +1066,7 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 89a40bd891e106..f31726869b508d 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -18,8 +18,10 @@ limitations under the License. 
#include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { +using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -27,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int2 = ::ml_dtypes::int2; using uint2 = ::ml_dtypes::uint2; diff --git a/xla/array2d_test.cc b/xla/array2d_test.cc index 921da30256fa3d..c62f6e882713e5 100644 --- a/xla/array2d_test.cc +++ b/xla/array2d_test.cc @@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF4E2M1FN) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 3.5 rounded up +} + +TEST(Array2dTest, LinspaceF8E8M0FNU) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 2.0); // 1.5 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.0); // 2.5 rounded down + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 4.0); // 3.0 rounded up + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 4.0); // 
3.5 rounded up +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/xla/comparison_util.h b/xla/comparison_util.h index 5a21595da4d741..44f0dd48640bb1 100644 --- a/xla/comparison_util.h +++ b/xla/comparison_util.h @@ -193,8 +193,13 @@ class Comparison { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - using R = SignedIntegerTypeForSizeType; - return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + if constexpr (std::numeric_limits::is_signed) { + using R = SignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } else { + using R = UnsignedIntegerTypeForSizeType; + return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); + } } } // Applies the comparison from this Comparison's direction and ordering. diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index cf98210af1b717..013b343b249f20 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "C128"; case XLA_FFI_DataType_TOKEN: return os << "TOKEN"; + case XLA_FFI_DataType_F4E2M1FN: + return os << "F4E2M1FN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; case XLA_FFI_DataType_F8E3M4: @@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os, return os << "F8E5M2FNUZ"; case XLA_FFI_DataType_F8E4M3FNUZ: return os << "F8E4M3FNUZ"; + case XLA_FFI_DataType_F8E8M0FNU: + return os << "F8E8M0FNU"; } } diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index f0c4f40e78ea7a..5dc6bee08902d4 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -201,6 +201,8 @@ typedef enum { XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, XLA_FFI_DataType_F8E4M3FNUZ = 25, + XLA_FFI_DataType_F4E2M1FN = 30, + XLA_FFI_DataType_F8E8M0FNU = 31, } 
XLA_FFI_DataType; // LINT.ThenChange(ffi_test.cc) diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index 91124533e7cca1..946b4284ebda78 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -79,6 +79,8 @@ enum class DataType : uint8_t { F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, F8E3M4 = XLA_FFI_DataType_F8E3M4, + F4E2M1FN = XLA_FFI_DataType_F4E2M1FN, + F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; inline constexpr DataType F8E3M4 = DataType::F8E3M4; +inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN; +inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: case DataType::F8E3M4: + case DataType::F4E2M1FN: + case DataType::F8E8M0FNU: return 1; case DataType::S16: case DataType::U16: diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 339dffac81172d..c40c03834057ee 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); + EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); @@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); 
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); + EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128), ByteWidth(DataType::C128)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN), + ByteWidth(DataType::F4E2M1FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), @@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E4M3FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), ByteWidth(DataType::F8E3M4)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU), + ByteWidth(DataType::F8E8M0FNU)); } TEST(FfiTest, ErrorEnumValue) { diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 3fb2ac3c7786fa..7bcb14da445e8c 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -264,6 +264,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C64: case PrimitiveType::C128: case PrimitiveType::TOKEN: + case PrimitiveType::F4E2M1FN: case PrimitiveType::F8E5M2: case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: @@ -271,6 +272,7 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: case PrimitiveType::F8E3M4: + case PrimitiveType::F8E8M0FNU: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 3eb7c54f919b0a..8ea22d9d1602bf 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -119,6 +119,76 @@ class FP8E4M3DistanceTest : public ::testing::Test {}; using F8E4M3Types = ::testing::Types; 
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F4E2M1FNDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)), + 1); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)), + 2); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float4_e2m1fn(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + tsl::float4_e2m1fn(0)), + 2); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 4); +} + +TEST(FPDistanceTest, F8E8M0FNUDistance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)), + 0); + + // one step apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)), + 1); + + // two steps apart + EXPECT_EQ(CalculateDistanceInFloats( + tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)), + 2); +} + TEST(FPDistanceTest, F8E3M4Distance) { // a & b are equal EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc 
index f2a77df3d7ddaa..620e907f8cf112 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F4E2M1FN: case F8E3M4: case F8E4M3: case F8E5M2: @@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : + {BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN, + F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index 30eaf4b503de62..1cdf6648b52f9e 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -95,9 +95,13 @@ class MathTypedTest : public MathTest { Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); bool has_inf = std::numeric_limits::has_infinity; + bool has_nan = std::numeric_limits::has_quiet_NaN; + bool has_finite = !has_inf && !has_nan; + bool has_nan_only = !has_inf && has_nan; + auto expected = LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1( - {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1({true, true, true, true, true, 
has_finite, + has_finite, has_finite, has_finite}), LiteralUtil::CreateR1({false, false, false, false, false, has_inf, has_inf, false, false}), LiteralUtil::CreateR1( @@ -105,7 +109,8 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1( {false, false, false, false, false, false, has_inf, false, false}), LiteralUtil::CreateR1({false, false, false, false, false, - !has_inf, !has_inf, true, true})); + has_nan_only, has_nan_only, has_nan, + has_nan})); ComputeAndCompareLiteral(&b, expected, {}); } @@ -118,10 +123,11 @@ class MathTypedTest : public MathTest { LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), &b)); + bool is_mx = std::is_same_v; ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( - {has_negative_zero_v, false, false, false, false, false, false}), + {has_negative_zero_v, false, false, false, false, false, is_mx}), {}, error_spec_); } @@ -136,6 +142,9 @@ class MathTypedTest : public MathTest { // For good measure, we also check pow with an exponent other than 0.5. void TestSqrtPowInequivalence() { SetFastMathDisabled(true); + if (std::is_same_v) { + GTEST_SKIP() << "Skipping due to low precision"; + } // Tests disable constant folding by default, but this test needs it // enabled, otherwise we don't tickle the bug we're trying to catch. @@ -181,9 +190,14 @@ class MathTypedTest : public MathTest { &b); Erf(x); - bool has_inf = std::numeric_limits::has_infinity; - std::vector expected = { - has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)}; + bool inf_as_nan = !std::numeric_limits::has_infinity && + std::numeric_limits::has_quiet_NaN; + std::vector expected = {inf_as_nan ? nan : T(-1), + inf_as_nan ? nan : T(1), + T(-0), + T(0), + T(-1), + T(1)}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } @@ -191,9 +205,9 @@ class MathTypedTest : public MathTest { // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. 
using TestTypes = - ::testing::Types StochasticConvertOp(const Literal& operand_literal, const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { - bool is_negative = static_cast(Eigen::numext::signbit(operand)); + bool is_negative = static_cast(SignAndMagnitude(operand).first); if (Eigen::numext::isinf(operand)) { return is_negative ? std::numeric_limits::min() : std::numeric_limits::max(); diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 41cd753d987201..7f0925f1a3179b 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1734,6 +1734,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; @@ -1741,6 +1742,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc new file mode 100644 index 00000000000000..6bc96c1a1f7cda --- /dev/null +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_mxfloat.cc @@ -0,0 +1,23 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "tsl/platform/ml_dtypes.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/xla/hlo/transforms/expanders/comparison_expander.cc b/xla/hlo/transforms/expanders/comparison_expander.cc index 0f09ecced1ebaf..86d1eeafcd5931 100644 --- a/xla/hlo/transforms/expanders/comparison_expander.cc +++ b/xla/hlo/transforms/expanders/comparison_expander.cc @@ -115,34 +115,41 @@ absl::StatusOr ComparisonExpander::ExpandInstruction( ShapeUtil::ChangeElementType(rhs->shape(), compare_type), rhs)); } - int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); - PrimitiveType signed_type = - primitive_util::SignedIntegralTypeForBitWidth(bit_width); - auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); - - auto zero_value = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); - zero_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); - - auto min_value = computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MinValue(signed_shape.element_type()))); - min_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, min_value, {})); - - auto max_value = computation->AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); - max_value = computation->AddInstruction( - HloInstruction::CreateBroadcast(signed_shape, max_value, {})); - - lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, - min_value, max_value); - rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, - min_value, max_value); + if (compare_type != F8E8M0FNU) { + int64_t bit_width = primitive_util::BitWidth(lhs->shape().element_type()); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = ShapeUtil::ChangeElementType(lhs->shape(), signed_type); + + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, zero_value, {})); + + auto min_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(signed_type))); + min_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, min_value, {})); + + auto max_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + max_value = computation->AddInstruction( + HloInstruction::CreateBroadcast(signed_shape, max_value, {})); + + lhs = BitcastConvertFloatingPointToIntegral(computation, lhs, zero_value, + min_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral(computation, rhs, zero_value, + min_value, max_value); + } else { + auto int8_shape = ShapeUtil::ChangeElementType(lhs->shape(), U8); + lhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(int8_shape, rhs)); + } auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( - instruction->shape(), lhs, rhs, compare->direction(), - Comparison::Type::kSigned)); + 
instruction->shape(), lhs, rhs, compare->direction())); VLOG(2) << "New comparison instruction for total order:" << new_compare->ToString(); diff --git a/xla/hlo/transforms/simplifiers/float_normalization.cc b/xla/hlo/transforms/simplifiers/float_normalization.cc index b6d8a532054502..2a29a7f0ce8f4b 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) { if (subshape->element_type() == from) { subshape->set_element_type(to); + if (subshape->has_layout() && from == F4E2M1FN) { + subshape->mutable_layout()->set_element_size_in_bits(0); + } } }); float_normalization_->UpdateLayout(hlo->mutable_shape()); diff --git a/xla/hlo/transforms/simplifiers/float_normalization_test.cc b/xla/hlo/transforms/simplifiers/float_normalization_test.cc index 86ec889abc6527..b614f74229c0e5 100644 --- a/xla/hlo/transforms/simplifiers/float_normalization_test.cc +++ b/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -150,7 +150,9 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); + ::testing::Values(F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, + F8E5M2, F8E5M2FNUZ, F8E8M0FNU)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); diff --git a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 3a1e7ceabb160f..577e4ad61f89e2 100644 --- a/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo +++ b/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -421,6 +421,12 @@ add { // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : 
tensor<4xf8E3M4> %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + %constant.14 = f4e2m1fn[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_15:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + %constant.15 = f8e8m0fnu[4] constant({1, 2, 4, 8}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -542,7 +548,19 @@ add { %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> - ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) + + // CHECK-NEXT: %14 = mhlo.convert %13 : (tensor<4xf32>) -> tensor<4xf4E2M1FN> + %convert.17 = f4e2m1fn[4] convert(f32[4] %convert.16) + + // CHECK-NEXT: %15 = mhlo.convert %14 : (tensor<4xf4E2M1FN>) -> tensor<4xf32> + %convert.18 = f32[4] convert(f4e2m1fn[4] %convert.17) + + // CHECK-NEXT: %16 = mhlo.convert %15 : (tensor<4xf32>) -> tensor<4xf8E8M0FNU> + %convert.19 = f8e8m0fnu[4] convert(f32[4] %convert.18) + + // CHECK-NEXT: %17 = mhlo.convert %16 : (tensor<4xf8E8M0FNU>) -> tensor<4xf32> + ROOT %convert.20 = f32[4] convert(f8e8m0fnu[4] %convert.19) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc index 821f1487cf88c1..f50e2a097a3277 100644 --- a/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc +++ b/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -41,6 +41,12 @@ xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { xla::Array array(shape.dimensions()); if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { array.SetValues(dense_attr.getValues()); + } else if constexpr 
(xla::primitive_util::IsMXType(type)) { + // Bitcast MX floating point types from APFloat. + auto values = dense_attr.getValues(); + for (int i = 0; i < values.size(); i++) { + array.data()[i] = T::FromRep(values[i].bitcastToAPInt().getZExtValue()); + } } else { // The only way to get subbyte integers from getValues() is to get them as // APInts. diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index 17b686cc2f5ebe..20b7c7e2642808 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -606,6 +606,12 @@ func.func @main() { // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + // CHECK: f4e2m1fn[4] constant({1, 2, 3, 4}) + %cst_18 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf4E2M1FN> + + // CHECK: f8e8m0fnu[4] constant({1, 2, 4, 8}) + %cst_19 = arith.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00]> : tensor<4xf8E8M0FNU> + func.return } @@ -739,7 +745,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> - func.return %11 : tensor<2xf32> + %12 = "mhlo.convert"(%11) : (tensor<2xf32>) -> tensor<2xf4E2M1FN> + %13 = "mhlo.convert"(%12) : (tensor<2xf4E2M1FN>) -> tensor<2xf32> + %14 = "mhlo.convert"(%13) : (tensor<2xf32>) -> tensor<2xf8E8M0FNU> + %15 = "mhlo.convert"(%14) : (tensor<2xf8E8M0FNU>) -> tensor<2xf32> + func.return %15 : tensor<2xf32> } // CHECK: ENTRY @@ -755,7 +765,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) // CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) // CHECK: 
%[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) -// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) +// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) +// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) +// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) +// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) // ----- diff --git a/xla/literal.cc b/xla/literal.cc index 997f44a4dd0f62..866bc1838a9190 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -91,10 +91,11 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || - !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || - !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || - !proto.f8e3m4s().empty() || !proto.f16s().empty() || + !proto.f4e2m1fns().empty() || !proto.f8e3m4s().empty() || + !proto.f8e4m3b11fnuzs().empty() || !proto.f8e4m3fns().empty() || + !proto.f8e4m3fnuzs().empty() || !proto.f8e4m3s().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e5m2s().empty() || + !proto.f8e8m0fnus().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1874,7 +1875,6 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { << __func__ << " is only supported for dense arrays: " << subshape(); CHECK_EQ(size_bytes_dense(), other.size_bytes_dense()); if (primitive_util::IsSubByteNonPredType(subshape().element_type())) { - 
CHECK(!primitive_util::IsFloatingPointType(subshape().element_type())); auto one_array = buffer(); auto two_array = other.buffer(); const int bits_per_element = @@ -2259,6 +2259,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case F4E2M1FN: + *proto->mutable_f4e2m1fns() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E5M2: *proto->mutable_f8e5m2s() = std::string( reinterpret_cast(data().data()), @@ -2294,6 +2299,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E8M0FNU: + *proto->mutable_f8e8m0fnus() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2445,6 +2455,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case F4E2M1FN: { + const std::string& s(proto.f4e2m1fns()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float4_e2m1fn) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E5M2: { const std::string& s(proto.f8e5m2s()); TF_RET_CHECK(data().size() * sizeof(tsl::float8_e5m2) == @@ -2498,6 +2516,14 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E8M0FNU: { + const std::string& s(proto.f8e8m0fnus()); + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e8m0fnu) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/xla/literal.h b/xla/literal.h index 3233126a5efb05..adee914c092108 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -589,18 +589,17 @@ class 
LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte = 0; for (int b = 0; b < elements_per_byte; ++b) { - uint8_t src = - static_cast(elements[i * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[i * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -609,9 +608,9 @@ class LiteralBase { if (rest != 0) { uint8_t byte = 0; for (int64_t b = 0; b < rest; ++b) { - uint8_t src = - static_cast(elements[bytes * elements_per_byte + b]) & - LsbMask(bits_per_element); + uint8_t src = Eigen::numext::bit_cast( + elements[bytes * elements_per_byte + b]) & + LsbMask(bits_per_element); byte |= src << (b * bits_per_element); } WriteElement(byte); @@ -701,11 +700,17 @@ class LiteralBase { primitive_util::NativeToPrimitiveType(); constexpr int bits_per_element = primitive_util::BitWidth(primitive_type); if constexpr (bits_per_element < 8) { - static_assert(!primitive_util::IsFloatingPointType(primitive_type)); static_assert(!primitive_util::IsComplexType(primitive_type)); static_assert(8 % bits_per_element == 0); - constexpr int elements_per_byte = 8 / bits_per_element; + constexpr auto cast = [](uint8_t x) -> NativeT { + if constexpr (primitive_util::IsFloatingPointType(primitive_type)) { + return Eigen::numext::bit_cast(x); + } + return static_cast(x); + }; + + constexpr int elements_per_byte = 8 / bits_per_element; int64_t bytes = elements.size() / 
elements_per_byte; for (int64_t i = 0; i < bytes; ++i) { uint8_t byte; @@ -714,7 +719,7 @@ class LiteralBase { } for (int b = 0; b < elements_per_byte; ++b) { elements[i * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } @@ -726,7 +731,7 @@ class LiteralBase { } for (int64_t b = 0; b < rest; ++b) { elements[bytes * elements_per_byte + b] = - static_cast(byte & LsbMask(bits_per_element)); + cast(byte & LsbMask(bits_per_element)); byte >>= bits_per_element; } } diff --git a/xla/literal_comparison.cc b/xla/literal_comparison.cc index c97629594122bb..fee817978ec3e4 100644 --- a/xla/literal_comparison.cc +++ b/xla/literal_comparison.cc @@ -418,6 +418,9 @@ class NearComparator { } else { float_distance = CalculateFloatDistance(expected, actual); abs_error = FpAbsoluteValue(actual - expected); + if (!std::numeric_limits::is_signed && IsNaN(abs_error)) { + abs_error = FpAbsoluteValue(expected - actual); + } // Avoid division by 0 even though it's well-defined because ubsan can be // configured to treat this as a fatal error. 
diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 7713aceaaa3bc5..29c12eb7c75e4a 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -30,13 +30,15 @@ template class LiteralComparisonTest : public ::testing::Test {}; using TestedTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + auto expected = LiteralUtil::CreateR0(TypeParam(1.0)); TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); @@ -44,12 +46,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 9.0; // F8E4M3* - if (type == F8E5M2) - expV = 10.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 1.125; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 1.25; else if (type == F8E3M4) - expV = 8.5; + expV = 1.0625; + else if (type == F4E2M1FN) + expV = 1.5; + else if (type == F8E8M0FNU) + expV = 2.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, @@ -64,12 +70,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = 12.0; // F8E4M3* - if (type == F8E5M2) - expV = 14.0; + auto actual = LiteralUtil::CreateR0(TypeParam(1.0)); + float expV = 
1.5; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.0; else if (type == F8E3M4) - expV = 10.0; + expV = 1.25; + else if (type == F4E2M1FN) + expV = 4.0; + else if (type == F8E8M0FNU) + expV = 16.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; @@ -86,12 +96,16 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); - auto actual = LiteralUtil::CreateR0(8.0); - float expV = 12.1; // F8E4M3* - if (type == F8E5M2) - expV = 13.0; + auto actual = LiteralUtil::CreateR0(1.0); + float expV = 1.51; // F8E4M3* + if (type == F8E5M2 || type == F8E5M2FNUZ) + expV = 2.01; else if (type == F8E3M4) - expV = 10.125; + expV = 1.26; + else if (type == F4E2M1FN) + expV = 4.1; + else if (type == F8E8M0FNU) + expV = 16.5; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index 44e4acd6a5cef7..7aa9f2dc040dcd 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -124,11 +124,11 @@ class LiteralUtilTest : public ::testing::Test { template class LiteralUtilFloatTest : public LiteralUtilTest {}; -using FloatTypes = - ::testing::Types; +using FloatTypes = ::testing::Types; TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); @@ -175,6 +175,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { LiteralUtil::CreateR0(static_cast(9.001f)); EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); + auto f4e2m1fn_lit = + LiteralUtil::CreateR0(tsl::float4_e2m1fn(0.5)); + EXPECT_EQ("f4e2m1fn[] 0.5", f4e2m1fn_lit.ToString()); + auto f8e5m2_lit = LiteralUtil::CreateR0(tsl::float8_e5m2(0.5)); EXPECT_EQ("f8e5m2[] 0.5", f8e5m2_lit.ToString()); @@ -207,6 +211,10 @@ TEST_F(LiteralUtilTest, 
LiteralScalarToString) { auto f8e3m4_lit = LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); + + auto f8e8m0fnu_lit = + LiteralUtil::CreateR0(tsl::float8_e8m0fnu(0.5)); + EXPECT_EQ("f8e8m0fnu[] 0.5", f8e8m0fnu_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -659,6 +667,11 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); + tsl::float4_e2m1fn m16(4); + EXPECT_TRUE(LiteralUtil::CreateR1({m16}).IsAll(4)); + // 5 rounds to 4 in E2M1FN but is not equal to 4, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({m16}).IsAll(5)); + tsl::float8_e5m2 p16(8); EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false @@ -689,6 +702,11 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + tsl::float8_e8m0fnu w16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({w16}).IsAll(8)); + // 9 rounds to 8 in E8M0FNU but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({w16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -2214,6 +2232,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + using e2m1 = tsl::float4_e2m1fn; + auto vector_f4e2m1fn = + LiteralUtil::CreateR1({e2m1{1.0}, e2m1{2.0}, e2m1{-3.0}}); using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); @@ -2234,6 +2255,9 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); using e3 = tsl::float8_e3m4; auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); + using e8m0 = tsl::float8_e8m0fnu; + auto 
vector_f8e8m0fnu = + LiteralUtil::CreateR1({e8m0{1.0}, e8m0{2.0}, e8m0{4.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2254,13 +2278,15 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); - EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f4e2m1fn, to_from_proto(vector_f4e2m1fn)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); - EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); - EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); - EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e8m0fnu, to_from_proto(vector_f8e8m0fnu)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2511,19 +2537,19 @@ TEST_F(LiteralUtilTest, SliceOnBool) { } TEST_F(LiteralUtilTest, IsEqualAt) { - double val_double = 10.0; - int val_integral = 10; - Literal c1 = LiteralUtil::CreateR0(10); + double val_double = 4.0; + int val_integral = 4; + Literal c1 = LiteralUtil::CreateR0(val_integral); EXPECT_TRUE(c1.IsEqualAt({}, val_double)); EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); - Literal c2 = LiteralUtil::CreateR0(10); + Literal c2 = LiteralUtil::CreateR0(val_double); EXPECT_TRUE(c2.IsEqualAt({}, val_double)); EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); Literal c3 = 
LiteralUtil::CreateR0(tsl::float8_e5m2{val_double}); EXPECT_TRUE(c3.IsEqualAt({}, val_double)); EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); - complex128 val_complex = {10, 0}; + complex128 val_complex = {val_double, 0}; EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); @@ -2532,8 +2558,8 @@ TEST_F(LiteralUtilTest, IsEqualAt) { EXPECT_TRUE(c4.IsEqualAt({}, val_integral)); EXPECT_TRUE(c4.IsEqualAt({}, val_complex)); EXPECT_FALSE(c4.IsEqualAt({}, std::numeric_limits::infinity())); - complex128 val_true_complex = {10, 3}; - complex64 val_smaller_complex = {10, 3}; + complex128 val_true_complex = {val_double, 3}; + complex64 val_smaller_complex = {static_cast(val_double), 3}; Literal c5 = LiteralUtil::CreateR0(val_true_complex); EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); @@ -2557,6 +2583,14 @@ TEST_F(LiteralUtilTest, IsEqualAt) { LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); EXPECT_TRUE(c10.IsEqualAt({}, val_double)); EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); + Literal c11 = + LiteralUtil::CreateR0(tsl::float4_e2m1fn{val_double}); + EXPECT_TRUE(c11.IsEqualAt({}, val_double)); + EXPECT_TRUE(c11.IsEqualAt({}, val_integral)); + Literal c12 = LiteralUtil::CreateR0( + tsl::float8_e8m0fnu{val_double}); + EXPECT_TRUE(c12.IsEqualAt({}, val_double)); + EXPECT_TRUE(c12.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2882,10 +2916,11 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, - F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F4E2M1FN, F8E3M4, F8E4M3, + F8E4M3B11FNUZ, 
F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F8E8M0FNU, + C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/xla/mlir/utils/type_util.cc b/xla/mlir/utils/type_util.cc index 2581390a1e13d7..ea8da4d4990d9d 100644 --- a/xla/mlir/utils/type_util.cc +++ b/xla/mlir/utils/type_util.cc @@ -32,6 +32,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( switch (type) { case xla::PrimitiveType::PRED: return b.getI1Type(); + case xla::PrimitiveType::F4E2M1FN: + return b.getFloat4E2M1FNType(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); case xla::PrimitiveType::F8E4M3: @@ -46,6 +48,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E4M3FNUZType(); case xla::PrimitiveType::F8E3M4: return b.getFloat8E3M4Type(); + case xla::PrimitiveType::F8E8M0FNU: + return b.getFloat8E8M0FNUType(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -78,7 +82,9 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( } xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { - if (type.isFloat8E5M2()) { + if (type.isFloat4E2M1FN()) { + return xla::PrimitiveType::F4E2M1FN; + } else if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; } else if (type.isFloat8E4M3()) { return xla::PrimitiveType::F8E4M3; @@ -92,6 +98,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E5M2FNUZ; } else if (type.isFloat8E3M4()) { return xla::PrimitiveType::F8E3M4; + } else if (type.isFloat8E8M0FNU()) { + return xla::PrimitiveType::F8E8M0FNU; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/xla/mlir/utils/type_util_test.cc b/xla/mlir/utils/type_util_test.cc index a8043ab0b5f140..2239943d906b7b 100644 --- a/xla/mlir/utils/type_util_test.cc +++ b/xla/mlir/utils/type_util_test.cc @@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P( Execute, TypeUtilTest, ::testing::ValuesIn(std::vector( 
{{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, + {F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, @@ -111,6 +112,7 @@ INSTANTIATE_TEST_SUITE_P( {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, + {F8E8M0FNU, [](mlir::Builder b) { return b.getFloat8E8M0FNUType(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 72a8248f2526cb..5dcd78efbfe3cb 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e3m4(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor @@ -6860,6 +6867,13 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- +func.func @f8e8m0fnu(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @top_k_1d(%arg0 : tensor<16xf32>) { %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 6034d631634e02..34e5f422677d49 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,8 @@ # PJRT C API changelog +## 0.58 +* Added types F4E2M1FN and F8E8M0FNU. 
+ ## 0.57 * Rearranged fields in the PJRT_Api * Update outdated struct sizes from previous changes to diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 75b83dfa0012ea..004c6ea7bb06e5 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 57 +#define PJRT_API_MINOR 58 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -648,6 +648,10 @@ typedef enum { // More truncated 8 bit floating-point formats. PJRT_Buffer_Type_F8E4M3, PJRT_Buffer_Type_F8E3M4, + PJRT_Buffer_Type_F8E8M0FNU, + + // 4-bit MX floating-point format. + PJRT_Buffer_Type_F4E2M1FN, } PJRT_Buffer_Type; typedef enum { diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 70ecc239439506..4ce2a66cdc31a4 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -295,6 +295,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_BF16; case xla::PrimitiveType::F64: return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; + case xla::PrimitiveType::F4E2M1FN: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; case xla::PrimitiveType::F8E4M3: @@ -309,6 +311,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; case xla::PrimitiveType::F8E3M4: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; + case xla::PrimitiveType::F8E8M0FNU: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ 
-362,6 +366,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C64; case PJRT_Buffer_Type::PJRT_Buffer_Type_C128: return xla::PrimitiveType::C128; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: @@ -376,6 +382,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E4M3FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: return xla::PrimitiveType::F8E3M4; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } diff --git a/xla/primitive_util.cc b/xla/primitive_util.cc index b70ba275a1f47f..5006406ea99779 100644 --- a/xla/primitive_util.cc +++ b/xla/primitive_util.cc @@ -93,6 +93,18 @@ bool HasInfinity(PrimitiveType type) { return false; } +bool HasNaN(PrimitiveType type) { + if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { + return FloatingPointTypeSwitch( + [&](auto constant_type) -> bool { + return std::numeric_limits< + NativeTypeOf>::has_quiet_NaN; + }, + type); + } + return false; +} + bool HasNegativeZero(PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { return FloatingPointTypeSwitch( diff --git a/xla/primitive_util.h b/xla/primitive_util.h index de5ee4fde11d7b..2a831d2df1044e 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -69,6 +69,9 @@ int ExponentBias(PrimitiveType type); // Returns whether the type has a value for infinity. bool HasInfinity(PrimitiveType type); +// Returns whether the type has a value for NaN. +bool HasNaN(PrimitiveType type); + // Returns whether the type has a value for negative zero. 
bool HasNegativeZero(PrimitiveType type); @@ -175,6 +178,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return BF16; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F4E2M1FN; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; @@ -210,6 +218,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E3M4; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E8M0FNU; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -314,6 +327,11 @@ struct PrimitiveTypeToNative { using type = bfloat16; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float4_e2m1fn; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; @@ -349,6 +367,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e3m4; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e8m0fnu; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -381,6 +404,10 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { primitive_type < PrimitiveType_ARRAYSIZE; } +constexpr bool IsMXType(PrimitiveType type) { + return type == F4E2M1FN || type == F8E8M0FNU; +} + constexpr bool IsF8Type(PrimitiveType type) { return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || @@ -389,7 +416,7 @@ constexpr bool IsF8Type(PrimitiveType type) { constexpr bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16 || - IsF8Type(type); + IsF8Type(type) || IsMXType(type); } constexpr bool IsComplexType(PrimitiveType type) { @@ -449,6 +476,9 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F4E2M1FN: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E3M4: return std::forward(f)( 
PrimitiveTypeConstant()); @@ -470,6 +500,9 @@ constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { case F8E5M2FNUZ: return std::forward(f)( PrimitiveTypeConstant()); + case F8E8M0FNU: + return std::forward(f)( + PrimitiveTypeConstant()); case F16: return std::forward(f)(PrimitiveTypeConstant()); case BF16: @@ -553,6 +586,9 @@ inline constexpr int PrimitiveTypeBitWidth() { if constexpr (primitive_type == PRED) { return std::numeric_limits::digits; } + if constexpr (IsMXType(primitive_type)) { + return NativeT::kBits; + } if constexpr (IsFloatingPointType(primitive_type)) { return sizeof(NativeT) * std::numeric_limits::digits; } @@ -689,6 +725,10 @@ inline bool CastPreservesValues(PrimitiveType from_type, if (from_type == to_type) { return true; } + // * -> F8E8M0FNU is not possible because zero cannot be represented. + if (to_type == F8E8M0FNU) { + return false; + } // PRED -> * if (from_type == PRED) { return true; @@ -711,21 +751,33 @@ inline bool CastPreservesValues(PrimitiveType from_type, return false; } // F -> F is safe if the exponent/significand are preserved and `to_type` - // preserves infinities in `from_type. + // preserves infinities/nans/unsigned zero in `from_type`. if (primitive_util::IsFloatingPointType(from_type) && primitive_util::IsFloatingPointType(to_type)) { - return (!primitive_util::HasInfinity(from_type) || - primitive_util::HasInfinity(to_type)) && - primitive_util::SignificandWidth(from_type) <= - primitive_util::SignificandWidth(to_type) && - primitive_util::ExponentWidth(from_type) <= - primitive_util::ExponentWidth(to_type) && - (primitive_util::UnderflowExponent(from_type) - - primitive_util::SignificandWidth(from_type)) >= - (primitive_util::UnderflowExponent(to_type) - - primitive_util::SignificandWidth(to_type)) && - primitive_util::OverflowExponent(from_type) <= - primitive_util::OverflowExponent(to_type); + return + // Target mantissa should be large enough. 
+ primitive_util::SignificandWidth(from_type) <= + primitive_util::SignificandWidth(to_type) && + // Target exponent should be large enough. + primitive_util::ExponentWidth(from_type) <= + primitive_util::ExponentWidth(to_type) && + // HasInfinity check. + (!primitive_util::HasInfinity(from_type) || + primitive_util::HasInfinity(to_type)) && + // HasNaN check. + (!primitive_util::HasNaN(from_type) || + primitive_util::HasNaN(to_type)) && + // HasNegativeZero check. + (!primitive_util::HasNegativeZero(from_type) || + primitive_util::HasNegativeZero(to_type)) && + // Minimum denormal should be representable by target type. + (primitive_util::UnderflowExponent(from_type) - + primitive_util::SignificandWidth(from_type)) >= + (primitive_util::UnderflowExponent(to_type) - + primitive_util::SignificandWidth(to_type)) && + // Maximum exponent may be larger with custom bias (e.g. F8E4M3B11FNUZ). + primitive_util::OverflowExponent(from_type) <= + primitive_util::OverflowExponent(to_type); } // F -> I is not safe because it drops fractional numbers. 
if (!primitive_util::IsIntegralType(from_type)) { diff --git a/xla/primitive_util_test.cc b/xla/primitive_util_test.cc index 850203f17379a4..740eb4ac693080 100644 --- a/xla/primitive_util_test.cc +++ b/xla/primitive_util_test.cc @@ -75,6 +75,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][C64] = true; expecteds[PRED][BF16] = true; expecteds[PRED][C128] = true; + expecteds[PRED][F4E2M1FN] = true; expecteds[PRED][F8E5M2] = true; expecteds[PRED][F8E4M3] = true; expecteds[PRED][F8E4M3FN] = true; @@ -82,6 +83,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = true; expecteds[PRED][F8E3M4] = true; + expecteds[PRED][F8E8M0FNU] = false; expecteds[S2][PRED] = false; expecteds[S2][S2] = true; expecteds[S2][S4] = true; @@ -101,6 +103,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][C64] = true; expecteds[S2][BF16] = true; expecteds[S2][C128] = true; + expecteds[S2][F4E2M1FN] = true; expecteds[S2][F8E5M2] = true; expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; @@ -108,6 +111,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; expecteds[S2][F8E3M4] = true; + expecteds[S2][F8E8M0FNU] = false; expecteds[S4][PRED] = false; expecteds[S4][S2] = false; expecteds[S4][S4] = true; @@ -127,6 +131,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][C64] = true; expecteds[S4][BF16] = true; expecteds[S4][C128] = true; + expecteds[S4][F4E2M1FN] = false; expecteds[S4][F8E5M2] = true; expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; @@ -134,6 +139,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; expecteds[S4][F8E3M4] = true; + expecteds[S4][F8E8M0FNU] = false; expecteds[S8][PRED] = false; expecteds[S8][S2] = false; expecteds[S8][S4] = false; @@ -153,6 +159,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { 
expecteds[S8][C64] = true; expecteds[S8][BF16] = true; expecteds[S8][C128] = true; + expecteds[S8][F4E2M1FN] = false; expecteds[S8][F8E5M2] = false; expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; @@ -160,6 +167,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; expecteds[S8][F8E3M4] = false; + expecteds[S8][F8E8M0FNU] = false; expecteds[S16][PRED] = false; expecteds[S16][S2] = false; expecteds[S16][S4] = false; @@ -179,6 +187,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][C64] = true; expecteds[S16][BF16] = false; expecteds[S16][C128] = true; + expecteds[S16][F4E2M1FN] = false; expecteds[S16][F8E5M2] = false; expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; @@ -186,6 +195,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; expecteds[S16][F8E3M4] = false; + expecteds[S16][F8E8M0FNU] = false; expecteds[S32][PRED] = false; expecteds[S32][S2] = false; expecteds[S32][S4] = false; @@ -205,6 +215,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][C64] = false; expecteds[S32][BF16] = false; expecteds[S32][C128] = true; + expecteds[S32][F4E2M1FN] = false; expecteds[S32][F8E5M2] = false; expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; @@ -212,6 +223,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; expecteds[S32][F8E3M4] = false; + expecteds[S32][F8E8M0FNU] = false; expecteds[S64][PRED] = false; expecteds[S64][S2] = false; expecteds[S64][S4] = false; @@ -231,6 +243,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][C64] = false; expecteds[S64][BF16] = false; expecteds[S64][C128] = false; + expecteds[S64][F4E2M1FN] = false; expecteds[S64][F8E5M2] = false; expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; @@ -238,6 +251,7 @@ 
TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; expecteds[S64][F8E3M4] = false; + expecteds[S64][F8E8M0FNU] = false; expecteds[U2][PRED] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; @@ -257,8 +271,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][C64] = true; expecteds[U2][BF16] = true; expecteds[U2][C128] = true; - expecteds[U2][BF16] = true; - expecteds[U2][C128] = true; + expecteds[U2][F4E2M1FN] = true; expecteds[U2][F8E5M2] = true; expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; @@ -266,6 +279,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; expecteds[U2][F8E3M4] = true; + expecteds[U2][F8E8M0FNU] = false; expecteds[U4][PRED] = false; expecteds[U4][S2] = false; expecteds[U4][S4] = false; @@ -285,8 +299,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][C64] = true; expecteds[U4][BF16] = true; expecteds[U4][C128] = true; - expecteds[U4][BF16] = true; - expecteds[U4][C128] = true; + expecteds[U4][F4E2M1FN] = false; expecteds[U4][F8E5M2] = false; expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; @@ -294,6 +307,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; expecteds[U4][F8E3M4] = true; + expecteds[U4][F8E8M0FNU] = false; expecteds[U8][PRED] = false; expecteds[U8][S2] = false; expecteds[U8][S4] = false; @@ -313,8 +327,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][C64] = true; expecteds[U8][BF16] = true; expecteds[U8][C128] = true; - expecteds[U8][BF16] = true; - expecteds[U8][C128] = true; + expecteds[U8][F4E2M1FN] = false; expecteds[U8][F8E5M2] = false; expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; @@ -322,6 +335,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; 
expecteds[U8][F8E3M4] = false; + expecteds[U8][F8E8M0FNU] = false; expecteds[U16][PRED] = false; expecteds[U16][S2] = false; expecteds[U16][S4] = false; @@ -341,6 +355,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][C64] = true; expecteds[U16][BF16] = false; expecteds[U16][C128] = true; + expecteds[U16][F4E2M1FN] = false; expecteds[U16][F8E5M2] = false; expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; @@ -348,6 +363,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; expecteds[U16][F8E3M4] = false; + expecteds[U16][F8E8M0FNU] = false; expecteds[U32][PRED] = false; expecteds[U32][S2] = false; expecteds[U32][S4] = false; @@ -367,6 +383,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][C64] = false; expecteds[U32][BF16] = false; expecteds[U32][C128] = true; + expecteds[U32][F4E2M1FN] = false; expecteds[U32][F8E5M2] = false; expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; @@ -374,6 +391,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; expecteds[U32][F8E3M4] = false; + expecteds[U32][F8E8M0FNU] = false; expecteds[U64][PRED] = false; expecteds[U64][S2] = false; expecteds[U64][S4] = false; @@ -393,6 +411,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][C64] = false; expecteds[U64][BF16] = false; expecteds[U64][C128] = false; + expecteds[U64][F4E2M1FN] = false; expecteds[U64][F8E5M2] = false; expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; @@ -400,6 +419,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; expecteds[U64][F8E3M4] = false; + expecteds[U64][F8E8M0FNU] = false; expecteds[F16][PRED] = false; expecteds[F16][S2] = false; expecteds[F16][S4] = false; @@ -419,6 +439,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][C64] = true; 
expecteds[F16][BF16] = false; expecteds[F16][C128] = true; + expecteds[F16][F4E2M1FN] = false; expecteds[F16][F8E5M2] = false; expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; @@ -426,6 +447,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; expecteds[F16][F8E3M4] = false; + expecteds[F16][F8E8M0FNU] = false; expecteds[F32][PRED] = false; expecteds[F32][S2] = false; expecteds[F32][S4] = false; @@ -445,6 +467,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][C64] = true; expecteds[F32][BF16] = false; expecteds[F32][C128] = true; + expecteds[F32][F4E2M1FN] = false; expecteds[F32][F8E5M2] = false; expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; @@ -452,6 +475,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; expecteds[F32][F8E3M4] = false; + expecteds[F32][F8E8M0FNU] = false; expecteds[F64][PRED] = false; expecteds[F64][S2] = false; expecteds[F64][S4] = false; @@ -471,6 +495,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][C64] = false; expecteds[F64][BF16] = false; expecteds[F64][C128] = true; + expecteds[F64][F4E2M1FN] = false; expecteds[F64][F8E5M2] = false; expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; @@ -478,6 +503,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; expecteds[F64][F8E3M4] = false; + expecteds[F64][F8E8M0FNU] = false; expecteds[C64][PRED] = false; expecteds[C64][S2] = false; expecteds[C64][S4] = false; @@ -497,6 +523,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][C64] = true; expecteds[C64][BF16] = false; expecteds[C64][C128] = true; + expecteds[C64][F4E2M1FN] = false; expecteds[C64][F8E5M2] = false; expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; @@ -504,6 +531,7 @@ TEST(PrimitiveUtilTest, 
CastPreservesValues) { expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; expecteds[C64][F8E3M4] = false; + expecteds[C64][F8E8M0FNU] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S2] = false; expecteds[BF16][S4] = false; @@ -523,6 +551,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][C64] = true; expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; + expecteds[BF16][F4E2M1FN] = false; expecteds[BF16][F8E5M2] = false; expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; @@ -530,6 +559,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; expecteds[BF16][F8E3M4] = false; + expecteds[BF16][F8E8M0FNU] = false; expecteds[C128][PRED] = false; expecteds[C128][S2] = false; expecteds[C128][S4] = false; @@ -549,6 +579,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][C64] = false; expecteds[C128][BF16] = false; expecteds[C128][C128] = true; + expecteds[C128][F4E2M1FN] = false; expecteds[C128][F8E5M2] = false; expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; @@ -556,6 +587,35 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; expecteds[C128][F8E3M4] = false; + expecteds[C128][F8E8M0FNU] = false; + expecteds[F4E2M1FN][PRED] = false; + expecteds[F4E2M1FN][S2] = false; + expecteds[F4E2M1FN][S4] = false; + expecteds[F4E2M1FN][S8] = false; + expecteds[F4E2M1FN][S16] = false; + expecteds[F4E2M1FN][S32] = false; + expecteds[F4E2M1FN][S64] = false; + expecteds[F4E2M1FN][U2] = false; + expecteds[F4E2M1FN][U4] = false; + expecteds[F4E2M1FN][U8] = false; + expecteds[F4E2M1FN][U16] = false; + expecteds[F4E2M1FN][U32] = false; + expecteds[F4E2M1FN][U64] = false; + expecteds[F4E2M1FN][F16] = true; + expecteds[F4E2M1FN][F32] = true; + expecteds[F4E2M1FN][F64] = true; + expecteds[F4E2M1FN][C64] = true; + expecteds[F4E2M1FN][BF16] = true; + 
expecteds[F4E2M1FN][C128] = true; + expecteds[F4E2M1FN][F4E2M1FN] = true; + expecteds[F4E2M1FN][F8E5M2] = true; + expecteds[F4E2M1FN][F8E4M3] = true; + expecteds[F4E2M1FN][F8E4M3FN] = true; + expecteds[F4E2M1FN][F8E4M3B11FNUZ] = false; + expecteds[F4E2M1FN][F8E4M3FNUZ] = false; + expecteds[F4E2M1FN][F8E5M2FNUZ] = false; + expecteds[F4E2M1FN][F8E3M4] = true; + expecteds[F4E2M1FN][F8E8M0FNU] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S2] = false; expecteds[F8E5M2][S4] = false; @@ -575,6 +635,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][C64] = true; expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; + expecteds[F8E5M2][F4E2M1FN] = false; expecteds[F8E5M2][F8E5M2] = true; expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; @@ -582,6 +643,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E5M2][F8E8M0FNU] = false; expecteds[F8E4M3][PRED] = false; expecteds[F8E4M3][S2] = false; expecteds[F8E4M3][S4] = false; @@ -601,6 +663,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][C64] = true; expecteds[F8E4M3][BF16] = true; expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F4E2M1FN] = false; expecteds[F8E4M3][F8E5M2] = false; expecteds[F8E4M3][F8E5M2FNUZ] = false; expecteds[F8E4M3][F8E4M3] = true; @@ -608,6 +671,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3][F8E4M3FNUZ] = false; expecteds[F8E4M3][F8E4M3B11FNUZ] = false; expecteds[F8E4M3][F8E3M4] = false; + expecteds[F8E4M3][F8E8M0FNU] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; @@ -627,6 +691,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][C64] = true; expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; + expecteds[F8E4M3FN][F4E2M1FN] = false; expecteds[F8E4M3FN][F8E5M2] = false; 
expecteds[F8E4M3FN][F8E5M2FNUZ] = false; expecteds[F8E4M3FN][F8E4M3] = false; @@ -634,6 +699,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FN][F8E3M4] = false; + expecteds[F8E4M3FN][F8E8M0FNU] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S2] = false; expecteds[F8E4M3B11FNUZ][S4] = false; @@ -653,6 +719,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][C64] = true; expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; + expecteds[F8E4M3B11FNUZ][F4E2M1FN] = false; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; @@ -660,6 +727,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E3M4] = false; + expecteds[F8E4M3B11FNUZ][F8E8M0FNU] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; @@ -679,6 +747,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][C64] = true; expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; + expecteds[F8E5M2FNUZ][F4E2M1FN] = false; expecteds[F8E5M2FNUZ][F8E5M2] = false; expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; @@ -686,6 +755,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; expecteds[F8E5M2FNUZ][F8E3M4] = false; + expecteds[F8E5M2FNUZ][F8E8M0FNU] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S2] = false; expecteds[F8E4M3FNUZ][S4] = false; @@ -705,6 +775,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][C64] = true; expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; + 
expecteds[F8E4M3FNUZ][F4E2M1FN] = false; expecteds[F8E4M3FNUZ][F8E5M2] = false; expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; @@ -712,6 +783,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E4M3FNUZ][F8E8M0FNU] = false; expecteds[F8E3M4][PRED] = false; expecteds[F8E3M4][S2] = false; expecteds[F8E3M4][S4] = false; @@ -731,6 +803,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][C64] = true; expecteds[F8E3M4][BF16] = true; expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F4E2M1FN] = false; expecteds[F8E3M4][F8E5M2] = false; expecteds[F8E3M4][F8E5M2FNUZ] = false; expecteds[F8E3M4][F8E4M3] = false; @@ -738,6 +811,35 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E3M4][F8E4M3FNUZ] = false; expecteds[F8E3M4][F8E4M3B11FNUZ] = false; expecteds[F8E3M4][F8E3M4] = true; + expecteds[F8E3M4][F8E8M0FNU] = false; + expecteds[F8E8M0FNU][PRED] = false; + expecteds[F8E8M0FNU][S2] = false; + expecteds[F8E8M0FNU][S4] = false; + expecteds[F8E8M0FNU][S8] = false; + expecteds[F8E8M0FNU][S16] = false; + expecteds[F8E8M0FNU][S32] = false; + expecteds[F8E8M0FNU][S64] = false; + expecteds[F8E8M0FNU][U2] = false; + expecteds[F8E8M0FNU][U4] = false; + expecteds[F8E8M0FNU][U8] = false; + expecteds[F8E8M0FNU][U16] = false; + expecteds[F8E8M0FNU][U32] = false; + expecteds[F8E8M0FNU][U64] = false; + expecteds[F8E8M0FNU][F16] = false; + expecteds[F8E8M0FNU][F32] = true; + expecteds[F8E8M0FNU][F64] = true; + expecteds[F8E8M0FNU][C64] = true; + expecteds[F8E8M0FNU][BF16] = true; + expecteds[F8E8M0FNU][C128] = true; + expecteds[F8E8M0FNU][F4E2M1FN] = false; + expecteds[F8E8M0FNU][F8E5M2] = false; + expecteds[F8E8M0FNU][F8E4M3] = false; + expecteds[F8E8M0FNU][F8E4M3FN] = false; + expecteds[F8E8M0FNU][F8E4M3B11FNUZ] = false; + expecteds[F8E8M0FNU][F8E4M3FNUZ] = false; + 
expecteds[F8E8M0FNU][F8E5M2FNUZ] = false; + expecteds[F8E8M0FNU][F8E3M4] = false; + expecteds[F8E8M0FNU][F8E8M0FNU] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { @@ -758,7 +860,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { << primitive_util::LowercasePrimitiveTypeName(to_type); } } -} +} // NOLINT(readability/fn_size) } // namespace } // namespace xla diff --git a/xla/python/ifrt/dtype.cc b/xla/python/ifrt/dtype.cc index a79240f51a7e23..4ba15201bcf868 100644 --- a/xla/python/ifrt/dtype.cc +++ b/xla/python/ifrt/dtype.cc @@ -32,6 +32,7 @@ std::optional DType::byte_size() const { case kU2: case kS4: case kU4: + case kF4E2M1FN: // Smaller than a byte. return std::nullopt; case kPred: @@ -39,6 +40,7 @@ std::optional DType::byte_size() const { case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -77,12 +79,14 @@ std::optional DType::bit_size() const { return 2; case kS4: case kU4: + case kF4E2M1FN: return 4; case kPred: case kS8: case kU8: case kF8E3M4: case kF8E4M3: + case kF8E8M0FNU: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -142,8 +146,10 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(C64); CASE(C128); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F4E2M1FN); // CASE(F8E3M4); // CASE(F8E4M3); + // CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -190,8 +196,10 @@ DTypeProto DType::ToProto() const { CASE(C64); CASE(C128); // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ // CASE(F4E2M1FN); // CASE(F8E3M4); // CASE(F8E4M3); + // CASE(F8E8M0FNU); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/xla/python/ifrt/dtype.h b/xla/python/ifrt/dtype.h index d23efc55a1aa12..ff724df90a2308 100644 --- a/xla/python/ifrt/dtype.h +++ b/xla/python/ifrt/dtype.h @@ -88,8 +88,12 @@ class DType { kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, + kF8E8M0FNU = 31, - // Next = 30 + // MX floating point types. + kF4E2M1FN = 30, + + // Next = 32 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind diff --git a/xla/python/ifrt/dtype.proto b/xla/python/ifrt/dtype.proto index 3a2b0df7976d6e..2cf453f26c291d 100644 --- a/xla/python/ifrt/dtype.proto +++ b/xla/python/ifrt/dtype.proto @@ -70,12 +70,18 @@ message DTypeProto { KIND_F8E4M3FNUZ = 25; KIND_F8E5M2 = 19; KIND_F8E5M2FNUZ = 24; + KIND_F8E8M0FNU = 31; + + // MX floating point types. + KIND_F4E2M1FN = 30; // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind // needs to match xla.PrimitiveType enum, so choose a large enum to avoid // collision. 
KIND_STRING = 99; + + // Next: 32 } // LINT.ThenChange() Kind kind = 1; diff --git a/xla/python/ifrt/dtype_test.cc b/xla/python/ifrt/dtype_test.cc index 57fec6702d277d..9d3d3105f54e54 100644 --- a/xla/python/ifrt/dtype_test.cc +++ b/xla/python/ifrt/dtype_test.cc @@ -42,34 +42,21 @@ TEST(DTypeTest, FromToFromProto) { TEST(DTypeTest, ByteSize) { for (const auto& [kind, byte_size] : std::vector>({ - {DType::kS2, -1}, - {DType::kU2, -1}, - {DType::kS4, -1}, - {DType::kU4, -1}, - {DType::kPred, 1}, - {DType::kS8, 1}, - {DType::kU8, 1}, - {DType::kF8E3M4, 1}, - {DType::kF8E4M3, 1}, - {DType::kF8E4M3FN, 1}, - {DType::kF8E4M3B11FNUZ, 1}, - {DType::kF8E4M3FNUZ, 1}, - {DType::kF8E5M2, 1}, - {DType::kF8E5M2FNUZ, 1}, - {DType::kS16, 2}, - {DType::kU16, 2}, - {DType::kF16, 2}, - {DType::kBF16, 2}, - {DType::kS32, 4}, - {DType::kU32, 4}, - {DType::kF32, 4}, - {DType::kS64, 8}, - {DType::kU64, 8}, - {DType::kF64, 8}, - {DType::kC64, 8}, - {DType::kC128, 16}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, -1}, {DType::kU2, -1}, + {DType::kS4, -1}, {DType::kU4, -1}, + {DType::kPred, 1}, {DType::kS8, 1}, + {DType::kU8, 1}, {DType::kF4E2M1FN, -1}, + {DType::kF8E3M4, 1}, {DType::kF8E4M3, 1}, + {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, + {DType::kF8E4M3FNUZ, 1}, {DType::kF8E5M2, 1}, + {DType::kF8E5M2FNUZ, 1}, {DType::kF8E8M0FNU, 1}, + {DType::kS16, 2}, {DType::kU16, 2}, + {DType::kF16, 2}, {DType::kBF16, 2}, + {DType::kS32, 4}, {DType::kU32, 4}, + {DType::kF32, 4}, {DType::kS64, 8}, + {DType::kU64, 8}, {DType::kF64, 8}, + {DType::kC64, 8}, {DType::kC128, 16}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).byte_size(), @@ -80,34 +67,21 @@ TEST(DTypeTest, ByteSize) { TEST(DTypeTest, BitSize) { for (const auto& [kind, bit_size] : std::vector>({ - {DType::kS2, 2}, - {DType::kU2, 2}, - {DType::kS4, 4}, - {DType::kU4, 4}, - {DType::kPred, 8}, - {DType::kS8, 8}, - {DType::kU8, 8}, - {DType::kF8E3M4, 8}, - 
{DType::kF8E4M3, 8}, - {DType::kF8E4M3FN, 8}, - {DType::kF8E4M3B11FNUZ, 8}, - {DType::kF8E4M3FNUZ, 8}, - {DType::kF8E5M2, 8}, - {DType::kF8E5M2FNUZ, 8}, - {DType::kS16, 16}, - {DType::kU16, 16}, - {DType::kF16, 16}, - {DType::kBF16, 16}, - {DType::kS32, 32}, - {DType::kU32, 32}, - {DType::kF32, 32}, - {DType::kS64, 64}, - {DType::kU64, 64}, - {DType::kF64, 64}, - {DType::kC64, 64}, - {DType::kC128, 128}, - {DType::kToken, -1}, - {DType::kInvalid, -1}, + {DType::kS2, 2}, {DType::kU2, 2}, + {DType::kS4, 4}, {DType::kU4, 4}, + {DType::kPred, 8}, {DType::kS8, 8}, + {DType::kU8, 8}, {DType::kF4E2M1FN, 4}, + {DType::kF8E3M4, 8}, {DType::kF8E4M3, 8}, + {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, + {DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8}, + {DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8}, + {DType::kS16, 16}, {DType::kU16, 16}, + {DType::kF16, 16}, {DType::kBF16, 16}, + {DType::kS32, 32}, {DType::kU32, 32}, + {DType::kF32, 32}, {DType::kS64, 64}, + {DType::kU64, 64}, {DType::kF64, 64}, + {DType::kC64, 64}, {DType::kC128, 128}, + {DType::kToken, -1}, {DType::kInvalid, -1}, {DType::kString, -1}, })) { EXPECT_EQ(DType(kind).bit_size(), diff --git a/xla/python/pjrt_ifrt/pjrt_dtype.cc b/xla/python/pjrt_ifrt/pjrt_dtype.cc index 9c581ec6227cae..2af3281a588cce 100644 --- a/xla/python/pjrt_ifrt/pjrt_dtype.cc +++ b/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -44,6 +44,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kU16, xla::PrimitiveType::U16); CASE(DType::kU32, xla::PrimitiveType::U32); CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF4E2M1FN, xla::PrimitiveType::F4E2M1FN); CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); @@ -51,6 +52,7 @@ absl::StatusOr ToPrimitiveType(DType dtype) { CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); CASE(DType::kF8E5M2FNUZ, 
xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF8E8M0FNU, xla::PrimitiveType::F8E8M0FNU); CASE(DType::kF16, xla::PrimitiveType::F16); CASE(DType::kF32, xla::PrimitiveType::F32); CASE(DType::kBF16, xla::PrimitiveType::BF16); @@ -83,6 +85,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::U16: case xla::PrimitiveType::U32: case xla::PrimitiveType::U64: + case xla::PrimitiveType::F4E2M1FN: case xla::PrimitiveType::F8E3M4: case xla::PrimitiveType::F8E4M3: case xla::PrimitiveType::F8E4M3FN: @@ -90,6 +93,7 @@ absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { case xla::PrimitiveType::F8E4M3FNUZ: case xla::PrimitiveType::F8E5M2: case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F8E8M0FNU: case xla::PrimitiveType::F16: case xla::PrimitiveType::F32: case xla::PrimitiveType::BF16: diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 9a9c63a922e90d..bc896ea2205aca 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -185,6 +185,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E3M4; @@ -206,6 +209,9 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; } else if (std::is_same() || !options.squash_64bit_types) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); type = primitive_util::NativeToPrimitiveType(); @@ -400,6 +406,10 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; 
(*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } if (dtypes.np_float8_e3m4.has_value()) { (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; @@ -417,6 +427,10 @@ absl::StatusOr DevicePut(nb::handle arg, HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; @@ -598,8 +612,10 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float4_e2m1fn.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e8m0fnu.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/xla/python/types.cc b/xla/python/types.cc index 125f96a75fdf25..af828232bdb5fd 100644 --- a/xla/python/types.cc +++ b/xla/python/types.cc @@ -59,6 +59,7 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float4_e2m1fn; std::optional float8_e3m4; std::optional float8_e4m3; nb_dtype float8_e4m3fn; @@ -66,6 +67,7 @@ struct CustomDtypes { nb_dtype float8_e4m3fnuz; nb_dtype float8_e5m2; nb_dtype float8_e5m2fnuz; + std::optional float8_e8m0fnu; std::optional int2; nb_dtype int4; std::optional uint2; @@ -77,6 +79,10 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes 
= new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->float4_e2m1fn = + nb_dtype::from_args(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); } @@ -92,6 +98,10 @@ const CustomDtypes& GetCustomDtypes() { nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->float8_e5m2fnuz = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->float8_e8m0fnu = + nb_dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->int4 = nb_dtype::from_args(ml_dtypes.attr("int4")); dtypes->uint4 = nb_dtype::from_args(ml_dtypes.attr("uint4")); if (nb::hasattr(ml_dtypes, "int2")) { @@ -148,6 +158,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float4_e2m1fn.has_value()) { + map->emplace(*custom_dtypes.float4_e2m1fn, F4E2M1FN); + } if (custom_dtypes.float8_e3m4.has_value()) { map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); } @@ -159,6 +172,9 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); map->emplace(custom_dtypes.float8_e5m2, F8E5M2); map->emplace(custom_dtypes.float8_e5m2fnuz, F8E5M2FNUZ); + if (custom_dtypes.float8_e8m0fnu.has_value()) { + map->emplace(*custom_dtypes.float8_e8m0fnu, F8E8M0FNU); + } if (custom_dtypes.int2.has_value()) { map->emplace(*custom_dtypes.int2, S2); } @@ -218,6 +234,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case F8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return 
*custom_dtypes.float8_e3m4; @@ -238,6 +259,11 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return custom_dtypes.float8_e5m2; case F8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case F8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case BF16: return custom_dtypes.bfloat16; case F16: @@ -308,6 +334,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF4E2M1FN: + if (custom_dtypes.float4_e2m1fn.has_value()) { + return *custom_dtypes.float4_e2m1fn; + } + break; case ifrt::DType::kF8E3M4: if (custom_dtypes.float8_e3m4.has_value()) { return *custom_dtypes.float8_e3m4; @@ -328,6 +359,11 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return custom_dtypes.float8_e5m2; case ifrt::DType::kF8E5M2FNUZ: return custom_dtypes.float8_e5m2fnuz; + case ifrt::DType::kF8E8M0FNU: + if (custom_dtypes.float8_e8m0fnu.has_value()) { + return *custom_dtypes.float8_e8m0fnu; + } + break; case ifrt::DType::kString: // PEP 3118 code for "pointer to Python Object". 
We use Python objects // instead of 'U' (Unicode string) or 'V' (raw data) because the latter @@ -381,6 +417,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float4_e2m1fn")) { + dtypes->np_float4_e2m1fn = nb::object(ml_dtypes.attr("float4_e2m1fn")); + } if (nb::hasattr(ml_dtypes, "float8_e3m4")) { dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); } @@ -393,6 +432,9 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_float8_e5m2 = nb::object(ml_dtypes.attr("float8_e5m2")); dtypes->np_float8_e4m3fnuz = nb::object(ml_dtypes.attr("float8_e4m3fnuz")); dtypes->np_float8_e5m2fnuz = nb::object(ml_dtypes.attr("float8_e5m2fnuz")); + if (nb::hasattr(ml_dtypes, "float8_e8m0fnu")) { + dtypes->np_float8_e8m0fnu = nb::object(ml_dtypes.attr("float8_e8m0fnu")); + } dtypes->np_float16 = nb::object(numpy.attr("float16")); dtypes->np_float32 = nb::object(numpy.attr("float32")); dtypes->np_float64 = nb::object(numpy.attr("float64")); diff --git a/xla/python/types.h b/xla/python/types.h index 59c27d99184e5c..31ae84d9320845 100644 --- a/xla/python/types.h +++ b/xla/python/types.h @@ -81,6 +81,7 @@ struct NumpyScalarTypes { nanobind::object np_uint64; nanobind::object np_bfloat16; // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ std::optional np_float4_e2m1fn; std::optional np_float8_e3m4; std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; @@ -88,6 +89,7 @@ struct NumpyScalarTypes { nanobind::object np_float8_e4m3fnuz; nanobind::object np_float8_e5m2; nanobind::object np_float8_e5m2fnuz; + std::optional np_float8_e8m0fnu; nanobind::object np_float16; nanobind::object np_float32; nanobind::object np_float64; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 0fe3da546b9526..cb1d6f6c5258fc 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -205,8 +205,10 @@ NB_MODULE(xla_extension, m_nb) { .value("U64", U64) .value("F16", F16) // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F4E2M1FN", F4E2M1FN) // .value("F8E3M4", F8E3M4) // .value("F8E4M3", F8E4M3) + // .value("F8E8M0FNU", F8E8M0FNU) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index d15ff8201d4b37..770a24e8bc0a81 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -277,8 +277,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): bfloat16 = ml_dtypes.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn = ml_dtypes.float4_e2m1fn # float8_e3m4 = ml_dtypes.float8_e3m4 # float8_e4m3 = ml_dtypes.float8_e4m3 +# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -298,8 +300,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
+ # PrimitiveType.F4E2M1FN: np.dtype(float4_e2m1fn), # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), + # PrimitiveType.F8E8M0FNU: np.dtype(float8_e8m0fnu), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi index cc7dd65927ee43..e808bbf7c99628 100644 --- a/xla/python/xla_client.pyi +++ b/xla/python/xla_client.pyi @@ -62,8 +62,10 @@ mlir_api_version: int bfloat16: type[numpy.generic] # TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn: type[numpy.generic] # float8_e3m4: type[numpy.generic] # float8_e4m3: type[numpy.generic] +# float8_e8m0fnu: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 75d76a683deb07..848a6280981b03 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -56,8 +56,10 @@ bfloat16 = xla_client.bfloat16 # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float4_e2m1fn = xla_client.float4_e2m1fn # float8_e3m4 = xla_client.float8_e3m4 # float8_e4m3 = xla_client.float8_e4m3 +# float8_e8m0fnu = xla_client.float8_e8m0fnu float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -190,7 +192,7 @@ def TestFactory(xla_backend, fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] standard_dtypes += fp8_dtypes # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. 
- # standard_dtypes += [float8_e3m4, float8_e4m3] + # standard_dtypes += [float4_e2m1fn, float8_e3m4, float8_e4m3, float8_e8m0fnu] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index cd6311ad06fa84..a11403910630c5 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -74,6 +74,7 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F4E2M1FN: PrimitiveType F8E3M4: PrimitiveType F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType @@ -81,6 +82,7 @@ class PrimitiveType(enum.IntEnum): F8E4M3FNUZ: PrimitiveType F8E5M2: PrimitiveType F8E5M2FNUZ: PrimitiveType + F8E8M0FNU: PrimitiveType BF16: PrimitiveType F16: PrimitiveType F32: PrimitiveType diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index bcd30e2aae8872..ef68a6dbbe95a0 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -607,6 +607,10 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&f8e4m3fnuz_support); FloatSupport f8e3m4_support(F8E3M4, F16); pipeline.AddPass(&f8e3m4_support); + FloatSupport f4e2m1fn_support(F4E2M1FN, F16); + pipeline.AddPass(&f4e2m1fn_support); + FloatSupport f8e8m0fnu_support(F8E8M0FNU, F32); + pipeline.AddPass(&f8e8m0fnu_support); // After canonicalization, there may be more batch dots that can be // simplified. 
pipeline.AddPass(); diff --git a/xla/service/cpu/onednn_memory_util.h b/xla/service/cpu/onednn_memory_util.h index 18841d2712dcbc..90c4f6c82e4082 100644 --- a/xla/service/cpu/onednn_memory_util.h +++ b/xla/service/cpu/onednn_memory_util.h @@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ, F8E4M3, F8E3M4 + // F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN, F8E8M0FNU default: return dt::undef; } diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index d1276e1717bab1..f8dc9f3dfbc5cd 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -809,6 +809,223 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, return f16_value; } +absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, + llvm::IRBuilderBase* b) { + auto i8_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt8Ty(), val); + }; + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // Cast the input value to an integer for bitwise manipulation. + // Get the absolute value of the input (discard the sign). + // f16_bits = bitcast(f16_value, int) + // f16_abs_bits = f16_bits & 0x7FFF + llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty()); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF)); + + // If the input absolute value is >= 7.0 or an infinity, the result saturates + // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0. + // If (0 <= input <= 0.25), the result is rounded to 0.0. + // If the input is NaN, the result is undefined (implemented as minus zero). + // The rest of the cases are handled by the "happy path". 
+ // is_overflow = f16_abs_bits >= 0x1.Cp2 + // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows) + // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows) + // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold) + llvm::Value* is_overflow = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0 + llvm::Value* is_one = + b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75 + llvm::Value* is_zero = + b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25 + llvm::Value* is_nan = + b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf + + // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as + // the type doesn't have Inf/NaN and can represent unbiased exponent 2). + // This case, as well as the denormal, is handled below. + TF_ASSIGN_OR_RETURN( + llvm::Value * reduced_precision, + EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1, + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1, + /*quiet_nans=*/false, b)); + + // Cast the reduced precision value to an integer for bitwise manipulation. + // Discard the least significant (9) mantissa bits leaving 1 bit. + // Truncate the shifted value to an 8-bit integer. + // as_int16 = bitcast(reduced_precision, int) + // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa) + llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty()); + llvm::Value* as_int8 = + b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty()); + + // Get the sign (0 or 1). + // f4_sign = as_int8 >> 6 + llvm::Value* f4_sign = b->CreateLShr(as_int8, 6); + + // Get exponent and mantissa bits without the sign. + // Important: the mask is 0x3F (not 0x7F), discard bit #6. + // f4_bits = as_int8 & 0x3F + llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F)); + + // Convert F16 exponent to F4 exponent by readjusting the exponent bias. + // This produces the "normal" result, i.e.
not Inf or NaN or denormal. + // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa) + constexpr int f4_exponent_offset = bias_diff << 1; + llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset)); + + // If the rounding resulted in zero exponent, the value is incorrect. + // This happens when the input is < 1.0 + // is_underflow = f4_normal <= 1 + llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1)); + + // Chain of selects that handles the special cases. + // f4_result = + // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) : + // is_overflow ? (is_nan ? -0.0 : 6.0) : + // f4_normal + llvm::Value* f4_result = b->CreateSelect( + is_underflow, + // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0 + b->CreateSelect(is_one, i8_const(0x2), + b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))), + // If overflow, the input is >= 7.0 or infinity or NaN. + b->CreateSelect(is_overflow, + b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)), + f4_normal)); + + // Add sign to the resulting value. + // f4_signed_result = (f4_sign << 3) | f4_result + return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3)); +} + +llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i16_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt16Ty(), val); + }; + constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4 + constexpr int bias_diff = 14; // 15 for F16, 1 for F4 + + // The input value is a 8-bit integer, extend it to 16-bit integer. + // as_int16 = bitcast(f8_value, int) + llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty()); + + // Get the sign and shift it to F16 position. + // f4_sign = as_int16 >> 3 + // f16_sign_bit = f4_sign << 15 + llvm::Value* f4_sign = b->CreateLShr(as_int16, 3); + llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15); + + // Get exponent and mantissa bits without the sign. 
+ // f4_bits = as_int16 & 0x7 + // f16_bits = f4_bits << (f16_mantissa - f4_mantissa) + llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7)); + llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff); + + // Convert F4 exponent to F16 exponent by readjusting the exponent bias. + // f16_normal = f16_bits + ((f16_bias - f4_bias) << f16_mantissa) + constexpr int f16_exponent_offset = bias_diff << 10; + llvm::Value* f16_normal = + b->CreateAdd(f16_bits, i16_const(f16_exponent_offset)); + + // For denormal and zero, the exponent is different. Handle these cases + // separately below. + // is_denorm_or_zero = f4_bits <= 1 + // is_zero = f4_bits == 0 + llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1)); + llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0)); + + // Chain of selects that handles the special cases. + // f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal + llvm::Value* f16_result = b->CreateSelect( + is_denorm_or_zero, + b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)), + f16_normal); + + // Add sign to the resulting value. + // f16_signed_result = f16_sign_bit | f16_result + llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit); + return b->CreateBitCast(f16_signed_result, b->getHalfTy()); +} + +llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, + llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // Cast the input value to an integer for bitwise manipulation. + // as_int32 = bitcast(f32_value, int) + llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty()); + + // Check if the input is zero, negative, overflow, infinity or NaN. + // All of these cases cannot be represented in the E8M0 format.
+ // is_zero_or_negative = as_int32 <= 0 + // is_overflow_or_nan = as_int32 >= 0x1.8p127 + // is_nan = is_zero_or_negative | is_overflow_or_nan + llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0)); + llvm::Value* is_overflow_or_nan = + b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127 + llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan); + + // Check if the input is a denormal which should round to the minimum value + // (2^-127), as there is no zero value. + // is_denorm = as_int32 <= 0x1p-127 + llvm::Value* is_denorm = + b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + + // Round to the nearest power of two (ties up) and discard the mantissa. + // rounded = as_int32 + 0x1p-127 + // f8_normal = rounded >> f32_mantissa + llvm::Value* rounded = + b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127 + llvm::Value* f8_normal = b->CreateAShr(rounded, 23); + + // Chain of selects that handles the special cases. + // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal) + llvm::Value* f8_result = + b->CreateSelect(is_nan, i32_const(0xFF), + b->CreateSelect(is_denorm, i32_const(0x00), f8_normal)); + + // Truncate to the result type. + return b->CreateTrunc(f8_result, b->getInt8Ty()); +} + +llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilderBase* b) { + auto i32_const = [&](int val) { + return llvm::ConstantInt::get(b->getInt32Ty(), val); + }; + + // The input value is an 8-bit integer, extend it to a 32-bit integer. + // as_int32 = zext(f8_value, int32) + llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty()); + + // Check if the input is a denormal or NaN. + // is_zero = as_int32 == 0x00 + // is_nan = as_int32 == 0xFF + llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0)); + llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF)); + + // Shift exponent to the left for the normal case.
+ // f32_normal = as_int32 << mantissa_diff + llvm::Value* f32_normal = b->CreateShl(as_int32, 23); + + // Chain of selects that handles the special cases. + // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal) + llvm::Value* f32_result = b->CreateSelect( + is_nan, i32_const(0x7FC00000), + b->CreateSelect(is_zero, i32_const(0x400000), f32_normal)); + + // Bitcast integer bits to the result type. + return b->CreateBitCast(f32_result, b->getFloatTy()); +} + llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, @@ -902,6 +1119,18 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F4E2M1FN) { + return EmitF16ToF4e2m1fn( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } + if (to_type == F8E8M0FNU) { + return EmitF32ToF8e8m0fnu( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + b_), + b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz( F16, @@ -1105,10 +1334,29 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F4E2M1FN) { + TF_RET_CHECK(to_type != F4E2M1FN); + operand_value = EmitF4e2m1fnToF16(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F8E8M0FNU) { + TF_RET_CHECK(to_type != F8E8M0FNU); + operand_value = EmitF8e8m0fnuToF32(operand_value, b_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E5M2FNUZ || from_type == F8E4M3FNUZ) { TF_RET_CHECK(to_type != from_type); PrimitiveType cast_type = primitive_util::IsFloatingPointType(to_type) ? 
to_type : F16; + if (to_type == F8E8M0FNU) { + cast_type = F32; + } TF_ASSIGN_OR_RETURN(operand_value, EmitF8fnuzToFloating(from_type, operand_value, cast_type, b_, module_)); @@ -1176,6 +1424,22 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e4m3b11fnuz(operand_value, b_); } + if (to_type == F4E2M1FN) { + // Cast to F16 first. Casts to F4E2M1FN must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF4e2m1fn(operand_value, b_); + } + if (to_type == F8E8M0FNU) { + // Cast to F32 first. Casts to F8E8M0FNU must be from F32. + if (from_type != F32) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_)); + } + return EmitF32ToF8e8m0fnu(operand_value, b_); + } if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } @@ -1721,6 +1985,12 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); + } else if (operand_type == F4E2M1FN) { + lhs_value = EmitF4e2m1fnToF16(lhs_value, b_); + rhs_value = EmitF4e2m1fnToF16(rhs_value, b_); + } else if (operand_type == F8E8M0FNU) { + lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); + rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { TF_ASSIGN_OR_RETURN( lhs_value, @@ -3569,9 +3839,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; - if (component_element_type == F8E4M3FNUZ) { - float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); - } else if (component_element_type == F8E5M2FNUZ) { + if (component_element_type == F8E4M3FNUZ || + component_element_type 
== F8E5M2FNUZ) { float_ir_type = llvm_ir::PrimitiveTypeToIrType(F16, module_); } else { float_ir_type = diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index c1edb9a4b856d7..050f6c386c9fc2 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -99,9 +99,10 @@ class ElementalIrEmitterExecutionTypedTest }; using FloatTypes = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); @@ -614,7 +615,9 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { std::is_same() || std::is_same() || std::is_same() || - std::is_same()) { + std::is_same() || + std::is_same() || + std::is_same()) { GTEST_SKIP() << "Skipping test for type " << tname; } const auto hlo_text = absl::StrReplaceAll(R"( @@ -629,6 +632,10 @@ TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { auto tname = this->TypeName(); + if (std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } const auto hlo_text = absl::StrReplaceAll(R"( HloModule matmul diff --git a/xla/service/float8_fnuz_ir_emitter.cc b/xla/service/float8_fnuz_ir_emitter.cc index 4afb96362cf86e..e0be95da5f6680 100644 --- a/xla/service/float8_fnuz_ir_emitter.cc +++ b/xla/service/float8_fnuz_ir_emitter.cc @@ -40,6 +40,8 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F4E2M1FN: + return &llvm::APFloat::Float4E2M1FN(); case F8E3M4: return &llvm::APFloat::Float8E3M4(); case F8E4M3: @@ -54,6 +56,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( return &llvm::APFloat::Float8E5M2(); case F8E5M2FNUZ: return &llvm::APFloat::Float8E5M2FNUZ(); + case F8E8M0FNU: + return &llvm::APFloat::Float8E8M0FNU(); case BF16: return &llvm::APFloat::BFloat(); case F16: @@ -72,6 +76,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( 
absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, PrimitiveType type) { switch (type) { + case F4E2M1FN: + return b->getIntNTy(4); case F8E3M4: case F8E4M3: case F8E4M3B11FNUZ: @@ -79,6 +85,7 @@ absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilderBase* b, case F8E4M3FNUZ: case F8E5M2: case F8E5M2FNUZ: + case F8E8M0FNU: return b->getInt8Ty(); case BF16: return b->getBFloatTy(); @@ -649,8 +656,14 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::ConstantInt::get(b->getInt8Ty(), 0x0u), sign); // Bitwise or the sign bit back in. - sign = b->CreateZExt(sign, output_int_type); - sign = b->CreateShl(sign, output_type_bit_width - BitWidth(input_type)); + int shift = output_type_bit_width - BitWidth(input_type); + if (shift >= 0) { + sign = b->CreateZExt(sign, output_int_type); + sign = b->CreateShl(sign, shift); + } else { + sign = b->CreateLShr(sign, -shift); + sign = b->CreateTrunc(sign, output_int_type); + } llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. diff --git a/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/xla/service/gpu/fusions/transforms/expand_float_ops.cc index 6fea3a97527f9b..d72b8eace3c09b 100644 --- a/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -163,7 +163,13 @@ int GetSignificandBits(mlir::FloatType ty) { } int GetExponentBias(mlir::FloatType ty) { - return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()); + return 1 - llvm::APFloat::semanticsMinExponent(ty.getFloatSemantics()) - + ty.isFloat8E8M0FNU(); // No zero exponent for E8M0. 
+} + +bool IsFNUZ(mlir::FloatType ty) { + return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E5M2FNUZ(); } Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -175,7 +181,7 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { return b.create(ma::CmpFPredicate::OEQ, value, inf); } - assert(ty.getIntOrFloatBitWidth() == 8); + assert(ty.getIntOrFloatBitWidth() <= 8); // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. if (ty.isFloat8E5M2()) { Val bits{b.create(b.getI8Type(), value), &b}; @@ -196,6 +202,9 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { if (mlir::LLVM::isCompatibleOuterType(ty)) { return b.create(ma::CmpFPredicate::UNO, value, value); } + if (ty.isFloat4E2M1FN()) { + return b.create(false, b.getI1Type()); + } assert(ty.getIntOrFloatBitWidth() == 8); Val bits{b.create(b.getI8Type(), value), &b}; @@ -207,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { return (bits & 0b0111'1111) == 0b0111'1111; } else if (ty.isFloat8E3M4()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); + } else if (ty.isFloat8E8M0FNU()) { + return bits == 0xFF; } return bits == 0x80; } @@ -281,11 +292,18 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, auto to_int_ty = b.getIntegerType(to_ty.getIntOrFloatBitWidth()); mlir::IntegerType wide_int_ty; - if (from_ty.getWidth() == 8 && to_ty.getWidth() == 8) { + if (from_ty.getWidth() <= 8 && to_ty.getWidth() <= 8) { wide_int_ty = b.getI16Type(); } else { wide_int_ty = b.getIntegerType( std::max(from_int_ty.getWidth(), to_int_ty.getWidth())); + // Avoid overflow for bit shifts. 
+ auto may_overflow = [&](mlir::Type a, mlir::Type b) { + return a.isFloat8E8M0FNU() && b.isF16(); + }; + if (may_overflow(from_ty, to_ty) || may_overflow(to_ty, from_ty)) { + wide_int_ty = b.getI32Type(); + } } auto convert_int = [&](mlir::Type ty, Value v) -> Val { if (v.getType() == ty) { @@ -300,24 +318,23 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, int64_t exp_offset = to_bias - from_bias; int digit_shift = to_mantissa - from_mantissa; - Val from_bits{ - b.create( - b.getIntegerType(value.getType().getIntOrFloatBitWidth()), value), - &b}; + int from_width = value.getType().getIntOrFloatBitWidth(); + Val from_bits{b.create(b.getIntegerType(from_width), value), + &b}; + if (from_width < 8) { + from_bits = convert_int(b.getIntegerType(8), from_bits); + } auto cst = [&](mlir::Type ty, int64_t n) -> Val { return {b.create(n, ty), &b}; }; // Shift bits to destination type, without sign bit. - Val from_sign_bit = - from_bits.shrui(value.getType().getIntOrFloatBitWidth() - 1) != 0; - - from_bits = - from_bits & ((1ULL << (value.getType().getIntOrFloatBitWidth() - 1)) - 1); - - Value result_is_inf = IsInf(value, b); - Value input_is_nan = IsNaN(value, b); + Val from_sign_bit; + if (!from_ty.isFloat8E8M0FNU()) { + from_sign_bit = from_bits.shrui(from_width - 1) != 0; + from_bits = from_bits & ((1ULL << (from_width - 1)) - 1); + } auto cst_bits = [&](llvm::APFloat f) { return cst(b.getIntegerType(llvm::APFloat::getSizeInBits(f.getSemantics())), @@ -327,7 +344,17 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value to_nan = cst_bits(llvm::APFloat::getNaN(to_ty.getFloatSemantics())); Val to_zero = cst_bits(llvm::APFloat::getZero(to_ty.getFloatSemantics())); - auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) { + // F4E2M1FN has no Inf/NaN; F8E8M0FNU has no Inf/zero (NaN is used).
+ if (to_ty.isFloat4E2M1FN()) { + to_inf = cst_bits(llvm::APFloat::getLargest(to_ty.getFloatSemantics())); + to_nan = to_zero | 0x8; + } else if (to_ty.isFloat8E8M0FNU()) { + to_inf = to_nan; + to_zero = Val{to_nan, &b}; + } + + auto round_bits_to_nearest_even = [&](Val bits, Val roundoff, + bool use_implicit_bit = false) { assert(bits.value.getType() == roundoff.value.getType()); // Round to nearest even by adding a bias term. // Consider a bit pattern @@ -337,9 +364,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // - L is 1, R is 1, OR // - L is 0, R is 1, any T is one. // We do this by adding L to a bit pattern consisting of all T = 1. - Val rounded = (bits.shrui(roundoff) & 1) + - (bits.MakeConstant(1).shl(roundoff - 1) - 1); - Val bias{b.create(roundoff == 0, roundoff, rounded), &b}; + Val bias = !use_implicit_bit + ? (bits.shrui(roundoff) & 1) + + (bits.MakeConstant(1).shl(roundoff - 1) - 1) + : bits.MakeConstant(1).shl(roundoff - 1); return bits + bias; }; @@ -349,9 +377,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, // Round the mantissa if it is shrinking. Val rounded_from_bits = convert_int(wide_int_ty, from_bits); if (digit_shift < 0) { - rounded_from_bits = round_bits_to_nearest_even( - from_bits, from_bits.MakeConstant(-digit_shift)) & - ~((1ll << (-digit_shift)) - 1); + rounded_from_bits = + round_bits_to_nearest_even( + rounded_from_bits, rounded_from_bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0) & + ~((1ll << (-digit_shift)) - 1); } // Re-bias the exponent. @@ -394,10 +424,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Val bits = convert_int(wide_int_ty, from_bits); // Determine exponent in target type. 
- Value normalization_factor = - convert_int(i32_ty, - b.create(from_bits)) - - (from_int_ty.getWidth() - from_mantissa - 1); + Value clz = convert_int( + i32_ty, b.create(from_bits)); + Value msb = cst(i32_ty, std::max(from_width, 8) - 1) - clz; + Value normalization_factor = cst(i32_ty, from_mantissa) - msb; Val biased_exponent = cst(i32_ty, exp_offset + 1) - normalization_factor; // If the result is subnormal, adjust the subnormal bits to account for @@ -418,10 +448,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, Value biased_exp_sle_zero = biased_exponent.cmp(CmpIPredicate::sle, 0); bits.value = b.create(biased_exp_sle_zero, subnormal_bits, normal_bits); - if (digit_shift > 0) { + if (digit_shift >= 0) { bits = bits.shl(digit_shift); } else { - bits = round_bits_to_nearest_even(bits, bits.MakeConstant(-digit_shift)); + bits = round_bits_to_nearest_even( + bits, bits.MakeConstant(-digit_shift), + /*use_implicit_bit=*/to_mantissa == 0 && exp_offset != 0); bits = bits.shrui(-digit_shift); } bits = convert_int(to_int_ty, bits); @@ -430,11 +462,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, } else if (to_min_exp > from_min_exp) { // `To` supports fewer exponents near zero which means that some values in // `From` may become subnormal. - Val unbiased_exp = biased_from_exp - from_bias; - Val biased_to_exp = unbiased_exp + to_bias; + Val biased_to_exp = biased_from_exp + (to_bias - from_bias); // Subnormals and zero. // Round and shift mantissa down. - Val from_has_leading_one = biased_from_exp != 0; + Val from_has_leading_one = + !from_ty.isFloat8E8M0FNU() ? biased_from_exp != 0 : cst(i32_ty, 1); Val from_has_leading_one_i32 = convert_int(i32_ty, from_has_leading_one); from_has_leading_one = convert_int(from_int_ty, from_has_leading_one); Val exponent_shift_i32 = @@ -469,31 +501,35 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty, result); } - // Handle types with no unsigned zero. 
- auto is_nuz = [](mlir::FloatType ty) { - return ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E5M2FNUZ(); - }; + Value result_is_inf = IsInf(value, b); + Value input_is_nan = IsNaN(value, b); - if (is_nuz(to_ty)) { + if (to_ty.isFloat8E8M0FNU()) { + // Converting a negative number to E8M0 results in NaN. + input_is_nan = from_sign_bit | input_is_nan; + } else if (IsFNUZ(to_ty)) { // Clear the sign bit if the result is zero (the output has no negative - // zero). - Val result_is_non_zero = Val{result, &b} != 0; + // zero). Handle the edge case when the input is zero and the result is not. + Val result_is_non_zero = + (digit_shift > 0 ? from_bits : Val{result, &b}) != 0; from_sign_bit = from_sign_bit & result_is_non_zero; - } else if (is_nuz(from_ty)) { + } else if (IsFNUZ(from_ty)) { // Clear the sign bit if the input is NaN (it's positive but encoded as // negative 0). from_sign_bit = from_sign_bit ^ input_is_nan; } + if (!from_ty.isFloat8E8M0FNU()) { + result = b.create(from_bits == 0, to_zero, result); + } result = b.create(result_is_inf, to_inf, result); - result = b.create(from_bits == 0, to_zero, result); result = b.create(input_is_nan, to_nan, result); - Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); - // Insert sign bit. 
- result = b.create(from_sign_bit, neg_result, result); + if (!from_ty.isFloat8E8M0FNU()) { + Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth() - 1)); + result = b.create(from_sign_bit, neg_result, result); + } result = b.create(to_ty, result); return result; } @@ -506,8 +542,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (dst_ty.getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit truncf"); + if (dst_ty.getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) truncf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -524,8 +560,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { using FloatValue = mlir::TypedValue; auto src = mlir::cast(op.getOperand()); auto dst_ty = mlir::cast(op.getType()); - if (src.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit extf"); + if (src.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) extf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -544,8 +580,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { auto lhs = mlir::cast(op.getLhs()); auto rhs = mlir::cast(op.getRhs()); - if (lhs.getType().getWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an 8 bit cmpf"); + if (lhs.getType().getWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an 8 bit (or less) cmpf"); } mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -553,16 +589,16 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { llvm::APFloat rhs_cst(rhs.getType().getFloatSemantics()); if (op.getPredicate() == ma::CmpFPredicate::UNE && mlir::matchPattern(rhs, mlir::m_ConstantFloat(&rhs_cst))) { - Val int_value{b.create(rewriter.getI8Type(), lhs), &b}; + mlir::Type int_ty = rewriter.getIntegerType(lhs.getType().getWidth()); + Val 
int_value{b.create(int_ty, lhs), &b}; int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. - if (rhs_cst.isZero() && - (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || - lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { - int_value = int_value & 0x7f; - constant &= 0x7f; + if (rhs_cst.isZero() && !IsFNUZ(lhs.getType())) { + int64_t mask = (1 << (lhs.getType().getWidth() - 1)) - 1; + int_value = int_value & mask; + constant &= mask; } - auto cst = b.create(constant, rewriter.getI8Type()); + auto cst = b.create(constant, int_ty); rewriter.replaceOpWithNewOp(op, ma::CmpIPredicate::ne, int_value, cst); return mlir::success(); @@ -586,18 +622,23 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern { auto src = mlir::cast(op.getOperand()); // LowerGpuOpsToNVVMOps has a lowering for abs that doesn't work with bf16. // Once that's removed, remove the code for BF16 here. - if (src.getType().getWidth() != 8 && !src.getType().isBF16()) { - return rewriter.notifyMatchFailure(op, "not an f8 or bf16 absf"); + if (src.getType().getWidth() > 8 && !src.getType().isBF16()) { + return rewriter.notifyMatchFailure(op, + "not an f8 (or less) or bf16 absf"); } + + // If type is unsigned (E8M0), the operation is no-op. 
+ if (!llvm::APFloat::semanticsHasSignedRepr( + src.getType().getFloatSemantics())) { + rewriter.replaceAllOpUsesWith(op, op.getOperand()); + return mlir::success(); + } + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); mlir::Type i_ty = rewriter.getIntegerType(src.getType().getWidth()); Val value{b.create(i_ty, src), &b}; - if (src.getType().getWidth() == 8) { - value = value & 0x7f; - } else { - CHECK(src.getType().isBF16()); - value = value & 0x7fff; - } + int64_t mask = (1ull << (src.getType().getWidth() - 1)) - 1; + value = value & mask; rewriter.replaceOpWithNewOp(op, src.getType(), value); return mlir::success(); } @@ -609,8 +650,8 @@ struct RewriteIToFpPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 itofp"); + if (op.getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) itofp"); } Value to_float = rewriter.create(op.getLoc(), rewriter.getF32Type(), op.getIn()); @@ -625,8 +666,8 @@ struct RewriteFpToIPattern : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter& rewriter) const override { - if (op.getIn().getType().getIntOrFloatBitWidth() != 8) { - return rewriter.notifyMatchFailure(op, "not an f8 fptoi"); + if (op.getIn().getType().getIntOrFloatBitWidth() > 8) { + return rewriter.notifyMatchFailure(op, "not an f8 (or less) fptoi"); } Value to_f32 = rewriter.create( op.getLoc(), rewriter.getF32Type(), op.getIn()); diff --git a/xla/service/gpu/fusions/transforms/lower_tensors.cc b/xla/service/gpu/fusions/transforms/lower_tensors.cc index 87238edacec1d3..b5266dd4e278e2 100644 --- a/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -194,7 +194,7 @@ std::tuple GetI4IndexAndNibble(Value linear_index, 
mlir::LLVM::GEPOp CreateGep(TypedValue tensor, Value linear_index, mlir::ImplicitLocOpBuilder& b) { Type element_type = tensor.getType().getElementType(); - if (element_type == b.getI4Type()) { + if (element_type.getIntOrFloatBitWidth() == 4) { element_type = b.getI8Type(); } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); @@ -222,8 +222,9 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); + Value is_low_nibble = nullptr; - if (element_type == rewriter.getI4Type()) { + if (element_type.getIntOrFloatBitWidth() == 4) { std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); } @@ -238,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { auto high_value = b.create( load, b.create(4, load.getType())); load = b.create( - op.getType(), + rewriter.getI4Type(), b.create(is_low_nibble, load, high_value)); } @@ -275,6 +276,7 @@ struct RewriteTransferRead auto source = mlir::dyn_cast>( op.getSource()); + mlir::Type source_element_type = source.getType().getElementType(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto linear_index = GetLinearIndex(op.getIndices(), b); @@ -283,7 +285,8 @@ struct RewriteTransferRead if (vector_type.getElementType().isInteger(1)) { vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type()); } - if (op.getVectorType().getElementType().isInteger(4)) { + mlir::Type gep_element_type = vector_type.getElementType(); + if (gep_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -295,11 +298,11 @@ struct RewriteTransferRead auto loaded = b.create(llvm_vector_type, gep).getResult(); - if (source.getType().getElementType().isInteger(1)) { + if (source_element_type.isInteger(1)) { Value zero = b.create( mlir::DenseElementsAttr::get(vector_type, 
b.getI8IntegerAttr(0))); loaded = b.create(arith::CmpIPredicate::ne, loaded, zero); - } else if (source.getType().getElementType().isInteger(4)) { + } else if (source_element_type.getIntOrFloatBitWidth() == 4) { // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the // elements. loaded = PermutePairsInVector(loaded, b); @@ -328,7 +331,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { auto scalar_value = op.getScalar(); // For i4 we store 2 values into one byte. This needs special handling here. - if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) { + if (tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) { // We need to use directly op.getDest() as input, otherwise the following // rewrite might remove the only user of it. tensor_dest = op.getDest(); @@ -346,6 +349,10 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { auto tensor_dest_i8 = b.create(tensor_ty, tensor_dest) .getResult(0); + if (scalar_value.getType() != rewriter.getI4Type()) { + scalar_value = + b.create(rewriter.getI4Type(), scalar_value); + } scalar_value = b.create(ty, scalar_value); // We need AtomicRMWOp because it can happen that different threads try to @@ -406,12 +413,13 @@ struct RewriteTransferWrite auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::Value vector_value = op.getVector(); - if (op.getVectorType().getElementType().isInteger(1)) { + mlir::Type vector_element_type = op.getVectorType().getElementType(); + if (vector_element_type.isInteger(1)) { vector_value = b.create( op.getVectorType().cloneWith(std::nullopt, b.getI8Type()), vector_value); } - if (op.getVectorType().getElementType().isInteger(4)) { + if (vector_element_type.getIntOrFloatBitWidth() == 4) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); @@ -477,21 +485,18 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, // Needed to support complex element type. 
mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - if (mlir::isa(element_type)) { - int bit_width = mlir::cast(element_type).getWidth(); - if (bit_width == 4) { - num_elements = CeilOfRatio(num_elements, 2); - llvm_element_type = b.getI8Type(); - auto unpacked_data = - mlir::cast(value).getRawData(); - std::vector packed_data(num_elements); - absl::Span packed_data_span = - absl::MakeSpan(packed_data.data(), packed_data.size()); - PackIntN(4, unpacked_data, packed_data_span); - value = mlir::DenseElementsAttr::getFromRawBuffer( - mlir::RankedTensorType::get({num_elements}, llvm_element_type), - packed_data); - } + if (element_type.getIntOrFloatBitWidth() == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); } auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); diff --git a/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir b/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir index ff97c1b46f708d..15a8d27d74ad94 100644 --- a/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir +++ b/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir @@ -115,3 +115,53 @@ module { // CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32 // CHECK: arith.truncf %[[EXT]] : f32 to f16 // CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 { + %ret = arith.extf %arg0 : f4E2M1FN to f16 + return %ret : f16 + } +} + +// CHECK-LABEL: @f4_to_f16 +// CHECK-NOT: arith.extf + +// ----- + +module { + func.func 
@f16_to_f4(%arg0: f16) -> f4E2M1FN { + %ret = arith.truncf %arg0 : f16 to f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f16_to_f4 +// CHECK-NOT: arith.truncf + +// ----- + +module { + func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN { + %ret = math.absf %arg0 : f4E2M1FN + return %ret : f4E2M1FN + } +} + +// CHECK-LABEL: @f4_abs +// CHECK-NOT: math.absf +// CHECK: arith.constant 7 : i4 + +// ----- + +module { + func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU { + %ret = math.absf %arg0 : f8E8M0FNU + return %ret : f8E8M0FNU + } +} + +// CHECK-LABEL: @e8m0_abs +// CHECK-NOT: math.absf +// CHECK: return %arg0 diff --git a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index a894c13dce1293..3074483a77b361 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -732,3 +732,55 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 { // CHECK: llvm.mlir.global private constant // CHECK-SAME: dense<[18, 48]> // CHECK-LABEL: @int4_constant + +// ----- + +func.func @complex_expm1_approx(%arg0: tensor<3xcomplex>, %i: index) + -> complex { + %extracted = tensor.extract %arg0[%i] : tensor<3xcomplex> + %expm1 = complex.expm1 %extracted : complex + return %expm1 : complex +} +// CHECK-LABEL: @complex_expm1_approx +// CHECK: math.expm1 +// CHECK-COUNT-6: math.fma + +// ----- + +func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN { + %cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN> + %extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN> + %0 = arith.addf %extracted, %extracted_0 : f4E2M1FN + return %0 : f4E2M1FN +} +// CHECK: llvm.mlir.global private constant +// CHECK-SAME: dense<[25, 64]> +// CHECK-LABEL: @f4_constant + +// ----- + +func.func @transfer_read_f4(%arg0: 
tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f4E2M1FN + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN> + func.return %out : vector<2xf4E2M1FN> +} +// CHECK-LABEL: @transfer_read_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8] +// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4> +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN> +// CHECK: return %[[OUT]] : vector<2xf4E2M1FN> + +// ----- + +func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}, + %arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> { + %c10 = arith.constant 10 : index + %out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN> + func.return %out : tensor<43xf4E2M1FN> +} +// CHECK-LABEL: @transfer_write_f4 +// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8 +// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4> +// CHECK: llvm.store %[[OUT]], %[[PTR]] : vector<2xi4>, !llvm.ptr diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 0f2d12aeda1811..23f8cc74c91bc0 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1476,6 +1476,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); + const GpuFloatSupport f4e2m1fn_support(gpu_version, F4E2M1FN, F16); + const GpuFloatSupport f8e8m0fnu_support(gpu_version, F8E8M0FNU, F32); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); @@ -1487,6 +1489,8 @@ absl::Status 
GpuCompiler::OptimizeHloPostLayoutAssignment( sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); sub_pipeline.AddPass(&f8e3m4_support); + sub_pipeline.AddPass(&f4e2m1fn_support); + sub_pipeline.AddPass(&f8e8m0fnu_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); diff --git a/xla/service/gpu/tests/float_conversions_test.cc b/xla/service/gpu/tests/float_conversions_test.cc index 16383324dfb016..6e0e14e320a7f9 100644 --- a/xla/service/gpu/tests/float_conversions_test.cc +++ b/xla/service/gpu/tests/float_conversions_test.cc @@ -29,9 +29,10 @@ class FloatConversionParamTest INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, ::testing::Values("f64", "f32", "f16", "bf16", - "f8e5m2", "f8e5m2fnuz", "f8e4m3", - "f8e4m3fn", "f8e4m3fnuz", - "f8e4m3b11fnuz", "f8e3m4")); + "f4e2m1fn", "f8e5m2", "f8e5m2fnuz", + "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4", + "f8e8m0fnu")); TEST_P(FloatConversionParamTest, FloatToF16) { auto type_name = GetParam(); diff --git a/xla/service/llvm_ir/llvm_util.cc b/xla/service/llvm_ir/llvm_util.cc index 229b7f87b7d2c1..74cc0532516654 100644 --- a/xla/service/llvm_ir/llvm_util.cc +++ b/xla/service/llvm_ir/llvm_util.cc @@ -198,6 +198,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case S16: case U16: return llvm::Type::getInt16Ty(module->getContext()); + case F4E2M1FN: + return llvm::Type::getIntNTy(module->getContext(), 4); case F8E5M2: case F8E5M2FNUZ: case F8E4M3: @@ -205,6 +207,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, case F8E4M3B11FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: // We represent F8 as an int since there is no LLVM F8 dtype. 
return llvm::Type::getInt8Ty(module->getContext()); case BF16: diff --git a/xla/stream_executor/data_type.h b/xla/stream_executor/data_type.h index f5246389e485c3..e3e7d1f17e312f 100644 --- a/xla/stream_executor/data_type.h +++ b/xla/stream_executor/data_type.h @@ -37,6 +37,10 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> +struct ToDataType { + static constexpr DataType value = DataType::kF4E2M1FN; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E3M4; }; @@ -61,6 +65,10 @@ struct ToDataType { static constexpr DataType value = DataType::kF8E5M2FNUZ; }; template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E8M0FNU; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kFloat; }; diff --git a/xla/stream_executor/dnn.cc b/xla/stream_executor/dnn.cc index 6b7a87d80b3aec..24851e56d75eda 100644 --- a/xla/stream_executor/dnn.cc +++ b/xla/stream_executor/dnn.cc @@ -69,12 +69,14 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; diff --git a/xla/stream_executor/gpu/gpu_blas_lt.cc b/xla/stream_executor/gpu/gpu_blas_lt.cc index 6aee86bf2cbc19..182af599af9e5c 100644 --- a/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -56,6 +56,10 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { return DataType::kF8E4M3FNUZ; case PrimitiveType::F8E3M4: return DataType::kF8E3M4; + case PrimitiveType::F4E2M1FN: + return 
DataType::kF4E2M1FN; + case PrimitiveType::F8E8M0FNU: + return DataType::kF8E8M0FNU; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -93,6 +97,10 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { return PrimitiveType::F8E4M3FNUZ; case DataType::kF8E3M4: return PrimitiveType::F8E3M4; + case DataType::kF4E2M1FN: + return PrimitiveType::F4E2M1FN; + case DataType::kF8E8M0FNU: + return PrimitiveType::F8E8M0FNU; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -154,6 +162,8 @@ absl::StatusOr GetBlasComputationType( case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through case PrimitiveType::F8E3M4: // fall-through + case PrimitiveType::F4E2M1FN: // fall-through + case PrimitiveType::F8E8M0FNU: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/xla/stream_executor/rocm/hip_blas_utils.cc b/xla/stream_executor/rocm/hip_blas_utils.cc index e5730121addd8d..8864476bf0d825 100644 --- a/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/xla/stream_executor/rocm/hip_blas_utils.cc @@ -39,8 +39,10 @@ hipDataType AsHipblasDataType(blas::DataType type) { case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: case blas::DataType::kF8E3M4: - LOG(FATAL) - << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; + case blas::DataType::kF4E2M1FN: + case blas::DataType::kF8E8M0FNU: + LOG(FATAL) << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN, " + "F8E3M4, F4E2M1FN and F8E8M0FNU"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/xla/tests/BUILD b/xla/tests/BUILD index ab1c01190fadae..857ed141da14b6 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -844,6 +844,7 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:types", + "//xla:util", "//xla/client:global_data", "//xla/client:local_client", "//xla/hlo/builder:xla_builder", diff --git 
a/xla/tests/array_elementwise_ops_test.cc b/xla/tests/array_elementwise_ops_test.cc index c12ce79a06e8fa..b4fc4932a6a11c 100644 --- a/xla/tests/array_elementwise_ops_test.cc +++ b/xla/tests/array_elementwise_ops_test.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "xla/util.h" #include "tsl/platform/ml_dtypes.h" #if TENSORFLOW_USE_ROCM @@ -93,6 +94,20 @@ std::pair, std::vector> AllSignedPairs( return {xs, ys}; } +template +void AddNegativeValuesMaybeRemoveZero(std::vector& values) { + values.reserve(values.size() * 2); + if (!has_zero_v) { + values.erase(values.begin()); + } + for (size_t i = 0, n = values.size(); i < n; ++i) { + auto neg = -values[i]; + if (SignAndMagnitude(neg).first) { + values.push_back(neg); + } + } +} + class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: static constexpr float kEpsF32 = std::numeric_limits::epsilon(); @@ -1371,14 +1386,7 @@ class TotalOrderTest : public ClientLibraryTestBase { values.push_back(Eigen::numext::abs(std::numeric_limits::quiet_NaN())); } #endif - values.reserve(values.size() * 2); - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) != Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); std::vector lhs_data; std::vector rhs_data; lhs_data.reserve(values.size() * values.size()); @@ -1423,19 +1431,21 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types; + float>; TYPED_TEST_SUITE(TotalOrderTest, Types); @@ -1462,13 +1472,7 @@ TYPED_TEST(TotalOrderTest, LargeMagnitudeVsNaN) { if constexpr (std::numeric_limits::has_infinity) { values.push_back(std::numeric_limits::infinity()); } - for (size_t i = 0, n = values.size(); i < n; ++i) { - auto value = values[i]; - auto neg = -value; - if (Eigen::numext::signbit(neg) 
!= Eigen::numext::signbit(value)) { - values.push_back(neg); - } - } + AddNegativeValuesMaybeRemoveZero(values); auto lhs = ConstantR1(&builder, values); auto rhs = ConstantR1( &builder, diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 9650077ed57b28..2018f88079f867 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -48,11 +48,10 @@ class ConstantsTest : public ClientLibraryTestBase { template class ConstantsFloatTest : public ConstantsTest {}; -using FloatTypes = - ::testing::Types; +using FloatTypes = ::testing::Types< + float, half, tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3, + tsl::float8_e4m3fn, tsl::float8_e4m3b11fnuz, tsl::float8_e4m3fnuz, + tsl::float8_e5m2, tsl::float8_e5m2fnuz, tsl::float8_e8m0fnu>; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 4f06ea0cc290c7..95d6525ab300be 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -54,9 +54,11 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); template @@ -741,10 +743,11 @@ XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { XlaBuilder builder(this->TestName()); using FP = TypeParam; - auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.5}, FP{2.0}, FP{-0.0}}); ConvertElementType(a, PRED); - std::array expected = {false, true, true, false}; + bool zero_pred = !has_zero_v; + std::array expected = {zero_pred, true, true, zero_pred}; this->template ComputeAndCompareR1(&builder, expected, {}); } @@ -1925,5 +1928,274 @@ XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F4E2M1FN + +XLA_TEST_F(ConvertTest, ConvertF16F4e2m1fnRoundtrip) { + // Convert from 
FP16 to FP4, then back to FP16. + XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFCp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.004p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FCp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } + + auto f4 = + ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); +} + +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F4e2m1fnRoundtrip)) { + // Convert from FP32 to FP4, then back to FP32. 
+ XlaBuilder builder(TestName()); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, 0x1.8p2}, + // clang-format on + {0x1.4p0, 0x1p0}, // Round-to-even down + {0x1.Cp0, 0x1p1}, // Round-to-even up + {0x1.8p2, 0x1.8p2}, // Max value + {0x1.BFFFFEp2, 0x1.8p2}, // Largest number that doesn't overflow + {0x1.Cp2, 0x1.8p2}, // Smallest number that overflows + {0x1p3, 0x1.8p2}, // Overflow + {0x1p0, 0x1p0}, // Smallest F8 normal + {0x1.8p-1, 0x1p0}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-1, 0x1.0p-1}, // Denormal without rounding + {0x1.8p-1, 0x1.0p0}, // Round-to-even up + {0x1.6p-1, 0x1.0p-1}, // Round-to-nearest down + {0x1.Ep-1, 0x1.0p0}, // Round-to-nearest up + {0x1p-2, 0}, // Largest number that underflows + {0x1.000002p-2, 0x1p-1}, // Smallest number that doesn't underflow + {0x1.7FFFFEp-1, 0x1p-1}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f4 = ConvertElementType(ConstantR1(&builder, inputs), F4E2M1FN); + ConvertElementType(f4, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive) { + // Convert from FP4 to supported floating point type, then back to FP4. 
+ XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f4_as_fp = + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f4_as_fp, F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive2) { + // Convert from supported floating point type to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF4e2m1fnRoundtripExhaustive3) { + // Convert from FP4 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float4_e2m1fn; + std::vector all_f4; + for (int i = 0; i < 16; i++) { + all_f4.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f4), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF4e2m1fnF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP4. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F4E2M1FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +// ----- F8E8M0FNU + +XLA_TEST_F(ConvertTest, ConvertF32F8e8m0fnuRoundtrip) { + // Convert from FP32 to FP8, then back to FP32. 
+ XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, nan}, // No zero values + {-0.0, nan}, + {1.0, 1.0}, + {-1.0, nan}, // No negative values + {nan, nan}, + {inf, nan}, + // clang-format on + {0x1.8p1, 0x1p2}, // Round-to-even up + {0x1.8p2, 0x1p3}, // Round-to-even up (always rounds up) + {0x1p127, 0x1p127}, // Max value + {0x1.7FFFFEp127, 0x1p127}, // Largest number that doesn't overflow + {0x1.8p127, nan}, // Smallest number that overflows + {0x1.FFFFFEp127, nan}, // Overflow + {0x1p-126, 0x1p-126}, // Smallest F8 normal + {0x0.800002p-126, 0x1p-126}, // Smallest number rounding up to normal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E8M0FNU); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. 
+ XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive2) { +#ifdef XLA_TEST_BACKEND_CPU + // This test is disabled on CPU, as converting 0x1p-127 from double to float + // using CVTSD2SS on x64 results in an underflow (even though the result is + // representable as denormalized float32). + if (std::is_same_v) { + GTEST_SKIP() << "Skipping test for double precision floating point that " + "loses denormal value during conversion"; + } +#endif + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e8m0fnuRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e8m0fnu; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e8m0fnuF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. 
+ XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f16), F8E8M0FNU); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + } // namespace } // namespace xla diff --git a/xla/tools/driver.cc b/xla/tools/driver.cc index 4f4895b57123ae..5e5119a726a236 100644 --- a/xla/tools/driver.cc +++ b/xla/tools/driver.cc @@ -119,6 +119,7 @@ enum PrimitiveType { F64, C64, C128, + F4E2M1FN, F8E5M2, F8E4M3, F8E4M3FN, @@ -126,22 +127,18 @@ enum PrimitiveType { F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, + F8E8M0FNU, }; const std::vector& primitive_strings() { - static auto vec = new std::vector({"s2", "s4", - "s8", "s16", - "s32", "s64", - "u2", "u4", - "u8", "u16", - "u32", "u64", - "f16", "bf16", - "f32", "f64", - "c64", "c128", - "f8e5m2", "f8e4m3", - "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz", - "f8e3m4"}); + static auto vec = new std::vector( + {"s2", "s4", "s8", "s16", + "s32", "s64", "u2", "u4", + "u8", "u16", "u32", "u64", + "f16", "bf16", "f32", "f64", + "c64", "c128", "f4e2m1fn", "f8e3m4", + "f8e4m3", "f8e4m3fn", "f8e4m3fnuz", "f8e4m3b11fnuz", + "f8e5m2", "f8e5m2fnuz", "f8e8m0fnu"}); return *vec; } @@ -418,6 +415,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F64: return FillFloatT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -425,6 +423,7 @@ void Fill(void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: @@ -476,6 +475,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F64: return DisplayT(buffer, num_elements); + case F4E2M1FN: case F8E5M2: case F8E4M3: case F8E4M3FN: @@ -483,6 +483,7 @@ void Display(const void* buffer, const ArrayShape& shape) { case F8E5M2FNUZ: case F8E4M3FNUZ: case F8E3M4: + case F8E8M0FNU: case F16: case BF16: case C64: diff 
--git a/xla/tsl/framework/type_traits.h b/xla/tsl/framework/type_traits.h index f7a9bd7a54bc91..bda9244327b4fd 100644 --- a/xla/tsl/framework/type_traits.h +++ b/xla/tsl/framework/type_traits.h @@ -70,13 +70,15 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/xla/tsl/protobuf/dnn.proto b/xla/tsl/protobuf/dnn.proto index 2ac31005c16629..4a6d8fff6f72cd 100644 --- a/xla/tsl/protobuf/dnn.proto +++ b/xla/tsl/protobuf/dnn.proto @@ -24,6 +24,8 @@ enum DataType { kInt64 = 12; kF8E4M3 = 13; kF8E3M4 = 14; + kF4E2M1FN = 15; + kF8E8M0FNU = 16; } // Describes how a convolution input or output layer's data is formatted. 
diff --git a/xla/tsl/python/lib/core/ml_dtypes.cc b/xla/tsl/python/lib/core/ml_dtypes.cc index e2c5eb295c6b12..a986efb7cca963 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,8 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float4_e2m1fn = + py::dtype::from_args(ml_dtypes.attr("float4_e2m1fn")).num(); numpy_dtypes.float8_e3m4 = py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); numpy_dtypes.float8_e4m3 = @@ -75,6 +77,8 @@ struct MlDtypesInitInfo { py::dtype::from_args(ml_dtypes.attr("float8_e4m3fnuz")).num(); numpy_dtypes.float8_e5m2fnuz = py::dtype::from_args(ml_dtypes.attr("float8_e5m2fnuz")).num(); + numpy_dtypes.float8_e8m0fnu = + py::dtype::from_args(ml_dtypes.attr("float8_e8m0fnu")).num(); numpy_dtypes.int4 = py::dtype::from_args(ml_dtypes.attr("int4")).num(); numpy_dtypes.uint4 = py::dtype::from_args(ml_dtypes.attr("uint4")).num(); } catch (const std::exception& e) { @@ -85,6 +89,7 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. 
if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float4_e2m1fn == NPY_NOTYPE || numpy_dtypes.float8_e3m4 == NPY_NOTYPE || numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || @@ -92,6 +97,7 @@ struct MlDtypesInitInfo { numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || numpy_dtypes.float8_e5m2 == NPY_NOTYPE || numpy_dtypes.float8_e5m2fnuz == NPY_NOTYPE || + numpy_dtypes.float8_e8m0fnu == NPY_NOTYPE || numpy_dtypes.int4 == NPY_NOTYPE || numpy_dtypes.uint4 == NPY_NOTYPE) { init_valid = false; } diff --git a/xla/tsl/python/lib/core/ml_dtypes.h b/xla/tsl/python/lib/core/ml_dtypes.h index b3aa94e430239a..725d844c27bb4e 100644 --- a/xla/tsl/python/lib/core/ml_dtypes.h +++ b/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,7 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float4_e2m1fn; int float8_e3m4; int float8_e4m3; int float8_e4m3fn; @@ -31,6 +32,7 @@ struct NumpyDtypes { int float8_e4m3fnuz; int float8_e5m2; int float8_e5m2fnuz; + int float8_e8m0fnu; int int4; int uint4; }; diff --git a/xla/types.h b/xla/types.h index 8d30a2b2500131..6ebb1caba93c34 100644 --- a/xla/types.h +++ b/xla/types.h @@ -129,16 +129,32 @@ struct make_specialized_signed>> { template using make_specialized_signed_t = typename make_specialized_signed::type; +// has_negative_zero[_v] + template struct has_negative_zero : std::bool_constant::is_iec559> {}; +template <> +struct has_negative_zero : std::bool_constant {}; + template <> struct has_negative_zero : std::bool_constant {}; template inline constexpr bool has_negative_zero_v = has_negative_zero::value; +// has_zero[_v] + +template +struct has_zero : std::bool_constant {}; + +template <> +struct has_zero : std::bool_constant {}; + +template +inline constexpr bool has_zero_v = has_zero::value; + } // namespace xla #endif // XLA_TYPES_H_ diff --git a/xla/util.cc b/xla/util.cc index 8378920df5921f..71d340db0347fd 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -145,6 +145,7 @@ std::string 
Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { + static_assert(std::numeric_limits::has_quiet_NaN); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, @@ -171,6 +172,10 @@ static std::string GenericRoundTripFpToString(FloatT value) { static_cast(value)); } +std::string RoundTripFpToString(tsl::float4_e2m1fn value) { + return GenericRoundTripFpToString(value); +} + std::string RoundTripFpToString(tsl::float8_e5m2 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); @@ -209,6 +214,11 @@ std::string RoundTripFpToString(tsl::float8_e3m4 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e8m0fnu value) { + std::string result = GenericRoundTripFpToString(value); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); diff --git a/xla/util.h b/xla/util.h index 959009073e96f9..a5fea9e787b7f3 100644 --- a/xla/util.h +++ b/xla/util.h @@ -416,6 +416,9 @@ std::string VectorString(const std::initializer_list& c) { return VectorString>(c); } +// Returns a string which can losslessly round trip to a float4 E2M1FN. +std::string RoundTripFpToString(tsl::float4_e2m1fn value); + // Returns a string which can losslessly round trip to a float8 E5M2. std::string RoundTripFpToString(tsl::float8_e5m2 value); @@ -437,6 +440,9 @@ std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a float8 E3M4. std::string RoundTripFpToString(tsl::float8_e3m4 value); +// Returns a string which can losslessly round trip to a float8 E8M0FNU. +std::string RoundTripFpToString(tsl::float8_e8m0fnu value); + // Returns a string which can losslessly round trip to a bfloat. 
std::string RoundTripFpToString(tsl::bfloat16 value); @@ -652,8 +658,9 @@ template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); - const BitType x_bits = Eigen::numext::bit_cast(x); - const BitType x_sign = x_bits ^ x_abs_bits; + // Eigen implements the sign value to be either all-zeros (for positive input) + // or all-ones (for negative input). + BitType x_sign = Eigen::numext::bit_cast(Eigen::numext::signbit(x)); if constexpr (!has_negative_zero_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. @@ -664,12 +671,17 @@ auto SignAndMagnitude(T x) { return std::make_pair(x_sign, x_abs_bits); } +template <> +inline auto SignAndMagnitude(tsl::float8_e8m0fnu x) { + uint8_t x_bits = Eigen::numext::bit_cast(x); + return std::make_pair(static_cast(0), x_bits); +} + template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); using SignedType = std::make_signed_t; - return static_cast(magnitude) ^ - (static_cast(sign) < 0 ? SignedType{-1} : SignedType{0}); + return static_cast(magnitude) ^ static_cast(sign); } // Returns the signed magnitude of T. @@ -679,6 +691,11 @@ auto ToSignMagnitude(T input) { return SignAndMagnitudeToTwosComplement(sign, magnitude); } +template <> +inline auto ToSignMagnitude(tsl::float8_e8m0fnu x) { + return Eigen::numext::bit_cast(x); +} + template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. 
diff --git a/xla/util_test.cc b/xla/util_test.cc index 2fe6317bfbb8ea..74dc0820c4f869 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -207,9 +207,9 @@ namespace { template void TotalOrderHelper(T x, T y) { auto x_sm = ToSignMagnitude(x); - bool x_sign = static_cast(Eigen::numext::signbit(x)); - bool y_sign = static_cast(Eigen::numext::signbit(y)); auto y_sm = ToSignMagnitude(y); + bool x_sign = static_cast(SignAndMagnitude(x).first); + bool y_sign = static_cast(SignAndMagnitude(y).first); if (x_sign && !y_sign) { EXPECT_LT(x_sm, y_sm) << x << " " << y; } @@ -240,6 +240,18 @@ void TotalOrderHelper(T x, T y) { } } // namespace +TEST(UtilTest, TotalOrder_F4E2M1FN) { + for (int a = 0; a < 16; ++a) { + tsl::float4_e2m1fn x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 16; ++b) { + tsl::float4_e2m1fn y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E5M2) { for (int a = 0; a < 256; ++a) { tsl::float8_e5m2 x = @@ -326,6 +338,18 @@ TEST(UtilTest, TotalOrder_F8E3M4) { } } +TEST(UtilTest, TotalOrder_F8E8M0FNU) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e8m0fnu x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e8m0fnu y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 7d9563b11ab795..89c62b52659cde 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -109,6 +109,17 @@ enum PrimitiveType { F8E5M2FNUZ = 24; F8E4M3FNUZ = 25; + // MX float dtypes, as described in: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + // + // F4E2M1FN has 2 exponent bits and 1 mantissa bit. + // F8E8M0FNU has 8 exponent bits, no mantissa and no sign. 
+ // + // Unlike IEEE types, infinities are not supported (hence "FN" suffix). + // F4E2M1FN has no NaN encoding either; F8E8M0FNU encodes NaN as 0xFF. + F4E2M1FN = 30; + F8E8M0FNU = 31; + // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. C128 = 18; // Paired F64 (real, imag), as in std::complex. @@ -134,7 +145,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 30 + // Next = 32 } // LINT.ThenChange( // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc @@ -579,15 +590,17 @@ message LiteralProto { bytes bf16s = 13; bytes u16s = 16; bytes s16s = 17; - bytes f8e5m2s = 19; - bytes f8e4m3s = 28; - bytes f8e4m3fns = 20; + bytes f4e2m1fns = 30; + bytes f8e3m4s = 29; bytes f8e4m3b11fnuzs = 23; - bytes f8e5m2fnuzs = 24; + bytes f8e4m3fns = 20; bytes f8e4m3fnuzs = 25; - bytes f8e3m4s = 29; + bytes f8e4m3s = 28; + bytes f8e5m2fnuzs = 24; + bytes f8e5m2s = 19; + bytes f8e8m0fnus = 31; repeated int64 sparse_indices = 14; - // Next = 30 + // Next = 32 } message WindowDimension {