Skip to content

Commit 14e0ad8

Browse files
committed
Fix #[project] on non-statement expressions
1 parent 7d46ff1 commit 14e0ad8

File tree

4 files changed

+101
-42
lines changed

4 files changed

+101
-42
lines changed

pin-project-internal/src/project.rs

+41-16
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability)
1313
.unwrap_or_else(|e| e.to_compile_error())
1414
}
1515

16-
fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
17-
match stmt {
18-
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
16+
fn replace_expr(expr: &mut Expr, mutability: Mutability) {
17+
match expr {
18+
Expr::Match(expr) => {
1919
Context::new(mutability).replace_expr_match(expr);
2020
}
21-
Stmt::Expr(Expr::If(expr_if)) => {
21+
Expr::If(expr_if) => {
2222
let mut expr_if = expr_if;
2323
while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
2424
Context::new(mutability).replace_expr_let(expr);
@@ -31,15 +31,18 @@ fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
3131
break;
3232
}
3333
}
34-
Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
3534
_ => {}
3635
}
37-
Ok(())
36+
}
37+
38+
fn replace_local(local: &mut Local, mutability: Mutability) -> Result<()> {
39+
Context::new(mutability).replace_local(local)
3840
}
3941

4042
fn parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream> {
41-
replace_stmt(&mut stmt, mutability)?;
4243
match &mut stmt {
44+
Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, mutability),
45+
Stmt::Local(local) => replace_local(local, mutability)?,
4346
Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?,
4447
Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability),
4548
Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?,
@@ -219,12 +222,28 @@ impl FnVisitor {
219222
}
220223

221224
fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
222-
let attr = match node {
223-
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
224-
expr.attrs.find_remove(self.name())?
225+
match node {
226+
Stmt::Expr(expr) | Stmt::Semi(expr, _) => {
227+
visit_mut::visit_expr_mut(self, expr);
228+
self.visit_expr(expr)
229+
}
230+
Stmt::Local(local) => {
231+
visit_mut::visit_local_mut(self, local);
232+
if let Some(attr) = local.attrs.find_remove(self.name())? {
233+
parse_as_empty(&attr.tokens)?;
234+
replace_local(local, self.mutability)?;
235+
}
236+
Ok(())
225237
}
226-
Stmt::Local(local) => local.attrs.find_remove(self.name())?,
227-
Stmt::Expr(Expr::If(expr_if)) => {
238+
// Do not recurse into nested items.
239+
Stmt::Item(_) => Ok(()),
240+
}
241+
}
242+
243+
fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
244+
let attr = match node {
245+
Expr::Match(expr) => expr.attrs.find_remove(self.name())?,
246+
Expr::If(expr_if) => {
228247
if let Expr::Let(_) = &*expr_if.cond {
229248
expr_if.attrs.find_remove(self.name())?
230249
} else {
@@ -235,7 +254,7 @@ impl FnVisitor {
235254
};
236255
if let Some(attr) = attr {
237256
parse_as_empty(&attr.tokens)?;
238-
replace_stmt(node, self.mutability)?;
257+
replace_expr(node, self.mutability);
239258
}
240259
Ok(())
241260
}
@@ -246,14 +265,20 @@ impl VisitMut for FnVisitor {
246265
if self.res.is_err() {
247266
return;
248267
}
249-
250-
visit_mut::visit_stmt_mut(self, node);
251-
252268
if let Err(e) = self.visit_stmt(node) {
253269
self.res = Err(e)
254270
}
255271
}
256272

273+
fn visit_expr_mut(&mut self, node: &mut Expr) {
274+
if self.res.is_err() {
275+
return;
276+
}
277+
if let Err(e) = self.visit_expr(node) {
278+
self.res = Err(e)
279+
}
280+
}
281+
257282
fn visit_item_mut(&mut self, _: &mut Item) {
258283
// Do not recurse into nested items.
259284
}

tests/pin_project.rs

+24
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,27 @@ fn self_in_where_clause() {
556556
type Foo = Struct1<T>;
557557
}
558558
}
559+
560+
#[test]
561+
fn where_clause() {
562+
#[pin_project]
563+
struct StructWhereClause<T>
564+
where
565+
T: Copy,
566+
{
567+
field: T,
568+
}
569+
570+
#[pin_project]
571+
struct TupleStructWhereClause<T>(T)
572+
where
573+
T: Copy;
574+
575+
#[pin_project]
576+
enum EnumWhereClause<T>
577+
where
578+
T: Copy,
579+
{
580+
Variant(T),
581+
}
582+
}

tests/project.rs

+19-23
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
#![warn(rust_2018_idioms, single_use_lifetimes)]
22
#![allow(dead_code)]
33

4-
// This hack is needed until https://github.com/rust-lang/rust/pull/69201
5-
// makes it way into stable.
6-
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable,
7-
// which means that it will error even behind a `#[rustversion::nightly]`
4+
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on MSRV,
5+
// which means that it will error even behind a `#[rustversion::since(..)]`
86
//
97
// This trick makes sure that we don't even attempt to parse
10-
// the `#[project] if let _` test on stable.
11-
#[rustversion::nightly]
8+
// the `#[project] if let _` test on MSRV.
9+
#[rustversion::since(1.43)]
1210
include!("project_if_attr.rs.in");
1311

1412
use pin_project::{pin_project, project};
@@ -194,23 +192,21 @@ mod project_use_2 {
194192
}
195193
}
196194

197-
#[pin_project]
198-
struct StructWhereClause<T>
199-
where
200-
T: Copy,
201-
{
202-
field: T,
203-
}
195+
#[test]
196+
#[project]
197+
fn non_stmt_expr_match() {
198+
#[pin_project]
199+
enum Enum<A> {
200+
Variant(#[pin] A),
201+
}
204202

205-
#[pin_project]
206-
struct TupleStructWhereClause<T>(T)
207-
where
208-
T: Copy;
203+
let mut x = Enum::Variant(1);
204+
let x = Pin::new(&mut x).project();
209205

210-
#[pin_project]
211-
enum EnumWhereClause<T>
212-
where
213-
T: Copy,
214-
{
215-
Variant(T),
206+
Some(
207+
#[project]
208+
match x {
209+
Enum::Variant(_x) => {}
210+
},
211+
);
216212
}

tests/project_if_attr.rs.in

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its
2-
// way into stable, move this back into `project.rs
3-
41
#[test]
52
#[project]
63
fn project_if_let() {
@@ -27,3 +24,20 @@ fn project_if_let() {
2724
}
2825
}
2926

27+
#[test]
28+
#[project]
29+
fn non_stmt_expr_if_let() {
30+
#[pin_project]
31+
enum Enum<A> {
32+
Variant(#[pin] A),
33+
}
34+
35+
let mut x = Enum::Variant(1);
36+
let x = Pin::new(&mut x).project();
37+
38+
#[allow(irrefutable_let_patterns)]
39+
Some(
40+
#[project]
41+
if let Enum::Variant(_x) = x {},
42+
);
43+
}

0 commit comments

Comments
 (0)