{-# LANGUAGE BangPatterns #-}

module Data.CanonicalMaps (
  CanonicalZero (..),
  canonicalInsert,
  canonicalMapUnion,
  canonicalMap,
  pointWise,
  Map.Map,
)
where

import Data.Map.Internal (
  Map (..),
  link,
  link2,
 )
import qualified Data.Map.Strict as Map

-- =====================================================================================
-- Operations on maps from keys to values that are specialised to `CanonicalZero` values.
-- Given (CanonicalZero v), a (Map k v) is canonical if it never stores 'zeroC' as a value.
-- To maintain this we need to know what 'zeroC' is, and 'joinC' has to know how to
-- join two values (possibly maps) when either of its arguments might be 'zeroC'.
-- This class is meant to be used only in the implementation and is not intended to be observable by the user.
-- ======================================================================================

-- | Types with a distinguished "zero" value that canonical maps never store.
class (Eq t) => CanonicalZero t where
  -- | The value that must never appear as an entry of a canonical map.
  zeroC :: t

  -- | Combine two values; either argument may be 'zeroC'.
  joinC :: t -> t -> t

-- | Integers use @0@ as their zero and addition as their join.
instance CanonicalZero Integer where
  joinC = (+)
  zeroC = 0

-- | A map's zero is the empty map. Two maps are joined with 'canonicalMapUnion',
-- combining the values of shared keys with the value type's own 'joinC'.
instance (Ord k, CanonicalZero v) => CanonicalZero (Map k v) where
  joinC leftMap rightMap = canonicalMapUnion joinC leftMap rightMap
  zeroC = Map.empty

-- Note that the class CanonicalZero and the function canonicalMapUnion are mutually recursive.

-- | Union of two canonical maps. Values of keys present in both maps are
-- combined with @f@, and every entry whose resulting value is 'zeroC' is
-- dropped, so the result is canonical again.
--
-- @f@ is always applied as @f leftValue rightValue@, regardless of which
-- argument happens to be the smaller map.
canonicalMapUnion ::
  (Ord k, CanonicalZero a) =>
  (a -> a -> a) -> -- combining function, applied as (f left right)
  Map k a ->
  Map k a ->
  Map k a
canonicalMapUnion _f t1 Tip = t1
canonicalMapUnion f t1 (Bin _ k x Tip Tip) = canonicalInsert f k x t1
-- canonicalInsert applies its function as (f valueInTree newValue). Here the
-- tree is the right-hand map, so flip f to keep the (f left right) order used
-- by every other equation.
canonicalMapUnion f (Bin _ k x Tip Tip) t2 = canonicalInsert (flip f) k x t2
canonicalMapUnion _f Tip t2 = t2
canonicalMapUnion f (Bin _ k1 x1 l1 r1) t2 = case Map.splitLookup k1 t2 of
  (l2, mb, r2) -> case mb of
    Nothing ->
      if x1 == zeroC
        then link2 l1l2 r1r2
        else link k1 x1 l1l2 r1r2
    Just x2 ->
      if new == zeroC
        then link2 l1l2 r1r2
        else link k1 new l1l2 r1r2
      where
        new = f x1 x2
    where
      -- Bang patterns force both recursive unions before the node is rebuilt.
      !l1l2 = canonicalMapUnion f l1 l2
      !r1r2 = canonicalMapUnion f r1 r2
{-# INLINEABLE canonicalMapUnion #-}

-- | Insert @x@ at key @kx@ into a canonical map, keeping the map canonical.
--
-- If @kx@ is already bound to @y@, the stored value becomes @f y x@ (the
-- existing value first, the inserted value second). If that result is 'zeroC',
-- the key is removed instead. Inserting 'zeroC' at an absent key stores nothing.
canonicalInsert ::
  (Ord k, CanonicalZero a) =>
  (a -> a -> a) ->
  k ->
  a ->
  Map k a ->
  Map k a
canonicalInsert f !kx x = insertInto
  where
    insertInto Tip
      | x == zeroC = Tip
      | otherwise = Map.singleton kx x
    insertInto (Bin sz ky y l r) = case compare kx ky of
      LT -> link ky y (insertInto l) r
      GT -> link ky y l (insertInto r)
      EQ
        | combined == zeroC -> link2 l r
        | otherwise -> Bin sz kx combined l r
        where
          -- existing value in the tree first, then the newly inserted one
          combined = f y x
{-# INLINEABLE canonicalInsert #-}

-- | Apply @f@ to every value of a map, discarding each entry whose new value is
-- 'zeroC', so that the result is canonical.
canonicalMap :: (Ord k, CanonicalZero a) => (a -> a) -> Map k a -> Map k a
canonicalMap f = Map.mapMaybe keepNonZero
  where
    -- Keys are unchanged, so a single strict 'Map.mapMaybe' pass rebuilds the
    -- map in O(n), rather than re-inserting every key into a fresh map.
    keepNonZero v
      | new == zeroC = Nothing
      | otherwise = Just new
      where
        new = f v
{-# INLINEABLE canonicalMap #-}

-- Pointwise comparison of two maps, assuming both are canonical (they never store 'zeroC').
-- Semantically, a key that does not appear in a map is treated as mapping to 'zeroC'.

-- | Compare two canonical maps pointwise with @p@, treating every key that is
-- absent from a map as if it were bound to 'zeroC' in that map.
--
-- The result is 'True' exactly when @p v1 v2@ holds for every key in the union
-- of both key sets, where @v1@ comes from the first map and @v2@ from the second.
pointWise ::
  (Ord k, CanonicalZero v) =>
  (v -> v -> Bool) -> -- predicate, applied as (p firstValue secondValue)
  Map k v ->
  Map k v ->
  Bool
pointWise _ Tip Tip = True
pointWise p Tip m@Bin {} = all (zeroC `p`) m
pointWise p m@Bin {} Tip = all (`p` zeroC) m
pointWise p m (Bin _ k v2 ls rs) =
  case Map.splitLookup k m of
    (lm, mv1, rm) ->
      -- Keep the first map's pieces (lm, rm) in the first argument position
      -- so that an asymmetric predicate such as (<=) is never reversed.
      p v1 v2 && pointWise p lm ls && pointWise p rm rs
      where
        -- A key missing from the first map stands for 'zeroC'.
        v1 = case mv1 of
          Just v -> v
          Nothing -> zeroC
{-# INLINEABLE pointWise #-}