module Agda.Interaction.Command
  ( CommandM, localStateCommandM, liftLocalState, revLift, revLiftTC
  ) where

import Control.Monad.State ( MonadState(..), execStateT, lift )

import Agda.TypeChecking.Monad.Base ( TCM,MonadTCState, TCState, getTC, putTC )
import Agda.TypeChecking.Monad.State ( localTCState )

import Agda.Interaction.Base ( CommandM' )

------------------------------------------------------------------------
-- The CommandM monad

-- | The interaction command monad: 'CommandM'' instantiated with 'TCM' as
--   its underlying monad.
type CommandM  = CommandM' TCM

-- | Run a 'CommandM' action and afterwards restore both the 'TCState' and
--   the 'CommandState' to the values they had before the action ran.
--
--   Only the result of the action is kept; any changes it made to either
--   state are overwritten by the saved snapshots.
--
--   NOTE(review): there is no exception handling here, so the restoration
--   steps run only when the action returns normally -- confirm that this is
--   the intended behaviour.
localStateCommandM :: CommandM a -> CommandM a
localStateCommandM m = do
  cSt  <- get    -- snapshot of the command state
  tcSt <- getTC  -- snapshot of the type-checker state
  x    <- m
  putTC tcSt
  put cSt
  return x

-- | Lift a 'TCM' computation into 'CommandM', wrapping it in
--   'localTCState' first so that the 'TCState' is restored afterwards.
--   The 'CommandState' is not touched.
liftLocalState :: TCM a -> CommandM a
liftLocalState = lift . localTCState

-- | Build an opposite action to 'lift' for state monads.
--
--   The continuation-style argument is handed a function that runs an
--   @m@-computation in the base monad @k@, starting from the state that was
--   current when 'revLift' was entered.  The final state produced by that
--   run is written back with 'put', and the computed value is returned.
revLift
    :: MonadState st m
    => (forall c . m c -> st -> k (c, st))      -- ^ run
    -> (forall b . k b -> m b)                  -- ^ lift
    -> (forall x . (m a -> k x) -> k x) -> m a  -- ^ reverse lift in double negative position
revLift run lift' f = do
    st <- get
    -- The section @(`run` st)@ turns an @m@-action into a @k@-action that
    -- starts from the captured state and yields the result and new state.
    (a, st') <- lift' $ f (`run` st)
    put st'
    return a

-- | Variant of 'revLift' for monads carrying a 'TCState': the state is read
--   with 'getTC' and written back with 'putTC' instead of 'get' and 'put'.
--
--   The continuation-style argument is handed a function that runs an
--   @m@-computation in the base monad @k@, starting from the 'TCState' that
--   was current when 'revLiftTC' was entered.
revLiftTC
    :: MonadTCState m
    => (forall c . m c -> TCState -> k (c, TCState))  -- ^ run
    -> (forall b . k b -> m b)                        -- ^ lift
    -> (forall x . (m a -> k x) -> k x) -> m a        -- ^ reverse lift in double negative position
revLiftTC run lift' f = do
    st <- getTC
    -- Run the supplied action from the captured state, then adopt the
    -- state it ended in.
    (a, st') <- lift' $ f (`run` st)
    putTC st'
    return a