Skip to content

Commit 1aa968d

Browse files
authored
Improvements to check.expect_almost_eq and FP8 support (#20198)
1. The fuzzy comparison used in `check.expect_almost_eq` is changed from a single absolute `tolerance` to a two-parameter (`atol`, `rtol`) NumPy-compatible fuzzy comparison: `abs(lhs - rhs) <= atol + rtol * abs(rhs)`. The new `rtol` parameter defaults to 0 to preserve existing behavior. The custom parser/printer are dropped, reverting to default-generated. 2. Fixed comparisons of `NaN`'s, which were silently *passing* comparison to anything else! The problem was that the implementation in modules/check/module.cc was using the negated condition: `if (abs(lhs - rhs) > tolerance) return false;`. Since all expressions involving `NaN` evaluate to `false`, that ensured that as long as `lhs` or `rhs` was `NaN`, we wouldn't return `false`. Generally, `NaN` semantics break symmetry under negation, so a good rule to take away from this is to avoid negating boolean expressions that involve floating-point values. 3. Dropped the test `tests/e2e/regression/disable_demote_f64_to_f32.mlir`. It isn't needed anyway as we by now have a number of tests relying on the flag in question - it is necessary in any test exercising f64. * The reason to drop this test now is that it was actually relying on the buggy handling of NaN (above point 2.). It was producing NaN's and was only passing its own checks thanks to that bug. * Note that that is the same test that #20177 fixed undefined behavior in. 5. Added support for bf16 and all the fp8 types both in `check.expect_almost_eq` and in general (`string_util.c`) printing. 6. Printing more helpful diagnostics on failures of fuzzy comparisions, giving details about the first array position that fails the comparision, otherwise on large arrays it is hard or impossible to tell the issue from the debug output. * Example: `Expected near equality of these values. Contents does not match to tolerance parameters atol=0.01, rtol=0. The first failure occurs at index 0 as the lhs value 1 differs from the rhs value 0.98.` --------- Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
1 parent 3719c01 commit 1aa968d

File tree

15 files changed

+334
-163
lines changed

15 files changed

+334
-163
lines changed

compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,12 @@ std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
181181
intAttr.getValue().getSExtValue())));
182182
return {{constValue}};
183183
} else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attrValue)) {
184+
bool lossy = false;
185+
APFloat value = floatAttr.getValue();
186+
value.convert(llvm::cast<FloatType>(inputType).getFloatSemantics(),
187+
llvm::RoundingMode::NearestTiesToEven, &lossy);
184188
auto constValue = builder.create<mlir::arith::ConstantOp>(
185-
loc, inputType, FloatAttr::get(inputType, floatAttr.getValue()));
189+
loc, inputType, FloatAttr::get(inputType, value));
186190
return {{constValue}};
187191
} else if (auto elementsAttr =
188192
llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {

compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp

+2-33
Original file line numberDiff line numberDiff line change
@@ -34,43 +34,12 @@ struct ExpectAlmostEqConstOpToExpectAlmostEqOp
3434
PatternRewriter &rewriter) const override {
3535
auto rhs = rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue());
3636
rewriter.replaceOpWithNewOp<ExpectAlmostEqOp>(
37-
op, op.getDevice(), op.getLhs(), rhs, op.getToleranceAttr());
37+
op, op.getDevice(), op.getLhs(), rhs, op.getAtolAttr(),
38+
op.getRtolAttr());
3839
return success();
3940
}
4041
};
4142

42-
static constexpr char kToleranceKeyword[] = "tolerance";
43-
static constexpr float kToleranceDefaultValue = 1e-4f;
44-
45-
static ParseResult parseOptionalFloatTolerance(OpAsmParser &parser,
46-
FloatAttr &tolerance) {
47-
float toleranceValue = kToleranceDefaultValue;
48-
if (succeeded(parser.parseOptionalComma())) {
49-
if (failed(parser.parseKeyword(kToleranceKeyword))) {
50-
return parser.emitError(parser.getCurrentLocation(),
51-
llvm::Twine("Expected keyword: ") +
52-
kToleranceKeyword);
53-
}
54-
llvm::APFloat parsedTolerance(APFloat::IEEEsingle());
55-
if (failed(parser.parseFloat(parsedTolerance.getSemantics(),
56-
parsedTolerance))) {
57-
return parser.emitError(parser.getCurrentLocation(),
58-
"Failed to parse optional float tolerance.");
59-
}
60-
toleranceValue = parsedTolerance.convertToFloat();
61-
}
62-
tolerance = parser.getBuilder().getF32FloatAttr(toleranceValue);
63-
return success();
64-
}
65-
66-
static void printOptionalFloatTolerance(OpAsmPrinter &p, Operation *op,
67-
FloatAttr tolerance) {
68-
float toleranceValue = tolerance.getValue().convertToFloat();
69-
if (toleranceValue != kToleranceDefaultValue) {
70-
p << ", " << kToleranceKeyword << " " << toleranceValue;
71-
}
72-
}
73-
7443
} // namespace
7544

