Skip to content

Commit e261fa2

Browse files
authored
SIP-62 - For comprehension improvements (#20522)
Implementation for SIP-62. ### Summary of the changes For more details read the committed markdown file here: scala/improvement-proposals#79 This introduces improvements to `for` comprehensions in Scala to improve usability and simplify desugaring. The changes are hidden behind a language import `scala.language.experimental.betterFors`. The main changes are: 1. **Starting `for` comprehensions with aliases**: - **Current Syntax**: ```scala val a = 1 for { b <- Some(2) c <- doSth(a) } yield b + c ``` - **New Syntax**: ```scala for { a = 1 b <- Some(2) c <- doSth(a) } yield b + c ``` 2. **Simpler Desugaring for Pure Aliases**: - **Current Desugaring**: ```scala for { a <- doSth(arg) b = a } yield a + b ``` Desugars to: ```scala doSth(arg).map { a => val b = a (a, b) }.map { case (a, b) => a + b } ``` - **New Desugaring**: (where possible) ```scala doSth(arg).map { a => val b = a a + b } ``` 3. **Avoiding Redundant `map` Calls**: - **Current Desugaring**: ```scala for { a <- List(1, 2, 3) } yield a ``` Desugars to: ```scala List(1, 2, 3).map(a => a) ``` - **New Desugaring**: ```scala List(1, 2, 3) ```
2 parents 1b644f6 + 4bc0a4a commit e261fa2

File tree

10 files changed

+257
-25
lines changed

10 files changed

+257
-25
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+88-21
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D
1111
import typer.{Namer, Checking}
1212
import util.{Property, SourceFile, SourcePosition, SrcPos, Chars}
1313
import config.{Feature, Config}
14+
import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled}
1415
import config.SourceVersion.*
1516
import collection.mutable
1617
import reporting.*
@@ -1815,46 +1816,81 @@ object desugar {
18151816
/** Create tree for for-comprehension `<for (enums) do body>` or
18161817
* `<for (enums) yield body>` where mapName and flatMapName are chosen
18171818
* corresponding to whether this is a for-do or a for-yield.
1818-
* The creation performs the following rewrite rules:
1819+
* If betterFors are enabled, the creation performs the following rewrite rules:
18191820
*
1820-
* 1.
1821+
* 1. if betterFors is enabled:
18211822
*
1822-
* for (P <- G) E ==> G.foreach (P => E)
1823+
* for () do E ==> E
1824+
* or
1825+
* for () yield E ==> E
18231826
*
1824-
* Here and in the following (P => E) is interpreted as the function (P => E)
1825-
* if P is a variable pattern and as the partial function { case P => E } otherwise.
1827+
* (Where empty for-comprehensions are excluded by the parser)
18261828
*
18271829
* 2.
18281830
*
1829-
* for (P <- G) yield E ==> G.map (P => E)
1831+
* for (P <- G) do E ==> G.foreach (P => E)
1832+
*
1833+
* Here and in the following (P => E) is interpreted as the function (P => E)
1834+
* if P is a variable pattern and as the partial function { case P => E } otherwise.
18301835
*
18311836
* 3.
18321837
*
1838+
* for (P <- G) yield P ==> G
1839+
*
1840+
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
1841+
*
1842+
* for (P <- G) yield E ==> G.map (P => E)
1843+
*
1844+
* Otherwise
1845+
*
1846+
* 4.
1847+
*
18331848
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
18341849
* ==>
18351850
* G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...)
18361851
*
1837-
* 4.
1852+
* 5.
18381853
*
1839-
* for (P <- G; E; ...) ...
1840-
* =>
1841-
* for (P <- G.filter (P => E); ...) ...
1854+
* for (P <- G; if E; ...) ...
1855+
* ==>
1856+
* for (P <- G.withFilter (P => E); ...) ...
18421857
*
1843-
* 5. For any N:
1858+
* 6. For any N, if betterFors is enabled:
18441859
*
1845-
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
1860+
* for (P <- G; P_1 = E_1; ... P_N = E_N; P1 <- G1; ...) ...
18461861
* ==>
1847-
* for (TupleN(P_1, P_2, ... P_N) <-
1848-
* for (x_1 @ P_1 <- G) yield {
1849-
* val x_2 @ P_2 = E_2
1862+
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...))
1863+
*
1864+
* 7. For any N, if betterFors is enabled:
1865+
*
1866+
* for (P <- G; P_1 = E_1; ... P_N = E_N) ...
1867+
* ==>
1868+
* G.map (P => for (P_1 = E_1; ... P_N = E_N) ...)
1869+
*
1870+
* 8. For any N:
1871+
*
1872+
* for (P <- G; P_1 = E_1; ... P_N = E_N; ...)
1873+
* ==>
1874+
* for (TupleN(P, P_1, ... P_N) <-
1875+
* for (x @ P <- G) yield {
1876+
* val x_1 @ P_1 = E_2
18501877
* ...
1851-
* val x_N & P_N = E_N
1852-
* TupleN(x_1, ..., x_N)
1853-
* } ...)
1878+
* val x_N @ P_N = E_N
1879+
* TupleN(x, x_1, ..., x_N)
1880+
* }; if E; ...)
18541881
*
18551882
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
18561883
* and the variable constituting P_i is used instead of x_i
18571884
*
1885+
* 9. For any N, if betterFors is enabled:
1886+
*
1887+
* for (P_1 = E_1; ... P_N = E_N; ...)
1888+
* ==>
1889+
* {
1890+
* val x_N @ P_N = E_N
1891+
* for (...)
1892+
* }
1893+
*
18581894
* @param mapName The name to be used for maps (either map or foreach)
18591895
* @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
18601896
* @param enums The enumerators in the for expression
@@ -1963,7 +1999,7 @@ object desugar {
19631999
case GenCheckMode.FilterAlways => false // pattern was prefixed by `case`
19642000
case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr)
19652001
case GenCheckMode.Check => true
1966-
case GenCheckMode.Ignore => true
2002+
case GenCheckMode.Ignore | GenCheckMode.Filtered => true
19672003

19682004
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
19692005
* matched against `rhs`.
@@ -1973,12 +2009,31 @@ object desugar {
19732009
Select(rhs, name)
19742010
}
19752011

2012+
def deepEquals(t1: Tree, t2: Tree): Boolean =
2013+
(unsplice(t1), unsplice(t2)) match
2014+
case (Ident(n1), Ident(n2)) => n1 == n2
2015+
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
2016+
case _ => false
2017+
19762018
enums match {
2019+
case Nil if betterForsEnabled => body
19772020
case (gen: GenFrom) :: Nil =>
1978-
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2021+
if betterForsEnabled
2022+
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2023+
&& deepEquals(gen.pat, body)
2024+
then gen.expr // avoid a redundant map with identity
2025+
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
19792026
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
19802027
val cont = makeFor(mapName, flatMapName, rest, body)
19812028
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
2029+
case (gen: GenFrom) :: rest
2030+
if betterForsEnabled
2031+
&& rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => // possible aliases followed by a generator or end of for
2032+
val cont = makeFor(mapName, flatMapName, rest, body)
2033+
val selectName =
2034+
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
2035+
else mapName
2036+
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
19822037
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
19832038
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
19842039
val pats = valeqs map { case GenAlias(pat, _) => pat }
@@ -1997,8 +2052,20 @@ object desugar {
19972052
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
19982053
case (gen: GenFrom) :: test :: rest =>
19992054
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
2000-
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
2055+
val genFrom = GenFrom(gen.pat, filtered, if betterForsEnabled then GenCheckMode.Filtered else GenCheckMode.Ignore)
20012056
makeFor(mapName, flatMapName, genFrom :: rest, body)
2057+
case GenAlias(_, _) :: _ if betterForsEnabled =>
2058+
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
2059+
val pats = valeqs.map { case GenAlias(pat, _) => pat }
2060+
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
2061+
val (defpats, ids) = pats.map(makeIdPat).unzip
2062+
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
2063+
val mods = defpat match
2064+
case defTree: DefTree => defTree.mods
2065+
case _ => Modifiers()
2066+
makePatDef(valeq, mods, defpat, rhs)
2067+
}
2068+
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
20022069
case _ =>
20032070
EmptyTree //may happen for erroneous input
20042071
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
183183

184184
/** An enum to control checking or filtering of patterns in GenFrom trees */
185185
enum GenCheckMode {
186-
case Ignore // neither filter nor check since filtering was done before
186+
case Ignore // neither filter nor check since pattern is trivially irrefutable
187+
case Filtered // neither filter nor check since filtering was done before
187188
case Check // check that pattern is irrefutable
188189
case CheckAndFilter // both check and filter (transitional period starting with 3.2)
189190
case FilterNow // filter out non-matching elements if we are not in 3.2 or later

compiler/src/dotty/tools/dotc/config/Feature.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ object Feature:
3838
val modularity = experimental("modularity")
3939
val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors")
4040
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
41+
val betterFors = experimental("betterFors")
4142

4243
def experimentalAutoEnableFeatures(using Context): List[TermName] =
4344
defn.languageExperimentalFeatures
@@ -67,7 +68,8 @@ object Feature:
6768
(into, "Allow into modifier on parameter types"),
6869
(namedTuples, "Allow named tuples"),
6970
(modularity, "Enable experimental modularity features"),
70-
(betterMatchTypeExtractors, "Enable better match type extractors")
71+
(betterMatchTypeExtractors, "Enable better match type extractors"),
72+
(betterFors, "Enable improvements in `for` comprehensions")
7173
)
7274

7375
// legacy language features from Scala 2 that are no longer supported.
@@ -125,6 +127,8 @@ object Feature:
125127
def clauseInterleavingEnabled(using Context) =
126128
sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving)
127129

