Skip to content

Commit 553828c

Browse files
committed
Mark more LLVM FFI as safe
1 parent 3565603 commit 553828c

File tree

5 files changed

+68
-58
lines changed

5 files changed

+68
-58
lines changed

compiler/rustc_codegen_llvm/src/asm.rs

+47-46
Original file line numberDiff line numberDiff line change
@@ -482,12 +482,13 @@ pub(crate) fn inline_asm_call<'ll>(
482482

483483
debug!("Asm Output Type: {:?}", output);
484484
let fty = bx.cx.type_func(&argtys, output);
485-
unsafe {
486-
// Ask LLVM to verify that the constraints are well-formed.
487-
let constraints_ok = llvm::LLVMRustInlineAsmVerify(fty, cons.as_c_char_ptr(), cons.len());
488-
debug!("constraint verification result: {:?}", constraints_ok);
489-
if constraints_ok {
490-
let v = llvm::LLVMRustInlineAsm(
485+
// Ask LLVM to verify that the constraints are well-formed.
486+
let constraints_ok =
487+
unsafe { llvm::LLVMRustInlineAsmVerify(fty, cons.as_c_char_ptr(), cons.len()) };
488+
debug!("constraint verification result: {:?}", constraints_ok);
489+
if constraints_ok {
490+
let v = unsafe {
491+
llvm::LLVMRustInlineAsm(
491492
fty,
492493
asm.as_c_char_ptr(),
493494
asm.len(),
@@ -497,50 +498,50 @@ pub(crate) fn inline_asm_call<'ll>(
497498
alignstack,
498499
dia,
499500
can_throw,
500-
);
501-
502-
let call = if !labels.is_empty() {
503-
assert!(catch_funclet.is_none());
504-
bx.callbr(fty, None, None, v, inputs, dest.unwrap(), labels, None, None)
505-
} else if let Some((catch, funclet)) = catch_funclet {
506-
bx.invoke(fty, None, None, v, inputs, dest.unwrap(), catch, funclet, None)
507-
} else {
508-
bx.call(fty, None, None, v, inputs, None, None)
509-
};
501+
)
502+
};
510503

511-
// Store mark in a metadata node so we can map LLVM errors
512-
// back to source locations. See #17552.
513-
let key = "srcloc";
514-
let kind = bx.get_md_kind_id(key);
504+
let call = if !labels.is_empty() {
505+
assert!(catch_funclet.is_none());
506+
bx.callbr(fty, None, None, v, inputs, dest.unwrap(), labels, None, None)
507+
} else if let Some((catch, funclet)) = catch_funclet {
508+
bx.invoke(fty, None, None, v, inputs, dest.unwrap(), catch, funclet, None)
509+
} else {
510+
bx.call(fty, None, None, v, inputs, None, None)
511+
};
515512

516-
// `srcloc` contains one 64-bit integer for each line of assembly code,
517-
// where the lower 32 bits hold the lo byte position and the upper 32 bits
518-
// hold the hi byte position.
519-
let mut srcloc = vec![];
520-
if dia == llvm::AsmDialect::Intel && line_spans.len() > 1 {
521-
// LLVM inserts an extra line to add the ".intel_syntax", so add
522-
// a dummy srcloc entry for it.
523-
//
524-
// Don't do this if we only have 1 line span since that may be
525-
// due to the asm template string coming from a macro. LLVM will
526-
// default to the first srcloc for lines that don't have an
527-
// associated srcloc.
528-
srcloc.push(llvm::LLVMValueAsMetadata(bx.const_u64(0)));
529-
}
530-
srcloc.extend(line_spans.iter().map(|span| {
531-
llvm::LLVMValueAsMetadata(bx.const_u64(
532-
u64::from(span.lo().to_u32()) | (u64::from(span.hi().to_u32()) << 32),
533-
))
534-
}));
535-
let md = llvm::LLVMMDNodeInContext2(bx.llcx, srcloc.as_ptr(), srcloc.len());
536-
let md = bx.get_metadata_value(md);
537-
llvm::LLVMSetMetadata(call, kind, md);
513+
// Store mark in a metadata node so we can map LLVM errors
514+
// back to source locations. See #17552.
515+
let key = "srcloc";
516+
let kind = bx.get_md_kind_id(key);
538517

539-
Some(call)
540-
} else {
541-
// LLVM has detected an issue with our constraints, bail out
542-
None
518+
// `srcloc` contains one 64-bit integer for each line of assembly code,
519+
// where the lower 32 bits hold the lo byte position and the upper 32 bits
520+
// hold the hi byte position.
521+
let mut srcloc = vec![];
522+
if dia == llvm::AsmDialect::Intel && line_spans.len() > 1 {
523+
// LLVM inserts an extra line to add the ".intel_syntax", so add
524+
// a dummy srcloc entry for it.
525+
//
526+
// Don't do this if we only have 1 line span since that may be
527+
// due to the asm template string coming from a macro. LLVM will
528+
// default to the first srcloc for lines that don't have an
529+
// associated srcloc.
530+
srcloc.push(llvm::LLVMValueAsMetadata(bx.const_u64(0)));
543531
}
532+
srcloc.extend(line_spans.iter().map(|span| {
533+
llvm::LLVMValueAsMetadata(
534+
bx.const_u64(u64::from(span.lo().to_u32()) | (u64::from(span.hi().to_u32()) << 32)),
535+
)
536+
}));
537+
let md = unsafe { llvm::LLVMMDNodeInContext2(bx.llcx, srcloc.as_ptr(), srcloc.len()) };
538+
let md = bx.get_metadata_value(md);
539+
llvm::LLVMSetMetadata(call, kind, md);
540+
541+
Some(call)
542+
} else {
543+
// LLVM has detected an issue with our constraints, bail out
544+
None
544545
}
545546
}
546547

compiler/rustc_codegen_llvm/src/builder.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
311311
// This function handles switch instructions with more than 2 targets and it needs to
312312
// emit branch weights metadata instead of using the intrinsic.
313313
// The values 1 and 2000 are the same as the values used by the `llvm.expect` intrinsic.
314-
let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) };
315-
let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) };
314+
let cold_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(1));
315+
let hot_weight = llvm::LLVMValueAsMetadata(self.cx.const_u32(2000));
316316
let weight =
317317
|is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } };
318318

