Skip to content

Commit

Permalink
fix bug, improve scalachecks
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Jan 3, 2025
1 parent 3b2c4bd commit 83fe74c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
27 changes: 17 additions & 10 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,15 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
arg match {
case Wrap(seq) =>
if (count == 1) {
lhs.append(seq(0))
lhs.append(seq.head)
} else {
// count > 1
val taken = seq.take(count)
// we may have not takeped all of count
val newCount = count - taken.length
val newLhs = lhs.concat(Wrap(taken))
val wrapped = Wrap(taken)
// this is more efficient than using concat
val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
if (newCount > 0) {
// we have to keep takeping on the rhs
go(newLhs, newCount, rhs, Chain.nil)
Expand All @@ -282,7 +284,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
}
}
case Append(l, r) =>
go(lhs, count, l, r.concat(rhs))
go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs))
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newLhs = if (lhs.isEmpty) s else Append(lhs, s)
Expand All @@ -308,13 +310,14 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
arg match {
case Wrap(seq) =>
if (count == 1) {
lhs.append(seq.last)
seq.last +: rhs
} else {
// count > 1
val taken = seq.takeRight(count)
// we may have not takeped all of count
val newCount = count - taken.length
val newRhs = Wrap(taken).concat(rhs)
val wrapped = Wrap(taken)
val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
if (newCount > 0) {
// we have to keep takeping on the rhs
go(Chain.nil, newCount, lhs, newRhs)
Expand All @@ -324,7 +327,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
}
}
case Append(l, r) =>
go(lhs.concat(l), count, r, rhs)
go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs)
case s @ Singleton(_) =>
// due to the invariant count >= 1
val newRhs = if (rhs.isEmpty) s else Append(s, rhs)
Expand Down Expand Up @@ -381,11 +384,13 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
rhs
}
} else {
// dropped is not empty
val wrapped = Wrap(dropped)
// we must be done
Chain.fromSeq(dropped).concat(rhs)
if (rhs.isEmpty) wrapped else Append(wrapped, rhs)
}
case Append(l, r) =>
go(count, l, r.concat(rhs))
go(count, l, if (rhs.isEmpty) r else Append(r, rhs))
case Singleton(_) =>
// due to the invariant count >= 1
if (count > 1) go(count - 1, rhs, Chain.nil)
Expand Down Expand Up @@ -422,10 +427,12 @@ sealed abstract class Chain[+A] extends ChainCompat[A] {
}
} else {
// we must be done
lhs.concat(Chain.fromSeq(dropped))
// note: dropped.nonEmpty
val wrapped = Wrap(dropped)
if (lhs.isEmpty) wrapped else Append(lhs, wrapped)
}
case Append(l, r) =>
go(lhs.concat(l), count, r)
go(if (lhs.isEmpty) l else Append(lhs, l), count, r)
case Singleton(_) =>
// due to the invariant count >= 1
if (count > 1) go(Chain.nil, count - 1, lhs)
Expand Down
19 changes: 15 additions & 4 deletions tests/shared/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -449,25 +449,36 @@ class ChainSuite extends CatsSuite {
}
}

private val genChainDropTakeArgs =
Arbitrary.arbitrary[Chain[Int]].flatMap { chain =>
// Bias to values close to the length
Gen
.oneOf(
Gen.choose(Int.MinValue, Int.MaxValue),
Gen.choose(-1, chain.length.toInt + 1)
)
.map((chain, _))
}

test("drop(cnt).toList == toList.drop(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assert(chain.drop(count).toList == chain.toList.drop(count))
}
}

test("dropRight(cnt).toList == toList.dropRight(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assert(chain.dropRight(count).toList == chain.toList.dropRight(count))
}
}
test("take(cnt).toList == toList.take(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assert(chain.take(count).toList == chain.toList.take(count))
}
}

test("takeRight(cnt).toList == toList.takeRight(cnt)") {
forAll { (chain: Chain[Int], count: Int) =>
forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) =>
assert(chain.takeRight(count).toList == chain.toList.takeRight(count))
}
}
Expand Down

0 comments on commit 83fe74c

Please sign in to comment.