-- | Wrappers around 'MVar's which provide an atomic modification operation.
module Agda.Utils.Atomic
  ( Atomic
  , newAtomic
  , readAtomic
  , modifyAtomic
  , withAtomic
  )
  where

import Control.Monad.IO.Class
import Control.Monad.Catch
import Control.DeepSeq

import Control.Concurrent
import Control.Exception (evaluate)

-- | A mutable cell whose contents can be read and /modified/ atomically.
--
-- Unlike 'modifyMVar' on a bare 'MVar', 'modifyAtomic' cannot be
-- disturbed by another thread refilling the variable mid-update, because
-- this module exposes no way to put a value into an 'Atomic' (the
-- constructor is not exported).
--
-- 'Eq' compares the underlying 'MVar's: two 'Atomic's are equal exactly
-- when they refer to the same variable. 'NFData' only forces the
-- reference, not the value stored in it.
newtype Atomic a = Atomic
  { unAtomic :: MVar a
    -- ^ The underlying variable. It is empty while a 'modifyAtomic' on
    -- it is in progress.
  }
  deriving (Eq, NFData)

-- Implementation note: the caveat to the atomicity of 'modifyMVar' is
-- that another thread may hold a *different* value of type 'a'. It can
-- 'putMVar' that value into the variable while the continuation passed to
-- 'modifyMVar' is running, because the variable is left empty for that
-- whole duration. The final 'putMVar' performed by 'modifyMVar', after
-- the continuation returns, would then block.
--
-- By contrast, this module simply does not export a "put" operation for
-- an 'Atomic', nor its constructor, so no other thread can refill the
-- variable while it is being modified. This should, hopefully, be enough
-- to ensure atomicity.

-- | Create a fresh atomic variable holding the given initial value.
--
-- The value is stored as given; it is not evaluated first.
newAtomic :: MonadIO m => a -> m (Atomic a)
newAtomic initial = liftIO $ do
  cell <- newMVar initial
  pure (Atomic cell)

-- | Read the current contents of an atomic variable.
--
-- While another thread is in the middle of a 'modifyAtomic' on the same
-- variable, this blocks until that modification has finished. As with
-- 'readMVar', reading is multiple-wakeup: every thread blocked here is
-- woken when the modification completes. Reading does not take the value
-- out of the variable.
readAtomic :: MonadIO m => Atomic a -> m a
readAtomic (Atomic cell) = liftIO (readMVar cell)
{-# INLINE readAtomic #-}

-- | Modify the contents of an atomic variable, also returning an extra
-- result.
--
-- If the continuation throws (synchronously or asynchronously), or aborts
-- the monad without producing a result (e.g. via @throwError@ in
-- @ExceptT@), the variable is restored to its old contents.
--
-- No thread, /including the calling thread/, can access the contents of
-- the same 'Atomic' while this function is executing, for reading /or/
-- writing. This means that /any/ nested use of the variable, as in
--
-- > f var = modifyAtomic var \old ->
-- >   -- ...
-- >   modifyAtomic var \old' -> ... -- (!)
-- >   readAtomic var                -- (!)
--
-- will result in a deadlock.
--
-- The new contents are evaluated to WHNF /before/ they are written back,
-- while failures are still caught. If that evaluation throws, the
-- variable keeps its old contents instead of being left empty.
modifyAtomic :: (MonadIO m, MonadMask m) => Atomic a -> (a -> m (a, b)) -> m b
modifyAtomic (Atomic var) k = do
  ((_, result), _) <- generalBracket (liftIO (takeMVar var)) writeBack update
  pure result
  where
    -- Runs with the caller's masking state restored. Forcing the new state
    -- here, rather than at the final 'putMVar', means an exception from
    -- evaluating it is handled by 'writeBack' instead of leaving 'var'
    -- empty forever.
    update old = do
      (new, result) <- k old
      new' <- liftIO (evaluate new)
      pure (new', result)

    -- Runs with asynchronous exceptions masked. 'var' was emptied by the
    -- acquire step and nothing else can fill it, so 'putMVar' does not
    -- block. Any outcome other than success (exception or monadic abort)
    -- puts the old contents back.
    writeBack old exitCase = liftIO $ putMVar var $ case exitCase of
      ExitCaseSuccess (new, _) -> new
      _                        -> old

{-# SPECIALISE modifyAtomic :: Atomic a -> (a -> IO (a, b)) -> IO b #-}

-- | Run an action on the current contents of an atomic variable while
-- holding exclusive access to it; the contents are left unchanged.
--
-- This is 'modifyAtomic' with a continuation that writes back exactly the
-- value it received, so the same rules apply. The action must not touch
-- the same 'Atomic' again (that deadlocks), and if it fails, the variable
-- keeps its old contents.
withAtomic :: (MonadIO m, MonadMask m) => Atomic a -> (a -> m ()) -> m ()
withAtomic var k = modifyAtomic var $ \current -> do
  r <- k current
  pure (current, r)