Skip to content

Commit fca80ac

Browse files
committed
rustc: Fix (again) simd vectors by-val in ABI
The issue of passing around SIMD types as values between functions has seen [quite a lot] of [discussion], and although we thought [we fixed it][quite a lot] it [wasn't]! This PR is a change to rustc to, again, try to fix this issue. The fundamental problem here remains the same, if a SIMD vector argument is passed by-value in LLVM's function type, then if the caller and callee disagree on target features a miscompile happens. We solve this by never passing SIMD vectors by-value, but LLVM will still thwart us with its argument promotion pass to promote by-ref SIMD arguments to by-val SIMD arguments. This commit is an attempt to thwart LLVM thwarting us. We, just before codegen, will take yet another look at the LLVM module and demote any by-value SIMD arguments we see. This is a very manual attempt by us to ensure the codegen for a module keeps working, and it unfortunately is likely producing suboptimal code, even in release mode. The saving grace for this, in theory, is that if SIMD types are passed by-value across a boundary in release mode it's pretty unlikely to be performance sensitive (as it's already doing a load/store, and otherwise perf-sensitive bits should be inlined). The implementation here is basically a big wad of C++. It was largely copied from LLVM's own argument promotion pass, only doing the reverse. In local testing this... Closes rust-lang#50154 Closes rust-lang#52636 Closes rust-lang#54583 Closes rust-lang#55059 [quite a lot]: rust-lang#47743 [discussion]: rust-lang#44367 [wasn't]: rust-lang#50154
1 parent b1bdf04 commit fca80ac

File tree

10 files changed

+330
-10
lines changed

10 files changed

+330
-10
lines changed

