This repository was archived by the owner on Mar 5, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
fix: Conversion operations having poison results #131
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
118e779
fix: Conversion operations haveing poison results.
doug-q 176d98f
test: add test for roundtripping signed ints.
doug-q d239a64
test: Add tests for exact float <-> int roundtrips
doug-q e17b7a5
tidy
doug-q 8e285f8
fmt
doug-q 3c20cd3
Merge branch 'main' into doug/fix-poison-conversions
doug-q File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
use anyhow::{anyhow, Result}; | ||
use anyhow::{anyhow, bail, ensure, Result}; | ||
|
||
use hugr::{ | ||
extension::{ | ||
prelude::{sum_with_error, ConstError, BOOL_T}, | ||
simple_op::MakeExtensionOp, | ||
}, | ||
ops::{constant::Value, custom::ExtensionOp}, | ||
ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _}, | ||
std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}, | ||
types::{TypeArg, TypeEnum}, | ||
types::{TypeArg, TypeEnum, TypeRow}, | ||
HugrView, | ||
}; | ||
|
||
use inkwell::{values::BasicValue, FloatPredicate, IntPredicate}; | ||
use inkwell::{types::IntType, values::BasicValue, FloatPredicate, IntPredicate}; | ||
|
||
use crate::{ | ||
custom::{CodegenExtension, CodegenExtsBuilder}, | ||
|
@@ -21,6 +21,7 @@ use crate::{ | |
EmitOpArgs, | ||
}, | ||
sum::LLVMSumValue, | ||
types::HugrType, | ||
}; | ||
|
||
fn build_trunc_op<'c, H: HugrView>( | ||
|
@@ -29,53 +30,58 @@ fn build_trunc_op<'c, H: HugrView>( | |
log_width: u64, | ||
args: EmitOpArgs<'c, '_, ExtensionOp, H>, | ||
) -> Result<()> { | ||
// Note: This logic is copied from `llvm_type` in the IntTypes | ||
// extension. We need to have a common source of truth for this. | ||
let (width, (int_min_value_s, int_max_value_s), int_max_value_u) = match log_width { | ||
0..=3 => (8, (i8::MIN as i64, i8::MAX as i64), u8::MAX as u64), | ||
4 => (16, (i16::MIN as i64, i16::MAX as i64), u16::MAX as u64), | ||
5 => (32, (i32::MIN as i64, i32::MAX as i64), u32::MAX as u64), | ||
6 => (64, (i64::MIN, i64::MAX), u64::MAX), | ||
m => return Err(anyhow!("ConversionEmitter: unsupported log_width: {}", m)), | ||
}; | ||
|
||
let hugr_int_ty = INT_TYPES[log_width as usize].clone(); | ||
let int_ty = context | ||
.typing_session() | ||
.llvm_type(&hugr_int_ty)? | ||
.into_int_type(); | ||
let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]); | ||
// TODO: it would be nice to get this info out of `ops.node()`, this would | ||
// require adding appropriate methods to `ConvertOpDef`. In the meantime, we | ||
// assert that the output types are as we expect. | ||
debug_assert_eq!( | ||
TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]), | ||
args.node().signature().output | ||
); | ||
|
||
let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else { | ||
bail!("Expected `arithmetic.int` to lower to an llvm integer") | ||
}; | ||
|
||
let hugr_sum_ty = sum_with_error(vec![hugr_int_ty]); | ||
let sum_ty = context.typing_session().llvm_sum_type(hugr_sum_ty)?; | ||
let sum_ty = context.llvm_sum_type(hugr_sum_ty)?; | ||
|
||
let (width, int_min_value_s, int_max_value_s, int_max_value_u) = { | ||
ensure!( | ||
log_width <= 6, | ||
"Expected log_width of output to be <= 6, found: {log_width}" | ||
); | ||
let width = 1 << log_width; | ||
( | ||
width, | ||
i64::MIN >> (64 - width), | ||
i64::MAX >> (64 - width), | ||
u64::MAX >> (64 - width), | ||
) | ||
}; | ||
|
||
emit_custom_unary_op(context, args, |ctx, arg, _| { | ||
// We have to check if the conversion will work, so we | ||
// make the maximum int and convert to a float, then compare | ||
// with the function input. | ||
let flt_max = if signed { | ||
ctx.iw_context() | ||
.f64_type() | ||
.const_float(int_max_value_s as f64) | ||
let flt_max = ctx.iw_context().f64_type().const_float(if signed { | ||
int_max_value_s as f64 | ||
} else { | ||
ctx.iw_context() | ||
.f64_type() | ||
.const_float(int_max_value_u as f64) | ||
}; | ||
int_max_value_u as f64 | ||
}); | ||
|
||
let within_upper_bound = ctx.builder().build_float_compare( | ||
FloatPredicate::OLE, | ||
FloatPredicate::OLT, | ||
arg.into_float_value(), | ||
flt_max, | ||
"within_upper_bound", | ||
)?; | ||
|
||
let flt_min = if signed { | ||
ctx.iw_context() | ||
.f64_type() | ||
.const_float(int_min_value_s as f64) | ||
let flt_min = ctx.iw_context().f64_type().const_float(if signed { | ||
int_min_value_s as f64 | ||
} else { | ||
ctx.iw_context().f64_type().const_float(0.0) | ||
}; | ||
0.0 | ||
}); | ||
|
||
let within_lower_bound = ctx.builder().build_float_compare( | ||
FloatPredicate::OLE, | ||
|
@@ -401,7 +407,7 @@ mod test { | |
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main")); | ||
} | ||
|
||
fn roundtrip_hugr(val: u64) -> Hugr { | ||
fn roundtrip_hugr(val: u64, signed: bool) -> Hugr { | ||
let int64 = INT_TYPES[6].clone(); | ||
SimpleHugrConfig::new() | ||
.with_outs(USIZE_T) | ||
|
@@ -412,14 +418,23 @@ mod test { | |
.add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k]) | ||
.unwrap() | ||
.outputs_arr(); | ||
let [flt] = builder | ||
.add_dataflow_op(ConvertOpDef::convert_u.with_log_width(6), [int]) | ||
.unwrap() | ||
.outputs_arr(); | ||
let [int_or_err] = builder | ||
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(6), [flt]) | ||
.unwrap() | ||
.outputs_arr(); | ||
let [flt] = { | ||
let op = if signed { | ||
ConvertOpDef::convert_s.with_log_width(6) | ||
} else { | ||
ConvertOpDef::convert_u.with_log_width(6) | ||
}; | ||
builder.add_dataflow_op(op, [int]).unwrap().outputs_arr() | ||
}; | ||
|
||
let [int_or_err] = { | ||
let op = if signed { | ||
ConvertOpDef::trunc_s.with_log_width(6) | ||
} else { | ||
ConvertOpDef::trunc_u.with_log_width(6) | ||
}; | ||
builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr() | ||
}; | ||
let sum_ty = sum_with_error(int64.clone()); | ||
let variants = (0..sum_ty.num_variants()) | ||
.map(|i| sum_ty.get_variant(i).unwrap().clone().try_into().unwrap()); | ||
|
@@ -467,25 +482,89 @@ mod test { | |
#[case(4294967295)] | ||
#[case(42)] | ||
#[case(18_000_000_000_000_000_000)] | ||
fn roundtrip(mut exec_ctx: TestContext, #[case] val: u64) { | ||
fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) { | ||
add_extensions(&mut exec_ctx); | ||
let hugr = roundtrip_hugr(val); | ||
let hugr = roundtrip_hugr(val, false); | ||
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main")); | ||
} | ||
|
||
// N.B.: There's some strange behaviour at the upper end of the ints - the | ||
// first case gets converted to something that's off by 1,000, but the second | ||
// (which is (2 ^ 64) - 1) gets converted to (2 ^ 32) - off by 9 million! | ||
// The fact that the first case works as expected tells me this isn't to do | ||
// with int widths - maybe a floating point expert could explain that this | ||
// is standard behaviour... | ||
#[rstest] | ||
#[case(18_446_744_073_709_550_000, 18_446_744_073_709_549_568)] | ||
#[case(18_446_744_073_709_551_615, 9_223_372_036_854_775_808)] // 2 ^ 63 | ||
fn approx_roundtrip(mut exec_ctx: TestContext, #[case] val: u64, #[case] expected: u64) { | ||
// Exact roundtrip conversion is defined on values up to 2**53 for f64. | ||
#[case(0)] | ||
#[case(3)] | ||
#[case(255)] | ||
#[case(4294967295)] | ||
#[case(42)] | ||
#[case(-9_000_000_000_000_000_000)] | ||
fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) { | ||
add_extensions(&mut exec_ctx); | ||
let hugr = roundtrip_hugr(val as u64, true); | ||
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64); | ||
} | ||
|
||
// For unsigned ints larger than (1 << 54) - 1, f64s do not have enough | ||
// precision to exactly roundtrip the int. | ||
// The exact behaviour of the round-trip is platform-dependent. | ||
#[rstest] | ||
#[case(u64::MAX)] | ||
#[case(u64::MAX - 1)] | ||
#[case(u64::MAX - (1 << 1))] | ||
#[case(u64::MAX - (1 << 2))] | ||
#[case(u64::MAX - (1 << 3))] | ||
#[case(u64::MAX - (1 << 4))] | ||
#[case(u64::MAX - (1 << 5))] | ||
#[case(u64::MAX - (1 << 6))] | ||
#[case(u64::MAX - (1 << 7))] | ||
#[case(u64::MAX - (1 << 8))] | ||
#[case(u64::MAX - (1 << 9))] | ||
#[case(u64::MAX - (1 << 10))] | ||
#[case(u64::MAX - (1 << 11))] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not happy with these cases and those for the signed variant below. Any ideas how to improve this? |
||
fn approx_roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) { | ||
add_extensions(&mut exec_ctx); | ||
|
||
let hugr = roundtrip_hugr(val, false); | ||
let result = exec_ctx.exec_hugr_u64(hugr, "main"); | ||
let (v_r_max, v_r_min) = (val.max(result), val.min(result)); | ||
// If val is too large the `trunc_u` op in `hugr` will return None. | ||
// In this case the hugr returns the magic number `999`. | ||
assert!(result == 999 || (v_r_max - v_r_min) < 1 << 10); | ||
} | ||
|
||
#[rstest] | ||
#[case(i64::MAX)] | ||
#[case(i64::MAX - 1)] | ||
#[case(i64::MAX - (1 << 1))] | ||
#[case(i64::MAX - (1 << 2))] | ||
#[case(i64::MAX - (1 << 3))] | ||
#[case(i64::MAX - (1 << 4))] | ||
#[case(i64::MAX - (1 << 5))] | ||
#[case(i64::MAX - (1 << 6))] | ||
#[case(i64::MAX - (1 << 7))] | ||
#[case(i64::MAX - (1 << 8))] | ||
#[case(i64::MAX - (1 << 9))] | ||
#[case(i64::MAX - (1 << 10))] | ||
#[case(i64::MAX - (1 << 11))] | ||
#[case(i64::MIN)] | ||
#[case(i64::MIN + 1)] | ||
#[case(i64::MIN + (1 << 1))] | ||
#[case(i64::MIN + (1 << 2))] | ||
#[case(i64::MIN + (1 << 3))] | ||
#[case(i64::MIN + (1 << 4))] | ||
#[case(i64::MIN + (1 << 5))] | ||
#[case(i64::MIN + (1 << 6))] | ||
#[case(i64::MIN + (1 << 7))] | ||
#[case(i64::MIN + (1 << 8))] | ||
#[case(i64::MIN + (1 << 9))] | ||
#[case(i64::MIN + (1 << 10))] | ||
#[case(i64::MIN + (1 << 11))] | ||
fn approx_roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) { | ||
add_extensions(&mut exec_ctx); | ||
let hugr = roundtrip_hugr(val); | ||
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); | ||
|
||
let hugr = roundtrip_hugr(val as u64, true); | ||
let result = exec_ctx.exec_hugr_u64(hugr, "main") as i64; | ||
// If val.abs() is too large the `trunc_s` op in `hugr` will return None. | ||
// In this case the hugr returns the magic number `999`. | ||
assert!(result == 999 || (val - result).abs() < 1 << 10); | ||
} | ||
|
||
#[rstest] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These lines scare me a bit (I think in C shift operations on signed integers are "implementation dependent"?)
According to https://doc.rust-lang.org/reference/expressions/operator-expr.html#arithmetic-and-logical-binary-operators these are arithmetic right shift, but which way does rounding go? I couldn't find chapter and verse on this.
It's probably fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think rounding is the right way to think about it. Shifting is a bit operation, which happens to mean "divide by 2 rounding down, repeat". Instead, the bit pattern of `i64::MIN` is `1000...`. Arithmetic right shifting (i.e. sign extending so that new most-significant bits are 1) the 1 to the new most-significant bit is the correct thing.
Similarly, the bit pattern of `i64::MAX` is `0111...`. Arithmetic right shifting (i.e. new most-significant bits are 0) the 0 to the new most-significant bit is the correct thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this guaranteed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I believe so. https://doc.rust-lang.org/book/ch03-02-data-types.html#integer-types says "Signed numbers are stored using two’s complement representation."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! All good then.