7645
void ExpectEqConstOp::getCanonicalizationPatterns(RewritePatternSet &results,

compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td

+28-12
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,40 @@ def CHECK_ExpectAlmostEqOp :
137137
Op<CHECK_Dialect, "expect_almost_eq", [AllTypesMatch<["lhs", "rhs"]>]> {
138138
let summary = [{Checks that the operands are almost equal}];
139139
let description = [{
140-
Verifies that the buffer view or tensor operands with float elements are
141-
almost equal to within an implementation-defined "reasonable" tolerance.
140+
Verifies that the buffer view or tensor operands with float elements satisfy
141+
the Numpy-style fuzzy-comparision condition with pararameters `atol`,
142+
`rtol`, which is the following element-wise on array elements `lhs`, `rhs`:
143+
```
144+
abs(lhs - rhs) <= atol + rtol * abs(rhs).
145+
```
142146

143147
Issues a non-fatal failure if the verification fails.
144148

149+
The `atol`, `rtol` parameters may be omitted, in which case some default
150+
value is used. The default `atol` is nonzero, while the default `rtol` is
151+
zero, which makes these comparision behave closer to exact comparisons as
152+
the values being compared get large.
153+
154+
This default behavior is supported for legacy compatibility and to support
155+
some use cases that legitimately don't care, but the majority of use cases
156+
should care and so should provide explicit `atol`, `rtol` values.
157+
145158
```mlir
146-
check.expect_almost_eq(%arg0, %arg1) : tensor<5xf32>
159+
check.expect_almost_eq(%arg0, %arg1, atol 1.0e-2, rtol 1.0e-3) : tensor<5xf32>
147160
```
148161
}];
149162

150163
let arguments = (ins
151164
Optional<HAL_Device>:$device,
152165
AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs,
153166
AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs,
154-
F32Attr:$tolerance
167+
DefaultValuedAttr<F32Attr, "1.e-4f">:$atol,
168+
DefaultValuedAttr<F32Attr, "0.f">:$rtol
155169
);
156170

157171
let assemblyFormat = [{
158172
(`` `<` $device^ `>`)?
159-
`` `(` $lhs `,` $rhs `` custom<OptionalFloatTolerance>($tolerance) `)`
173+
`` `(` $lhs `,` $rhs (`` `,` `atol` $atol^)? (`` `,` `rtol` $rtol^)? `)`
160174
attr-dict `:` type($lhs)
161175
}];
162176
}
@@ -166,30 +180,32 @@ def CHECK_ExpectAlmostEqConstOp :
166180
"expect_almost_eq_const", [AllTypesMatch<["lhs", "value"]>]> {
167181
let summary = [{Checks that the tensor operand is almost equal to some constant}];
168182
let description = [{
169-
Verifies that the tensor operand with float elements is almost equal to the
170-
constant attribute within an implementation-defined "reasonable" tolerance.
183+
This op is just a convenience wrapper around the expect_almost_eq op.
171184

172-
Issues a non-fatal failure if the verification fails.
185+
Verifies that the buffer view or tensor operands with float elements satisfy
186+
the Numpy-style fuzzy-comparision condition with pararameters `atol`,
187+
`rtol`. More details in the description of `expect_almost_eq`.
173188

174-
This op is just a convenience wrapper around the expect_almost_eq op.
189+
Issues a non-fatal failure if the verification fails.
175190

176191
```mlir
177-
check.expect_almost_eq_const(%const0, dense<[0.999999, 2.0]> : tensor<5xf32>) : tensor<5xf32>
192+
check.expect_almost_eq_const(%const0, dense<[0.999999, 2.0]> : tensor<5xf32>, atol 1.0e-2, rtol 1.0e-3) : tensor<5xf32>
178193
```
179194
}];
180195

181196
let arguments = (ins
182197
Optional<HAL_Device>:$device,
183198
TensorOf<[AnyFloat]>:$lhs,
184199
ElementsAttr:$value,
185-
F32Attr:$tolerance
200+
DefaultValuedAttr<F32Attr, "1.e-4f">:$atol,
201+
DefaultValuedAttr<F32Attr, "0.f">:$rtol
186202
);
187203

188204
let hasCanonicalizer = 1;
189205

190206
let assemblyFormat = [{
191207
(`` `<` $device^ `>`)?
192-
`` `(` $lhs `,` $value `` custom<OptionalFloatTolerance>($tolerance) `)`
208+
`` `(` $lhs `,` $value (`` `,` `atol` $atol^)? (`` `,` `rtol` $rtol^)? `)`
193209
attr-dict `:` type($lhs)
194210
}];
195211
}

compiler/src/iree/compiler/Modules/Check/check.imports.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ vm.import private optional @expect_almost_eq(
2929
%device : !vm.ref<!hal.device>,
3030
%lhs : !vm.ref<!hal.buffer_view>,
3131
%rhs : !vm.ref<!hal.buffer_view>,
32-
%tolerance : f32
32+
%atol : f32,
33+
%rtol : f32,
3334
)
3435

3536
} // vm.module

