diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index 7b2d3e9b82..59116636ea 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -374,8 +374,8 @@ class FormatOps( if (afterInfix ne Newlines.AfterInfix.keep) if (isBeforeOp) Seq(Split(spaceMod, 0)) else { - val (fullInfix, isEnclosed) = InfixSplits.findMaybeEnclosingInfix(app) - val ok = isEnclosed || fullInfix.parent.forall { + val (fullInfix, enclosedIn) = InfixSplits.findMaybeEnclosingInfix(app) + val ok = enclosedIn.isDefined || fullInfix.parent.forall { case t: Defn.Val => t.rhs eq fullInfix case t: Defn.Var => t.body eq fullInfix case _ => true @@ -387,9 +387,10 @@ class FormatOps( else { // we don't modify line breaks generally around infix expressions // TODO: if that ever changes, modify how rewrite rules handle infix - val (fullInfix, isFullInfixEnclosed) = InfixSplits + val (fullInfix, fullInfixEnclosedIn) = InfixSplits .findMaybeEnclosingInfix(app) - def okToBreak: Boolean = !isBeforeOp || isFullInfixEnclosed || + val fullInfixEnclosedInParens = fullInfixEnclosedIn.exists(_.isRight) + def okToBreak: Boolean = !isBeforeOp || fullInfixEnclosedInParens || initStyle.dialect.allowInfixOperatorAfterNL || (fullInfix.parent match { case Some(p: Case) => p.cond.contains(fullInfix) @@ -401,7 +402,7 @@ class FormatOps( } val mod = if (ft.noBreak || !okToBreak) spaceMod - else Newline2x(isFullInfixEnclosed && ft.hasBlankLine) + else Newline2x(fullInfixEnclosedInParens && ft.hasBlankLine) def split(implicit fl: FileLine) = Split(mod, 0) if (isBeforeOp && isFewerBracesRhs(app.arg)) Seq(split) else Seq(InfixSplits.withNLIndent(split, app, fullInfix)) @@ -452,8 +453,9 @@ class FormatOps( private def findMaybeEnclosingInfix( child: Member.Infix, childTree: Tree, - ): (Member.Infix, Boolean) = - if (isEnclosedWithinParensOrBraces(childTree)) (child, true) + ): (Member.Infix, Option[Either[FT, FT]]) = { + val inParensOrBraces = getClosingIfWithinParensOrBraces(childTree) + if (inParensOrBraces.isDefined) (child, inParensOrBraces) else childTree.parent match { case Some(p: Member.Infix) if !p.isAssignment => findMaybeEnclosingInfix(p, p) @@ -461,12 +463,13 @@ class FormatOps( findMaybeEnclosingInfix(child, p) case Some(p @ Tree.Block(`childTree` :: Nil)) => findMaybeEnclosingInfix(child, p) - case _ => (child, false) + case _ => (child, None) } + } private[FormatOps] def findMaybeEnclosingInfix( app: Member.Infix, - ): (Member.Infix, Boolean) = findMaybeEnclosingInfix(app, app) + ): (Member.Infix, Option[Either[FT, FT]]) = findMaybeEnclosingInfix(app, app) private[FormatOps] def findEnclosingInfix(app: Member.Infix): Member.Infix = findMaybeEnclosingInfix(app)._1 @@ -1745,7 +1748,7 @@ class FormatOps( val op = t.op.value op == "->" || op == "→" } => None - case _ => getClosingIfWithinParens(body).toOption + case _ => getClosingIfWithinParens(body) } def isBodyEnclosedAsBlock(body: Tree): Boolean = @@ -1887,7 +1890,7 @@ class FormatOps( case _: Member.Tuple => None case Term.Block((_: Member.Tuple) :: Nil) if !head.left.is[T.LeftBrace] => None - case _ => tokens.getClosingIfWithinParens(nonTrivialEnd)(head).toOption + case _ => tokens.getClosingIfWithinParens(nonTrivialEnd)(head) .map(prevNonCommentSameLine) }).getOrElse(nextNonCommentSameLine(end)) def nlPolicy(implicit fileLine: FileLine) = Policy ? danglingKeyword && { @@ -2504,7 +2507,7 @@ class FormatOps( val beg = getOnOrBeforeOwned(hft, tree) val end = getLastNonTrivial(treeTokens, tree) - val kw = next(getClosingIfWithinParens(end)(beg).toOption.fold(end)(next)) + val kw = next(getClosingIfWithinParens(end)(beg).fold(end)(next)) if (!kw.left.is[A]) return None val indent = style.indent.ctrlSite.getOrElse(style.indent.getSignificant) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala index 63e4a87e10..823d8cf587 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala @@ -133,28 +133,37 @@ class FormatTokens(leftTok2tok: Map[TokenHash, Int])(val arr: Array[FT]) .exists(_._1.left.is[T.LeftBrace]) def isEnclosedWithinParens(tree: Tree): Boolean = - getClosingIfWithinParens(tree).isRight + getClosingIfWithinParens(tree).isDefined + + def getClosingIfWithinParens(tree: Tree): Option[FT] = + getClosingIfWithinParensOrBraces(tree).flatMap(_.toOption) + + def getClosingIfWithinParens(last: FT)(head: FT): Option[FT] = + getClosingIfWithinParensOrBraces(last)(head).flatMap(_.toOption) def isEnclosedWithinParensOrBraces(tree: Tree): Boolean = - getClosingIfWithinParens(tree) != Left(false) + getClosingIfWithinParensOrBraces(tree).isDefined - def getClosingIfWithinParens(last: FT)(head: FT): Either[Boolean, FT] = { + def getClosingIfWithinParensOrBraces( + last: FT, + )(head: FT): Option[Either[FT, FT]] = { val innerMatched = matchingOptLeft(last).contains(head) - if (innerMatched && last.left.is[T.RightParen]) Right(prev(last)) + if (innerMatched && last.left.is[T.RightParen]) Some(Right(prev(last))) else { val afterLast = nextNonComment(last) - if (matchingOptRight(afterLast).exists(_ eq prevNonCommentBefore(head))) - if (afterLast.right.is[T.RightParen]) Right(afterLast) else Left(true) - else Left(innerMatched) + if (!matchingOptRight(afterLast).exists(_ eq prevNonCommentBefore(head))) + if (innerMatched) Some(Left(prev(last))) else None + else + Some(Either.cond(afterLast.right.is[T.RightParen], afterLast, afterLast)) } } - def getClosingIfWithinParens(tree: Tree): Either[Boolean, FT] = { + def getClosingIfWithinParensOrBraces(tree: Tree): Option[Either[FT, FT]] = { val tokens = tree.tokens getHeadOpt(tokens, tree) match { case Some(head) => - getClosingIfWithinParens(getLastNonTrivial(tokens, tree))(head) - case None => Left(false) + getClosingIfWithinParensOrBraces(getLastNonTrivial(tokens, tree))(head) + case None => None } }