-- Source code on GitHub
{-# OPTIONS --safe --without-K #-}
-- {-# OPTIONS -v tactic.inline:100 #-}

module Tactic.Inline where

import Data.Nat as ℕ
import Data.Nat.Properties as ℕ
import Data.List as L
open import Relation.Unary using () renaming (Decidable to Decidable¹)
open import Relation.Nullary.Decidable using () renaming (map′ to mapDec)
open import Algebra.Core using (Op₁)

open import Class.Functor
open import Class.Monad
open import Class.Semigroup
open import Class.DecEq
open import Class.Show
open import Class.MonadTC.Instances
open import Class.MonadError.Instances

open import Meta.Prelude
open import Reflection.Syntax
open import Reflection.Utils using (apply∗)
open import Reflection.Utils.Debug; open Debug ("tactic.inline" , 100)
-- open import Meta.Init
open import Reflection using (TC)
-- Instances for the reflection monad `TC`; presumably required by the
-- MonadTC/MonadError-style helpers (`getDefinition`, `declareDef`,
-- `defineFun`, `_IMPOSSIBLE_`) used below -- TODO confirm.
instance
  iTCE = MonadError-TC
  iTC  = MonadTC-TC

private
  -- Reflected form of `case x of y`, i.e. `case_of_` applied to the two
  -- visible arguments `x` and `y`.
  pattern `case_of_ x y = quote case_of_ ∙⟦ x ∣ y ⟧

  -- `$inline genType n' e` inlines the quoted application `e = def n xs`:
  -- it copies the clauses of `n` into the definition `n'`, instantiating the
  -- leading `length xs` parameters with `xs` and redirecting recursive calls
  -- of `n` to `n'`.
  -- If `genType` is true, `n'` is first declared with the type inferred for
  -- `e` (for `unquoteDecl`); otherwise `n'` must already have a signature
  -- (for `unquoteDef`).
  $inline : Bool → Name → Term → TC ⊤
  $inline genType n' `e = do
    -- only an applied definition can be inlined
    e@(def n xs) ← return `e
      where _ → _IMPOSSIBLE_
    printLn $ "** Inlining " ◇ show n ◇ "(" ◇ show xs ◇ ")"
    if genType
      then (declareDef (vArg n') =<< inferType e)
      else return tt
    -- only a function definition has clauses to copy
    function cs ← getDefinition n
      where _ → _IMPOSSIBLE_
    print $ show n ◇ "'s clauses: "
    void $ forM cs λ c → print $ " - " ◇ show c
    print ""
    let cs' = goᶜ n n' xs <$> cs
    print $ "\n" ◇ show n' ◇ "'s clauses: "
    void $ forM cs' λ c → print $ " - " ◇ show c
    print ""
    defineFun n' cs'

   where module _ (n n' : Name) (xs : Args Term) (let ∣xs∣ = length xs) where

    -- Resolve the de Bruijn index `x`, seen under `lvl` local binders:
    -- indices below `lvl` are local (`nothing`); index `lvl + k` denotes the
    -- k-th instantiated argument counting back from the last one of `xs`
    -- (indices beyond the first argument are clamped to it by `∸`).
    -- NOTE(review): the returned term is not weakened by `lvl`, so this is
    -- only sound when the arguments in `xs` are closed terms -- TODO confirm.
    lookupVar : ℕ → ℕ → Maybe Term
    lookupVar lvl x
      with x ℕ.<? lvl
    ... | yes _ = nothing
    -- here `lvl ≤ x`, so `x ∸ lvl` is the offset among the parameters
    ... | no  _ = unArg <$> xs ⁉ (∣xs∣ ∸ suc (x ∸ lvl))

    -- (B) recursively substitute free variables for the values in given `xs`
    -- Every traversal carries `lvl`, the number of binders crossed so far.
    mutual
      go : ℕ → Op₁ Term
      go lvl = λ where
        -- * (B1) substitute free variables
        (var x as) → let as′ = go∗ lvl as in case lookupVar lvl x of λ where
          nothing  → var x as′
          (just t) → apply∗ t as′
        -- * (B2) rename (& instantiate) recursive calls
        (def 𝕟 as) → let as′ = go∗ lvl as in
          if 𝕟 == n then
            def n' (drop ∣xs∣ as′)
          else
            def 𝕟 as′
        (con c as) → con c (go∗ lvl as)
        -- under a binder every outer index is shifted by one
        (pi (arg i ty) (abs x t)) → pi (arg i $ go lvl ty) (abs x $ go (suc lvl) t)
        (lam v (abs x t)) → lam v (abs x $ go (suc lvl) t)
        -- use case_of_ for single-argument pattern lambdas (cf. example 7)
        (pat-lam cs (vArg a ∷ [])) → `case go lvl a of pat-lam (goCls lvl cs) []
        (pat-lam cs as) → pat-lam (goCls lvl cs) (go∗ lvl as)
        (agda-sort s) → agda-sort (goSort lvl s)
        (meta x as) → meta x (go∗ lvl as)
        t → t

      -- `go` on every argument, keeping the argument info
      go∗ : ℕ → Op₁ (Args Term)
      go∗ lvl = λ where
        [] → []
        (arg i x ∷ as) → arg i (go lvl x) ∷ go∗ lvl as

      -- `go` inside the level term of `set`/`prop` sorts
      goSort : ℕ → Op₁ Sort
      goSort lvl = λ where
        (set t)  → set  $ go lvl t
        (prop t) → prop $ go lvl t
        s → s

      -- patterns and body of a nested clause live under its own telescope
      goC : ℕ → Op₁ Clause
      goC lvl = λ where
        (clause tel ps t) →
          let lvl' = lvl ℕ.+ length tel
          in clause (goTel lvl tel) (goPs lvl' ps) (go lvl' t)
        (absurd-clause tel ps) →
          let lvl' = lvl ℕ.+ length tel
          in absurd-clause (goTel lvl tel) (goPs lvl' ps)

      goCls : ℕ → Op₁ (List Clause)
      goCls lvl = λ where
        [] → []
        (c ∷ cs) → goC lvl c ∷ goCls lvl cs

      -- pattern variables that denote a parameter become dot patterns
      goP : ℕ → Op₁ Pattern
      goP lvl = λ where
        (con c ps) → con c (goPs lvl ps)
        (dot t) → dot (go lvl t)
        (var x) → case lookupVar lvl x of λ where
          nothing  → var x
          (just t) → dot t
        p → p

      goPs : ℕ → Op₁ (Args Pattern)
      goPs lvl = λ where
        [] → []
        (arg i p ∷ ps) → arg i (goP lvl p) ∷ goPs lvl ps

      -- each telescope entry is in scope of the entries before it
      goTel : ℕ → Op₁ Telescope
      goTel lvl = λ where
        [] → []
        ((x , arg i t) ∷ tel) → (x , arg i (go lvl t)) ∷ goTel (suc lvl) tel

    -- ** Entrypoint (A): instantiating the clauses of a definition
    -- The first `∣xs∣` telescope entries and patterns of each clause are
    -- assumed to correspond to the arguments `xs`; they are dropped, and the
    -- remaining `lvl = length tel ∸ ∣xs∣` variables stay local.
    goᶜ : Clause → Clause
    goᶜ = λ where
      (clause tel ps t) → let lvl = length tel ∸ ∣xs∣ in
        clause (instTel tel) (instPs lvl ps) (go lvl t)
      (absurd-clause tel ps) → let lvl = length tel ∸ ∣xs∣ in
        absurd-clause (instTel tel) (instPs lvl ps)
     where
      -- (A1) instantiating a clause's telescope
      instTel : Op₁ Telescope
      instTel = goTel 0 ∘ drop ∣xs∣

      -- (A2) instantiating a clause's parameters
      instPs : ℕ → Op₁ (Args Pattern)
      instPs lvl = goPs lvl ∘ drop ∣xs∣

-- Public entry points: both take the name to define and a quoted application
-- `def n xs`, and define that name by instantiating the clauses of `n`.
-- `inline` expects the name to be declared already (use with `unquoteDef`);
-- `inlineDecl` also declares it with the inferred type (use with `unquoteDecl`).
inline inlineDecl : Name → Term → TC ⊤
inline     = $inline false -- for use with `unquoteDef`
inlineDecl = $inline true  -- for use with `unquoteDecl`

-- ** Tests

private
  -- (1) specializing the function to be applied by `map`
  unquoteDecl sucs = inlineDecl sucs (quoteTerm (L.map suc))
  {-
  sucs : List ℕ → List ℕ
  sucs [] = [] {_} {_}
  sucs (x ∷ xs) = _∷_ {_} {_} (suc x) (sucs xs)
  -}
  _ = sucs (0 ∷ 1 ∷ 2 ∷ 3 ∷ []) ≡ (1 ∷ 2 ∷ 3 ∷ 4 ∷ [])
    ∋ refl

  -- (2) specializing the predicate to be checked by `all?`
  data Even : ℕ → Set where
    zero : Even 0
    suc  : ∀ {n} → Even n → Even (suc (suc n))

  -- decide evenness by peeling off two successors at a time
  even? : Decidable¹ Even
  even? = λ where
    0 → yes zero
    1 → no λ ()
    (suc (suc n)) → mapDec suc (λ where (suc p) → p) (even? n)

  open import Data.List.Relation.Unary.All using (All; []; _∷_; all?)

  unquoteDecl evens? = inlineDecl evens? (quoteTerm (all? even?))
  {-
  evens? : Decidable¹ (All Even)
  evens? []       = yes []
  evens? (x ∷ xs) = mapDec (uncurry _∷_) uncons (even? x ×-dec evens? xs)
  -}
  _ = evens? (0 ∷ 2 ∷ []) ≡ yes (zero ∷ suc zero ∷ [])
    ∋ refl

  -- (3) works under module parameters
  module _ (n m : ℕ) where
    unquoteDecl ⟫evens? = inlineDecl ⟫evens? (quoteTerm (all? even?))
    _ = ⟫evens? (0 ∷ 2 ∷ []) ≡ yes (zero ∷ suc zero ∷ [])
      ∋ refl

  -- a definition whose leading parameters come from an enclosing module
  module _ {A B : Set} (f : A → B) where
    map' : List A → List B
    map' [] = []
    map' (x ∷ xs) = f x ∷ map' xs

  unquoteDecl sucs' = inlineDecl sucs' (quoteTerm (map' {B = ℕ} suc))
  _ = sucs' (0 ∷ 1 ∷ 2 ∷ 3 ∷ []) ≡ (1 ∷ 2 ∷ 3 ∷ 4 ∷ [])
    ∋ refl

  -- (4) works under mutual blocks
  data Odd : ℕ → Set where
    one : Odd 1
    suc : ∀ {n} → Odd n → Odd (suc (suc n))

  mutual
    -- decide oddness by peeling off two successors at a time
    odd? : Decidable¹ Odd
    odd? = λ where
      0 → no λ ()
      1 → yes one
      (suc (suc n)) → mapDec suc (λ where (suc p) → p) (odd? n)

    unquoteDecl mevens? = inlineDecl mevens? (quoteTerm (all? even?))
    unquoteDecl modds?  = inlineDecl modds?  (quoteTerm (all? odd?))
    _ = mevens? (0 ∷ 2 ∷ []) ≡ yes (zero ∷ suc zero ∷ [])
      ∋ refl

    -- [AGDA BUG] cannot use _∋_ (cf. Agda issue #7028)
    _ : modds? (1 ∷ 3 ∷ []) ≡ yes (one ∷ suc one ∷ [])
    _ = refl

  -- (5) works with `with`-statements (e.g. for specializing `mapMaybe`)
  toEvenOdd : ℕ → ℕ {- even part -} × Maybe ℕ {- odd part -}
  toEvenOdd n with even? n
  ... | yes _ = n , nothing
  ... | no  _ = pred n , just 1

  -- keep only the odd part
  toOdd : ℕ → Maybe ℕ
  toOdd = proj₂ ∘ toEvenOdd

  unquoteDecl toOdds = inlineDecl toOdds (quoteTerm (L.mapMaybe toOdd))
  {-
  toOdds : List ℕ → List ℕ
  toOdds []       = []
  toOdds (x ∷ xs) with toOdd x
  -- ** [LIMITATION] does not recursively inline `with`-statements
  -- ... | just y  = y ∷ toOdds xs
  -- ... | nothing = toOdds xs
  ... | just y  = y ∷ mapMaybe toOdd xs
  ... | nothing = mapMaybe toOdd xs
  -}
  _ = toOdds (0 ∷ 1 ∷ 2 ∷ 3 ∷ []) ≡ (1 ∷ 1 ∷ [])
    ∋ refl

  -- (6) test for `MOf?`
  -- open import Data.List.Relation.Unary.MOf using (MOf; mOf; MOf?)

  -- unquoteDecl MOf-even? = inlineDecl MOf-even? (quoteTerm (MOf? even?))
  -- {-
  -- MOf-even? : ∀ m xs → Dec (MOf m Even xs)
  -- MOf-even? zero    xs = yes done
  -- MOf-even? (suc m) [] = no λ where (mOf (_ ∷ _) len≡ () _)
  -- MOf-even? (suc m) (x ∷ xs) =
  --   if even? x then
  --     (λ {px} → mapDec (cons px) uncons (MOf-even? m xs))
  --   else
  --     (λ {¬px} → mapDec skip (unskip ¬px) (MOf-even? (suc m) xs))
  -- -}

  -- _ = MOf-even? 2 (0 ∷ 1 ∷ 2 ∷ 3 ∷ [])
  --   ≡ yes (mOf (0 ∷ 2 ∷ []) refl (refl ∷ 1 ∷ʳ refl ∷ 3 ∷ʳ []) (zero ∷ suc zero ∷ []))
  --   ∋ refl
  --   where open import Data.List.Relation.Binary.Sublist.Ext

  -- (7) works on pattern lambdas arising from `case_of_`
  -- `refl? n` is `⊤` when `n` is decidably equal to itself, `⊥` otherwise
  refl? : ℕ → Set
  refl? n = case n ≟ n of λ where
    (yes p) → ⊤
    (no ¬p) → ⊥

  unquoteDecl refl42 = inlineDecl refl42 (quoteTerm (refl? 42))
  {-
  refl42
    = case 42 ℕ.≟ 42 of
      (λ { (true because ofʸ p) → ⊤ ; (false because ofⁿ ¬p) → ⊥ })
  -}