Skip to content

Commit 85ab5be

Browse files
authored
Rollup merge of rust-lang#133122 - compiler-errors:afidt, r=oli-obk
Add unpolished, experimental support for AFIDT (async fn in dyn trait) This allows us to begin messing around `async fn` in `dyn Trait`. Calling an async fn from a trait object always returns a `dyn* Future<Output = ...>`. To make it work, Implementations are currently required to return something that can be coerced to a `dyn* Future` (see the example in `tests/ui/async-await/dyn/works.rs`). If it's not the right size, then it'll raise an error at the coercion site (see the example in `tests/ui/async-await/dyn/wrong-size.rs`). Currently the only practical way of doing this is wrapping the body in `Box::pin(async move { .. })`. This PR does not implement a helper type like a "`Boxing`"[^boxing] adapter, and I'll probably follow-up with another PR to improve the error message for the `PointerLike` trait (something that explains in just normal prose what is happening here, rather than a trait error). [^boxing]: https://rust-lang.github.io/async-fundamentals-initiative/explainer/user_guide_future.html#the-boxing-adapter This PR also does not implement new trait solver support for AFIDT; I'll need to think how best to integrate it into candidate assembly, and that's a bit of a matter of taste, but I don't think it will be difficult to do. This could also be generalized: * To work on functions that are `-> impl Future` (soon). * To work on functions that are `-> impl Iterator` and other "dyn rpitit safe" traits. We still need to nail down exactly what is needed for this to be okay (not soon). Tracking: * rust-lang#133119
2 parents 2779de7 + 57e8a1c commit 85ab5be

23 files changed

+621
-33
lines changed

compiler/rustc_feature/src/unstable.rs