compiler/src/iree/compiler/Modules/Check/test/ops.mlir

+14-4
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ util.func public @expect_almost_eq(%lhs : !hal.buffer_view, %rhs : !hal.buffer_v
8989
// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
9090
// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
9191
util.func public @expect_almost_eq(%lhs : !hal.buffer_view, %rhs : !hal.buffer_view) {
92-
// CHECK: check.expect_almost_eq(%[[LHS]], %[[RHS]], tolerance 1.23{{0*}}e+02) : !hal.buffer_view
93-
check.expect_almost_eq(%lhs, %rhs, tolerance 123.0) : !hal.buffer_view
92+
// CHECK: check.expect_almost_eq(%[[LHS]], %[[RHS]], atol 1.23{{0*}}e+02) : !hal.buffer_view
93+
check.expect_almost_eq(%lhs, %rhs, atol 123.0) : !hal.buffer_view
9494
util.return
9595
}
9696

@@ -120,7 +120,17 @@ util.func public @expect_almost_eq_const(%lhs : tensor<2x2xf32>) {
120120
// CHECK-LABEL: @expect_almost_eq_const
121121
// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
122122
util.func public @expect_almost_eq_const(%lhs : tensor<2x2xf32>) {
123-
// CHECK: check.expect_almost_eq_const(%[[LHS]], dense<1.000000e+00> : tensor<2x2xf32>, tolerance 1.23{{0*}}e+02) : tensor<2x2xf32>
124-
check.expect_almost_eq_const(%lhs, dense<1.0> : tensor<2x2xf32>, tolerance 123.0) : tensor<2x2xf32>
123+
// CHECK: check.expect_almost_eq_const(%[[LHS]], dense<1.000000e+00> : tensor<2x2xf32>, atol 1.23{{0*}}e+02) : tensor<2x2xf32>
124+
check.expect_almost_eq_const(%lhs, dense<1.0> : tensor<2x2xf32>, atol 123.0) : tensor<2x2xf32>
125+
util.return
126+
}
127+
128+
// -----
129+
130+
// CHECK-LABEL: @expect_almost_eq_const
131+
// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
132+
util.func public @expect_almost_eq_const(%lhs : tensor<2x2xf32>) {
133+
// CHECK: check.expect_almost_eq_const(%[[LHS]], dense<1.000000e+00> : tensor<2x2xf32>, atol 2.5{{0*}}e-01, rtol 1.25{{0*}}e-01) : tensor<2x2xf32>
134+
check.expect_almost_eq_const(%lhs, dense<1.0> : tensor<2x2xf32>, atol 0.25, rtol 0.125) : tensor<2x2xf32>
125135
util.return
126136
}

runtime/src/iree/hal/string_util.c

+16
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,22 @@ IREE_API_EXPORT iree_status_t iree_hal_format_element(
572572
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu64,
573573
*(const uint64_t*)data.data);
574574
break;
575+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FN:
576+
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
577+
iree_math_f8e4m3fn_to_f32(*(const uint8_t*)data.data));
578+
break;
579+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
580+
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
581+
iree_math_f8e4m3fnuz_to_f32(*(const uint8_t*)data.data));
582+
break;
583+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
584+
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
585+
iree_math_f8e5m2_to_f32(*(const uint8_t*)data.data));
586+
break;
587+
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
588+
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
589+
iree_math_f8e5m2fnuz_to_f32(*(const uint8_t*)data.data));
590+
break;
575591
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
576592
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
577593
iree_math_bf16_to_f32(*(const uint16_t*)data.data));

runtime/src/iree/modules/check/check_test.cc

