Skip to content
This repository was archived by the owner on Mar 5, 2025. It is now read-only.

fix: Conversion operations having poison results #131

Merged
merged 6 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 136 additions & 57 deletions src/extension/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, ensure, Result};

use hugr::{
extension::{
prelude::{sum_with_error, ConstError, BOOL_T},
simple_op::MakeExtensionOp,
},
ops::{constant::Value, custom::ExtensionOp},
ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _},
std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES},
types::{TypeArg, TypeEnum},
types::{TypeArg, TypeEnum, TypeRow},
HugrView,
};

use inkwell::{values::BasicValue, FloatPredicate, IntPredicate};
use inkwell::{types::IntType, values::BasicValue, FloatPredicate, IntPredicate};

use crate::{
custom::{CodegenExtension, CodegenExtsBuilder},
Expand All @@ -21,6 +21,7 @@ use crate::{
EmitOpArgs,
},
sum::LLVMSumValue,
types::HugrType,
};

fn build_trunc_op<'c, H: HugrView>(
Expand All @@ -29,53 +30,58 @@ fn build_trunc_op<'c, H: HugrView>(
log_width: u64,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
// Note: This logic is copied from `llvm_type` in the IntTypes
// extension. We need to have a common source of truth for this.
let (width, (int_min_value_s, int_max_value_s), int_max_value_u) = match log_width {
0..=3 => (8, (i8::MIN as i64, i8::MAX as i64), u8::MAX as u64),
4 => (16, (i16::MIN as i64, i16::MAX as i64), u16::MAX as u64),
5 => (32, (i32::MIN as i64, i32::MAX as i64), u32::MAX as u64),
6 => (64, (i64::MIN, i64::MAX), u64::MAX),
m => return Err(anyhow!("ConversionEmitter: unsupported log_width: {}", m)),
};

let hugr_int_ty = INT_TYPES[log_width as usize].clone();
let int_ty = context
.typing_session()
.llvm_type(&hugr_int_ty)?
.into_int_type();
let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]);
// TODO: it would be nice to get this info out of `ops.node()`, this would
// require adding appropriate methods to `ConvertOpDef`. In the meantime, we
// assert that the output types are as we expect.
debug_assert_eq!(
TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]),
args.node().signature().output
);

let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else {
bail!("Expected `arithmetic.int` to lower to an llvm integer")
};

let hugr_sum_ty = sum_with_error(vec![hugr_int_ty]);
let sum_ty = context.typing_session().llvm_sum_type(hugr_sum_ty)?;
let sum_ty = context.llvm_sum_type(hugr_sum_ty)?;

let (width, int_min_value_s, int_max_value_s, int_max_value_u) = {
ensure!(
log_width <= 6,
"Expected log_width of output to be <= 6, found: {log_width}"
);
let width = 1 << log_width;
(
width,
i64::MIN >> (64 - width),
i64::MAX >> (64 - width),
Comment on lines +57 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines scare me a bit (I think in C shift operations on signed integers are "implementation dependent"?)

According to https://doc.rust-lang.org/reference/expressions/operator-expr.html#arithmetic-and-logical-binary-operators these are arithmetic right shift, but which way does rounding go? I couldn't find chapter and verse on this.

It's probably fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think rounding is the right way to think about it. Shifting is a bit operation, which happens to mean "divide by 2 rounding down, repeat". Instead, the bit pattern of i64::MIN is 1000.... Arithmetic right shifting (i.e. sign extending, so that the new most-significant bits are 1) moves the 1 to the new most-significant bit, which is the correct thing.
Similarly, the bit pattern of i64::MAX is 0111.... Arithmetic right shifting (i.e. the new most-significant bits are 0) moves the 0 to the new most-significant bit, which is the correct thing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the bit pattern of i64::MIN is 1000...

Is this guaranteed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe so. https://doc.rust-lang.org/book/ch03-02-data-types.html#integer-types says "Signed numbers are stored using two’s complement representation."

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! All good then.

u64::MAX >> (64 - width),
)
};

emit_custom_unary_op(context, args, |ctx, arg, _| {
// We have to check if the conversion will work, so we
// make the maximum int and convert to a float, then compare
// with the function input.
let flt_max = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_s as f64)
let flt_max = ctx.iw_context().f64_type().const_float(if signed {
int_max_value_s as f64
} else {
ctx.iw_context()
.f64_type()
.const_float(int_max_value_u as f64)
};
int_max_value_u as f64
});

