{-# LANGUAGE Trustworthy       #-}
{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE DeriveFoldable    #-}
{-# LANGUAGE DeriveFunctor     #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE MagicHash         #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE UnboxedTuples     #-}
module GHC.Event.PSQ
    (
    
      Elem(..)
    , Key
    , Prio
    
    , PSQ
    
    , size
    , null
    , lookup
    
    , empty
    , singleton
    
    , unsafeInsertNew
    
    , delete
    , adjust
    
    , toList
    
    , findMin
    , deleteMin
    , minView
    , atMost
    ) where
import GHC.Base hiding (Nat, empty)
import GHC.Event.Unique
import GHC.Word (Word64)
import GHC.Num (Num(..))
import GHC.Real (fromIntegral)
import GHC.Types (Int)
#include "MachDeps.h"
-- | Keys identifying queue entries.
type Key = Unique

-- | Priorities; the entry with the smallest priority sits at the root.
type Prio = Word64

-- | Branching-bit mask stored in every 'Bin' node.
type Mask = Int

-- | Unsigned machine word used by the bit-twiddling helpers below.
type Nat = Word

-- | Public name of the queue type.
type PSQ a = IntPSQ a
-- | One queue entry, as returned by 'findMin', 'minView', 'toList' and
-- 'atMost'.
data Elem a = E
    { key   :: {-# UNPACK #-} !Key   -- ^ The key of the entry.
    , prio  :: {-# UNPACK #-} !Prio  -- ^ The priority of the entry.
    , value :: a                     -- ^ The payload (a lazy field).
    }
-- | The queue: a Patricia-style binary trie on the bits of 'Key' in which
-- every non-empty node also carries one element.
--
-- The element stored at a node has the smallest @(priority, key)@ pair of
-- its whole subtree (insertion and 'merge' compare those pairs to decide
-- which element stays on top), which is why 'findMin' and 'minView' simply
-- return the root.  Below a 'Bin' node with mask @m@, keys whose @m@ bit is
-- clear live in the left subtree and the others in the right subtree
-- (see 'lookup').  A node's own element is not repeated in its subtrees.
data IntPSQ v
    = Bin {-# UNPACK #-} !Key {-# UNPACK #-} !Prio !v {-# UNPACK #-} !Mask !(IntPSQ v) !(IntPSQ v)
      -- ^ Key, priority, value, branching mask, left subtree, right subtree.
    | Tip {-# UNPACK #-} !Key {-# UNPACK #-} !Prio !v
      -- ^ A single element with no subtrees.
    | Nil
      -- ^ The empty queue.
-- | Bitwise AND of two machine words.
(.&.) :: Nat -> Nat -> Nat
W# a .&. W# b = W# (and# a b)
{-# INLINE (.&.) #-}
-- | Bitwise exclusive OR of two machine words.
xor :: Nat -> Nat -> Nat
W# a `xor` W# b = W# (xor# a b)
{-# INLINE xor #-}
-- | Flip every bit of a machine word.
--
-- Uses the @not#@ primop, which complements a whole 'Word#' whatever the
-- platform word size. The previous version XOR-ed with an all-ones
-- constant that CPP picked for 32- and 64-bit targets only.
complement :: Nat -> Nat
complement (W# w) = W# (not# w)
{-# INLINE complement #-}
-- | Convert an 'Int' to a 'Nat' with wrap-around ('fromIntegral'), so
-- negative values map to their two's-complement bit pattern.
natFromInt :: Int -> Nat
natFromInt i = fromIntegral i
{-# INLINE natFromInt #-}
-- | Convert a 'Nat' back to an 'Int' with wrap-around ('fromIntegral').
intFromNat :: Nat -> Int
intFromNat w = fromIntegral w
{-# INLINE intFromNat #-}
{-# INLINE zero #-}
-- | 'True' when the bit selected by the mask is clear in the key, i.e. the
-- key belongs in the left subtree of a node with that mask.
zero :: Key -> Mask -> Bool
zero i m = (keyBits .&. maskBits) == 0
  where
    keyBits  = natFromInt (asInt i)
    maskBits = natFromInt m
{-# INLINE nomatch #-}
-- | 'True' when the two keys disagree on the bits selected by @maskW m@.
-- For a single-bit @m@ (as produced by 'branchMask') those are the bits
-- strictly above the branching bit, so a 'True' result means @k1@ cannot
-- live under a node keyed by @k2@ with mask @m@.
nomatch :: Key -> Key -> Mask -> Bool
nomatch k1 k2 m = prefixOf k1 /= prefixOf k2
  where
    prefixBits = maskW (natFromInt m)
    prefixOf k = natFromInt (asInt k) .&. prefixBits
{-# INLINE maskW #-}
-- | For a mask with a single bit set, the mask of all bits strictly above
-- that bit: @complement (m - 1)@ sets bit @m@ and everything higher, and
-- XOR with @m@ then clears bit @m@ itself.
maskW :: Nat -> Nat
maskW m = xor (complement (m - 1)) m
{-# INLINE branchMask #-}
-- | Single-bit mask selecting the highest bit in which two keys differ.
-- The keys are expected to be distinct (see 'highestBitMask').
branchMask :: Key -> Key -> Mask
branchMask a b = intFromNat (highestBitMask diff)
  where
    diff = natFromInt (asInt a) `xor` natFromInt (asInt b)
-- | Keep only the most significant set bit of a word.
--
-- The bit index is computed as @(word size - 1) - (count of leading zeros)@.
-- The input is expected to be non-zero: for zero the subtraction underflows
-- and the shift amount is not a valid bit index.
highestBitMask :: Nat -> Nat
highestBitMask (W# w) =
    case WORD_SIZE_IN_BITS## `minusWord#` 1## `minusWord#` clz# w of
      topBit -> W# (uncheckedShiftL# 1## (word2Int# topBit))
{-# INLINE highestBitMask #-}
-- | 'True' iff the queue holds no elements.
null :: IntPSQ v -> Bool
null t = case t of
    Nil -> True
    _   -> False
-- | Number of elements in the queue.  Visits every node, so it is linear in
-- the size of the queue.
size :: IntPSQ v -> Int
size = count 0
  where
    count :: Int -> IntPSQ v -> Int
    count !n Nil               = n
    count !n (Tip _ _ _)       = n + 1
    count !n (Bin _ _ _ _ l r) = count (count (n + 1) l) r
-- | The priority and value bound to a key, or 'Nothing' if it is absent.
-- Descends by branching bit and stops early once the key prefix diverges.
lookup :: Key -> IntPSQ v -> Maybe (Prio, v)
lookup k = search
  where
    search Nil = Nothing
    search (Tip k' p' x')
      | k == k'   = Just (p', x')
      | otherwise = Nothing
    search (Bin k' p' x' m l r)
      | nomatch k k' m = Nothing
      | k == k'        = Just (p', x')
      | zero k m       = search l
      | otherwise      = search r
-- | The element stored at the root (the minimum), without removing it.
findMin :: IntPSQ v -> Maybe (Elem v)
findMin Nil               = Nothing
findMin (Tip k p x)       = Just (E k p x)
findMin (Bin k p x _ _ _) = Just (E k p x)
-- | The queue with no elements.
empty :: IntPSQ v
empty = Nil

-- | A queue holding exactly one element.
singleton :: Key -> Prio -> v -> IntPSQ v
singleton k p x = Tip k p x
{-# INLINABLE unsafeInsertNew #-}
-- | Insert a binding for a key that the caller guarantees is /not/ already
-- present in the queue; no check is made here.  ('alter' satisfies this by
-- removing the key with 'deleteView' before calling this function.)
--
-- At each node the new element and the stored element are compared by their
-- @(priority, key)@ pairs.  The smaller pair stays at (or moves to) the node,
-- and the other one is pushed further down.
unsafeInsertNew :: Key -> Prio -> v -> IntPSQ v -> IntPSQ v
unsafeInsertNew k p x = go
  where
    go t = case t of
      -- Empty queue: the new element becomes a leaf.
      Nil       -> Tip k p x
      -- Leaf: build a branch whose node holds the smaller of the two
      -- elements, with the other one as a leaf below it.
      Tip k' p' x'
        | (p, k) < (p', k') -> link k  p  x  k' t           Nil
        | otherwise         -> link k' p' x' k  (Tip k p x) Nil
      Bin k' p' x' m l r
        -- The new key does not share this node's prefix: create a new branch
        -- above it.  If the existing element stays on top, its former
        -- children are merged into a single sibling subtree.
        | nomatch k k' m ->
            if (p, k) < (p', k')
              then link k  p  x  k' t           Nil
              else link k' p' x' k  (Tip k p x) (merge m l r)
        -- The new key falls under this node: keep the smaller element here
        -- and recursively insert the displaced one into the subtree chosen
        -- by its branching bit.
        | otherwise ->
            if (p, k) < (p', k')
              then
                if zero k' m
                  then Bin k  p  x  m (unsafeInsertNew k' p' x' l) r
                  else Bin k  p  x  m l (unsafeInsertNew k' p' x' r)
              else
                if zero k m
                  then Bin k' p' x' m (unsafeInsertNew k  p  x  l) r
                  else Bin k' p' x' m l (unsafeInsertNew k  p  x  r)
-- | Build a branch node holding @(k, p, x)@ above two subtrees: @k't@, a tree
-- containing key @k'@, and @otherTree@.  The branching mask is the highest
-- bit in which @k@ and @k'@ differ.  @k't@ goes on the left when that bit is
-- clear in @k'@, and on the right otherwise.
link :: Key -> Prio -> v -> Key -> IntPSQ v -> IntPSQ v -> IntPSQ v
link k p x k' k't otherTree =
    if zero k' m
      then Bin k p x m k't otherTree
      else Bin k p x m otherTree k't
  where
    m = branchMask k k'
{-# INLINABLE delete #-}
-- | Remove a key from the queue.  When the key is absent, the result holds
-- the same elements as the input.  A node whose own element is removed is
-- replaced by the 'merge' of its children.
delete :: Key -> IntPSQ v -> IntPSQ v
delete k = remove
  where
    remove Nil = Nil
    remove t@(Tip k' _ _)
      | k == k'   = Nil
      | otherwise = t
    remove t@(Bin k' p' x' m l r)
      | nomatch k k' m = t
      | k == k'        = merge m l r
      | zero k m       = binShrinkL k' p' x' m (remove l) r
      | otherwise      = binShrinkR k' p' x' m l (remove r)
{-# INLINE deleteMin #-}
-- | Drop the minimum element.  An empty queue is returned as is.
deleteMin :: IntPSQ v -> IntPSQ v
deleteMin t
  | Just (_, rest) <- minView t = rest
  | otherwise                   = t
-- | Apply a function to the priority bound to a key.  The queue is returned
-- unchanged if the key is absent.  This goes through 'alter', so a present
-- entry is removed and reinserted with its new priority.
adjust
    :: (Prio -> Prio)
    -> Key
    -> PSQ a
    -> PSQ a
adjust f k q =
    case alter update k q of
      (_, q') -> q'
  where
    update Nothing       = ((), Nothing)
    update (Just (p, v)) = ((), Just (f p, v))
{-# INLINE adjust #-}
{-# INLINE alter #-}
-- | General update of the binding for a key.
--
-- The current binding (if any) is first removed with 'deleteView'.  The
-- callback then receives it and decides the new binding.  A 'Just' result is
-- inserted with 'unsafeInsertNew', which is safe here because the key has
-- just been deleted.  The first component of the callback result is
-- returned alongside the new queue.
alter
    :: (Maybe (Prio, v) -> (b, Maybe (Prio, v)))
    -> Key
    -> IntPSQ v
    -> (b, IntPSQ v)
alter f = \k t0 ->
    let (rest, old) = case deleteView k t0 of
          Nothing          -> (t0, Nothing)
          Just (p, v, t0') -> (t0', Just (p, v))
        settle Nothing       = rest
        settle (Just (p, v)) = unsafeInsertNew k p v rest
    in case f old of
         (b, new) -> (b, settle new)
{-# INLINE binShrinkL #-}
-- | Rebuild a 'Bin' node whose left subtree may have shrunk.  It collapses
-- to a 'Tip' when both subtrees are empty.
binShrinkL :: Key -> Prio -> v -> Mask -> IntPSQ v -> IntPSQ v -> IntPSQ v
binShrinkL k p x m l r = case (l, r) of
    (Nil, Nil) -> Tip k p x
    _          -> Bin k p x m l r
{-# INLINE binShrinkR #-}
-- | Rebuild a 'Bin' node whose right subtree may have shrunk.  It collapses
-- to a 'Tip' when both subtrees are empty.
binShrinkR :: Key -> Prio -> v -> Mask -> IntPSQ v -> IntPSQ v -> IntPSQ v
binShrinkR k p x m l r = case (r, l) of
    (Nil, Nil) -> Tip k p x
    _          -> Bin k p x m l r
-- | All elements of the queue in pre-order, not sorted by priority.  Each
-- node's element comes first, then the elements of its left subtree, then
-- those of its right subtree.
toList :: IntPSQ v -> [Elem v]
toList t0 = walk t0 []
  where
    walk Nil               rest = rest
    walk (Tip k p x)       rest = E k p x : rest
    walk (Bin k p x _ l r) rest = E k p x : walk l (walk r rest)
{-# INLINABLE deleteView #-}
-- | Remove a key and report what was bound to it.
--
-- Returns @Just (priority, value, remaining queue)@ when the key was present,
-- and 'Nothing' otherwise.
deleteView :: Key -> IntPSQ v -> Maybe (Prio, v, IntPSQ v)
deleteView k t0 =
    case delFrom t0 of
      (# _, Nothing     #) -> Nothing
      (# t, Just (p, x) #) -> Just (p, x, t)
  where
    -- Returns an unboxed pair of the rebuilt subtree and the binding found
    -- (if any).  Rebuilt subtrees are forced to WHNF with 'seq' before they
    -- are put into the pair.
    delFrom t = case t of
      Nil -> (# Nil, Nothing #)
      Tip k' p' x'
        | k == k'   -> (# Nil, Just (p', x') #)
        | otherwise -> (# t,   Nothing       #)
      Bin k' p' x' m l r
        -- The key cannot be anywhere below this node.
        | nomatch k k' m -> (# t, Nothing #)
        -- The key is at this node: replace the node by the merge of its
        -- children.
        | k == k'   -> let t' = merge m l r
                       in  t' `seq` (# t', Just (p', x') #)
        -- Otherwise recurse into the side selected by the branching bit and
        -- rebuild this node around the result.
        | zero k m  -> case delFrom l of
                         (# l', mbPX #) -> let t' = binShrinkL k' p' x' m l' r
                                           in  t' `seq` (# t', mbPX #)
        | otherwise -> case delFrom r of
                         (# r', mbPX #) -> let t' = binShrinkR k' p' x' m l  r'
                                           in  t' `seq` (# t', mbPX #)
{-# INLINE minView #-}
-- | Split off the element stored at the root (the minimum).  The rest of the
-- queue is the 'merge' of the root's children.
minView :: IntPSQ v -> Maybe (Elem v, IntPSQ v)
minView Nil               = Nothing
minView (Tip k p x)       = Just (E k p x, Nil)
minView (Bin k p x m l r) = Just (E k p x, merge m l r)
{-# INLINABLE atMost #-}
-- | Remove every element whose priority is @<= pt@.
--
-- Returns the removed elements together with the remaining queue.  The list
-- is built in the traversal order of the local @go@ and is not sorted by
-- priority.  Descent into a subtree stops as soon as its root priority
-- exceeds @pt@, which relies on each node holding the minimum element of
-- its subtree.
atMost :: Prio -> IntPSQ v -> ([Elem v], IntPSQ v)
atMost pt t0 = go [] t0
  where
    -- acc collects the removed elements; new ones are prepended to it.
    go acc t = case t of
        Nil             -> (acc, t)
        Tip k p x
            | p > pt    -> (acc, t)
            | otherwise -> ((E k p x) : acc, Nil)
        Bin k p x m l r
            -- Root above the threshold: keep this whole subtree.
            | p > pt    -> (acc, t)
            -- Root is removed: collect from both children (left first, then
            -- right), then re-join what is left of them under the same mask.
            | otherwise ->
                let (acc',  l') = go acc  l
                    (acc'', r') = go acc' r
                in  ((E k p x) : acc'', merge m l' r')
{-# INLINABLE merge #-}
-- | Join the two subtrees of a node whose own element has been removed, using
-- @m@ as the branching mask of the result.
--
-- The subtree root with the smaller @(priority, key)@ pair is promoted to be
-- the new root, keeping mask @m@.  The promoted root's former children are
-- merged recursively under their own mask, and the other subtree stays on
-- its original side.
merge :: Mask -> IntPSQ v -> IntPSQ v -> IntPSQ v
merge m l r = case l of
    -- Nothing on the left: the right subtree is already the result.
    Nil -> r
    -- Left subtree is a single leaf; if it wins, it has no children to merge.
    Tip lk lp lx ->
      case r of
        Nil                     -> l
        Tip rk rp rx
          | (lp, lk) < (rp, rk) -> Bin lk lp lx m Nil r
          | otherwise           -> Bin rk rp rx m l   Nil
        Bin rk rp rx rm rl rr
          | (lp, lk) < (rp, rk) -> Bin lk lp lx m Nil r
          | otherwise           -> Bin rk rp rx m l   (merge rm rl rr)
    -- Left subtree is a branch; if it wins, its own children are merged.
    Bin lk lp lx lm ll lr ->
      case r of
        Nil                     -> l
        Tip rk rp rx
          | (lp, lk) < (rp, rk) -> Bin lk lp lx m (merge lm ll lr) r
          | otherwise           -> Bin rk rp rx m l                Nil
        Bin rk rp rx rm rl rr
          | (lp, lk) < (rp, rk) -> Bin lk lp lx m (merge lm ll lr) r
          | otherwise           -> Bin rk rp rx m l                (merge rm rl rr)