{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

-- | Utilities for working with 'ByteArray' and 'ByteArray#'.
module Agda.Utils.ByteArray
  (
  -- * Constructing byte arrays
    byteArrayOnes#
  -- * Queries
  , byteArrayIsSubsetOf#
  , byteArrayDisjoint#
  -- * Folds
  --
  -- $byteArrayFolds
  , byteArrayFoldrBits#
  , byteArrayFoldlBits#
  -- ** Strict folds
  , byteArrayFoldrBitsStrict#
  , byteArrayFoldlBitsStrict#
  ) where

-- We need the machine's word size for some bitwise operations.
#include "MachDeps.h"

import GHC.Base
import GHC.Num.WordArray

import Agda.Utils.Word

--------------------------------------------------------------------------------
-- Constructing byte arrays

-- | Construct a 'ByteArray#' consisting of @n@ 1 bits.
--
-- The bit count is split into a number of whole machine words and a
-- remainder. Every whole word is filled with ones via 'byteArrayFillOnes#'.
-- When the remainder is non-zero, one extra trailing word is allocated and
-- set to the result of 'uncheckedWordOnes#' applied to the remainder
-- (presumably a word whose low @r@ bits are set -- see "Agda.Utils.Word").
--
-- NOTE(review): @n@ is not checked for negativity; callers are presumably
-- expected to pass a non-negative bit count.
byteArrayOnes# :: Int# -> ByteArray#
byteArrayOnes# n =
  case quotRemInt# n WORD_SIZE_IN_BITS# of
    (# fullWords, spareBits #)
      | isTrue# (spareBits ==# 0#) ->
          -- @n@ is a whole number of words, so no partial word is needed.
          withNewWordArray# fullWords (\mwa st ->
            byteArrayFillOnes# mwa 0# fullWords st)
      | otherwise ->
          -- Allocate one extra word for the leftover bits and write it after
          -- the full words have been filled.
          withNewWordArray# (fullWords +# 1#) (\mwa st ->
            let st' = byteArrayFillOnes# mwa 0# fullWords st
            in mwaWrite# mwa fullWords (uncheckedWordOnes# spareBits) st')
{-# NOINLINE byteArrayOnes# #-}

-- | @byteArrayFillOnes# mwa start end st@ overwrites the words of @mwa@ at
-- word indices @start@ through @end - 1@ with all-ones words (@not# 0##@),
-- threading the state token through each write.
--
-- If @start >= end@ nothing is written and @st@ is returned unchanged.
-- No bounds checking is performed on the indices.
byteArrayFillOnes# :: MutableByteArray# s -> Int# -> Int# -> State# s -> State# s
byteArrayFillOnes# bs i len st
  | isTrue# (i >=# len) = st
  | otherwise =
      -- Write the current word, then continue with the next index using the
      -- state token produced by the write.
      byteArrayFillOnes# bs (i +# 1#) len (mwaWrite# bs i (not# 0##) st)

--------------------------------------------------------------------------------
-- Queries

-- | Check whether the set bits of the first 'ByteArray#' are a subset of the
-- set bits of the second 'ByteArray#'.
--
-- Returns @1#@ when they are and @0#@ otherwise. The arrays are compared one
-- machine word at a time: word @i@ of the first array passes when
-- @(w1 `and#` w2)@ equals @w1@.
--
-- If the first array has more words than the second, the result is @0#@
-- without inspecting any word.
--
-- NOTE(review): that shortcut is only correct if arrays never carry trailing
-- all-zero words -- confirm that callers maintain this invariant.
byteArrayIsSubsetOf# :: ByteArray# -> ByteArray# -> Int#
byteArrayIsSubsetOf# bs1 bs2
  | isTrue# (len1 ># len2) = 0#
  | otherwise = go 0#
  where
    len1 = wordArraySize# bs1
    len2 = wordArraySize# bs2

    go :: Int# -> Int#
    go i
      | isTrue# (i >=# len1) = 1#
      | otherwise =
          -- The reads are kept inside this branch so that they only happen
          -- for in-bounds indices.
          let w1 = indexWordArray# bs1 i
              w2 = indexWordArray# bs2 i
          in if isTrue# ((w1 `and#` w2) `eqWord#` w1)
               then go (i +# 1#)
               else 0#
{-# NOINLINE byteArrayIsSubsetOf# #-}

-- | Check if two 'ByteArray#'s are bitwise disjoint.
--
-- Returns @1#@ when every pair of corresponding words is reported disjoint
-- by 'disjointWord#', and @0#@ as soon as one pair is not. Only the words
-- present in both arrays are compared, i.e. the loop runs over the shorter
-- of the two word lengths; the extra words of the longer array are ignored.
byteArrayDisjoint# :: ByteArray# -> ByteArray# -> Int#
byteArrayDisjoint# bs1 bs2
  | isTrue# (len1 <=# len2) = go len1 0#
  | otherwise = go len2 0#
  where
    len1 = wordArraySize# bs1
    len2 = wordArraySize# bs2

    -- @go len i@ checks words @i@ up to @len - 1@.
    go :: Int# -> Int# -> Int#
    go len i
      | isTrue# (i >=# len) = 1#
      | otherwise =
          -- The reads are kept inside this branch so that they only happen
          -- for in-bounds indices.
          let w1 = indexWordArray# bs1 i
              w2 = indexWordArray# bs2 i
          in if isTrue# (disjointWord# w1 w2)
               then go len (i +# 1#)
               else 0#
{-# NOINLINE byteArrayDisjoint# #-}

--------------------------------------------------------------------------------
-- Folds

-- $byteArrayFolds
-- As usual, there is an ambiguity in left/right folds for folding over the bits of a
-- 'ByteArray#'. We opt to use the convention where we treat the 0th bit as the "head" of
-- the 'ByteArray#', so a right fold like @byteArrayFoldrBits# f x 0b1011@ would give @f 0 (f 1 (f 3 x))@.

-- | Perform a lazy right fold over the bit indices of a 'ByteArray#'.
--
-- Words are visited from word index 0 upwards. Word @i@ is folded with
-- 'wordFoldrBitsOffset#' using the offset @WORD_SIZE_IN_BITS * i@, and the
-- fold over the remaining (higher) words is passed to it as the accumulator.
-- An empty array yields the initial value unchanged.
byteArrayFoldrBits# :: (Int -> a -> a) -> a -> ByteArray# -> a
byteArrayFoldrBits# f a bs = go 0#
  where
    len = wordArraySize# bs

    -- Not tail recursive: the recursive call is the accumulator argument of
    -- the per-word fold.
    go i
      | isTrue# (i <# len) =
          wordFoldrBitsOffset#
            (WORD_SIZE_IN_BITS# *# i)
            f
            (go (i +# 1#))
            (indexWordArray# bs i)
      | otherwise = a

-- | Perform a lazy left fold over the bit indices of a 'ByteArray#'.
--
-- Words are visited from the last word index downwards. Word @i@ is folded
-- with 'wordFoldlBitsOffset#' using the offset @WORD_SIZE_IN_BITS * i@, and
-- the fold over the remaining (lower) words is passed to it as the
-- accumulator. An empty array yields the initial value unchanged.
byteArrayFoldlBits# :: (a -> Int -> a) -> a -> ByteArray# -> a
byteArrayFoldlBits# f a bs = go (wordArraySize# bs -# 1#)
  where
    -- Not tail recursive: the recursive call is the accumulator argument of
    -- the per-word fold.
    go i
      | isTrue# (i >=# 0#) =
          wordFoldlBitsOffset#
            (WORD_SIZE_IN_BITS# *# i)
            f
            (go (i -# 1#))
            (indexWordArray# bs i)
      | otherwise = a

-- | Perform a strict right fold over the bit indices of a 'ByteArray#'.
--
-- Words are processed from the last word index down to 0. The accumulator is
-- forced at every step and fed into 'wordFoldrBitsOffsetStrict#' together
-- with word @i@ and the offset @WORD_SIZE_IN_BITS * i@. An empty array yields
-- the (forced) initial value.
byteArrayFoldrBitsStrict# :: (Int -> a -> a) -> a -> ByteArray# -> a
byteArrayFoldrBitsStrict# f a bs = go (wordArraySize# bs -# 1#) a
  where
    -- Tail recursive; the bang keeps the accumulator evaluated.
    go i !acc
      | isTrue# (i >=# 0#) =
          go (i -# 1#)
             (wordFoldrBitsOffsetStrict#
                (WORD_SIZE_IN_BITS# *# i)
                f
                acc
                (indexWordArray# bs i))
      | otherwise = acc

-- | Perform a strict left fold over the bit indices of a 'ByteArray#'.
--
-- Words are processed from word index 0 upwards. The accumulator is forced at
-- every step and fed into 'wordFoldlBitsOffsetStrict#' together with word @i@
-- and the offset @WORD_SIZE_IN_BITS * i@. An empty array yields the (forced)
-- initial value.
byteArrayFoldlBitsStrict# :: (a -> Int -> a) -> a -> ByteArray# -> a
byteArrayFoldlBitsStrict# f a bs = go 0# a
  where
    len = wordArraySize# bs

    -- Tail recursive; the bang keeps the accumulator evaluated.
    go i !acc
      | isTrue# (i <# len) =
          go (i +# 1#)
             (wordFoldlBitsOffsetStrict#
                (WORD_SIZE_IN_BITS# *# i)
                f
                acc
                (indexWordArray# bs i))
      | otherwise = acc