Skip to content

Commit

Permalink
PR #21380: Add F4E2M1FN and F8E8M0FNU types
Browse files Browse the repository at this point in the history
Imported from GitHub PR #21380

Previous PR #19096 was rolled back, re-trying.

This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented.

This will enable using microscaling (MX) formats ([RFC](#18085)), such as MXFP4.

```c
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028
Copybara import of the project:

--
d7e00c4 by Sergey Kozub <skozub@nvidia.com>:

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

COPYBARA_INTEGRATE_REVIEW=#21380 from openxla:skozub/e2m1_e8m0 d7e00c4
PiperOrigin-RevId: 715434229
  • Loading branch information
sergey-kozub authored and Google-ML-Automation committed Jan 14, 2025
1 parent 88a2497 commit b912750
Show file tree
Hide file tree
Showing 79 changed files with 1,851 additions and 376 deletions.
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
191 changes: 119 additions & 72 deletions xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Large diffs are not rendered by default.

59 changes: 35 additions & 24 deletions xla/backends/gpu/codegen/transforms/lower_tensors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
Type element_type = tensor.getType().getElementType();
if (element_type == b.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
element_type = b.getI8Type();
}
auto ptr = ml::LLVMPointerType::get(b.getContext());
Expand Down Expand Up @@ -328,7 +329,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);
Type element_type = op.getTensor().getType().getElementType();
Value is_low_nibble = nullptr;
if (element_type == rewriter.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
std::tie(linear_index, is_low_nibble) =
GetI4IndexAndNibble(linear_index, b);
}
Expand All @@ -342,7 +344,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto high_value = b.create<mlir::arith::ShRUIOp>(
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
rewriter.getI4Type(),
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}

Expand Down Expand Up @@ -378,6 +380,7 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {

auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
op.getSource());
mlir::Type source_element_type = source.getType().getElementType();

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto linear_index = GetLinearIndex(op.getIndices(), b);
Expand All @@ -386,7 +389,9 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
if (vector_type.getElementType().isInteger(1)) {
vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
}
if (op.getVectorType().getElementType().isInteger(4)) {
mlir::Type gep_element_type = vector_type.getElementType();
if (gep_element_type.isIntOrFloat() &&
gep_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
Expand All @@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
auto llvm_vector_type = converter.convertType(vector_type);
auto loaded = b.create<ml::LoadOp>(llvm_vector_type, gep).getResult();

if (source.getType().getElementType().isInteger(1)) {
if (source_element_type.isInteger(1)) {
Value zero = b.create<mlir::arith::ConstantOp>(
mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
} else if (source.getType().getElementType().isInteger(4)) {
} else if (source_element_type.isIntOrFloat() &&
source_element_type.getIntOrFloatBitWidth() == 4) {
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
// elements.
loaded = PermutePairsInVector(loaded, b);
Expand Down Expand Up @@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto scalar_value = op.getScalar();

// For i4 we store 2 values into one byte. This needs special handling here.
if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
if (tensor_dest.getType().getElementType().isIntOrFloat() &&
tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
// We need to use directly op.getDest() as input, otherwise the following
// rewrite might remove the only user of it.
tensor_dest = op.getDest();
Expand All @@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto tensor_dest_i8 =
b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
.getResult(0);
if (scalar_value.getType() != rewriter.getI4Type()) {
scalar_value =
b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
}
scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);

// We need AtomicRMWOp because it can happen that different threads try to
Expand Down Expand Up @@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);

mlir::Value vector_value = op.getVector();
if (op.getVectorType().getElementType().isInteger(1)) {
mlir::Type vector_element_type = op.getVectorType().getElementType();
if (vector_element_type.isInteger(1)) {
vector_value = b.create<arith::ExtUIOp>(
op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
vector_value);
}
if (op.getVectorType().getElementType().isInteger(4)) {
if (vector_element_type.isIntOrFloat() &&
vector_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
Expand Down Expand Up @@ -575,21 +588,19 @@ ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
// Needed to support complex element type.
mlir::LLVMTypeConverter converter(b.getContext());
auto llvm_element_type = converter.convertType(element_type);
if (mlir::isa<mlir::IntegerType>(element_type)) {
int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth();
if (bit_width == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
auto array_ty = ml::LLVMArrayType::get(llvm_element_type, num_elements);
std::string name;
Expand Down
50 changes: 50 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,53 @@ module {
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
// CHECK: arith.truncf %[[EXT]] : f32 to f16
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
%ret = arith.extf %arg0 : f4E2M1FN to f16
return %ret : f16
}
}

// CHECK-LABEL: @f4_to_f16
// CHECK-NOT: arith.extf

// -----

module {
func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
%ret = arith.truncf %arg0 : f16 to f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f16_to_f4
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
%ret = math.absf %arg0 : f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f4_abs
// CHECK-NOT: math.absf
// CHECK: arith.constant 7 : i4

// -----

module {
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
%ret = math.absf %arg0 : f8E8M0FNU
return %ret : f8E8M0FNU
}
}

// CHECK-LABEL: @e8m0_abs
// CHECK-NOT: math.absf
// CHECK: return %arg0
39 changes: 39 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,42 @@ func.func @vector_atomic_rmw(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32

// -----

func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN {
%cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN>
%extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN>
%extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN>
%0 = arith.addf %extracted, %extracted_0 : f4E2M1FN
return %0 : f4E2M1FN
}
// CHECK: llvm.mlir.global private constant
// CHECK-SAME: dense<[25, 64]>
// CHECK-LABEL: @f4_constant

// -----

func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> {
%c16 = arith.constant 16 : index
%c0 = arith.constant 0.0 : f4E2M1FN
%out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN>
func.return %out : vector<2xf4E2M1FN>
}
// CHECK-LABEL: @transfer_read_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8]
// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4>
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN>
// CHECK: return %[[OUT]] : vector<2xf4E2M1FN>

// -----

func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
%arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> {
%c10 = arith.constant 10 : index
%out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN>
func.return %out : tensor<43xf4E2M1FN>
}
// CHECK-LABEL: @transfer_write_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>
Loading

0 comments on commit b912750

Please sign in to comment.