PR #21380: Add F4E2M1FN and F8E8M0FNU types #21392

Merged 1 commit on Jan 14, 2025
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
@@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}
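The expected values above follow from round-to-nearest-even over the small set of F4E2M1FN magnitudes (2 exponent bits, 1 mantissa bit, no inf/nan). A minimal sketch of that rounding rule, as an illustrative helper that is not part of the PR:

```cpp
#include <cmath>

// Round-to-nearest-even over the positive F4E2M1FN magnitudes:
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Values at even indices in this table
// have an even (zero) mantissa bit.
float RoundToNearestE2M1(float x) {
  static const float kValues[] = {0.0f, 0.5f, 1.0f, 1.5f,
                                  2.0f, 3.0f, 4.0f, 6.0f};
  int best = 0;
  for (int i = 1; i < 8; ++i) {
    float d = std::abs(x - kValues[i]);
    float best_d = std::abs(x - kValues[best]);
    // On an exact tie, prefer the candidate with the even mantissa bit.
    if (d < best_d || (d == best_d && i % 2 == 0)) best = i;
  }
  return kValues[best];
}
// RoundToNearestE2M1(2.5f) == 2.0f  (tie broken toward even: 2.0)
// RoundToNearestE2M1(3.5f) == 4.0f  (tie broken toward even: 4.0)
```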

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}
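F8E8M0FNU holds only an 8-bit exponent: no sign bit, no mantissa, no zero, so every representable value is a power of two. The expectations above are consistent with rounding to the nearest power of two with ties going up; a sketch of that rule (illustrative helper, not part of the PR):

```cpp
#include <cmath>

// Nearest power of two, ties upward. Matches the test comments:
// 1.5 -> 2, 2.5 -> 2, 3.0 -> 4, 3.5 -> 4.
float RoundToNearestE8M0(float x) {
  float lo = std::exp2(std::floor(std::log2(x)));  // power of two <= x
  float hi = lo * 2.0f;                            // power of two above x
  return (x - lo < hi - x) ? lo : hi;              // ties (1.5, 3.0) go up
}
```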

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
191 changes: 119 additions & 72 deletions xla/backends/gpu/codegen/transforms/expand_float_ops.cc

Large diffs are not rendered by default.

59 changes: 35 additions & 24 deletions xla/backends/gpu/codegen/transforms/lower_tensors.cc
@@ -299,7 +299,8 @@ std::tuple<Value, Value> GetI4IndexAndNibble(Value linear_index,
ml::GEPOp CreateGep(TypedValue<mlir::RankedTensorType> tensor,
Value linear_index, mlir::ImplicitLocOpBuilder& b) {
Type element_type = tensor.getType().getElementType();
if (element_type == b.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
element_type = b.getI8Type();
}
auto ptr = ml::LLVMPointerType::get(b.getContext());
@@ -328,7 +329,8 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);
Type element_type = op.getTensor().getType().getElementType();
Value is_low_nibble = nullptr;
if (element_type == rewriter.getI4Type()) {
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
std::tie(linear_index, is_low_nibble) =
GetI4IndexAndNibble(linear_index, b);
}
@@ -342,7 +344,7 @@ struct RewriteTensorExtract : OpRewritePattern<mlir::tensor::ExtractOp> {
auto high_value = b.create<mlir::arith::ShRUIOp>(
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
load = b.create<mlir::arith::TruncIOp>(
op.getType(),
rewriter.getI4Type(),
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
}
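For 4-bit element types, two values share a byte, so the rewrite loads the containing i8 and selects a nibble. A scalar model of the emitted IR (the parity mapping lives in GetI4IndexAndNibble, which this hunk does not show, so even-index-low is only an assumption here):

```cpp
#include <cstdint>

// Scalar model: fetch 4-bit element `i` from a byte buffer holding two
// elements per byte. Even-index-low is assumed purely for illustration.
uint8_t ExtractNibble(const uint8_t* data, int64_t i) {
  uint8_t byte = data[i / 2];                  // CreateGep + LoadOp on i8
  bool is_low_nibble = (i % 2) == 0;           // assumed parity
  uint8_t high_value = byte >> 4;              // arith.shrui by 4
  return (is_low_nibble ? byte : high_value) & 0xF;  // select + trunc to i4
}
```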

@@ -378,6 +380,7 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {

auto source = mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(
op.getSource());
mlir::Type source_element_type = source.getType().getElementType();

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto linear_index = GetLinearIndex(op.getIndices(), b);
@@ -386,7 +389,9 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
if (vector_type.getElementType().isInteger(1)) {
vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type());
}
if (op.getVectorType().getElementType().isInteger(4)) {
mlir::Type gep_element_type = vector_type.getElementType();
if (gep_element_type.isIntOrFloat() &&
gep_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -397,11 +402,12 @@ struct RewriteTransferRead : OpRewritePattern<vector::TransferReadOp> {
auto llvm_vector_type = converter.convertType(vector_type);
auto loaded = b.create<ml::LoadOp>(llvm_vector_type, gep).getResult();

if (source.getType().getElementType().isInteger(1)) {
if (source_element_type.isInteger(1)) {
Value zero = b.create<mlir::arith::ConstantOp>(
mlir::DenseElementsAttr::get(vector_type, b.getI8IntegerAttr(0)));
loaded = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne, loaded, zero);
} else if (source.getType().getElementType().isInteger(4)) {
} else if (source_element_type.isIntOrFloat() &&
source_element_type.getIntOrFloatBitWidth() == 4) {
// LLVM and XLA pack i4s in opposite order, so we have to reshuffle the
// elements.
loaded = PermutePairsInVector(loaded, b);
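PermutePairsInVector is not shown in this diff; judging from the comment, it presumably swaps each adjacent pair of elements so that a packed LLVM load ends up in XLA's nibble order. A scalar sketch of that permutation:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Swap each adjacent pair: [a, b, c, d] -> [b, a, d, c].
void PermutePairs(std::vector<uint8_t>& v) {
  for (std::size_t i = 0; i + 1 < v.size(); i += 2) std::swap(v[i], v[i + 1]);
}
```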
@@ -430,7 +436,8 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto scalar_value = op.getScalar();

// For i4 we store 2 values into one byte. This needs special handling here.
if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) {
if (tensor_dest.getType().getElementType().isIntOrFloat() &&
tensor_dest.getType().getElementType().getIntOrFloatBitWidth() == 4) {
// We need to use directly op.getDest() as input, otherwise the following
// rewrite might remove the only user of it.
tensor_dest = op.getDest();
@@ -448,6 +455,10 @@ struct RewriteTensorInsert : OpRewritePattern<mlir::tensor::InsertOp> {
auto tensor_dest_i8 =
b.create<UnrealizedConversionCastOp>(tensor_ty, tensor_dest)
.getResult(0);
if (scalar_value.getType() != rewriter.getI4Type()) {
scalar_value =
b.create<arith::BitcastOp>(rewriter.getI4Type(), scalar_value);
}
scalar_value = b.create<mlir::arith::ExtUIOp>(ty, scalar_value);

// We need AtomicRMWOp because it can happen that different threads try to
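The diff truncates the comment here, but the surrounding logic suggests why an atomic is needed: two 4-bit elements share one byte, so two threads inserting neighboring elements would race on the same i8. A hypothetical scalar model of such a read-modify-write (the actual masking in the pass is not shown in this hunk):

```cpp
#include <atomic>
#include <cstdint>

// Insert a 4-bit value into one half of a shared byte with a CAS loop,
// leaving the neighboring nibble intact. Illustration only.
void InsertNibble(std::atomic<uint8_t>& byte, bool is_low_nibble, uint8_t v) {
  uint8_t keep_mask = is_low_nibble ? 0xF0 : 0x0F;  // bits to preserve
  uint8_t val = is_low_nibble ? uint8_t(v & 0xF) : uint8_t(v << 4);
  uint8_t old = byte.load();
  while (!byte.compare_exchange_weak(
      old, static_cast<uint8_t>((old & keep_mask) | val))) {
    // `old` is refreshed by compare_exchange_weak on failure; retry.
  }
}
```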
@@ -507,12 +518,14 @@ struct RewriteTransferWrite : OpRewritePattern<vector::TransferWriteOp> {
auto linear_index = GetLinearIndex(op.getIndices(), b);

mlir::Value vector_value = op.getVector();
if (op.getVectorType().getElementType().isInteger(1)) {
mlir::Type vector_element_type = op.getVectorType().getElementType();
if (vector_element_type.isInteger(1)) {
vector_value = b.create<arith::ExtUIOp>(
op.getVectorType().cloneWith(std::nullopt, b.getI8Type()),
vector_value);
}
if (op.getVectorType().getElementType().isInteger(4)) {
if (vector_element_type.isIntOrFloat() &&
vector_element_type.getIntOrFloatBitWidth() == 4) {
linear_index = b.create<arith::ShRUIOp>(
linear_index,
b.create<arith::ConstantIntOp>(1, linear_index.getType()));
@@ -575,21 +588,19 @@ ml::GlobalOp CreateGlobalOp(mlir::Attribute value,
// Needed to support complex element type.
mlir::LLVMTypeConverter converter(b.getContext());
auto llvm_element_type = converter.convertType(element_type);
if (mlir::isa<mlir::IntegerType>(element_type)) {
int bit_width = mlir::cast<mlir::IntegerType>(element_type).getWidth();
if (bit_width == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
if (element_type.isIntOrFloat() &&
element_type.getIntOrFloatBitWidth() == 4) {
num_elements = CeilOfRatio<int64_t>(num_elements, 2);
llvm_element_type = b.getI8Type();
auto unpacked_data =
mlir::cast<mlir::DenseElementsAttr>(value).getRawData();
std::vector<char> packed_data(num_elements);
absl::Span<char> packed_data_span =
absl::MakeSpan(packed_data.data(), packed_data.size());
PackIntN(4, unpacked_data, packed_data_span);
value = mlir::DenseElementsAttr::getFromRawBuffer(
mlir::RankedTensorType::get({num_elements}, llvm_element_type),
packed_data);
}
auto array_ty = ml::LLVMArrayType::get(llvm_element_type, num_elements);
std::string name;
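In the 4-bit branch above, the element count is halved and pairs of values are packed into bytes with PackIntN. Judging from the dense<[25, 64]> expectation in the lower_tensors.mlir test further down, the first element of each pair lands in the high nibble; a sketch under that assumption:

```cpp
#include <cstdint>
#include <vector>

// Pack 4-bit values two per byte, first element of each pair in the high
// nibble: [0x1, 0x9, 0x4] -> [0x19, 0x40] == [25, 64].
std::vector<uint8_t> PackNibbles(const std::vector<uint8_t>& unpacked) {
  std::vector<uint8_t> packed((unpacked.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < unpacked.size(); ++i) {
    uint8_t nibble = unpacked[i] & 0xF;
    packed[i / 2] |= (i % 2 == 0) ? uint8_t(nibble << 4) : nibble;
  }
  return packed;
}
```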
50 changes: 50 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/expand_float_ops.mlir
@@ -115,3 +115,53 @@ module {
// CHECK: %[[EXT:.*]] = arith.extf {{.*}} : bf16 to f32
// CHECK: arith.truncf %[[EXT]] : f32 to f16
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_to_f16(%arg0: f4E2M1FN) -> f16 {
%ret = arith.extf %arg0 : f4E2M1FN to f16
return %ret : f16
}
}

// CHECK-LABEL: @f4_to_f16
// CHECK-NOT: arith.extf

// -----

module {
func.func @f16_to_f4(%arg0: f16) -> f4E2M1FN {
%ret = arith.truncf %arg0 : f16 to f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f16_to_f4
// CHECK-NOT: arith.truncf

// -----

module {
func.func @f4_abs(%arg0: f4E2M1FN) -> f4E2M1FN {
%ret = math.absf %arg0 : f4E2M1FN
return %ret : f4E2M1FN
}
}

// CHECK-LABEL: @f4_abs
// CHECK-NOT: math.absf
// CHECK: arith.constant 7 : i4

// -----

module {
func.func @e8m0_abs(%arg0: f8E8M0FNU) -> f8E8M0FNU {
%ret = math.absf %arg0 : f8E8M0FNU
return %ret : f8E8M0FNU
}
}

// CHECK-LABEL: @e8m0_abs
// CHECK-NOT: math.absf
// CHECK: return %arg0
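Both abs lowerings follow from the bit layouts: f4E2M1FN keeps its sign in bit 3, so math.absf reduces to masking with 0b0111 (the constant 7 checked earlier), while f8E8M0FNU has no sign bit at all, so absf is the identity. A minimal sketch:

```cpp
#include <cstdint>

// f4E2M1FN: 1 sign bit + 2 exponent bits + 1 mantissa bit.
// Clearing the sign bit (mask 0b0111 = 7) implements absf.
uint8_t AbsF4E2M1FN(uint8_t bits) { return bits & 0x7; }

// f8E8M0FNU: 8 exponent bits, unsigned -- absf is the identity,
// which is why the lowering checks for a plain `return %arg0`.
uint8_t AbsF8E8M0FNU(uint8_t bits) { return bits; }
```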
39 changes: 39 additions & 0 deletions xla/backends/gpu/codegen/transforms/tests/lower_tensors.mlir
@@ -785,3 +785,42 @@ func.func @vector_atomic_rmw(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32
// CHECK-HOPPER: llvm.atomicrmw fadd {{.*}} !llvm.ptr, f32

// -----

func.func @f4_constant(%arg0: tensor<3xf4E2M1FN>, %arg1: index) -> f4E2M1FN {
%cst = arith.constant dense<[0.5, -0.5, 2.5]> : tensor<3xf4E2M1FN>
%extracted = tensor.extract %arg0[%arg1] : tensor<3xf4E2M1FN>
%extracted_0 = tensor.extract %cst[%arg1] : tensor<3xf4E2M1FN>
%0 = arith.addf %extracted, %extracted_0 : f4E2M1FN
return %0 : f4E2M1FN
}
// CHECK: llvm.mlir.global private constant
// CHECK-SAME: dense<[25, 64]>
// CHECK-LABEL: @f4_constant
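The expected bytes can be derived by hand: in E2M1, 0.5 encodes as 0x1 and -0.5 as 0x9 (sign bit set), while 2.5 rounds to 2.0, which encodes as 0x4; packing the first element of each pair into the high nibble gives [0x19, 0x40], i.e. dense<[25, 64]>.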

// -----

func.func @transfer_read_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1}) -> vector<2xf4E2M1FN> {
%c16 = arith.constant 16 : index
%c0 = arith.constant 0.0 : f4E2M1FN
%out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf4E2M1FN>, vector<2xf4E2M1FN>
func.return %out : vector<2xf4E2M1FN>
}
// CHECK-LABEL: @transfer_read_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[8]
// CHECK: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xi4>
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xi4> to vector<2xf4E2M1FN>
// CHECK: return %[[OUT]] : vector<2xf4E2M1FN>
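Index 16 maps to byte offset 8 because two f4E2M1FN elements share each byte; the same halving appears in @transfer_write_f4 below, where index 10 maps to byte 5.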

// -----

func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
%arg1: vector<2xf4E2M1FN>) -> tensor<43xf4E2M1FN> {
%c10 = arith.constant 10 : index
%out = vector.transfer_write %arg1, %arg0[%c10] : vector<2xf4E2M1FN>, tensor<43xf4E2M1FN>
func.return %out : tensor<43xf4E2M1FN>
}
// CHECK-LABEL: @transfer_write_f4
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[5] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>