From e506322599ef0520c2e73e3eb602c2915036ee4b Mon Sep 17 00:00:00 2001 From: Bas Zalmstra <zalmstra.bas@gmail.com> Date: Sat, 23 Nov 2019 12:41:14 +0100 Subject: [PATCH 1/3] feat: parsing of break expressions --- crates/mun_hir/src/expr.rs | 15 ++++ crates/mun_syntax/src/ast/generated.rs | 40 ++++++++++- crates/mun_syntax/src/grammar.ron | 3 + .../src/parsing/grammar/expressions.rs | 68 +++++++++++++++---- crates/mun_syntax/src/parsing/token_set.rs | 4 +- .../mun_syntax/src/syntax_kind/generated.rs | 2 + crates/mun_syntax/src/tests/parser.rs | 13 ++++ .../tests/snapshots/parser__break_expr.snap | 62 +++++++++++++++++ 8 files changed, 191 insertions(+), 16 deletions(-) create mode 100644 crates/mun_syntax/src/tests/snapshots/parser__break_expr.snap diff --git a/crates/mun_hir/src/expr.rs b/crates/mun_hir/src/expr.rs index 08810a782..00b837455 100644 --- a/crates/mun_hir/src/expr.rs +++ b/crates/mun_hir/src/expr.rs @@ -201,6 +201,9 @@ pub enum Expr { Return { expr: Option<ExprId>, }, + Break { + expr: Option<ExprId>, + }, Loop { body: ExprId, }, @@ -293,6 +296,11 @@ impl Expr { f(*expr); } } + Expr::Break { expr } => { + if let Some(expr) = expr { + f(*expr); + } + } Expr::Loop { body } => { f(*body); } @@ -461,6 +469,7 @@ where match expr.kind() { ast::ExprKind::LoopExpr(expr) => self.collect_loop(expr), ast::ExprKind::ReturnExpr(r) => self.collect_return(r), + ast::ExprKind::BreakExpr(r) => self.collect_break(r), ast::ExprKind::BlockExpr(b) => self.collect_block(b), ast::ExprKind::Literal(e) => { let lit = match e.kind() { @@ -634,6 +643,12 @@ where self.alloc_expr(Expr::Return { expr }, syntax_node_ptr) } + fn collect_break(&mut self, expr: ast::BreakExpr) -> ExprId { + let syntax_node_ptr = AstPtr::new(&expr.clone().into()); + let expr = expr.expr().map(|e| self.collect_expr(e)); + self.alloc_expr(Expr::Break { expr }, syntax_node_ptr) + } + fn collect_loop(&mut self, expr: ast::LoopExpr) -> ExprId { let syntax_node_ptr = AstPtr::new(&expr.clone().into()); let body = self.collect_block_opt(expr.loop_body()); diff --git a/crates/mun_syntax/src/ast/generated.rs b/crates/mun_syntax/src/ast/generated.rs index 1ec906755..112c644e1 100644 --- a/crates/mun_syntax/src/ast/generated.rs +++ b/crates/mun_syntax/src/ast/generated.rs @@ -140,6 +140,37 @@ impl BlockExpr { } } +// BreakExpr + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct BreakExpr { + pub(crate) syntax: SyntaxNode, +} + +impl AstNode for BreakExpr { + fn can_cast(kind: SyntaxKind) -> bool { + match kind { + BREAK_EXPR => true, + _ => false, + } + } + fn cast(syntax: SyntaxNode) -> Option<Self> { + if Self::can_cast(syntax.kind()) { + Some(BreakExpr { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} +impl BreakExpr { + pub fn expr(&self) -> Option<Expr> { + super::child_opt(self) + } +} + // CallExpr #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -218,7 +249,7 @@ impl AstNode for Expr { fn can_cast(kind: SyntaxKind) -> bool { match kind { LITERAL | PREFIX_EXPR | PATH_EXPR | BIN_EXPR | PAREN_EXPR | CALL_EXPR | IF_EXPR - | LOOP_EXPR | RETURN_EXPR | BLOCK_EXPR => true, + | LOOP_EXPR | RETURN_EXPR | BREAK_EXPR | BLOCK_EXPR => true, _ => false, } } @@ -244,6 +275,7 @@ pub enum ExprKind { IfExpr(IfExpr), LoopExpr(LoopExpr), ReturnExpr(ReturnExpr), + BreakExpr(BreakExpr), BlockExpr(BlockExpr), } impl From<Literal> for Expr { @@ -291,6 +323,11 @@ impl From<ReturnExpr> for Expr { Expr { syntax: n.syntax } } } +impl From<BreakExpr> for Expr { + fn from(n: BreakExpr) -> Expr { + Expr { syntax: n.syntax } + } +} impl From<BlockExpr> for Expr { fn from(n: BlockExpr) -> Expr { Expr { syntax: n.syntax } @@ -309,6 +346,7 @@ impl Expr { IF_EXPR => ExprKind::IfExpr(IfExpr::cast(self.syntax.clone()).unwrap()), LOOP_EXPR => ExprKind::LoopExpr(LoopExpr::cast(self.syntax.clone()).unwrap()), RETURN_EXPR => ExprKind::ReturnExpr(ReturnExpr::cast(self.syntax.clone()).unwrap()), + BREAK_EXPR => ExprKind::BreakExpr(BreakExpr::cast(self.syntax.clone()).unwrap()), BLOCK_EXPR => ExprKind::BlockExpr(BlockExpr::cast(self.syntax.clone()).unwrap()), _ => unreachable!(), } diff --git a/crates/mun_syntax/src/grammar.ron b/crates/mun_syntax/src/grammar.ron index 62507012f..787c9c57c 100644 --- a/crates/mun_syntax/src/grammar.ron +++ b/crates/mun_syntax/src/grammar.ron @@ -123,6 +123,7 @@ Grammar( "BLOCK_EXPR", "RETURN_EXPR", "LOOP_EXPR", + "BREAK_EXPR", "CONDITION", "BIND_PAT", @@ -200,6 +201,7 @@ Grammar( "IfExpr": ( options: [ "Condition" ] ), + "BreakExpr": (options: ["Expr"]), "ArgList": ( collections: [ ["args", "Expr"] @@ -217,6 +219,7 @@ Grammar( "IfExpr", "LoopExpr", "ReturnExpr", + "BreakExpr", "BlockExpr", ] ), diff --git a/crates/mun_syntax/src/parsing/grammar/expressions.rs b/crates/mun_syntax/src/parsing/grammar/expressions.rs index b6e4d77c5..bcbce1b2e 100644 --- a/crates/mun_syntax/src/parsing/grammar/expressions.rs +++ b/crates/mun_syntax/src/parsing/grammar/expressions.rs @@ -6,14 +6,32 @@ pub(crate) const LITERAL_FIRST: TokenSet = const EXPR_RECOVERY_SET: TokenSet = token_set![LET_KW]; -const ATOM_EXPR_FIRST: TokenSet = LITERAL_FIRST - .union(PATH_FIRST) - .union(token_set![IDENT, L_PAREN, L_CURLY, IF_KW, RETURN_KW,]); +const ATOM_EXPR_FIRST: TokenSet = LITERAL_FIRST.union(PATH_FIRST).union(token_set![ + IDENT, + T!['('], + T!['{'], + T![if], + T![loop], + T![return], + T![break], +]); const LHS_FIRST: TokenSet = ATOM_EXPR_FIRST.union(token_set![EXCLAMATION, MINUS]); const EXPR_FIRST: TokenSet = LHS_FIRST; +#[derive(Clone, Copy)] +struct Restrictions { + /// Indicates that parsing of structs is not valid in the current context. For instance: + /// ```mun + /// if break { 3 } + /// if break 4 { 3 } + /// ``` + /// In the first if expression we do not want the `break` expression to capture the block as an + /// expression. However, in the second statement we do want the break to capture the 4. + forbid_structs: bool, +} + pub(crate) fn expr_block_contents(p: &mut Parser) { while !p.at(EOF) && !p.at(T!['}']) { if p.eat(T![;]) { @@ -84,16 +102,29 @@ fn let_stmt(p: &mut Parser, m: Marker) { } pub(super) fn expr(p: &mut Parser) { - expr_bp(p, 1); + let r = Restrictions { + forbid_structs: false, + }; + expr_bp(p, r, 1); +} + +fn expr_no_struct(p: &mut Parser) { + let r = Restrictions { + forbid_structs: true, + }; + expr_bp(p, r, 1); } fn expr_stmt(p: &mut Parser) -> Option<CompletedMarker> { - expr_bp(p, 1) + let r = Restrictions { + forbid_structs: false, + }; + expr_bp(p, r, 1) } -fn expr_bp(p: &mut Parser, bp: u8) -> Option<CompletedMarker> { +fn expr_bp(p: &mut Parser, r: Restrictions, bp: u8) -> Option<CompletedMarker> { // Parse left hand side of the expression - let mut lhs = match lhs(p) { + let mut lhs = match lhs(p, r) { Some(lhs) => lhs, None => return None, }; @@ -107,7 +138,7 @@ fn expr_bp(p: &mut Parser, bp: u8) -> Option<CompletedMarker> { let m = lhs.precede(p); p.bump(op); - expr_bp(p, op_bp + 1); + expr_bp(p, r, op_bp + 1); lhs = m.complete(p, BIN_EXPR); } @@ -135,7 +166,7 @@ fn current_op(p: &Parser) -> (u8, SyntaxKind) { } } -fn lhs(p: &mut Parser) -> Option<CompletedMarker> { +fn lhs(p: &mut Parser, r: Restrictions) -> Option<CompletedMarker> { let m; let kind = match p.current() { T![-] | T![!] => { @@ -144,11 +175,11 @@ fn lhs(p: &mut Parser) -> Option<CompletedMarker> { PREFIX_EXPR } _ => { - let lhs = atom_expr(p)?; + let lhs = atom_expr(p, r)?; return Some(postfix_expr(p, lhs)); } }; - expr_bp(p, 255); + expr_bp(p, r, 255); Some(m.complete(p, kind)) } @@ -188,7 +219,7 @@ fn arg_list(p: &mut Parser) { m.complete(p, ARG_LIST); } -fn atom_expr(p: &mut Parser) -> Option<CompletedMarker> { +fn atom_expr(p: &mut Parser, r: Restrictions) -> Option<CompletedMarker> { if let Some(m) = literal(p) { return Some(m); } @@ -203,6 +234,7 @@ fn atom_expr(p: &mut Parser) -> Option<CompletedMarker> { T![if] => if_expr(p), T![loop] => loop_expr(p), T![return] => ret_expr(p), + T![break] => break_expr(p, r), _ => { p.error_recover("expected expression", EXPR_RECOVERY_SET); return None; @@ -262,7 +294,7 @@ fn loop_expr(p: &mut Parser) -> CompletedMarker { fn cond(p: &mut Parser) { let m = p.start(); - expr(p); + expr_no_struct(p); m.complete(p, CONDITION); } @@ -275,3 +307,13 @@ fn ret_expr(p: &mut Parser) -> CompletedMarker { } m.complete(p, RETURN_EXPR) } + +fn break_expr(p: &mut Parser, r: Restrictions) -> CompletedMarker { + assert!(p.at(T![break])); + let m = p.start(); + p.bump(T![break]); + if p.at_ts(EXPR_FIRST) && !(r.forbid_structs && p.at(T!['{'])) { + expr(p); + } + m.complete(p, BREAK_EXPR) +} diff --git a/crates/mun_syntax/src/parsing/token_set.rs b/crates/mun_syntax/src/parsing/token_set.rs index 69d8b0d15..314064591 100644 --- a/crates/mun_syntax/src/parsing/token_set.rs +++ b/crates/mun_syntax/src/parsing/token_set.rs @@ -28,8 +28,8 @@ const fn mask(kind: SyntaxKind) -> u128 { #[macro_export] macro_rules! token_set { - ($($t:ident),*) => { TokenSet::empty()$(.union(TokenSet::singleton($t)))* }; - ($($t:ident),* ,) => { token_set!($($t),*) }; + ($($t:expr),*) => { TokenSet::empty()$(.union(TokenSet::singleton($t)))* }; + ($($t:expr),* ,) => { token_set!($($t),*) }; } #[test] diff --git a/crates/mun_syntax/src/syntax_kind/generated.rs b/crates/mun_syntax/src/syntax_kind/generated.rs index 88f03d0cf..0bf4058b2 100644 --- a/crates/mun_syntax/src/syntax_kind/generated.rs +++ b/crates/mun_syntax/src/syntax_kind/generated.rs @@ -104,6 +104,7 @@ pub enum SyntaxKind { BLOCK_EXPR, RETURN_EXPR, LOOP_EXPR, + BREAK_EXPR, CONDITION, BIND_PAT, PLACEHOLDER_PAT, @@ -369,6 +370,7 @@ impl SyntaxKind { BLOCK_EXPR => &SyntaxInfo { name: "BLOCK_EXPR" }, RETURN_EXPR => &SyntaxInfo { name: "RETURN_EXPR" }, LOOP_EXPR => &SyntaxInfo { name: "LOOP_EXPR" }, + BREAK_EXPR => &SyntaxInfo { name: "BREAK_EXPR" }, CONDITION => &SyntaxInfo { name: "CONDITION" }, BIND_PAT => &SyntaxInfo { name: "BIND_PAT" }, PLACEHOLDER_PAT => &SyntaxInfo { name: "PLACEHOLDER_PAT" }, diff --git a/crates/mun_syntax/src/tests/parser.rs b/crates/mun_syntax/src/tests/parser.rs index da0e780ec..97624773f 100644 --- a/crates/mun_syntax/src/tests/parser.rs +++ b/crates/mun_syntax/src/tests/parser.rs @@ -186,3 +186,16 @@ fn loop_expr() { }"#, ) } + +#[test] +fn break_expr() { + ok_snapshot_test( + r#" + fn foo() { + break; + if break { 3; } + if break 4 { 3; } + } + "#, + ) +} diff --git a/crates/mun_syntax/src/tests/snapshots/parser__break_expr.snap b/crates/mun_syntax/src/tests/snapshots/parser__break_expr.snap new file mode 100644 index 000000000..ea8474ff3 --- /dev/null +++ b/crates/mun_syntax/src/tests/snapshots/parser__break_expr.snap @@ -0,0 +1,62 @@ +--- +source: crates/mun_syntax/src/tests/parser.rs +expression: "fn foo() {\n break;\n if break { 3; }\n if break 4 { 3; }\n}" +--- +SOURCE_FILE@[0; 65) + FUNCTION_DEF@[0; 65) + FN_KW@[0; 2) "fn" + WHITESPACE@[2; 3) " " + NAME@[3; 6) + IDENT@[3; 6) "foo" + PARAM_LIST@[6; 8) + L_PAREN@[6; 7) "(" + R_PAREN@[7; 8) ")" + WHITESPACE@[8; 9) " " + BLOCK_EXPR@[9; 65) + L_CURLY@[9; 10) "{" + WHITESPACE@[10; 15) "\n " + EXPR_STMT@[15; 21) + BREAK_EXPR@[15; 20) + BREAK_KW@[15; 20) "break" + SEMI@[20; 21) ";" + WHITESPACE@[21; 26) "\n " + EXPR_STMT@[26; 41) + IF_EXPR@[26; 41) + IF_KW@[26; 28) "if" + WHITESPACE@[28; 29) " " + CONDITION@[29; 34) + BREAK_EXPR@[29; 34) + BREAK_KW@[29; 34) "break" + WHITESPACE@[34; 35) " " + BLOCK_EXPR@[35; 41) + L_CURLY@[35; 36) "{" + WHITESPACE@[36; 37) " " + EXPR_STMT@[37; 39) + LITERAL@[37; 38) + INT_NUMBER@[37; 38) "3" + SEMI@[38; 39) ";" + WHITESPACE@[39; 40) " " + R_CURLY@[40; 41) "}" + WHITESPACE@[41; 46) "\n " + IF_EXPR@[46; 63) + IF_KW@[46; 48) "if" + WHITESPACE@[48; 49) " " + CONDITION@[49; 56) + BREAK_EXPR@[49; 56) + BREAK_KW@[49; 54) "break" + WHITESPACE@[54; 55) " " + LITERAL@[55; 56) + INT_NUMBER@[55; 56) "4" + WHITESPACE@[56; 57) " " + BLOCK_EXPR@[57; 63) + L_CURLY@[57; 58) "{" + WHITESPACE@[58; 59) " " + EXPR_STMT@[59; 61) + LITERAL@[59; 60) + INT_NUMBER@[59; 60) "3" + SEMI@[60; 61) ";" + WHITESPACE@[61; 62) " " + R_CURLY@[62; 63) "}" + WHITESPACE@[63; 64) "\n" + R_CURLY@[64; 65) "}" + From 78c6181e528da237429986b4b0ba956a28f4538d Mon Sep 17 00:00:00 2001 From: Bas Zalmstra <zalmstra.bas@gmail.com> Date: Sat, 23 Nov 2019 14:41:54 +0100 Subject: [PATCH 2/3] feat: breaks in loop type checking --- crates/mun/test/main.mun | 33 ++------ crates/mun_hir/src/diagnostics.rs | 24 ++++++ crates/mun_hir/src/ty/infer.rs | 82 +++++++++++++++++-- .../src/ty/snapshots/tests__infer_break.snap | 40 +++++++++ crates/mun_hir/src/ty/tests.rs | 16 ++++ 5 files changed, 164 insertions(+), 31 deletions(-) create mode 100644 crates/mun_hir/src/ty/snapshots/tests__infer_break.snap diff --git a/crates/mun/test/main.mun b/crates/mun/test/main.mun index 9697d054a..dd9c9eb89 100644 --- a/crates/mun/test/main.mun +++ b/crates/mun/test/main.mun @@ -1,25 +1,8 @@ -// function to subtract two floats -fn subtract(a:float, b:float):float { - a-b -} - -// function to subtract two floats -fn multiply(a:float, b:float):float { - a*b -} - -fn main():int { - add(5, 3) -} - -fn add_impl(a:int, b:int):int { - a+b -} - -fn add(a:int, b:int):int { - add_impl(a,b) -} - -fn test():int { - add(4,5) -} +fn foo():int { + break; // error: not in a loop + loop { break 3; break 3.0; } // error: mismatched type + let a:int = loop { break 3.0; } // error: mismatched type + loop { break 3; } + let a:int = loop { break loop { break 3; } } + loop { break loop { break 3.0; } } // error: mismatched type + } diff --git a/crates/mun_hir/src/diagnostics.rs b/crates/mun_hir/src/diagnostics.rs index 4ba5e94a4..ab2896cd0 100644 --- a/crates/mun_hir/src/diagnostics.rs +++ b/crates/mun_hir/src/diagnostics.rs @@ -360,3 +360,27 @@ impl Diagnostic for ReturnMissingExpression { self } } + +#[derive(Debug)] +pub struct BreakOutsideLoop { + pub file: FileId, + pub break_expr: SyntaxNodePtr, +} + +impl Diagnostic for BreakOutsideLoop { + fn message(&self) -> String { + "`break` outside of a loop".to_owned() + } + + fn file(&self) -> FileId { + self.file + } + + fn syntax_node_ptr(&self) -> SyntaxNodePtr { + self.break_expr + } + + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} diff --git a/crates/mun_hir/src/ty/infer.rs b/crates/mun_hir/src/ty/infer.rs index 72606710f..e4686f89f 100644 --- a/crates/mun_hir/src/ty/infer.rs +++ b/crates/mun_hir/src/ty/infer.rs @@ -104,6 +104,12 @@ struct InferenceResultBuilder<'a, D: HirDatabase> { type_variables: TypeVariableTable, + /// Information on the current loop that we're processing (or None if we're not in a loop) the + /// entry contains the current type of the loop statement (initially `never`) and the expected + /// type of the loop expression. Both these values are updated when a break statement is + /// encountered. + active_loop: Option<(Ty, Expectation)>, + /// The return type of the function being inferred. return_ty: Ty, } @@ -115,6 +121,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { type_of_expr: ArenaMap::default(), type_of_pat: ArenaMap::default(), diagnostics: Vec::default(), + active_loop: None, type_variables: TypeVariableTable::default(), db, body, @@ -306,10 +313,8 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { Ty::simple(TypeCtor::Never) } - Expr::Loop { body } => { - self.infer_expr(*body, &Expectation::has_type(Ty::Empty)); - Ty::simple(TypeCtor::Never) - } + Expr::Break { expr } => self.infer_break(tgt_expr, *expr), + Expr::Loop { body } => self.infer_loop_expr(tgt_expr, *body, expected), _ => Ty::Unknown, // Expr::UnaryOp { expr: _, op: _ } => {} // Expr::Block { statements: _, tail: _ } => {} @@ -513,6 +518,61 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } } + fn infer_break(&mut self, tgt_expr: ExprId, expr: Option<ExprId>) -> Ty { + // Fetch the expected type + let expected = if let Some((_, info)) = &self.active_loop { + info.clone() + } else { + self.diagnostics + .push(InferenceDiagnostic::BreakOutsideLoop { id: tgt_expr }); + return Ty::simple(TypeCtor::Never); + }; + + // Infer the type of the break expression + let ty = if let Some(expr) = expr { + self.infer_expr_inner(expr, &expected) + } else { + Ty::Empty + }; + + // Verify that it matches what we expected + let ty = if !expected.is_none() && ty != expected.ty { + self.diagnostics.push(InferenceDiagnostic::MismatchedTypes { + expected: expected.ty.clone(), + found: ty.clone(), + id: tgt_expr, + }); + expected.ty + } else { + ty + }; + + // Update the expected type for the rest of the loop + self.active_loop = Some((ty.clone(), Expectation::has_type(ty))); + + Ty::simple(TypeCtor::Never) + } + + fn infer_loop_expr(&mut self, _tgt_expr: ExprId, body: ExprId, expected: &Expectation) -> Ty { + self.infer_loop_block(body, expected) + } + + /// Infers the type of a loop body, taking into account breaks. + fn infer_loop_block(&mut self, body: ExprId, expected: &Expectation) -> Ty { + // Take the previous loop information and replace it with a new entry + let top_level_loop = std::mem::replace( + &mut self.active_loop, + Some((Ty::simple(TypeCtor::Never), expected.clone())), + ); + + // Infer the body of the loop + self.infer_expr_coerce(body, &Expectation::has_type(Ty::Empty)); + + // Take the result of the loop information and replace with top level loop + let (ty, _) = std::mem::replace(&mut self.active_loop, top_level_loop).unwrap(); + ty + } + pub fn report_pat_inference_failure(&mut self, _pat: PatId) { // self.diagnostics.push(InferenceDiagnostic::PatInferenceFailed { // pat @@ -573,8 +633,8 @@ impl From<PatId> for ExprOrPatId { mod diagnostics { use crate::diagnostics::{ - CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, InvalidLHS, MismatchedType, - MissingElseBranch, ParameterCountMismatch, ReturnMissingExpression, + BreakOutsideLoop, CannotApplyBinaryOp, ExpectedFunction, IncompatibleBranch, InvalidLHS, + MismatchedType, MissingElseBranch, ParameterCountMismatch, ReturnMissingExpression, }; use crate::{ code_model::src::HasSource, @@ -627,6 +687,9 @@ mod diagnostics { ReturnMissingExpression { id: ExprId, }, + BreakOutsideLoop { + id: ExprId, + }, } impl InferenceDiagnostic { @@ -736,6 +799,13 @@ mod diagnostics { return_expr: id, }); } + InferenceDiagnostic::BreakOutsideLoop { id } => { + let id = body.expr_syntax(*id).unwrap().ast.syntax_node_ptr(); + sink.push(BreakOutsideLoop { + file, + break_expr: id, + }); + } } } } diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_break.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_break.snap new file mode 100644 index 000000000..be793ed35 --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_break.snap @@ -0,0 +1,40 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn foo():int {\n break; // error: not in a loop\n loop { break 3; break 3.0; } // error: mismatched type\n let a:int = loop { break 3.0; } // error: mismatched type\n loop { break 3; }\n let a:int = loop { break loop { break 3; } }\n loop { break loop { break 3.0; } } // error: mismatched type\n}" +--- +[19; 24): `break` outside of a loop +[70; 79): mismatched type +[132; 141): mismatched type +[266; 275): mismatched type +[13; 308) '{ ...type }': never +[19; 24) 'break': never +[54; 82) 'loop {...3.0; }': int +[59; 82) '{ brea...3.0; }': never +[61; 68) 'break 3': never +[67; 68) '3': int +[70; 79) 'break 3.0': never +[76; 79) '3.0': float +[117; 118) 'a': int +[125; 144) 'loop {...3.0; }': int +[130; 144) '{ break 3.0; }': never +[132; 141) 'break 3.0': never +[138; 141) '3.0': float +[175; 192) 'loop {...k 3; }': int +[180; 192) '{ break 3; }': never +[182; 189) 'break 3': never +[188; 189) '3': int +[201; 202) 'a': int +[209; 241) 'loop {...3; } }': int +[214; 241) '{ brea...3; } }': never +[216; 239) 'break ...k 3; }': never +[222; 239) 'loop {...k 3; }': int +[227; 239) '{ break 3; }': never +[229; 236) 'break 3': never +[235; 236) '3': int +[246; 280) 'loop {...0; } }': int +[251; 280) '{ brea...0; } }': never +[253; 278) 'break ...3.0; }': never +[259; 278) 'loop {...3.0; }': int +[264; 278) '{ break 3.0; }': never +[266; 275) 'break 3.0': never +[272; 275) '3.0': float diff --git a/crates/mun_hir/src/ty/tests.rs b/crates/mun_hir/src/ty/tests.rs index bb077989f..f65ad8214 100644 --- a/crates/mun_hir/src/ty/tests.rs +++ b/crates/mun_hir/src/ty/tests.rs @@ -128,6 +128,22 @@ fn infer_loop() { ) } +#[test] +fn infer_break() { + infer_snapshot( + r#" + fn foo():int { + break; // error: not in a loop + loop { break 3; break 3.0; } // error: mismatched type + let a:int = loop { break 3.0; } // error: mismatched type + loop { break 3; } + let a:int = loop { break loop { break 3; } } + loop { break loop { break 3.0; } } // error: mismatched type + } + "#, + ) +} + fn infer_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); insta::assert_snapshot!(insta::_macro_support::AutoName, infer(&text), &text); From 048a362c29b32afa2554beedaffb8eeeaaad2723 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra <zalmstra.bas@gmail.com> Date: Sat, 23 Nov 2019 15:40:36 +0100 Subject: [PATCH 3/3] feat: ir generation for breaks --- crates/mun_codegen/src/ir/body.rs | 53 ++++++++++++++++++- .../src/snapshots/test__loop_break_expr.snap | 29 ++++++++++ crates/mun_codegen/src/test.rs | 19 +++++++ crates/mun_runtime/src/test.rs | 27 ++++++++++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 crates/mun_codegen/src/snapshots/test__loop_break_expr.snap diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index ae77b23c6..20bd7c1fe 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -11,8 +11,17 @@ use mun_hir::{ }; use std::{collections::HashMap, mem, sync::Arc}; +use inkwell::basic_block::BasicBlock; use inkwell::values::PointerValue; +struct LoopInfo { + break_values: Vec<( + inkwell::values::BasicValueEnum, + inkwell::basic_block::BasicBlock, + )>, + exit_block: BasicBlock, +} + pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { db: &'a D, module: &'a Module, @@ -25,6 +34,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { pat_to_name: HashMap<PatId, String>, function_map: &'a HashMap<mun_hir::Function, FunctionValue>, dispatch_table: &'b DispatchTable, + active_loop: Option<LoopInfo>, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -58,6 +68,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { pat_to_name: HashMap::default(), function_map, dispatch_table, + active_loop: None, } } @@ -132,6 +143,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } => self.gen_if(expr, *condition, *then_branch, *else_branch), Expr::Return { expr: ret_expr } => self.gen_return(expr, *ret_expr), Expr::Loop { body } => self.gen_loop(expr, *body), + Expr::Break { expr: break_expr } => self.gen_break(expr, *break_expr), _ => unimplemented!("unimplemented expr type {:?}", &body[expr]), } } @@ -575,9 +587,31 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { None } + fn gen_break(&mut self, _expr: ExprId, break_expr: Option<ExprId>) -> Option<BasicValueEnum> { + let break_value = break_expr.and_then(|expr| self.gen_expr(expr)); + let loop_info = self.active_loop.as_mut().unwrap(); + if let Some(break_value) = break_value { + loop_info + .break_values + .push((break_value, self.builder.get_insert_block().unwrap())); + } + self.builder + .build_unconditional_branch(&loop_info.exit_block); + None + } + fn gen_loop(&mut self, _expr: ExprId, body_expr: ExprId) -> Option<BasicValueEnum> { let context = self.module.get_context(); let loop_block = context.append_basic_block(&self.fn_value, "loop"); + let exit_block = context.append_basic_block(&self.fn_value, "exit"); + + // Build a new loop info struct + let loop_info = LoopInfo { + exit_block, + break_values: Vec::new(), + }; + + let prev_loop = std::mem::replace(&mut self.active_loop, Some(loop_info)); // Insert an explicit fall through from the current block to the loop self.builder.build_unconditional_branch(&loop_block); @@ -589,6 +623,23 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Jump to the start of the loop self.builder.build_unconditional_branch(&loop_block); - None + let LoopInfo { + exit_block, + break_values, + } = std::mem::replace(&mut self.active_loop, prev_loop).unwrap(); + + // Move the builder to the exit block + self.builder.position_at_end(&exit_block); + + if !break_values.is_empty() { + let (value, _) = break_values.first().unwrap(); + let phi = self.builder.build_phi(value.get_type(), "exit"); + for (ref value, ref block) in break_values { + phi.add_incoming(&[(value, block)]) + } + Some(phi.as_basic_value()) + } else { + None + } } } diff --git a/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap b/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap new file mode 100644 index 000000000..d6e3e9909 --- /dev/null +++ b/crates/mun_codegen/src/snapshots/test__loop_break_expr.snap @@ -0,0 +1,29 @@ +--- +source: crates/mun_codegen/src/test.rs +expression: "fn foo(n:int):int {\n loop {\n if n > 5 {\n break n;\n }\n if n > 10 {\n break 10;\n }\n n += 1;\n }\n}" +--- +; ModuleID = 'main.mun' +source_filename = "main.mun" + +define i64 @foo(i64) { +body: + br label %loop + +loop: ; preds = %if_merge6, %body + %n.0 = phi i64 [ %0, %body ], [ %add, %if_merge6 ] + %greater = icmp sgt i64 %n.0, 5 + br i1 %greater, label %exit, label %if_merge + +exit: ; preds = %if_merge, %loop + %exit8 = phi i64 [ %n.0, %loop ], [ 10, %if_merge ] + ret i64 %exit8 + +if_merge: ; preds = %loop + %greater4 = icmp sgt i64 %n.0, 10 + br i1 %greater4, label %exit, label %if_merge6 + +if_merge6: ; preds = %if_merge + %add = add i64 %n.0, 1 + br label %loop +} + diff --git a/crates/mun_codegen/src/test.rs b/crates/mun_codegen/src/test.rs index 29285116b..eaaed1e2c 100644 --- a/crates/mun_codegen/src/test.rs +++ b/crates/mun_codegen/src/test.rs @@ -327,6 +327,25 @@ fn loop_expr() { ) } +#[test] +fn loop_break_expr() { + test_snapshot( + r#" + fn foo(n:int):int { + loop { + if n > 5 { + break n; + } + if n > 10 { + break 10; + } + n += 1; + } + } + "#, + ) +} + fn test_snapshot(text: &str) { let text = text.trim().replace("\n ", "\n"); diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index 9facd2de6..c2573cb76 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -214,6 +214,33 @@ fn fibonacci_loop() { assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); } +#[test] +fn fibonacci_loop_break() { + let mut driver = TestDriver::new( + r#" + fn fibonacci(n:int):int { + let a = 0; + let b = 1; + let i = 1; + loop { + if i > n { + break a; + } + let sum = a + b; + a = b; + b = sum; + i += 1; + } + } + "#, + ); + + assert_invoke_eq!(i64, 5, driver, "fibonacci", 5i64); + assert_invoke_eq!(i64, 89, driver, "fibonacci", 11i64); + assert_invoke_eq!(i64, 987, driver, "fibonacci", 16i64); + assert_invoke_eq!(i64, 46368, driver, "fibonacci", 24i64); +} + #[test] fn true_is_true() { let mut driver = TestDriver::new(