130+
def betterForsEnabled(using Context) = enabled(betterFors)
131+
128132
def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals)
129133

130134
def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+17-1
Original file line numberDiff line numberDiff line change
@@ -2894,7 +2894,11 @@ object Parsers {
28942894

28952895
/** Enumerators ::= Generator {semi Enumerator | Guard}
28962896
*/
2897-
def enumerators(): List[Tree] = generator() :: enumeratorsRest()
2897+
def enumerators(): List[Tree] =
2898+
if in.featureEnabled(Feature.betterFors) then
2899+
aliasesUntilGenerator() ++ enumeratorsRest()
2900+
else
2901+
generator() :: enumeratorsRest()
28982902

28992903
def enumeratorsRest(): List[Tree] =
29002904
if (isStatSep) {
@@ -2936,6 +2940,18 @@ object Parsers {
29362940
GenFrom(pat, subExpr(), checkMode)
29372941
}
29382942

2943+
def aliasesUntilGenerator(): List[Tree] =
2944+
if in.token == CASE then generator() :: Nil
2945+
else {
2946+
val pat = pattern1()
2947+
if in.token == EQUALS then
2948+
atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: {
2949+
if (isStatSep) in.nextToken()
2950+
aliasesUntilGenerator()
2951+
}
2952+
else generatorRest(pat, casePat = false) :: Nil
2953+
}
2954+
29392955
/** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr
29402956
* | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr
29412957
* | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr

library/src/scala/runtime/stdLibPatches/language.scala

+6
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ object language:
133133
@compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements")
134134
object quotedPatternsWithPolymorphicFunctions
135135

136+
/** Experimental support for improvements in `for` comprehensions
137+
*
138+
* @see [[https://github.com/scala/improvement-proposals/pull/79]]
139+
*/
140+
@compileTimeOnly("`betterFors` can only be used at compile time in import statements")
141+
object betterFors
136142
end experimental
137143

138144
/** The deprecated object contains features that are no longer officially suypported in Scala.

project/MiMaFilters.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ object MiMaFilters {
88
val ForwardsBreakingChanges: Map[String, Seq[ProblemFilter]] = Map(
99
// Additions that require a new minor version of the library
1010
Build.mimaPreviousDottyVersion -> Seq(
11-
11+
ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.betterFors"),
12+
ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$betterFors$"),
1213
),
1314

1415
// Additions since last LTS

tests/run/better-fors.check

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
List((1,3), (1,4), (2,3), (2,4))
2+
List((1,2,3), (1,2,4))
3+
List((1,3), (1,4), (2,3), (2,4))
4+
List((2,3), (2,4))
5+
List((2,3), (2,4))
6+
List((1,2), (2,4))
7+
List(1, 2, 3)
8+
List((2,3,6))
9+
List(6)
10+
List(3, 6)
11+
List(6)
12+
List(2)

tests/run/better-fors.scala

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import scala.language.experimental.betterFors
2+
3+
def for1 =
4+
for {
5+
a = 1
6+
b <- List(a, 2)
7+
c <- List(3, 4)
8+
} yield (b, c)
9+
10+
def for2 =
11+
for
12+
a = 1
13+
b = 2
14+
c <- List(3, 4)
15+
yield (a, b, c)
16+
17+
def for3 =
18+
for {
19+
a = 1
20+
b <- List(a, 2)
21+
c = 3
22+
d <- List(c, 4)
23+
} yield (b, d)
24+
25+
def for4 =
26+
for {
27+
a = 1
28+
b <- List(a, 2)
29+
if b > 1
30+
c <- List(3, 4)
31+
} yield (b, c)
32+
33+
def for5 =
34+
for {
35+
a = 1
36+
b <- List(a, 2)
37+
c = 3
38+
if b > 1
39+
d <- List(c, 4)
40+
} yield (b, d)
41+
42+
def for6 =
43+
for {
44+
a = 1
45+
b = 2
46+
c <- for {
47+
x <- List(a, b)
48+
y = x * 2
49+
} yield (x, y)
50+
} yield c
51+
52+
def for7 =
53+
for {
54+
a <- List(1, 2, 3)
55+
} yield a
56+
57+
def for8 =
58+
for {
59+
a <- List(1, 2)
60+
b = a + 1
61+
if b > 2
62+
c = b * 2
63+
if c < 8
64+
} yield (a, b, c)
65+
66+
def for9 =
67+
for {
68+
a <- List(1, 2)
69+
b = a * 2
70+
if b > 2
71+
} yield a + b
72+
73+
def for10 =
74+
for {
75+
a <- List(1, 2)
76+
b = a * 2
77+
} yield a + b
78+
79+
def for11 =
80+
for {
81+
a <- List(1, 2)
82+
b = a * 2
83+
if b > 2 && b % 2 == 0
84+
} yield a + b
85+
86+
def for12 =
87+
for {
88+
a <- List(1, 2)
89+
if a > 1
90+
} yield a
91+
92+
object Test extends App {
93+
println(for1)
94+
println(for2)
95+
println(for3)
96+
println(for4)
97+
println(for5)
98+
println(for6)
99+
println(for7)
100+
println(for8)
101+
println(for9)
102+
println(for10)
103+
println(for11)
104+
println(for12)
105+
}

tests/run/fors.check

+3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ hello world
4545
hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4
4646
(2,1) (4,3)
4747

48+
testTailrec
49+
List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c)))
50+
4851
testGivens
4952
123
5053
456

0 commit comments

Comments
 (0)