+37-24
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,8 @@ TEST_F(CheckTest, ExpectAlmostEqSameBufferSuccess) {
441441
CreateFloat32BufferView(contents, shape, &input_buffer_view));
442442
IREE_ASSERT_OK(Invoke("expect_almost_eq",
443443
{input_buffer_view, input_buffer_view},
444-
{/*tolerance=*/iree_vm_value_make_f32(0.f)}));
444+
{/*atol=*/iree_vm_value_make_f32(0.f),
445+
/*rtol=*/iree_vm_value_make_f32(0.f)}));
445446
}
446447

447448
TEST_F(CheckTest, ExpectAlmostEqIdenticalBufferSuccess) {
@@ -452,19 +453,21 @@ TEST_F(CheckTest, ExpectAlmostEqIdenticalBufferSuccess) {
452453
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(contents, shape, &lhs));
453454
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(contents, shape, &rhs));
454455
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
455-
{/*tolerance=*/iree_vm_value_make_f32(0.f)}));
456+
{/*atol=*/iree_vm_value_make_f32(0.f),
457+
/*rtol=*/iree_vm_value_make_f32(0.f)}));
456458
}
457459

458460
TEST_F(CheckTest, ExpectAlmostEqNearIdenticalBufferSuccess) {
459461
vm::ref<iree_hal_buffer_view_t> lhs;
460462
vm::ref<iree_hal_buffer_view_t> rhs;
461-
float lhs_contents[] = {1.0f, 1.99999f, 0.00001f, 4.0f};
462-
float rhs_contents[] = {1.00001f, 2.0f, 0.0f, 4.0f};
463+
float lhs_contents[] = {1.0f, 1.99999f, 0.00001f, 10000.0f};
464+
float rhs_contents[] = {1.00001f, 2.0f, 0.0f, 10000.1f};
463465
iree_hal_dim_t shape[] = {4};
464466
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(lhs_contents, shape, &lhs));
465467
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
466468
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
467-
{/*tolerance=*/iree_vm_value_make_f32(1.e-4f)}));
469+
{/*atol=*/iree_vm_value_make_f32(1.e-4f),
470+
/*rtol=*/iree_vm_value_make_f32(1.e-4f)}));
468471
}
469472

470473
TEST_F(CheckTest, ExpectAlmostEqIdentical3DBufferSuccess) {
@@ -475,7 +478,8 @@ TEST_F(CheckTest, ExpectAlmostEqIdentical3DBufferSuccess) {
475478
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(contents, shape, &lhs));
476479
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(contents, shape, &rhs));
477480
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
478-
{/*tolerance=*/iree_vm_value_make_f32(0.f)}));
481+
{/*atol=*/iree_vm_value_make_f32(0.f),
482+
/*rtol=*/iree_vm_value_make_f32(0.f)}));
479483
}
480484

481485
TEST_F(CheckTest, ExpectAlmostEqDifferentShapeFailure) {
@@ -488,7 +492,8 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentShapeFailure) {
488492
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(contents, rhs_shape, &rhs));
489493
EXPECT_NONFATAL_FAILURE(
490494
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
491-
{/*tolerance=*/iree_vm_value_make_f32(0.f)})),
495+
{/*atol=*/iree_vm_value_make_f32(0.f),
496+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
492497
"Shapes do not match");
493498
}
494499

@@ -505,7 +510,8 @@ TEST_F(CheckTest, ExpectAlmostEqSmallerLhsElementCountFailure) {
505510
CreateFloat32BufferView(bigger_contents, bigger_shape, &bigger));
506511
EXPECT_NONFATAL_FAILURE(
507512
IREE_ASSERT_OK(Invoke("expect_almost_eq", {smaller, bigger},
508-
{/*tolerance=*/iree_vm_value_make_f32(0.f)})),
513+
{/*atol=*/iree_vm_value_make_f32(0.f),
514+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
509515
"Shapes do not match");
510516
}
511517

@@ -522,7 +528,8 @@ TEST_F(CheckTest, ExpectAlmostEqSmallerRhsElementCountFailure) {
522528
CreateFloat32BufferView(bigger_contents, bigger_shape, &bigger));
523529
EXPECT_NONFATAL_FAILURE(
524530
IREE_ASSERT_OK(Invoke("expect_almost_eq", {bigger, smaller},
525-
{/*tolerance=*/iree_vm_value_make_f32(0.f)})),
531+
{/*atol=*/iree_vm_value_make_f32(0.f),
532+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
526533
"Shapes do not match");
527534
}
528535

