{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -Wunused-imports -Wno-redundant-bang-patterns #-}

------------------------------------------------------------------------
-- | Hash tables.
------------------------------------------------------------------------

module Agda.Utils.HashTable
  ( HashTable
  , HashTableLU
  , HashTableLL
  , Agda.Utils.HashTable.empty
  , Agda.Utils.HashTable.insert
  , Agda.Utils.HashTable.lookup
  , Agda.Utils.HashTable.toList
  , forAssocs
  , Agda.Utils.HashTable.size
  , insertingIfAbsent
  ) where

import Prelude hiding (lookup)

import Data.Bits
import Data.Hashable
import Data.Primitive.MutVar
import Data.Vector.Hashtables
import Data.Vector.Hashtables.Internal
import Data.Vector.Hashtables.Internal.Mask

import qualified Data.Primitive.PrimArray as A

import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed.Mutable as VUM

import Agda.Utils.Monad


-- | A mutable hash table in 'IO', parameterised by the mutable vector
-- type used to store the keys (@ks@) and the one used to store the
-- values (@vs@).
--
-- Why this implementation: a very limited amount of (possibly outdated)
-- testing indicated that, for the use in Agda's serialiser/deserialiser,
-- Data.HashTable.IO.CuckooHashTable is somewhat slower than
-- Data.HashTable.IO.BasicHashTable, and that
-- Data.HashTable.IO.LinearHashTable and the hash tables from
-- Data.Hashtable are much slower. Other (also possibly outdated) testing
-- suggested that Data.HashTable.IO.CuckooHashTable is quite a bit faster
-- than Data.HashTable.IO.BasicHashTable on 64-bit Windows. More recent,
-- also limited, testing suggests that the implementation from
-- Data.Vector.Hashtables used here is quite a bit faster than
-- Data.HashTable.IO.BasicHashTable (see issue #5966).
newtype HashTable ks k vs v = HashTable (Dictionary (PrimState IO) ks k vs v)

-- | A hash table whose keys live in a boxed (lifted) mutable vector and
-- whose values live in an unboxed mutable vector.
type HashTableLU k v =
  HashTable VM.MVector k VUM.MVector v

-- | A hash table whose keys and values both live in boxed (lifted)
-- mutable vectors.
type HashTableLL k v =
  HashTable VM.MVector k VM.MVector v

-- | Creates a new, empty hash table.
--
-- The underlying dictionary is created with @'initialize' 0@, and the
-- 'HashTable' wrapper is applied strictly via '<$!>'.
empty :: (MVector ks k, MVector vs v) => IO (HashTable ks k vs v)
empty = HashTable <$!> initialize 0

-- | Inserts the key and the corresponding value into the hash table.
--
-- Thin wrapper around 'Data.Vector.Hashtables.insert'.
-- NOTE(review): presumably an existing binding for the key is
-- overwritten; confirm against the vector-hashtables documentation.
insert :: (Hashable k, MVector vs v, MVector ks k) =>
          HashTable ks k vs v -> k -> v -> IO ()
insert (HashTable h) = Data.Vector.Hashtables.insert h
{-# INLINABLE insert #-}

-- | Tries to find a value corresponding to the key in the hash table.
--
-- Yields 'Nothing' when no value is found. Thin wrapper around
-- 'Data.Vector.Hashtables.lookup'.
lookup :: (Hashable k, MVector ks k, MVector vs v)
       => HashTable ks k vs v -> k -> IO (Maybe v)
lookup (HashTable h) = Data.Vector.Hashtables.lookup h
{-# INLINABLE lookup #-}

-- | Converts the hash table to a list of key-value pairs.
--
-- The order of the elements in the list is unspecified. Thin wrapper
-- around 'Data.Vector.Hashtables.toList'.
toList :: (Hashable k, MVector ks k, MVector vs v) => HashTable ks k vs v -> IO [(k, v)]
toList (HashTable h) = Data.Vector.Hashtables.toList h
{-# INLINABLE toList #-}

-- | Iterate over key-value pairs in IO.
--
-- The callback is applied to the entry in every slot from index
-- @count - 1@ down to 0, so entries are visited in decreasing slot order.
-- Slots whose stored hash code is negative are skipped (presumably slots
-- that are unused or were freed by a deletion; TODO confirm against
-- vector-hashtables).
--
-- The dictionary record and the entry count are read once, before the
-- iteration starts. Modifications that the callback makes to this table
-- may therefore not be seen by the iteration.
forAssocs :: (MVector ks k, MVector vs v)
          => HashTable ks k vs v -> (k -> v -> IO ()) -> IO ()
forAssocs (HashTable dict) f = do
  Dictionary{..} <- readMutVar (getDRef dict)
  count <- refs ! getCount
  let go :: Int -> IO ()
      go i | i < 0 = pure ()
      go i = do
        hc <- hashCode ! i
        if hc < 0 then
          go (i - 1)
        else do
          k <- key !~ i
          v <- value !~ i
          -- Binding the result (rather than using '>>') keeps the same
          -- evaluation behaviour under the Strict extension.
          _ <- f k v
          go (i - 1)
  go (count - 1)
{-# INLINE forAssocs #-}

-- | The size of the hash table, as reported by
-- 'Data.Vector.Hashtables.size' (presumably the number of stored
-- key-value pairs).
size :: MVector ks k => HashTable ks k vs v -> IO Int
size (HashTable h) = Data.Vector.Hashtables.size h
{-# INLINE size #-}

-- | Look up a key in the table.
--
--   * If the key is present, its value is passed to the first
--     @(v -> IO a)@ continuation (@found@).
--
--   * Otherwise the @IO v@ action (@getValue'@) is run, and its result is
--     inserted under the key. That result is then passed to the second
--     @(v -> IO a)@ continuation (@notfound@).
--
--   This works directly on the dictionary internals, so the key is hashed
--   only once for both the probe and the insertion. A slot on the free
--   list is reused when one is available. Otherwise a new slot is taken
--   at the end, and the dictionary is resized when all slots are in use.
--
--   NOTE: the dictionary record and the free-list count are read before
--   @getValue'@ runs. @getValue'@ must therefore not modify this hash
--   table, or the insertion may act on stale internal state.
insertingIfAbsent :: forall ks k vs v a.
          (Hashable k, MVector ks k, MVector vs v)
       => HashTable ks k vs v
       -> k
       -> (v -> IO a)
       -> IO v
       -> (v -> IO a)
       -> IO a
insertingIfAbsent (HashTable DRef{..}) key' found getValue' notfound = do
    d@Dictionary{..} <- readMutVar getDRef
    let
        -- Masked hash of the key, in the same form as the hash codes
        -- stored in the table.
        hashCode' = hash key' .&. mask
        -- Bucket whose collision chain may contain the key.
        !targetBucket = hashCode' `fastRem` remSize

        -- Walk the collision chain starting at slot @i@. A negative index
        -- ends the chain, meaning the key is absent.
        go :: Int -> IO a
        go i
          | i >= 0 = do
              hc <- hashCode ! i
              if hc == hashCode'
                then do
                  k <- key !~ i
                  if k == key'
                    then do
                      v <- value !~ i
                      found v
                    else go =<< next ! i
                else go =<< next ! i
          | otherwise = addOrResize

        -- The key is absent: compute its value and store it.
        addOrResize :: IO a
        addOrResize = do
            freeCount <- refs ! getFreeCount
            value' <- getValue'
            if freeCount > 0
                then do
                    -- Reuse the slot at the head of the free list.
                    index <- refs ! getFreeList
                    nxt <- next ! index
                    refs <~ getFreeList $ nxt
                    refs <~ getFreeCount $ freeCount - 1
                    add index value'
                else do
                    count <- refs ! getCount
                    refs <~ getCount $ count + 1
                    nextLen <- A.getSizeofMutablePrimArray next
                    if count == nextLen
                        then do
                            -- Every slot is in use. 'resize' receives the
                            -- new entry and returns a replacement dictionary
                            -- (presumably larger), so 'add' is not called
                            -- here.
                            nd <- resize d count hashCode' key' value'
                            writeMutVar getDRef nd
                            notfound value'
                        else add count value'

        -- Store the entry in slot @index@ and make that slot the new head
        -- of the target bucket's chain.
        add :: Int -> v -> IO a
        add !index !value' = do
            hashCode <~ index $ hashCode'
            b <- buckets ! targetBucket
            next <~ index $ b
            key <~~ index $ key'
            value <~~ index $ value'
            buckets <~ targetBucket $ index
            notfound value'

    go =<< buckets ! targetBucket
{-# INLINE insertingIfAbsent #-}