+2
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ declare_features! (
390390
(unstable, associated_type_defaults, "1.2.0", Some(29661)),
391391
/// Allows `async || body` closures.
392392
(unstable, async_closure, "1.37.0", Some(62290)),
393+
/// Allows async functions to be called from `dyn Trait`.
394+
(incomplete, async_fn_in_dyn_trait, "CURRENT_RUSTC_VERSION", Some(133119)),
393395
/// Allows `#[track_caller]` on async functions.
394396
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
395397
/// Allows `for await` loops.

compiler/rustc_middle/src/ty/instance.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -677,23 +677,26 @@ impl<'tcx> Instance<'tcx> {
677677
//
678678
// 1) The underlying method expects a caller location parameter
679679
// in the ABI
680-
if resolved.def.requires_caller_location(tcx)
681-
// 2) The caller location parameter comes from having `#[track_caller]`
682-
// on the implementation, and *not* on the trait method.
683-
&& !tcx.should_inherit_track_caller(def)
684-
// If the method implementation comes from the trait definition itself
685-
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
686-
// then we don't need to generate a shim. This check is needed because
687-
// `should_inherit_track_caller` returns `false` if our method
688-
// implementation comes from the trait block, and not an impl block
689-
&& !matches!(
690-
tcx.opt_associated_item(def),
691-
Some(ty::AssocItem {
692-
container: ty::AssocItemContainer::Trait,
693-
..
694-
})
695-
)
696-
{
680+
let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
681+
// 2) The caller location parameter comes from having `#[track_caller]`
682+
// on the implementation, and *not* on the trait method.
683+
&& !tcx.should_inherit_track_caller(def)
684+
// If the method implementation comes from the trait definition itself
685+
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
686+
// then we don't need to generate a shim. This check is needed because
687+
// `should_inherit_track_caller` returns `false` if our method
688+
// implementation comes from the trait block, and not an impl block
689+
&& !matches!(
690+
tcx.opt_associated_item(def),
691+
Some(ty::AssocItem {
692+
container: ty::AssocItemContainer::Trait,
693+
..
694+
})
695+
);
696+
// We also need to generate a shim if this is an AFIT.
697+
let needs_rpitit_shim =
698+
tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
699+
if needs_track_caller_shim || needs_rpitit_shim {
697700
if tcx.is_closure_like(def) {
698701
debug!(
699702
" => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",

compiler/rustc_middle/src/ty/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ mod opaque_types;
146146
mod parameterized;
147147
mod predicate;
148148
mod region;
149+
mod return_position_impl_trait_in_trait;
149150
mod rvalue_scopes;
150151
mod structural_impls;
151152
#[allow(hidden_glob_reexports)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use rustc_hir::def_id::DefId;
2+
3+
use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};
4+
5+
impl<'tcx> TyCtxt<'tcx> {
6+
/// Given a `def_id` of a trait or impl method, compute whether that method needs to
7+
/// have an RPITIT shim applied to it for it to be object safe. If so, return the
8+
/// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
9+
///
10+
/// NOTE that these args are not, in general, the same as than the RPITIT's args. They
11+
/// are a subset of those args, since they do not include the late-bound lifetimes of
12+
/// the RPITIT. Depending on the context, these will need to be dealt with in different
13+
/// ways -- in codegen, it's okay to fill them with ReErased.
14+
pub fn return_position_impl_trait_in_trait_shim_data(
15+
self,
16+
def_id: DefId,
17+
) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
18+
let assoc_item = self.opt_associated_item(def_id)?;
19+
20+
let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
21+
ty::AssocItemContainer::Impl => {
22+
(assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
23+
}
24+
ty::AssocItemContainer::Trait => (def_id, None),
25+
};
26+
27+
let sig = self.fn_sig(trait_item_def_id);
28+
29+
// Check if the trait returns an RPITIT.
30+
let ty::Alias(ty::Projection, ty::AliasTy { def_id, .. }) =
31+
*sig.skip_binder().skip_binder().output().kind()
32+
else {
33+
return None;
34+
};
35+
if !self.is_impl_trait_in_trait(def_id) {
36+
return None;
37+
}
38+
39+
let args = if let Some(impl_def_id) = opt_impl_def_id {
40+
// Rebase the args from the RPITIT onto the impl trait ref, so we can later
41+
// substitute them with the method args of the *impl* method, since that's
42+
// the instance we're building a vtable shim for.
43+
ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
44+
self,
45+
self.parent(trait_item_def_id),
46+
self.impl_trait_ref(impl_def_id)
47+
.expect("expected impl trait ref from parent of impl item")
48+
.instantiate_identity()
49+
.args,
50+
)
51+
} else {
52+
// This is when we have a default trait implementation.
53+
ty::GenericArgs::identity_for_item(self, trait_item_def_id)
54+
};
55+
56+
Some((def_id, ty::EarlyBinder::bind(args)))
57+
}
58+
59+
/// Given a `DefId` of an RPITIT and its args, return the existential predicates
60+
/// that corresponds to the RPITIT's bounds with the self type erased.
61+
pub fn item_bounds_to_existential_predicates(
62+
self,
63+
def_id: DefId,
64+
args: ty::GenericArgsRef<'tcx>,
65+
) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
66+
let mut bounds: Vec<_> = self
67+
.item_super_predicates(def_id)
68+
.iter_instantiated(self, args)
69+
.filter_map(|clause| {
70+
clause
71+
.kind()
72+
.map_bound(|clause| match clause {
73+
ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
74+
ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
75+
)),
76+
ty::ClauseKind::Projection(projection_pred) => {
77+
Some(ty::ExistentialPredicate::Projection(
78+
ty::ExistentialProjection::erase_self_ty(self, projection_pred),
79+
))
80+
}
81+
ty::ClauseKind::TypeOutlives(_) => {
82+
// Type outlives bounds don't really turn into anything,
83+
// since we must use an intersection region for the `dyn*`'s
84+
// region anyways.
85+
None
86+
}
87+
_ => unreachable!("unexpected clause in item bounds: {clause:?}"),
88+
})
89+
.transpose()
90+
})
91+
.collect();
92+
bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
93+
self.mk_poly_existential_predicates(&bounds)
94+
}
95+
}

