{-# OPTIONS_GHC -Wunused-imports #-}

module Agda.TypeChecking.Monad.Imports
  ( addImport
  , locallyAddImport
  , checkForImportCycle
  , dropDecodedModule
  , getDecodedModule
  , getDecodedModules
  , getPrettyVisitedModules
  , getVisitedModule
  , getVisitedModules
  , setDecodedModules
  , setVisitedModules
  , storeDecodedModule
  , visitModule
  ) where

import Control.Monad   ( when )

import Data.Maybe (catMaybes)
import Data.Map qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set

import Agda.Syntax.Common.Pretty
import Agda.Syntax.TopLevelModuleName
import Agda.TypeChecking.Monad.Base

import Agda.Utils.List ( caseListM )
import Agda.Utils.List1 qualified as List1
import Agda.Utils.List2 qualified as List2
import Agda.Utils.Singleton (singleton)
import Agda.Utils.Tuple ( (***) )

import Agda.Utils.Impossible

-- | Record @top@ as imported in the current type-checking state.
--
--   Both components of 'stImportedModulesAndTransitive' are updated via
--   'updateImports': the first gains @top@; the second gains @top@ plus
--   everything reachable from it through the imports of the modules that
--   have already been visited.
addImport :: TopLevelModuleName -> TCM ()
addImport top =
  modifyTCLens' stImportedModulesAndTransitive . (`updateImports` top)
    =<< getVisitedModules

-- | Run @cont@ with @top@ registered as imported. The extended import sets
--   are installed through 'locallyTCState', so the registration only
--   applies to @cont@ rather than to the surrounding computation.
locallyAddImport :: TopLevelModuleName -> TCM () -> TCM ()
locallyAddImport top cont = getVisitedModules >>= \ visited ->
  locallyTCState stImportedModulesAndTransitive (updateImports visited top) cont

-- | Extend a pair of import sets with the module @top@.
--
--   * The first component gains just @top@.
--   * The second component gains @top@ and, via
--     'completeTransitiveImports', every module reachable from @top@
--     through the recorded imports of the modules in @vis@.
updateImports :: VisitedModules -> TopLevelModuleName
              -> (ImportedModules, ImportedModules)
              -> (ImportedModules, ImportedModules)
updateImports vis top = addDirect *** addTransitive
  where
    addDirect     = Set.insert top
    addTransitive = completeTransitiveImports vis (singleton top)

-- | @completeTransitiveImports vis ms old@ adds the modules @ms@ to @old@.
--   It then repeatedly adds the imports (taken from the interfaces stored
--   in @vis@) of every newly added module, until no new module turns up.
--   Modules that have no entry in @vis@ are still added, but contribute no
--   further imports.
--
--   Precondition: @ms@ is disjoint from @old@.
completeTransitiveImports :: VisitedModules -> Set TopLevelModuleName -> ImportedModules -> ImportedModules
completeTransitiveImports vis = go
  where
    -- @frontier@: modules added in this round; @acc@: everything so far.
    go frontier acc
      | null frontier = acc
      | otherwise     = go (importsOf frontier `Set.difference` acc') acc'
      where
        acc' = acc `Set.union` frontier

    -- Union of the imports recorded in the interfaces of the given modules;
    -- modules missing from @vis@ are skipped.
    importsOf mods =
      Set.unions
        [ Set.fromList (map fst (iImportedModules (miInterface modInfo)))
        | modInfo <- catMaybes [ Map.lookup m vis | m <- Set.toList mods ]
        ]

-- | Record a module as visited, keyed by the top-level module name found in
--   its interface. An existing entry under that name is replaced
--   ('Map.insert').
visitModule :: ModuleInfo -> TCM ()
visitModule mi = modifyTCLens stVisitedModules (Map.insert visitedName mi)
  where
    visitedName = iTopLevelModuleName (miInterface mi)

-- | Overwrite the whole table of visited modules.
setVisitedModules :: VisitedModules -> TCM ()
setVisitedModules = setTCLens stVisitedModules

-- | The current table of visited modules, read through the
--   'stVisitedModules' lens.
getVisitedModules :: ReadTCState m => m VisitedModules
getVisitedModules = useTC stVisitedModules

-- | Render the visited modules, in 'Map.toList' order, as a comma-separated
--   list. A module whose check mode is 'ModuleScopeChecked' is followed by
--   " (scope only)"; a 'ModuleTypeChecked' module gets no suffix.
getPrettyVisitedModules :: ReadTCState m => m Doc
getPrettyVisitedModules =
  hcat . punctuate ", " . map prettyEntry . Map.toList <$> getVisitedModules
  where
    prettyEntry :: (TopLevelModuleName, ModuleInfo) -> Doc
    prettyEntry = uncurry (<>) . (pretty *** (prettyCheckMode . miMode))

    prettyCheckMode :: ModuleCheckMode -> Doc
    prettyCheckMode ModuleTypeChecked  = ""
    prettyCheckMode ModuleScopeChecked = " (scope only)"

-- | Look up a module in the visited-modules table by its top-level name;
--   'Nothing' if it has not been visited.
getVisitedModule :: ReadTCState m
                 => TopLevelModuleName
                 -> m (Maybe ModuleInfo)
getVisitedModule x = do
  visited <- useTC stVisitedModules
  pure (Map.lookup x visited)

-- | Read the decoded-module table kept in the persistent part of the
--   type-checking state.
getDecodedModules :: TCM DecodedModules
getDecodedModules = do
  st <- getTC
  return $ stDecodedModules (stPersistentState st)

-- | Replace the decoded-module table in the persistent part of the
--   type-checking state, leaving every other field unchanged.
setDecodedModules :: DecodedModules -> TCM ()
setDecodedModules ms = modifyTC' updatePersistent
  where
    updatePersistent s =
      let persistent = stPersistentState s
      in  s { stPersistentState = persistent { stDecodedModules = ms } }

-- | Look up a module in the persistent decoded-module table by its
--   top-level name; 'Nothing' if no entry is stored for it.
getDecodedModule :: TopLevelModuleName -> TCM (Maybe ModuleInfo)
getDecodedModule x = do
  decoded <- stDecodedModules . stPersistentState <$> getTC
  return (Map.lookup x decoded)

-- | Add a module to the decoded-module table in the persistent state. The
--   key is the top-level module name stored in the module's interface; an
--   existing entry with that name is replaced ('Map.insert').
--
--   The state is written with 'modifyTC'', matching 'setDecodedModules'
--   (which writes the same field). With plain 'modifyTC', repeated stores
--   could leave a chain of unevaluated 'Map.insert' updates in this
--   long-lived part of the state.
storeDecodedModule :: ModuleInfo -> TCM ()
storeDecodedModule mi = modifyTC' $ \ s ->
  let persistent = stPersistentState s
  in  s { stPersistentState =
            persistent { stDecodedModules =
                           Map.insert (iTopLevelModuleName $ miInterface mi) mi $
                             stDecodedModules persistent
                       }
        }

-- | Remove the entry for the given top-level module name from the
--   decoded-module table in the persistent state. Nothing changes in the
--   table if there is no such entry ('Map.delete').
--
--   The state is written with 'modifyTC'', matching 'setDecodedModules'
--   (which writes the same field), so that repeated drops do not pile up
--   unevaluated 'Map.delete' updates in the persistent state.
dropDecodedModule :: TopLevelModuleName -> TCM ()
dropDecodedModule x = modifyTC' $ \ s ->
  let persistent = stPersistentState s
  in  s { stPersistentState =
            persistent { stDecodedModules =
                           Map.delete x $ stDecodedModules persistent
                       }
        }


-- | Raise 'CyclicModuleDependency' if the head @m@ of 'envImportStack' also
--   occurs in the rest of the stack.
--
--   Assumes that the first module in the import path is the module we are
--   worried about; an empty import stack is an internal error. The reported
--   cycle is the rest of the stack reversed, starting at its first
--   occurrence of @m@, followed by @m@ once more.
checkForImportCycle :: TCM ()
checkForImportCycle =
  caseListM (asksTC envImportStack) __IMPOSSIBLE__ $ \ m ms ->
    when (m `elem` ms) $ do
      -- NB: ms contains m, so this suffix is non-empty.
      let cycleStart = dropWhile (/= m) (reverse ms)
      typeError $ CyclicModuleDependency $
        List2.snoc (List1.fromListSafe __IMPOSSIBLE__ cycleStart) m