src/librustc_codegen_llvm/back/lto.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ impl LtoModuleCodegen {
8080
let module = module.take().unwrap();
8181
{
8282
let config = cgcx.config(module.kind);
83-
let llmod = module.module_llvm.llmod();
84-
let tm = &*module.module_llvm.tm;
85-
run_pass_manager(cgcx, tm, llmod, config, false);
83+
run_pass_manager(cgcx, &module, config, false);
8684
timeline.record("fat-done");
8785
}
8886
Ok(module)
@@ -557,8 +555,7 @@ fn thin_lto(cgcx: &CodegenContext,
557555
}
558556

559557
fn run_pass_manager(cgcx: &CodegenContext,
560-
tm: &llvm::TargetMachine,
561-
llmod: &llvm::Module,
558+
module: &ModuleCodegen,
562559
config: &ModuleConfig,
563560
thin: bool) {
564561
// Now we have one massive module inside of llmod. Time to run the
@@ -569,7 +566,8 @@ fn run_pass_manager(cgcx: &CodegenContext,
569566
debug!("running the pass manager");
570567
unsafe {
571568
let pm = llvm::LLVMCreatePassManager();
572-
llvm::LLVMRustAddAnalysisPasses(tm, pm, llmod);
569+
let llmod = module.module_llvm.llmod();
570+
llvm::LLVMRustAddAnalysisPasses(module.module_llvm.tm, pm, llmod);
573571

574572
if config.verify_llvm_ir {
575573
let pass = llvm::LLVMRustFindAndCreatePass("verify\0".as_ptr() as *const _);
@@ -864,7 +862,7 @@ impl ThinModule {
864862
// little differently.
865863
info!("running thin lto passes over {}", module.name);
866864
let config = cgcx.config(module.kind);
867-
run_pass_manager(cgcx, module.module_llvm.tm, llmod, config, true);
865+
run_pass_manager(cgcx, &module, config, true);
868866
cgcx.save_temp_bitcode(&module, "thin-lto-after-pm");
869867
timeline.record("thin-done");
870868
}

src/librustc_codegen_llvm/back/write.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ unsafe fn optimize(cgcx: &CodegenContext,
633633
None,
634634
&format!("llvm module passes [{}]", module_name.unwrap()),
635635
|| {
636-
llvm::LLVMRunPassManager(mpm, llmod)
636+
llvm::LLVMRunPassManager(mpm, llmod);
637637
});
638638

639639
// Deallocate managers that we're now done with
@@ -691,6 +691,38 @@ unsafe fn codegen(cgcx: &CodegenContext,
691691
create_msvc_imps(cgcx, llcx, llmod);
692692
}
693693

694+
// Ok now this one's a super interesting invocation. SIMD in rustc is
695+
// difficult where we want some parts of the program to be able to use
696+
// some SIMD features while other parts of the program don't. The real
697+
// tough part is that we want this to actually work correctly!
698+
//
699+
// We go to great lengths to make sure this works, and one crucial
700+
// aspect is that vector arguments (simd types) are never passed by
701+
// value in the ABI of functions. It turns out, however, that LLVM will
702+
// undo our "clever work" of passing vector types by reference. Its
703+
// argument promotion pass will promote these by-ref arguments to
704+
// by-val. That, however, introduces codegen errors!
705+
//
706+
// The upstream LLVM bug [1] has unfortunately not really seen a lot of
707+
// activity. The Rust bug [2], however, has seen quite a lot of reports
708+
// of this in the wild. As a result, this is worked around locally here.
709+
// We have a custom transformation, `LLVMRustDemoteSimdArguments`, which
710+
// does the opposite of argument promotion by demoting any by-value SIMD
711+
// arguments in function signatures to pointers instead of being
712+
// by-value.
713+
//
714+
// This operates at the LLVM IR layer because LLVM is thwarting our
715+
// codegen and this is the only chance we get to make sure it's correct
716+
// before we hit codegen.
717+
//
718+
// Hopefully one day the upstream LLVM bug will be fixed and we'll no
719+
// longer need this!
720+
//
721+
// [1]: https://bugs.llvm.org/show_bug.cgi?id=37358
722+
// [2]: https://github.com/rust-lang/rust/issues/50154
723+
llvm::LLVMRustDemoteSimdArguments(llmod);
724+
cgcx.save_temp_bitcode(&module, "simd-demoted");
725+
694726
// A codegen-specific pass manager is used to generate object
695727
// files for an LLVM module.
696728
//

src/librustc_codegen_llvm/llvm/ffi.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,8 @@ extern "C" {
11361136
/// Runs a pass manager on a module.
11371137
pub fn LLVMRunPassManager(PM: &PassManager<'a>, M: &'a Module) -> Bool;
11381138

1139+
pub fn LLVMRustDemoteSimdArguments(M: &'a Module);
1140+
11391141
pub fn LLVMInitializePasses();
11401142

11411143
pub fn LLVMPassManagerBuilderCreate() -> &'static mut PassManagerBuilder;

src/librustc_llvm/build.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ fn main() {
162162
}
163163

164164
build_helper::rerun_if_changed_anything_in_dir(Path::new("../rustllvm"));
165-
cfg.file("../rustllvm/PassWrapper.cpp")
165+
cfg
166+
.file("../rustllvm/DemoteSimd.cpp")
167+
.file("../rustllvm/PassWrapper.cpp")
166168
.file("../rustllvm/RustWrapper.cpp")
167169
.file("../rustllvm/ArchiveWrapper.cpp")
168170
.file("../rustllvm/Linker.cpp")

src/rustllvm/DemoteSimd.cpp

+186
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// http://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
#include <vector>
12+
#include <set>
13+
14+
#include "rustllvm.h"
15+
16+
#include "llvm/IR/CallSite.h"
17+
#include "llvm/IR/Module.h"
18+
19+
using namespace llvm;
20+
21+
static std::vector<Function*>
22+
GetFunctionsWithSimdArgs(Module *M) {
23+
std::vector<Function*> Ret;
24+
25+
for (auto &F : M->functions()) {
26+
// Skip all intrinsic calls as these are always tightly controlled to "work
27+
// correctly", so no need to fixup any of these.
28+
if (F.isIntrinsic())
29+
continue;
30+
31+
// We know that we started out with a `Module` that has what we want, so
32+
// we're just trying to undo specifically the work of the
33+
// `ArgumentPromotion` pass. That only runs in a select few circumstances,
34+
// so make sure that we don't get anything surprising. For example, make
35+
// sure we don't actually return a vector type because rustc shouldn't ever
36+
// generate this and nor should passes make this happen.
37+
assert(!F->getReturnType()->isVectorTy());
38+
39+
// If any argument to this function is a by-value vector type, then that's
40+
// bad! The compiler didn't generate any functions that looked like this,
41+
// and we try to rely on LLVM to not do this! Argument promotion may,
42+
// however, promote arguments from behind references. In any case, figure
43+
// out if we're interested in demoting this argument.
44+
bool anyVector = false;
45+
for (auto &Arg : F.args())
46+
anyVector = anyVector || Arg.getType()->isVectorTy();
47+
if (anyVector)
48+
Ret.push_back(&F);
49+
}
50+
51+
return Ret;
52+
}
53+
54+
extern "C" void
55+
LLVMRustDemoteSimdArguments(LLVMModuleRef Mod) {
56+
Module *M = unwrap(Mod);
57+
58+
auto Functions = GetFunctionsWithSimdArgs(M);
59+
60+
for (auto F : Functions) {
61+
// The argument promotion pass in LLVM should only run on functions that
62+
// have local linkage. We're modifying function signatures here, so make
63+
// sure such a desctructive change doesn't affect the public ABI.
64+
assert(F->hasLocalLinkage());
65+
66+
// Build up our list of new parameters and new argument attributes.
67+
// We're only changing those arguments which are vector types.
68+
SmallVector<Type*, 8> Params;
69+
SmallVector<AttributeSet, 8> ArgAttrVec;
70+
auto PAL = F->getAttributes();
71+
for (auto &Arg : F->args()) {
72+
auto *Ty = Arg.getType();
73+
if (Ty->isVectorTy()) {
74+
Params.push_back(PointerType::get(Ty, 0));
75+
ArgAttrVec.push_back(AttributeSet());
76+
} else {
77+
Params.push_back(Ty);
78+
ArgAttrVec.push_back(PAL.getParamAttributes(Arg.getArgNo()));
79+
}
80+
}
81+
82+
// Replace `F` with a new function with our new signature. I'm... not really
83+
// sure how this works, but this is all the steps `ArgumentPromotion` does
84+
// to replace a signature as well.
85+
assert(!F->isVarArg()); // ArgumentPromotion should skip these fns
86+
FunctionType *NFTy = FunctionType::get(F->getReturnType(), Params, false);
87+
Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
88+
NF->copyAttributesFrom(F);
89+
NF->setSubprogram(F->getSubprogram());
90+
F->setSubprogram(nullptr);
91+
NF->setAttributes(AttributeList::get(F->getContext(),
92+
PAL.getFnAttributes(),
93+
PAL.getRetAttributes(),
94+
ArgAttrVec));
95+
ArgAttrVec.clear();
96+
F->getParent()->getFunctionList().insert(F->getIterator(), NF);
97+
NF->takeName(F);
98+
99+
// Iterate over all invocations of `F`, updating all `call` instructions to
100+
// store immediate vector types in a local `alloc` instead of a by-value
101+
// vector.
102+
//
103+
// Like before, much of this is copied from the `ArgumentPromotion` pass in
104+
// LLVM.
105+
SmallVector<Value*, 16> Args;
106+
while (!F->use_empty()) {
107+
CallSite CS(F->user_back());
108+
assert(CS.getCalledFunction() == F);
109+
Instruction *Call = CS.getInstruction();
110+
const AttributeList &CallPAL = CS.getAttributes();
111+
112+
// Loop over the operands, inserting an `alloca` and a store for any
113+
// argument we're demoting to be by reference
114+
//
115+
// FIXME: we probably want to figure out an LLVM pass to run and clean up
116+
// this function and instructions we're generating, we should in theory
117+
// only generate a maximum number of `alloca` instructions rather than
118+
// one-per-variable unconditionally.
119+
CallSite::arg_iterator AI = CS.arg_begin();
120+
size_t ArgNo = 0;
121+
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
122+
++I, ++AI, ++ArgNo) {
123+
if (I->getType()->isVectorTy()) {
124+
AllocaInst *AllocA = new AllocaInst(I->getType(), 0, nullptr, "", Call);
125+
new StoreInst(*AI, AllocA, Call);
126+
Args.push_back(AllocA);
127+
ArgAttrVec.push_back(AttributeSet());
128+
} else {
129+
Args.push_back(*AI);
130+
ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo));
131+
}
132+
}
133+
assert(AI == CS.arg_end());
134+
135+
// Create a new call instructions which we'll use to replace the old call
136+
// instruction, copying over as many attributes and such as possible.
137+
SmallVector<OperandBundleDef, 1> OpBundles;
138+
CS.getOperandBundlesAsDefs(OpBundles);
139+
140+
CallSite NewCS;
141+
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
142+
InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
143+
Args, OpBundles, "", Call);
144+
} else {
145+
auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call);
146+
NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
147+
NewCS = NewCall;
148+
}
149+
NewCS.setCallingConv(CS.getCallingConv());
150+
NewCS.setAttributes(
151+
AttributeList::get(F->getContext(), CallPAL.getFnAttributes(),
152+
CallPAL.getRetAttributes(), ArgAttrVec));
153+
NewCS->setDebugLoc(Call->getDebugLoc());
154+
Args.clear();
155+
ArgAttrVec.clear();
156+
Call->replaceAllUsesWith(NewCS.getInstruction());
157+
NewCS->takeName(Call);
158+
Call->eraseFromParent();
159+
}
160+
161+
// Splice the body of the old function right into the new function.
162+
NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
163+
164+
// Update our new function to replace all uses of the by-value argument with
165+
// loads of the pointer argument we've generated.
166+
//
167+
// FIXME: we probably want to only generate one load instruction per
168+
// function? Or maybe run an LLVM pass to clean up this function?
169+
for (Function::arg_iterator I = F->arg_begin(),
170+
E = F->arg_end(),
171+
I2 = NF->arg_begin();
172+
I != E;
173+
++I, ++I2) {
174+
if (I->getType()->isVectorTy()) {
175+
I->replaceAllUsesWith(new LoadInst(&*I2, "", &NF->begin()->front()));
176+
} else {
177+
I->replaceAllUsesWith(&*I2);
178+
}
179+
I2->takeName(&*I);
180+
}
181+
182+
// Delete all references to the old function, it should be entirely dead
183+
// now.
184+
M->getFunctionList().remove(F);
185+
}
186+
}

src/test/codegen/repr-transparent.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ struct f32x4(f32, f32, f32, f32);
108108
#[repr(transparent)]
109109
pub struct Vector(f32x4);
110110

111-
// CHECK: define <4 x float> @test_Vector(<4 x float> %arg0)
111+
// CHECK: define <4 x float> @test_Vector(<4 x float>* %arg0)
112112
#[no_mangle]
113113
pub extern fn test_Vector(_: Vector) -> Vector { loop {} }
114114

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Regression test for SIMD vectors being passed by value between functions.
# Each tN.rs program is compiled at -C opt-level=3 and then executed.
-include ../../run-make-fulldeps/tools.mk
2+
3+
# The programs use std::arch::x86_64, so they are only built and run when the
# target is x86_64-unknown-linux-gnu.
ifeq ($(TARGET),x86_64-unknown-linux-gnu)
4+
all:
5+
$(RUSTC) t1.rs -C opt-level=3
6+
$(TMPDIR)/t1
7+
$(RUSTC) t2.rs -C opt-level=3
8+
$(TMPDIR)/t2
9+
$(RUSTC) t3.rs -C opt-level=3
10+
$(TMPDIR)/t3
11+
# On every other target `all` has no recipe, so the test does nothing.
else
12+
all:
13+
endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use std::arch::x86_64;
2+
3+
// Round-trips 32 bytes through a 256-bit vector value: load them with
// `_mm256_loadu_si256`, print the vector, store it back with
// `_mm256_storeu_si256`, and check the bytes are unchanged.
fn main() {
4+
// Bail out (successfully) when AVX2 is not detected at runtime.
if !is_x86_feature_detected!("avx2") {
5+
return println!("AVX2 is not supported on this machine/build.");
6+
}
7+
let load_bytes: [u8; 32] = [0x0f; 32];
8+
let lb_ptr = load_bytes.as_ptr();
9+
// Unaligned load of the 32 source bytes into a `__m256i`.
let reg_load = unsafe {
10+
x86_64::_mm256_loadu_si256(
11+
lb_ptr as *const x86_64::__m256i
12+
)
13+
};
14+
// Formatting the vector hands the `__m256i` value to other code between the
// load and the store.
println!("{:?}", reg_load);
15+
let mut store_bytes: [u8; 32] = [0; 32];
16+
let sb_ptr = store_bytes.as_mut_ptr();
17+
// Unaligned store of the vector into the zeroed destination buffer.
unsafe {
18+
x86_64::_mm256_storeu_si256(sb_ptr as *mut x86_64::__m256i, reg_load);
19+
}
20+
// The stored bytes must match what was loaded.
assert_eq!(load_bytes, store_bytes);
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use std::arch::x86_64::*;
2+
3+
// Multiplies a 256-bit vector of four `2.0` doubles by itself and checks that
// every lane of the result is `4.0`.
fn main() {
4+
// Bail out (successfully) when AVX is not detected at runtime.
if !is_x86_feature_detected!("avx") {
5+
return println!("AVX is not supported on this machine/build.");
6+
}
7+
unsafe {
8+
let f = _mm256_set_pd(2.0, 2.0, 2.0, 2.0);
9+
// `f` is passed by value as both operands of the multiply intrinsic.
let r = _mm256_mul_pd(f, f);
10+
11+
// The union reinterprets the `__m256d` result as four `f64` lanes so they
// can be compared against plain floats.
union A { a: __m256d, b: [f64; 4] }
12+
assert_eq!(A { a: r }.b, [4.0, 4.0, 4.0, 4.0]);
13+
}
14+
}

0 commit comments

Comments
 (0)