compiler/rustc_mir_transform/src/shim.rs

+53-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
99
use rustc_middle::mir::patch::MirPatch;
1010
use rustc_middle::mir::*;
1111
use rustc_middle::query::Providers;
12+
use rustc_middle::ty::adjustment::PointerCoercion;
1213
use rustc_middle::ty::{
1314
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
1415
};
@@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
710711
};
711712

712713
let def_id = instance.def_id();
714+
715+
let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
716+
tcx.return_position_impl_trait_in_trait_shim_data(def_id)
717+
} else {
718+
None
719+
};
720+
713721
let sig = tcx.fn_sig(def_id);
714722
let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));
715723

@@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
765773
let mut local_decls = local_decls_for_sig(&sig, span);
766774
let source_info = SourceInfo::outermost(span);
767775

776+
let mut destination = Place::return_place();
777+
if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
778+
let rpitit_args =
779+
fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
780+
match param.kind {
781+
ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
782+
ty::GenericParamDefKind::Type { .. }
783+
| ty::GenericParamDefKind::Const { .. } => {
784+
unreachable!("rpitit should have no addition ty/ct")
785+
}
786+
}
787+
});
788+
let dyn_star_ty = Ty::new_dynamic(
789+
tcx,
790+
tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
791+
tcx.lifetimes.re_erased,
792+
ty::DynStar,
793+
);
794+
destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
795+
local_decls[RETURN_PLACE].ty = dyn_star_ty;
796+
let mut inputs_and_output = sig.inputs_and_output.to_vec();
797+
*inputs_and_output.last_mut().unwrap() = dyn_star_ty;
798+
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
799+
}
800+
768801
let rcvr_place = || {
769802
assert!(rcvr_adjustment.is_some());
770-
Place::from(Local::new(1 + 0))
803+
Place::from(Local::new(1))
771804
};
772805
let mut statements = vec![];
773806

