Skip to content

Commit de4031d

Browse files
authored
[Runtime] Preserve denormals in floating point conversions (#20252)
These helper functions are used notably for printing and parsing, so flushing denormals to zero there was not a good idea, even if actual codegen would flush to zero. For example, flusing to zero when printing could hide denormals, and parsing to zero could prevent reproducing issues with denormals. A couple of unrelated changes are lumped in this PR: dropping some redundant branches to handle `nan_as_neg_zero` (which is used in FP8 types that have FN in their name). I just spotted that and it's redundant because it's already handled at the end of the function, and the tests still passing confirm that. As rightly called out by @krzysz00 in #20242 (review). Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent 1c2bfc1 commit de4031d

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

runtime/src/iree/base/internal/math.h

+33-21
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,8 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits,
304304
const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift);
305305
const uint32_t src_exp = src & src_exp_mask;
306306
const uint32_t src_mantissa = src & src_mantissa_mask;
307-
// Initializing f32_exp and f32_mantissa for the case of normal finite values.
308-
// Below we will overload that in other cases.
309-
uint32_t f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias)
310-
<< f32_exp_shift;
311-
uint32_t f32_mantissa = src_mantissa
312-
<< (f32_mantissa_bits - src_mantissa_bits);
307+
uint32_t f32_exp = 0;
308+
uint32_t f32_mantissa = 0;
313309
if (src_exp == src_exp_mask) {
314310
// Top exponent value normally means infinity or NaN.
315311
if (have_infinity) {
@@ -333,16 +329,16 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits,
333329
f32_mantissa = f32_mantissa_mask;
334330
}
335331
}
336-
} else if (src_exp == 0) {
337-
// Zero or subnormal. Generate zero, except in one case: if the source type
338-
// encodes NaN as signed zero, we handle that now.
339-
if (nan_as_neg_zero && src == src_sign_mask) {
340-
f32_exp = f32_exp_mask;
341-
f32_mantissa = f32_mantissa_mask;
342-
} else {
343-
f32_exp = 0;
344-
f32_mantissa = 0;
345-
}
332+
} else if (nan_as_neg_zero && src == src_sign_mask) {
333+
// Source is NaN encoded as negative zero. Generate NaN.
334+
f32_exp = f32_exp_mask;
335+
f32_mantissa = f32_mantissa_mask;
336+
} else if (src_exp == 0 && src_mantissa == 0) {
337+
// Zero. Leave f32_exp and f32_mantissa as zero.
338+
} else {
339+
f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias)
340+
<< f32_exp_shift;
341+
f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits);
346342
}
347343
const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa;
348344
float f32_value;
@@ -378,11 +374,16 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even(
378374
// Inf. Leave zero mantissa.
379375
}
380376
} else if (f32_exp == 0) {
381-
// Zero or subnormal. Generate zero. Leave zero mantissa.
382-
if (nan_as_neg_zero) {
383-
// The destination has no signed zero. Avoid accidentally generating NaN.
384-
dst_sign = 0;
377+
// Zero or subnormal.
378+
if (dst_exp_bits == f32_exp_bits) {
379+
// When the destination type still has as many exponent bits, denormals
380+
// can remain nonzero. This happens only with the bf16 type.
381+
// Just truncate the mantissa. Not worth bothering with round-to-nearest
382+
// for denormals for bf16 only.
383+
dst_mantissa = f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits);
385384
}
385+
// The destination type has fewer exponent bits, so f32 subnormal values
386+
// become exactly zero. Leave the mantissa zero.
386387
} else {
387388
// Normal finite value.
388389
int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias;
@@ -397,8 +398,19 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even(
397398
generate_nan = true;
398399
}
399400
} else if (arithmetic_exp + dst_exp_bias <= 0) {
400-
// Underflow. Generate zero. Leave zero mantissa.
401+
// Underflow. Generate a subnormal or zero.
401402
dst_exp = 0;
403+
// The exponent has to be clamped to 0 when the value
404+
// (arithmetic_exp + dst_exp_bias) is negative. This has to be compensated
405+
// by right-shifting the subnormal mantissa.
406+
int exp_to_encode_as_bitshift = -(arithmetic_exp + dst_exp_bias);
407+
int shift_amount =
408+
f32_mantissa_bits - dst_mantissa_bits + exp_to_encode_as_bitshift;
409+
if (shift_amount >= f32_mantissa_bits) {
410+
dst_mantissa = 0;
411+
} else {
412+
dst_mantissa = f32_mantissa >> shift_amount;
413+
}
402414
} else {
403415
// Normal case.
404416
// Implement round-to-nearest-even, by adding a bias before truncating.

runtime/src/iree/base/internal/math_test.cc

+9-8
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,10 @@ TEST(F16ConversionTest, F32ToF16) {
192192
// Underflow
193193
EXPECT_EQ(0, iree_math_f32_to_f16(FLT_MIN));
194194
EXPECT_EQ(0x8000, iree_math_f32_to_f16(-FLT_MIN));
195-
EXPECT_EQ(0, iree_math_f32_to_f16(1.0e-05));
196-
EXPECT_EQ(0x8000, iree_math_f32_to_f16(-1.0e-05));
197-
EXPECT_EQ(0, iree_math_f32_to_f16(6.1e-05)); // Near largest denormal
198-
EXPECT_EQ(0x8000, iree_math_f32_to_f16(-6.1e-05));
195+
EXPECT_EQ(0x004F, iree_math_f32_to_f16(1.0e-05));
196+
EXPECT_EQ(0x804F, iree_math_f32_to_f16(-1.0e-05));
197+
EXPECT_EQ(0x03FE, iree_math_f32_to_f16(6.1e-05)); // Near largest denormal
198+
EXPECT_EQ(0x83FE, iree_math_f32_to_f16(-6.1e-05));
199199

200200
// Denormals may or may not get flushed to zero. Accept both ways.
201201
uint16_t positive_denormal = iree_math_f32_to_f16(kF16Min / 2);
@@ -319,7 +319,8 @@ TEST(BF16ConversionTest, F32ToBF16ToF32) {
319319
EXPECT_EQ(FLT_MIN, iree_math_bf16_to_f32(iree_math_f32_to_bf16(FLT_MIN)));
320320
EXPECT_EQ(-FLT_MIN, iree_math_bf16_to_f32(iree_math_f32_to_bf16(-FLT_MIN)));
321321
// Denormals
322-
EXPECT_EQ(0.0f, iree_math_bf16_to_f32(iree_math_f32_to_bf16(2.0e-40f)));
322+
EXPECT_EQ(1.83670992e-40f,
323+
iree_math_bf16_to_f32(iree_math_f32_to_bf16(2.0e-40f)));
323324
// Inf and Nan
324325
EXPECT_EQ(INFINITY, iree_math_bf16_to_f32(iree_math_f32_to_bf16(INFINITY)));
325326
EXPECT_EQ(-INFINITY, iree_math_bf16_to_f32(iree_math_f32_to_bf16(-INFINITY)));
@@ -362,10 +363,10 @@ TEST(F8E5M2ConversionTest, F32ToF8E5M2) {
362363
// Underflow
363364
EXPECT_EQ(0, iree_math_f32_to_f8e5m2(FLT_MIN));
364365
EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2(-FLT_MIN));
365-
EXPECT_EQ(0, iree_math_f32_to_f8e5m2(kF8E5M2Min * 0.5f));
366+
EXPECT_EQ(0x00, iree_math_f32_to_f8e5m2(kF8E5M2Min * 0.5f));
366367
EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2(-kF8E5M2Min * 0.5f));
367-
EXPECT_EQ(0, iree_math_f32_to_f8e5m2(kF8E5M2Min * 0.75f));
368-
EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2(-kF8E5M2Min * 0.75f));
368+
EXPECT_EQ(0x02, iree_math_f32_to_f8e5m2(kF8E5M2Min * 0.75f));
369+
EXPECT_EQ(0x82, iree_math_f32_to_f8e5m2(-kF8E5M2Min * 0.75f));
369370

370371
// Denormals may or may not get flushed to zero. Accept both ways.
371372
uint16_t positive_denormal = iree_math_f32_to_f8e5m2(kF8E5M2Min / 2);

0 commit comments

Comments
 (0)