let within_upper_bound = ctx.builder().build_float_compare(
FloatPredicate::OLE,
FloatPredicate::OLT,
arg.into_float_value(),
flt_max,
"within_upper_bound",
)?;

let flt_min = if signed {
ctx.iw_context()
.f64_type()
.const_float(int_min_value_s as f64)
let flt_min = ctx.iw_context().f64_type().const_float(if signed {
int_min_value_s as f64
} else {
ctx.iw_context().f64_type().const_float(0.0)
};
0.0
});

let within_lower_bound = ctx.builder().build_float_compare(
FloatPredicate::OLE,
Expand Down Expand Up @@ -401,7 +407,7 @@ mod test {
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
}

fn roundtrip_hugr(val: u64) -> Hugr {
fn roundtrip_hugr(val: u64, signed: bool) -> Hugr {
let int64 = INT_TYPES[6].clone();
SimpleHugrConfig::new()
.with_outs(USIZE_T)
Expand All @@ -412,14 +418,23 @@ mod test {
.add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
.unwrap()
.outputs_arr();
let [flt] = builder
.add_dataflow_op(ConvertOpDef::convert_u.with_log_width(6), [int])
.unwrap()
.outputs_arr();
let [int_or_err] = builder
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(6), [flt])
.unwrap()
.outputs_arr();
let [flt] = {
let op = if signed {
ConvertOpDef::convert_s.with_log_width(6)
} else {
ConvertOpDef::convert_u.with_log_width(6)
};
builder.add_dataflow_op(op, [int]).unwrap().outputs_arr()
};

let [int_or_err] = {
let op = if signed {
ConvertOpDef::trunc_s.with_log_width(6)
} else {
ConvertOpDef::trunc_u.with_log_width(6)
};
builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr()
};
let sum_ty = sum_with_error(int64.clone());
let variants = (0..sum_ty.num_variants())
.map(|i| sum_ty.get_variant(i).unwrap().clone().try_into().unwrap());
Expand Down Expand Up @@ -467,25 +482,89 @@ mod test {
#[case(4294967295)]
#[case(42)]
#[case(18_000_000_000_000_000_000)]
fn roundtrip(mut exec_ctx: TestContext, #[case] val: u64) {
fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val);
let hugr = roundtrip_hugr(val, false);
assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
}

// N.B.: There's some strange behaviour at the upper end of the ints - the
// first case gets converted to something that's off by a few hundred, but the
// second (which is (2 ^ 64) - 1) gets converted to 2 ^ 63 - off by roughly 9.2e18!
// The fact that the first case works as expected tells me this isn't to do
// with int widths - maybe a floating point expert could explain that this
// is standard behaviour...
#[rstest]
#[case(18_446_744_073_709_550_000, 18_446_744_073_709_549_568)]
#[case(18_446_744_073_709_551_615, 9_223_372_036_854_775_808)] // 2 ^ 63
fn approx_roundtrip(mut exec_ctx: TestContext, #[case] val: u64, #[case] expected: u64) {
// Exact roundtrip conversion is defined on values up to 2**53 for f64.
#[case(0)]
#[case(3)]
#[case(255)]
#[case(4294967295)]
#[case(42)]
#[case(-9_000_000_000_000_000_000)]
fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
    // Signed counterpart of the exact round-trip test: the value is handed to
    // the hugr builder as a `u64`, and the `u64` execution result is cast back
    // to `i64` before being compared with the input.
    add_extensions(&mut exec_ctx);
    let signed_hugr = roundtrip_hugr(val as u64, true);
    let returned = exec_ctx.exec_hugr_u64(signed_hugr, "main");
    assert_eq!(val, returned as i64);
}

// For unsigned ints larger than 1 << 53, f64s do not have enough
// precision to exactly roundtrip every int.
// The exact behaviour of the round-trip is platform-dependent.
#[rstest]
#[case(u64::MAX)]
#[case(u64::MAX - 1)]
#[case(u64::MAX - (1 << 1))]
#[case(u64::MAX - (1 << 2))]
#[case(u64::MAX - (1 << 3))]
#[case(u64::MAX - (1 << 4))]
#[case(u64::MAX - (1 << 5))]
#[case(u64::MAX - (1 << 6))]
#[case(u64::MAX - (1 << 7))]
#[case(u64::MAX - (1 << 8))]
#[case(u64::MAX - (1 << 9))]
#[case(u64::MAX - (1 << 10))]
#[case(u64::MAX - (1 << 11))]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not happy with these cases and those for the signed variant below. Any ideas how to improve this?

fn approx_roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
    // Runs the unsigned int -> float -> int round trip and checks that the
    // result is either the failure sentinel or within 1 << 10 of the input.
    add_extensions(&mut exec_ctx);

    let unsigned_hugr = roundtrip_hugr(val, false);
    let roundtripped = exec_ctx.exec_hugr_u64(unsigned_hugr, "main");

    // If val is too large the `trunc_u` op in the hugr does not produce an
    // int; in that case the hugr returns the magic number `999`.
    let hit_sentinel = roundtripped == 999;
    assert!(hit_sentinel || val.abs_diff(roundtripped) < (1 << 10));
}

#[rstest]
#[case(i64::MAX)]
#[case(i64::MAX - 1)]
#[case(i64::MAX - (1 << 1))]
#[case(i64::MAX - (1 << 2))]
#[case(i64::MAX - (1 << 3))]
#[case(i64::MAX - (1 << 4))]
#[case(i64::MAX - (1 << 5))]
#[case(i64::MAX - (1 << 6))]
#[case(i64::MAX - (1 << 7))]
#[case(i64::MAX - (1 << 8))]
#[case(i64::MAX - (1 << 9))]
#[case(i64::MAX - (1 << 10))]
#[case(i64::MAX - (1 << 11))]
#[case(i64::MIN)]
#[case(i64::MIN + 1)]
#[case(i64::MIN + (1 << 1))]
#[case(i64::MIN + (1 << 2))]
#[case(i64::MIN + (1 << 3))]
#[case(i64::MIN + (1 << 4))]
#[case(i64::MIN + (1 << 5))]
#[case(i64::MIN + (1 << 6))]
#[case(i64::MIN + (1 << 7))]
#[case(i64::MIN + (1 << 8))]
#[case(i64::MIN + (1 << 9))]
#[case(i64::MIN + (1 << 10))]
#[case(i64::MIN + (1 << 11))]
fn approx_roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
add_extensions(&mut exec_ctx);
let hugr = roundtrip_hugr(val);
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));

