From 83fe74c3567c2c99d4e9e3802a5f59281a082df9 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 2 Jan 2025 18:06:34 -1000 Subject: [PATCH] fix bug, improve scalachecks --- core/src/main/scala/cats/data/Chain.scala | 27 ++++++++++++------- .../test/scala/cats/tests/ChainSuite.scala | 19 ++++++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index bb803c9a92..262beaaf87 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -266,13 +266,15 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => if (count == 1) { - lhs.append(seq(0)) + lhs.append(seq.head) } else { // count > 1 val taken = seq.take(count) // we may have not takeped all of count val newCount = count - taken.length - val newLhs = lhs.concat(Wrap(taken)) + val wrapped = Wrap(taken) + // this is more efficient than using concat + val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped) if (newCount > 0) { // we have to keep takeping on the rhs go(newLhs, newCount, rhs, Chain.nil) @@ -282,7 +284,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } case Append(l, r) => - go(lhs, count, l, r.concat(rhs)) + go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs)) case s @ Singleton(_) => // due to the invariant count >= 1 val newLhs = if (lhs.isEmpty) s else Append(lhs, s) @@ -308,13 +310,14 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => if (count == 1) { - lhs.append(seq.last) + seq.last +: rhs } else { // count > 1 val taken = seq.takeRight(count) // we may have not takeped all of count val newCount = count - taken.length - val newRhs = Wrap(taken).concat(rhs) + val wrapped = Wrap(taken) + val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs) if (newCount > 0) { // we have to keep takeping on the rhs go(Chain.nil, newCount, lhs, newRhs) @@ -324,7 +327,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } case Append(l, r) => - go(lhs.concat(l), count, r, rhs) + go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs) case s @ Singleton(_) => // due to the invariant count >= 1 val newRhs = if (rhs.isEmpty) s else Append(s, rhs) @@ -381,11 +384,13 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { rhs } } else { + // dropped is not empty + val wrapped = Wrap(dropped) // we must be done - Chain.fromSeq(dropped).concat(rhs) + if (rhs.isEmpty) wrapped else Append(wrapped, rhs) } case Append(l, r) => - go(count, l, r.concat(rhs)) + go(count, l, if (rhs.isEmpty) r else Append(r, rhs)) case Singleton(_) => // due to the invariant count >= 1 if (count > 1) go(count - 1, rhs, Chain.nil) @@ -422,10 +427,12 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } else { // we must be done - lhs.concat(Chain.fromSeq(dropped)) + // note: dropped.nonEmpty + val wrapped = Wrap(dropped) + if (lhs.isEmpty) wrapped else Append(lhs, wrapped) } case Append(l, r) => - go(lhs.concat(l), count, r) + go(if (lhs.isEmpty) l else Append(lhs, l), count, r) case Singleton(_) => // due to the invariant count >= 1 if (count > 1) go(Chain.nil, count - 1, lhs) diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 7794823bfd..154ec56304 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -449,25 +449,36 @@ class ChainSuite extends CatsSuite { } } + private val genChainDropTakeArgs = + Arbitrary.arbitrary[Chain[Int]].flatMap { chain => + // Bias to values close to the length + Gen + .oneOf( + Gen.choose(Int.MinValue, Int.MaxValue), + Gen.choose(-1, chain.length.toInt + 1) + ) + .map((chain, _)) + } + test("drop(cnt).toList == toList.drop(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.drop(count).toList == chain.toList.drop(count)) } } test("dropRight(cnt).toList == toList.dropRight(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.dropRight(count).toList == chain.toList.dropRight(count)) } } test("take(cnt).toList == toList.take(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.take(count).toList == chain.toList.take(count)) } } test("takeRight(cnt).toList == toList.takeRight(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.takeRight(count).toList == chain.toList.takeRight(count)) } }