compiler/rustc_codegen_llvm/src/context.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::back::write::to_llvm_code_model;
3434
use crate::callee::get_fn;
3535
use crate::common::AsCCharPtr;
3636
use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
37-
use crate::llvm::{Metadata, MetadataType};
37+
use crate::llvm::Metadata;
3838
use crate::type_::Type;
3939
use crate::value::Value;
4040
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
@@ -664,7 +664,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
664664
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
665665
}
666666

667-
pub(crate) fn get_md_kind_id(&self, name: &str) -> u32 {
667+
pub(crate) fn get_md_kind_id(&self, name: &str) -> llvm::MetadataKindId {
668668
unsafe {
669669
llvm::LLVMGetMDKindIDInContext(
670670
self.llcx(),
@@ -1228,13 +1228,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
12281228
pub(crate) fn set_metadata<'a>(
12291229
&self,
12301230
val: &'a Value,
1231-
kind_id: MetadataType,
1231+
kind_id: impl Into<llvm::MetadataKindId>,
12321232
md: &'ll Metadata,
12331233
) {
12341234
let node = self.get_metadata_value(md);
1235-
unsafe {
1236-
llvm::LLVMSetMetadata(val, kind_id as c_uint, node);
1237-
}
1235+
llvm::LLVMSetMetadata(val, kind_id.into(), node);
12381236
}
12391237
}
12401238

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
use libc::{c_char, c_uint};
55

6+
use super::MetadataKindId;
67
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
78
use crate::llvm::Bool;
89

910
#[link(name = "llvm-wrapper", kind = "static")]
1011
unsafe extern "C" {
1112
// Enzyme
12-
pub(crate) fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
13+
pub(crate) safe fn LLVMRustHasMetadata(I: &Value, KindID: MetadataKindId) -> bool;
1314
pub(crate) fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
1415
pub(crate) fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
1516
pub(crate) fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -976,14 +976,24 @@ pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void);
976976
pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void;
977977
pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void;
978978

979+
#[derive(Copy, Clone)]
980+
#[repr(transparent)]
981+
pub struct MetadataKindId(c_uint);
982+
983+
impl From<MetadataType> for MetadataKindId {
984+
fn from(value: MetadataType) -> Self {
985+
Self(value as c_uint)
986+
}
987+
}
988+
979989
unsafe extern "C" {
980990
// Create and destroy contexts.
981991
pub(crate) fn LLVMContextDispose(C: &'static mut Context);
982992
pub(crate) fn LLVMGetMDKindIDInContext(
983993
C: &Context,
984994
Name: *const c_char,
985995
SLen: c_uint,
986-
) -> c_uint;
996+
) -> MetadataKindId;
987997

988998
// Create modules.
989999
pub(crate) fn LLVMModuleCreateWithNameInContext(
@@ -1051,9 +1061,9 @@ unsafe extern "C" {
10511061
pub(crate) fn LLVMGetValueName2(Val: &Value, Length: *mut size_t) -> *const c_char;
10521062
pub(crate) fn LLVMSetValueName2(Val: &Value, Name: *const c_char, NameLen: size_t);
10531063
pub(crate) fn LLVMReplaceAllUsesWith<'a>(OldVal: &'a Value, NewVal: &'a Value);
1054-
pub(crate) fn LLVMSetMetadata<'a>(Val: &'a Value, KindID: c_uint, Node: &'a Value);
1064+
pub(crate) safe fn LLVMSetMetadata<'a>(Val: &'a Value, KindID: MetadataKindId, Node: &'a Value);
10551065
pub(crate) fn LLVMGlobalSetMetadata<'a>(Val: &'a Value, KindID: c_uint, Metadata: &'a Metadata);
1056-
pub(crate) fn LLVMValueAsMetadata(Node: &Value) -> &Metadata;
1066+
pub(crate) safe fn LLVMValueAsMetadata(Node: &Value) -> &Metadata;
10571067

10581068
// Operations on constants of any type
10591069
pub(crate) fn LLVMConstNull(Ty: &Type) -> &Value;

0 commit comments

Comments
 (0)