diff --git a/containers-tests/benchmarks/Map.hs b/containers-tests/benchmarks/Map.hs index 0e324e556..392c15bb2 100644 --- a/containers-tests/benchmarks/Map.hs +++ b/containers-tests/benchmarks/Map.hs @@ -20,8 +20,10 @@ main = do let m = M.fromAscList elems :: M.Map Int Int m_even = M.fromAscList elems_even :: M.Map Int Int m_odd = M.fromAscList elems_odd :: M.Map Int Int + m_odd_keys = M.keysSet m_odd evaluate $ rnf [m, m_even, m_odd] evaluate $ rnf elems_rev + evaluate $ rnf m_odd_keys defaultMain [ bench "lookup absent" $ whnf (lookup evens) m_odd , bench "lookup present" $ whnf (lookup evens) m_even @@ -95,8 +97,13 @@ main = do , bench "fromDistinctDescList" $ whnf M.fromDistinctDescList elems_rev , bench "fromDistinctDescList:fusion" $ whnf (\n -> M.fromDistinctDescList [(i,i) | i <- [n,n-1..1]]) bound , bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int]) + , bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything , bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything + + , bench "restrictKeys" $ whnf (M.restrictKeys m) m_odd_keys + , bench "withoutKeys" $ whnf (M.withoutKeys m) m_odd_keys + , bench "partitionKeys" $ whnf (M.partitionKeys m) m_odd_keys ] where bound = 2^12 diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index da7c76a0f..289f31dc2 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -128,6 +128,7 @@ library Utils.Containers.Internal.State Utils.Containers.Internal.StrictMaybe Utils.Containers.Internal.EqOrdUtil + Utils.Containers.Internal.StrictTriple if impl(ghc) other-modules: diff --git a/containers-tests/tests/map-properties.hs b/containers-tests/tests/map-properties.hs index 6b7a45ad5..d6eea949e 100644 --- a/containers-tests/tests/map-properties.hs +++ b/containers-tests/tests/map-properties.hs @@ -173,6 +173,7 @@ main = defaultMain $ testGroup "map-properties" , testProperty "withoutKeys" prop_withoutKeys , testProperty "intersection" prop_intersection , testProperty "restrictKeys" prop_restrictKeys + , testProperty "partitionKeys" prop_partitionKeys , testProperty "intersection model" prop_intersectionModel , testProperty "intersectionWith" prop_intersectionWith , testProperty "intersectionWithModel" prop_intersectionWithModel @@ -1140,6 +1141,12 @@ prop_withoutKeys m s0 = valid reduced .&&. (m `withoutKeys` s === filterWithKey s = keysSet s0 reduced = withoutKeys m s +prop_partitionKeys :: IMap -> IMap -> Property +prop_partitionKeys m s0 = valid with .&&. valid without .&&. (m `partitionKeys` s === (m `restrictKeys` s, m `withoutKeys` s)) + where + s = keysSet s0 + (with, without) = partitionKeys m s + prop_intersection :: IMap -> IMap -> Bool prop_intersection t1 t2 = valid (intersection t1 t2) diff --git a/containers/containers.cabal b/containers/containers.cabal index 12185a9f5..eeb72dae6 100644 --- a/containers/containers.cabal +++ b/containers/containers.cabal @@ -83,6 +83,7 @@ Library Utils.Containers.Internal.PtrEquality Utils.Containers.Internal.Coercions Utils.Containers.Internal.EqOrdUtil + Utils.Containers.Internal.StrictTriple if impl(ghc) other-modules: Utils.Containers.Internal.TypeError diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index b230a574e..8cd484536 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -299,6 +299,7 @@ module Data.Map.Internal ( , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey @@ -398,6 +399,7 @@ import qualified Data.Set.Internal as Set import Data.Set.Internal (Set) import Utils.Containers.Internal.PtrEquality (ptrEq) import Utils.Containers.Internal.StrictPair +import Utils.Containers.Internal.StrictTriple import Utils.Containers.Internal.StrictMaybe import Utils.Containers.Internal.BitQueue import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) @@ -1966,6 +1968,51 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of {-# INLINABLE withoutKeys #-} #endif +-- | \(O\bigl(m \log\bigl(\frac{n}{m}+1\bigr)\bigr), \; 0 < m \leq n\). Partition the map according to a set. +-- The first map contains the input 'Map' restricted to those keys found in the 'Set', +-- the second map contains the input 'Map' without all keys in the 'Set'. +-- This is more efficient than using 'restrictKeys' and 'withoutKeys' together. +-- +-- @ +-- m \`partitionKeys\` s = (m ``restrictKeys`` s, m ``withoutKeys`` s) +-- @ +partitionKeys :: Ord k => Map k a -> Set k -> (Map k a, Map k a) +partitionKeys xs ys = + case partitionKeysWorker xs ys of + xs' :*: ys' -> (xs', ys') +#if __GLASGOW_HASKELL__ +{-# INLINABLE partitionKeys #-} +#endif + +partitionKeysWorker :: Ord k => Map k a -> Set k -> StrictPair (Map k a) (Map k a) +partitionKeysWorker Tip _ = Tip :*: Tip +partitionKeysWorker m Set.Tip = Tip :*: m +partitionKeysWorker m@(Bin _ k x lm rm) s@Set.Bin{} = + case b of + True -> with :*: without + where + with = + if lmWith `ptrEq` lm && rmWith `ptrEq` rm + then m + else link k x lmWith rmWith + without = + link2 lmWithout rmWithout + False -> with :*: without + where + with = link2 lmWith rmWith + without = + if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm + then m + else link k x lmWithout rmWithout + where + !(lmWith :*: lmWithout) = partitionKeysWorker lm ls' + !(rmWith :*: rmWithout) = partitionKeysWorker rm rs' + + !(!ls', b, !rs') = Set.splitMember k s +#if __GLASGOW_HASKELL__ +{-# INLINABLE partitionKeysWorker #-} +#endif + -- | \(O(n+m)\). Difference with a combining function. -- When two equal keys are -- encountered, the combining function is applied to the values of these keys. @@ -4004,8 +4051,6 @@ splitMember k0 m = case go k0 m of {-# INLINABLE splitMember #-} #endif -data StrictTriple a b c = StrictTriple !a !b !c - {-------------------------------------------------------------------- Utility functions that maintain the balance properties of the tree. All constructors assume that all values in [l] < [k] and all values diff --git a/containers/src/Data/Map/Lazy.hs b/containers/src/Data/Map/Lazy.hs index 2fca4c91d..d69e943f5 100644 --- a/containers/src/Data/Map/Lazy.hs +++ b/containers/src/Data/Map/Lazy.hs @@ -231,6 +231,7 @@ module Data.Map.Lazy ( , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone diff --git a/containers/src/Data/Map/Strict.hs b/containers/src/Data/Map/Strict.hs index 649898850..632a1bf70 100644 --- a/containers/src/Data/Map/Strict.hs +++ b/containers/src/Data/Map/Strict.hs @@ -246,6 +246,7 @@ module Data.Map.Strict , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey diff --git a/containers/src/Data/Map/Strict/Internal.hs b/containers/src/Data/Map/Strict/Internal.hs index 21afe2d91..9e6be1235 100644 --- a/containers/src/Data/Map/Strict/Internal.hs +++ b/containers/src/Data/Map/Strict/Internal.hs @@ -256,6 +256,7 @@ module Data.Map.Strict.Internal , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone @@ -418,7 +419,8 @@ import Data.Map.Internal , toDescList , union , unions - , withoutKeys ) + , withoutKeys + , partitionKeys ) import Data.Map.Internal.Debug (valid) diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index 6bc472f3d..432b674f6 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -256,6 +256,7 @@ import Data.List.NonEmpty (NonEmpty(..)) #endif import Utils.Containers.Internal.StrictPair +import Utils.Containers.Internal.StrictTriple import Utils.Containers.Internal.PtrEquality import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) @@ -1430,19 +1431,25 @@ splitS x (Bin _ y l r) EQ -> (l :*: r) {-# INLINABLE splitS #-} +splitMemberS :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a) +splitMemberS x = go + where + go Tip = StrictTriple Tip False Tip + go (Bin _ y l r) = case compare x y of + LT -> let StrictTriple lt found gt = splitMemberS x l + in StrictTriple lt found (link y gt r) + GT -> let StrictTriple lt found gt = splitMemberS x r + in StrictTriple (link y l lt) found gt + EQ -> StrictTriple l True r +#if __GLASGOW_HASKELL__ +{-# INLINABLE splitMemberS #-} +#endif + -- | \(O(\log n)\). Performs a 'split' but also returns whether the pivot -- element was found in the original set. splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a) -splitMember _ Tip = (Tip, False, Tip) -splitMember x (Bin _ y l r) - = case compare x y of - LT -> let (lt, found, gt) = splitMember x l - !gt' = link y gt r - in (lt, found, gt') - GT -> let (lt, found, gt) = splitMember x r - !lt' = link y l lt - in (lt', found, gt) - EQ -> (l, True, r) +splitMember k0 s = case splitMemberS k0 s of + StrictTriple l b r -> (l, b, r) #if __GLASGOW_HASKELL__ {-# INLINABLE splitMember #-} #endif diff --git a/containers/src/Utils/Containers/Internal/StrictTriple.hs b/containers/src/Utils/Containers/Internal/StrictTriple.hs new file mode 100644 index 000000000..45523f81f --- /dev/null +++ b/containers/src/Utils/Containers/Internal/StrictTriple.hs @@ -0,0 +1,15 @@ +{-# LANGUAGE CPP #-} +#if !defined(TESTING) && defined(__GLASGOW_HASKELL__) +{-# LANGUAGE Safe #-} +#endif + +-- | A strict triple + +module Utils.Containers.Internal.StrictTriple (StrictTriple(..)) where + +-- | The same as a regular Haskell tuple, but +-- +-- @ +-- StrictTriple x y _|_ = StrictTriple x _|_ z = StrictTriple _|_ y z = _|_ +-- @ +data StrictTriple a b c = StrictTriple !a !b !c