@@ -536,7 +543,8 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentElementTypeFailure) {
536543
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
537544
EXPECT_NONFATAL_FAILURE(
538545
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
539-
{/*tolerance=*/iree_vm_value_make_f32(0.f)})),
546+
{/*atol=*/iree_vm_value_make_f32(0.f),
547+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
540548
"Element types do not match");
541549
}
542550

@@ -550,7 +558,8 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentContentsFailure) {
550558
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
551559
EXPECT_NONFATAL_FAILURE(
552560
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
553-
{/*tolerance=*/iree_vm_value_make_f32(0.1f)})),
561+
{/*atol=*/iree_vm_value_make_f32(0.1f),
562+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
554563
"Contents does not match");
555564
}
556565

@@ -569,7 +578,8 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentEverythingFullMessageFailure) {
569578
// types.
570579
EXPECT_NONFATAL_FAILURE(
571580
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
572-
{/*tolerance=*/iree_vm_value_make_f32(0.f)})),
581+
{/*atol=*/iree_vm_value_make_f32(0.f),
582+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
573583
"Expected near equality of these values. Element types do not match."
574584
" Shapes do not match.\n"
575585
" lhs:\n"
@@ -588,9 +598,11 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentContents3DFullMessageFailure) {
588598
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
589599
EXPECT_NONFATAL_FAILURE(
590600
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
591-
{/*tolerance=*/iree_vm_value_make_f32(0.1f)})),
601+
{/*atol=*/iree_vm_value_make_f32(0.1f),
602+
/*rtol=*/iree_vm_value_make_f32(0.f)})),
592603
"Expected near equality of these values. Contents does not match to "
593-
"tolerance=0.1.\n"
604+
"tolerance parameters atol=0.1, rtol=0. The first failure occurs at "
605+
"index 3 as the lhs value 4 differs from the rhs value 42.\n"
594606
" lhs:\n"
595607
" 2x2x2xf32=[[1 2][3 4]][[5 6][7 8]]\n"
596608
" rhs:\n"
@@ -605,23 +617,23 @@ TEST_F(CheckTest, ExpectAlmostEqIdenticalBufferF16Success) {
605617
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(contents, shape, &lhs));
606618
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(contents, shape, &rhs));
607619
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
608-
{/*tolerance=*/iree_vm_value_make_f32(0.f)}));
620+
{/*atol=*/iree_vm_value_make_f32(0.f),
621+
/*rtol=*/iree_vm_value_make_f32(0.f)}));
609622
}
610623

611624
TEST_F(CheckTest, ExpectAlmostEqNearIdenticalBufferF16Success) {
612625
vm::ref<iree_hal_buffer_view_t> lhs;
613626
vm::ref<iree_hal_buffer_view_t> rhs;
614-
uint16_t lhs_contents[] = {
615-
iree_math_f32_to_f16(1.0f), iree_math_f32_to_f16(1.999f),
616-
iree_math_f32_to_f16(0.001f), iree_math_f32_to_f16(4.0f)};
617-
uint16_t rhs_contents[] = {
618-
iree_math_f32_to_f16(1.001f), iree_math_f32_to_f16(2.0f),
619-
iree_math_f32_to_f16(0.0f), iree_math_f32_to_f16(4.0f)};
620-
iree_hal_dim_t shape[] = {4};
627+
uint16_t lhs_contents[] = {iree_math_f32_to_f16(10000.0f),
628+
iree_math_f32_to_f16(10000.1f)};
629+
uint16_t rhs_contents[] = {iree_math_f32_to_f16(10000.1f),
630+
iree_math_f32_to_f16(10000.0f)};
631+
iree_hal_dim_t shape[] = {2};
621632
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(lhs_contents, shape, &lhs));
622633
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(rhs_contents, shape, &rhs));
623634
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
624-
{/*tolerance=*/iree_vm_value_make_f32(0.01f)}));
635+
{/*atol=*/iree_vm_value_make_f32(0.f),
636+
/*rtol=*/iree_vm_value_make_f32(1e-4f)}));
625637
}
626638

627639
TEST_F(CheckTest, ExpectAlmostEqDifferentContentsF16Failure) {
@@ -634,7 +646,8 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentContentsF16Failure) {
634646
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(rhs_contents, shape, &rhs));
635647
EXPECT_NONFATAL_FAILURE(
636648
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs},
637-
{/*tolerance=*/iree_vm_value_make_f32(0.1f)})),
649+
{/*atol=*/iree_vm_value_make_f32(0.1f),
650+
/*rtol=*/iree_vm_value_make_f32(0.1f)})),
638651
"Contents does not match");
639652
}
640653
} // namespace

0 commit comments

Comments
 (0)