let hugr = roundtrip_hugr(val as u64, true);
let result = exec_ctx.exec_hugr_u64(hugr, "main") as i64;
// If val.abs() is too large the `trunc_s` op in `hugr` will return None.
// In this case the hugr returns the magic number `999`.
assert!(result == 999 || (val - result).abs() < 1 << 10);
}

#[rstest]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%within_upper_bound = fcmp ole double %0, 0x41DFFFFFFFC00000
%within_upper_bound = fcmp olt double %0, 0x41DFFFFFFFC00000
%within_lower_bound = fcmp ole double 0xC1E0000000000000, %0
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptosi double %0 to i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ alloca_block:
entry_block: ; preds = %alloca_block
store double %0, double* %"2_0", align 8
%"2_01" = load double, double* %"2_0", align 8
%within_upper_bound = fcmp ole double %"2_01", 0x41DFFFFFFFC00000
%within_upper_bound = fcmp olt double %"2_01", 0x41DFFFFFFFC00000
%within_lower_bound = fcmp ole double 0xC1E0000000000000, %"2_01"
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptosi double %"2_01" to i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%within_upper_bound = fcmp ole double %0, 0x43F0000000000000
%within_upper_bound = fcmp olt double %0, 0x43F0000000000000
%within_lower_bound = fcmp ole double 0.000000e+00, %0
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptoui double %0 to i64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ alloca_block:
entry_block: ; preds = %alloca_block
store double %0, double* %"2_0", align 8
%"2_01" = load double, double* %"2_0", align 8
%within_upper_bound = fcmp ole double %"2_01", 0x43F0000000000000
%within_upper_bound = fcmp olt double %"2_01", 0x43F0000000000000
%within_lower_bound = fcmp ole double 0.000000e+00, %"2_01"
%success = and i1 %within_upper_bound, %within_lower_bound
%trunc_result = fptoui double %"2_01" to i64
Expand Down