From ab3115990d333b8b782f8ea062c881fda377a353 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Wed, 29 Jan 2025 09:20:53 +0000 Subject: [PATCH] Pretty print pattern type values with `transmute` if they don't satisfy their pattern --- compiler/rustc_const_eval/src/lib.rs | 2 + .../src/util/check_validity_requirement.rs | 54 ++++++++++++++----- compiler/rustc_const_eval/src/util/mod.rs | 1 + compiler/rustc_middle/src/hooks/mod.rs | 4 ++ compiler/rustc_middle/src/ty/print/pretty.rs | 2 +- .../pattern_types.main.PreCodegen.after.mir | 2 +- tests/mir-opt/pattern_types.rs | 2 +- 7 files changed, 51 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_const_eval/src/lib.rs b/compiler/rustc_const_eval/src/lib.rs index b44d2a01d57c0..ed5489652fba6 100644 --- a/compiler/rustc_const_eval/src/lib.rs +++ b/compiler/rustc_const_eval/src/lib.rs @@ -51,6 +51,8 @@ pub fn provide(providers: &mut Providers) { providers.check_validity_requirement = |tcx, (init_kind, param_env_and_ty)| { util::check_validity_requirement(tcx, init_kind, param_env_and_ty) }; + providers.hooks.validate_scalar_in_layout = + |tcx, scalar, layout| util::validate_scalar_in_layout(tcx, scalar, layout); } /// `rustc_driver::main` installs a handler that will set this to `true` if diff --git a/compiler/rustc_const_eval/src/util/check_validity_requirement.rs b/compiler/rustc_const_eval/src/util/check_validity_requirement.rs index a729d9325c84a..79baf91c3ce64 100644 --- a/compiler/rustc_const_eval/src/util/check_validity_requirement.rs +++ b/compiler/rustc_const_eval/src/util/check_validity_requirement.rs @@ -1,9 +1,10 @@ use rustc_abi::{BackendRepr, FieldsShape, Scalar, Variants}; -use rustc_middle::bug; use rustc_middle::ty::layout::{ HasTyCtxt, LayoutCx, LayoutError, LayoutOf, TyAndLayout, ValidityRequirement, }; -use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt}; +use rustc_middle::ty::{PseudoCanonicalInput, ScalarInt, Ty, TyCtxt}; +use rustc_middle::{bug, ty}; +use rustc_span::DUMMY_SP; use crate::const_eval::{CanAccessMutGlobal, CheckAlignment, CompileTimeMachine}; use crate::interpret::{InterpCx, MemoryKind}; @@ -34,7 +35,7 @@ pub fn check_validity_requirement<'tcx>( let layout_cx = LayoutCx::new(tcx, input.typing_env); if kind == ValidityRequirement::Uninit || tcx.sess.opts.unstable_opts.strict_init_checks { - check_validity_requirement_strict(layout, &layout_cx, kind) + Ok(check_validity_requirement_strict(layout, &layout_cx, kind)) } else { check_validity_requirement_lax(layout, &layout_cx, kind) } @@ -46,10 +47,10 @@ fn check_validity_requirement_strict<'tcx>( ty: TyAndLayout<'tcx>, cx: &LayoutCx<'tcx>, kind: ValidityRequirement, -) -> Result> { +) -> bool { let machine = CompileTimeMachine::new(CanAccessMutGlobal::No, CheckAlignment::Error); - let mut cx = InterpCx::new(cx.tcx(), rustc_span::DUMMY_SP, cx.typing_env, machine); + let mut cx = InterpCx::new(cx.tcx(), DUMMY_SP, cx.typing_env, machine); let allocated = cx .allocate(ty, MemoryKind::Machine(crate::const_eval::MemoryKind::Heap)) @@ -69,14 +70,13 @@ fn check_validity_requirement_strict<'tcx>( // due to this. // The value we are validating is temporary and discarded at the end of this function, so // there is no point in reseting provenance and padding. - Ok(cx - .validate_operand( - &allocated.into(), - /*recursive*/ false, - /*reset_provenance_and_padding*/ false, - ) - .discard_err() - .is_some()) + cx.validate_operand( + &allocated.into(), + /*recursive*/ false, + /*reset_provenance_and_padding*/ false, + ) + .discard_err() + .is_some() } /// Implements the 'lax' (default) version of the [`check_validity_requirement`] checks; see that @@ -168,3 +168,31 @@ fn check_validity_requirement_lax<'tcx>( Ok(true) } + +pub(crate) fn validate_scalar_in_layout<'tcx>( + tcx: TyCtxt<'tcx>, + scalar: ScalarInt, + ty: Ty<'tcx>, +) -> bool { + let machine = CompileTimeMachine::new(CanAccessMutGlobal::No, CheckAlignment::Error); + + let typing_env = ty::TypingEnv::fully_monomorphized(); + let mut cx = InterpCx::new(tcx, DUMMY_SP, typing_env, machine); + + let Ok(layout) = cx.layout_of(ty) else { + bug!("could not compute layout of {scalar:?}:{ty:?}") + }; + let allocated = cx + .allocate(layout, MemoryKind::Machine(crate::const_eval::MemoryKind::Heap)) + .expect("OOM: failed to allocate for uninit check"); + + cx.write_scalar(scalar, &allocated).unwrap(); + + cx.validate_operand( + &allocated.into(), + /*recursive*/ false, + /*reset_provenance_and_padding*/ false, + ) + .discard_err() + .is_some() +} diff --git a/compiler/rustc_const_eval/src/util/mod.rs b/compiler/rustc_const_eval/src/util/mod.rs index 25a9dbb2c1117..5be5ee8d1ae97 100644 --- a/compiler/rustc_const_eval/src/util/mod.rs +++ b/compiler/rustc_const_eval/src/util/mod.rs @@ -8,6 +8,7 @@ mod type_name; pub use self::alignment::{is_disaligned, is_within_packed}; pub use self::check_validity_requirement::check_validity_requirement; +pub(crate) use self::check_validity_requirement::validate_scalar_in_layout; pub use self::compare_types::{relate_types, sub_types}; pub use self::type_name::type_name; diff --git a/compiler/rustc_middle/src/hooks/mod.rs b/compiler/rustc_middle/src/hooks/mod.rs index 276a02b4e0feb..c5ce6efcb817e 100644 --- a/compiler/rustc_middle/src/hooks/mod.rs +++ b/compiler/rustc_middle/src/hooks/mod.rs @@ -98,6 +98,10 @@ declare_hooks! { hook save_dep_graph() -> (); hook query_key_hash_verify_all() -> (); + + /// Ensure the given scalar is valid for the given type. + /// This checks non-recursive runtime validity. + hook validate_scalar_in_layout(scalar: crate::ty::ScalarInt, ty: Ty<'tcx>) -> bool; } #[cold] diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs index 3431081b189e4..feae8ea312eaa 100644 --- a/compiler/rustc_middle/src/ty/print/pretty.rs +++ b/compiler/rustc_middle/src/ty/print/pretty.rs @@ -1741,7 +1741,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write { " as ", )?; } - ty::Pat(base_ty, pat) => { + ty::Pat(base_ty, pat) if self.tcx().validate_scalar_in_layout(int, ty) => { self.pretty_print_const_scalar_int(int, *base_ty, print_ty)?; p!(write(" is {pat:?}")); } diff --git a/tests/mir-opt/pattern_types.main.PreCodegen.after.mir b/tests/mir-opt/pattern_types.main.PreCodegen.after.mir index 8c99902f9b8f0..5ff90de961506 100644 --- a/tests/mir-opt/pattern_types.main.PreCodegen.after.mir +++ b/tests/mir-opt/pattern_types.main.PreCodegen.after.mir @@ -5,7 +5,7 @@ fn main() -> () { scope 1 { debug x => const 2_u32 is 1..=; scope 2 { - debug y => const 0_u32 is 1..=; + debug y => const {transmute(0x00000000): (u32) is 1..=}; } } diff --git a/tests/mir-opt/pattern_types.rs b/tests/mir-opt/pattern_types.rs index 217c64b90cbbc..0369ccf9a9d56 100644 --- a/tests/mir-opt/pattern_types.rs +++ b/tests/mir-opt/pattern_types.rs @@ -7,6 +7,6 @@ use std::pat::pattern_type; fn main() { // CHECK: debug x => const 2_u32 is 1..= let x: pattern_type!(u32 is 1..) = unsafe { std::mem::transmute(2) }; - // CHECK: debug y => const 0_u32 is 1..= + // CHECK: debug y => const {transmute(0x00000000): (u32) is 1..=} let y: pattern_type!(u32 is 1..) = unsafe { std::mem::transmute(0) }; }