@@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
854887
TerminatorKind::Call {
855888
func: callee,
856889
args,
857-
destination: Place::return_place(),
890+
destination,
858891
target: Some(BasicBlock::new(1)),
859892
unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
860893
UnwindAction::Cleanup(BasicBlock::new(3))
@@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
882915
);
883916
}
884917
// BB #1/#2 - return
885-
block(&mut blocks, vec![], TerminatorKind::Return, false);
918+
// NOTE: If this is an RPITIT in dyn, we also want to coerce
919+
// the return type of the function into a `dyn*`.
920+
let stmts = if rpitit_shim.is_some() {
921+
vec![Statement {
922+
source_info,
923+
kind: StatementKind::Assign(Box::new((
924+
Place::return_place(),
925+
Rvalue::Cast(
926+
CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
927+
Operand::Move(destination),
928+
sig.output(),
929+
),
930+
))),
931+
}]
932+
} else {
933+
vec![]
934+
};
935+
block(&mut blocks, stmts, TerminatorKind::Return, false);
886936
if let Some(Adjustment::RefMut) = rcvr_adjustment {
887937
// BB #3 - drop if closure panics
888938
block(

compiler/rustc_monomorphize/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ fn custom_coerce_unsize_info<'tcx>(
4242
..
4343
})) => Ok(tcx.coerce_unsized_info(impl_def_id)?.custom_kind.unwrap()),
4444
impl_source => {
45-
bug!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
45+
bug!(
46+
"invalid `CoerceUnsized` from {source_ty} to {target_ty}: impl_source: {:?}",
47+
impl_source
48+
);
4649
}
4750
}
4851
}

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ symbols! {
461461
async_drop_slice,
462462
async_drop_surface_drop_in_place,
463463
async_fn,
464+
async_fn_in_dyn_trait,
464465
async_fn_in_trait,
465466
async_fn_kind_helper,
466467
async_fn_kind_upvars,

compiler/rustc_trait_selection/src/traits/dyn_compatibility.rs

+48-11
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustc_abi::BackendRepr;
1111
use rustc_errors::FatalError;
1212
use rustc_hir as hir;
1313
use rustc_hir::def_id::DefId;
14+
use rustc_middle::bug;
1415
use rustc_middle::query::Providers;
1516
use rustc_middle::ty::{
1617
self, EarlyBinder, ExistentialPredicateStableCmpExt as _, GenericArgs, Ty, TyCtxt,
@@ -901,23 +902,59 @@ fn contains_illegal_impl_trait_in_trait<'tcx>(
901902
fn_def_id: DefId,
902903
ty: ty::Binder<'tcx, Ty<'tcx>>,
903904
) -> Option<MethodViolationCode> {
904-
// This would be caught below, but rendering the error as a separate
905-
// `async-specific` message is better.
905+
let ty = tcx.liberate_late_bound_regions(fn_def_id, ty);
906+
906907
if tcx.asyncness(fn_def_id).is_async() {
907-
return Some(MethodViolationCode::AsyncFn);
908+
// FIXME(async_fn_in_dyn_trait): Think of a better way to unify these code paths
909+
// to issue an appropriate feature suggestion when users try to use AFIDT.
910+
// Obviously we must only do this once AFIDT is finished enough to actually be usable.
911+
if tcx.features().async_fn_in_dyn_trait() {
912+
let ty::Alias(ty::Projection, proj) = *ty.kind() else {
913+
bug!("expected async fn in trait to return an RPITIT");
914+
};
915+
assert!(tcx.is_impl_trait_in_trait(proj.def_id));
916+
917+
// FIXME(async_fn_in_dyn_trait): We should check that this bound is legal too,
918+
// and stop relying on `async fn` in the definition.
919+
for bound in tcx.item_bounds(proj.def_id).instantiate(tcx, proj.args) {
920+
if let Some(violation) = bound
921+
.visit_with(&mut IllegalRpititVisitor { tcx, allowed: Some(proj) })
922+
.break_value()
923+
{
924+
return Some(violation);
925+
}
926+
}
927+
928+
None
929+
} else {
930+
// Rendering the error as a separate `async-specific` message is better.
931+
Some(MethodViolationCode::AsyncFn)
932+
}
933+
} else {
934+
ty.visit_with(&mut IllegalRpititVisitor { tcx, allowed: None }).break_value()
908935
}
936+
}
937+
938+
struct IllegalRpititVisitor<'tcx> {
939+
tcx: TyCtxt<'tcx>,
940+
allowed: Option<ty::AliasTy<'tcx>>,
941+
}
942+
943+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for IllegalRpititVisitor<'tcx> {
944+
type Result = ControlFlow<MethodViolationCode>;
909945

910-
// FIXME(RPITIT): Perhaps we should use a visitor here?
911-
ty.skip_binder().walk().find_map(|arg| {
912-
if let ty::GenericArgKind::Type(ty) = arg.unpack()
913-
&& let ty::Alias(ty::Projection, proj) = ty.kind()
914-
&& tcx.is_impl_trait_in_trait(proj.def_id)
946+
fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
947+
if let ty::Alias(ty::Projection, proj) = *ty.kind()
948+
&& Some(proj) != self.allowed
949+
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
915950
{
916-
Some(MethodViolationCode::ReferencesImplTraitInTrait(tcx.def_span(proj.def_id)))
951+
ControlFlow::Break(MethodViolationCode::ReferencesImplTraitInTrait(
952+
self.tcx.def_span(proj.def_id),
953+
))
917954
} else {
918-
None
955+
ty.super_visit_with(self)
919956
}
920-
})
957+
}
921958
}
922959

923960
pub(crate) fn provide(providers: &mut Providers) {

0 commit comments

Comments
 (0)