------------------------------------------------------------------------
-- The Agda standard library
--
-- A solver that uses reflection to automatically obtain and solve
-- equations over rings.
------------------------------------------------------------------------

{-# OPTIONS --cubical-compatible --safe #-}

module Tactic.RingSolver where

open import Algebra
open import Data.Fin.Base   as Fin     using (Fin)
open import Data.Vec.Base   as Vec     using (Vec; _∷_; [])
open import Data.List.Base  as List    using (List; _∷_; [])
open import Data.Maybe.Base as Maybe   using (Maybe; just; nothing; fromMaybe)
open import Data.Nat.Base              using (ℕ; suc; zero; _<ᵇ_)
open import Data.Bool.Base             using (Bool; if_then_else_; true; false)
open import Data.Unit.Base             using (⊤)
open import Data.String.Base as String using (String; _++_; parens)
open import Data.Product.Base          using (_,_; proj₁)
open import Function.Base

open import Reflection
open import Reflection.AST.Argument
open import Reflection.AST.Term as Term
open import Reflection.AST.AlphaEquality
open import Reflection.AST.Name as Name
open import Reflection.TCM.Syntax
open import Data.Nat.Reflection
open import Data.List.Reflection
import Data.Vec.Reflection as Vec

open import Tactic.RingSolver.NonReflective renaming (solve to solver)
open import Tactic.RingSolver.Core.AlmostCommutativeRing
open import Tactic.RingSolver.Core.NatSet as NatSet

open AlmostCommutativeRing

------------------------------------------------------------------------
-- Utilities

private
  -- A lookup from a variable's de Bruijn index to the term that should
  -- stand in for that variable, or `nothing` when no replacement is known.
  VarMap : Set
  VarMap = ℕ → Maybe Term

  -- Returns the payload of an argument if it is visible (explicit),
  -- and `nothing` for any other argument.
  getVisible : Arg Term → Maybe Term
  getVisible (arg (arg-info visible _) x) = just x
  getVisible _                            = nothing

  -- Extracts the last `n` visible arguments of an applied definition, in
  -- their original order. Returns `nothing` if the term is not a `def`
  -- or if it has fewer than `n` visible arguments.
  getVisibleArgs : ∀ n → Term → Maybe (Vec Term n)
  getVisibleArgs n (def _ xs) = Maybe.map Vec.reverse
    (List.foldl f c (List.mapMaybe getVisible xs) n)
    where
    -- One fold step: given a function producing the last `n` arguments
    -- seen so far (in reverse), extend it with the next argument `x`.
    f : (∀ n → Maybe (Vec Term n)) → Term → ∀ n → Maybe (Vec Term n)
    f xs x zero    = just []
    f xs x (suc n) = Maybe.map (x ∷_) (xs n)

    -- Base case: with no arguments seen, only a request for zero
    -- arguments can succeed.
    c : ∀ n → Maybe (Vec Term n)
    c zero     = just []
    c (suc _ ) = nothing
  getVisibleArgs _ _ = nothing

  -- Folds the indices in the set (in the order given by NatSet.toList)
  -- into a term built from Vec.`[] and Vec.`∷, with each index `x`
  -- contributing the variable term `var x []`.
  curriedTerm : NatSet → Term
  curriedTerm = List.foldr go Vec.`[] ∘ NatSet.toList
    where
    go : ℕ → Term → Term
    go x xs = var x [] Vec.`∷ xs

------------------------------------------------------------------------
-- Reflection utilities for rings

-- The term against which a candidate ring is type-checked: the quoted
-- name AlmostCommutativeRing applied to the argument list `2 ⋯⟨∷⟩ []`.
`AlmostCommutativeRing : Term
`AlmostCommutativeRing = def ringName ringArgs
  where
  ringName : Name
  ringName = quote AlmostCommutativeRing

  ringArgs : Args Term
  ringArgs = 2 ⋯⟨∷⟩ []

-- Bundles one term per ring operator. The field order matches the
-- argument order of the mixfix constructor.
record RingOperatorTerms : Set where
  constructor add⇒_mul⇒_pow⇒_neg⇒_sub⇒_
  field
    add : Term
    mul : Term
    pow : Term
    neg : Term
    sub : Term

-- Type-checks the supplied term against `AlmostCommutativeRing and
-- returns the result of `checkType`.
checkIsRing : Term → TC Term
checkIsRing ring = checkType ring `AlmostCommutativeRing

module RingReflection (`ring : Term) where

  -- Takes the name of a function that takes the ring as its first
  -- explicit argument and the terms of its arguments and inserts
  -- the required ring arguments
  --   e.g. "_+_" $ʳ xs = "_+_ {_} {_} ring xs"
  infixr 6 _$ʳ_
  _$ʳ_ : Name → Args Term → Term
  nm $ʳ args = def nm (2 ⋯⟅∷⟆ `ring ⟨∷⟩ args)

  -- The ring's carrier type.
  `Carrier : Term
  `Carrier = quote Carrier $ʳ []

  -- The ring's reflexivity proof, with one extra leading hidden argument.
  `refl : Term
  `refl = quote refl $ʳ (1 ⋯⟅∷⟆ [])

  -- The ring's symmetry proof applied to a proof term.
  `sym : Term → Term
  `sym x≈y = quote sym $ʳ (2 ⋯⟅∷⟆ x≈y ⟨∷⟩ [])

  -- The ring's transitivity proof applied to two proof terms.
  `trans : Term → Term → Term
  `trans x≈y y≈z = quote trans $ʳ (3 ⋯⟅∷⟆ x≈y ⟨∷⟩ y≈z ⟨∷⟩ [])

  -- Normalises each of the fields of the ring operator so we can
  -- compare the result against the normalised definitions we come
  -- across when converting the term passed to the macro.
  getRingOperatorTerms : TC RingOperatorTerms
  getRingOperatorTerms = ⦇
    add⇒ normalise (quote _+_  $ʳ [])
    mul⇒ normalise (quote _*_  $ʳ [])
    pow⇒ normalise (quote _^_  $ʳ [])
    neg⇒ normalise (quote (-_) $ʳ [])
    sub⇒ normalise (quote _-_  $ʳ [])
    ⦈
    

------------------------------------------------------------------------
-- Reflection utilities for ring solver

module RingSolverReflection (ring : Term) (numberOfVariables : ℕ) where
  open RingReflection ring

  -- The number of variables as a reflected term.
  `numberOfVariables : Term
  `numberOfVariables = toTerm numberOfVariables

  -- This function applies the hidden arguments that the constructors
  -- of Expr need. The first is the universe level, the second is the
  -- type it contains, and the third is the number of variables it's
  -- indexed by. All three of these could likely be inferred, but to
  -- make things easier we supply the second and third because we know
  -- them; the first is left as an unknown to be inferred.
  infix -1 _$ᵉ_
  _$ᵉ_ : Name → List (Arg Term) → Term
  e $ᵉ xs = con e (1 ⋯⟅∷⟆ `Carrier ⟅∷⟆ `numberOfVariables ⟅∷⟆ xs)

  -- A constant expression.
  `Κ : Term → Term
  `Κ x = quote Κ $ᵉ (x ⟨∷⟩ [])

  -- A variable expression; the argument is the term for its index.
  `I : Term → Term
  `I x = quote Ι $ᵉ (x ⟨∷⟩ [])

  infixl 6 _`⊜_

  -- An equation between two expressions.
  _`⊜_ : Term → Term → Term
  x `⊜ y = quote _⊜_  $ʳ (`numberOfVariables ⟅∷⟆ x ⟨∷⟩ y ⟨∷⟩ [])

  -- A call to Ops.correct on an expression and an environment.
  `correct : Term → Term → Term
  `correct x ρ = quote Ops.correct $ʳ (1 ⋯⟅∷⟆ x ⟨∷⟩ ρ ⟨∷⟩ [])

  -- A call to the non-reflective solver.
  `solver : Term → Term → Term
  `solver `f `eq = quote solver $ʳ (`numberOfVariables ⟨∷⟩ `f ⟨∷⟩ `eq ⟨∷⟩ [])

  -- Converts the raw terms provided by the macro into the `Expr`s
  -- used internally by the solver.
  --
  -- When trying to figure out the shape of an expression, one of
  -- the difficult tasks is recognizing where constants in the
  -- underlying ring are used. If we were only dealing with ℕ, we
  -- might look for its constructors: however, we want to deal with
  -- arbitrary types which implement AlmostCommutativeRing. If the
  -- Term type contained type information we might be able to
  -- recognize it there, but it doesn't.
  --
  -- We're in luck, though, because all other cases in the following
  -- function *are* recognizable. As a result, the "catch-all" case
  -- will just assume that it has a constant expression.
  convertTerm : RingOperatorTerms → VarMap → Term → TC Term
  convertTerm operatorTerms varMap = convert
    where
    open RingOperatorTerms operatorTerms

    mutual
      convert : Term → TC Term
      -- First try and match directly against the fields
      convert (def (quote _+_) xs) = convertOp₂ (quote _⊕_) xs
      convert (def (quote _*_) xs) = convertOp₂ (quote _⊗_) xs
      convert (def (quote  -_) xs) = convertOp₁ (quote  ⊝_) xs
      convert (def (quote _^_) xs) = convertExp xs
      convert (def (quote _-_) xs) = convertSub xs
      -- Other definitions may be the underlying implementation of one
      -- of the ring's fields
      convert (def nm          xs) = convertUnknownName nm xs
      -- Variables: use the mapping if there is one, otherwise treat
      -- the variable as a constant
      convert v@(var x _)          = pure $ fromMaybe (`Κ v) (varMap x)
      -- Special case to recognise "suc" for naturals
      convert (`suc x)             = convertSuc x
      -- Otherwise we're forced to treat it as a constant
      convert t                    = pure $ `Κ t

      -- Application of a ring operator often doesn't have a type as
      -- simple as "Carrier → Carrier → Carrier": there may be hidden
      -- arguments, etc. Here, we do our best to handle those cases,
      -- by just taking the last two explicit arguments.
      convertOp₂ : Name → Args Term → TC Term
      convertOp₂ nm (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        y' ← convert y
        pure (nm $ᵉ (x' ⟨∷⟩ y' ⟨∷⟩ []))
      convertOp₂ nm (x ∷ xs)         = convertOp₂ nm xs
      convertOp₂ _  _                = pure unknown

      -- As convertOp₂, but for a unary operator: uses the final
      -- argument when it is explicit.
      convertOp₁ : Name → Args Term → TC Term
      convertOp₁ nm (x ⟨∷⟩ []) = do
        x' ← convert x
        pure (nm $ᵉ (x' ⟨∷⟩ []))
      convertOp₁ nm (x ∷ xs)   = convertOp₁ nm xs
      convertOp₁ _  _          = pure unknown

      -- Exponentiation: only the base is converted; the exponent `y`
      -- is passed through unchanged.
      convertExp : Args Term → TC Term
      convertExp (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        pure (quote _⊛_ $ᵉ (x' ⟨∷⟩ y ⟨∷⟩ []))
      convertExp (x ∷ xs)         = convertExp xs
      convertExp _                = pure unknown

      -- Subtraction is converted to the sum of the first operand and
      -- the negation of the second.
      convertSub : Args Term → TC Term
      convertSub (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x'  ← convert x
        -y' ← convertOp₁ (quote (⊝_)) (y ⟨∷⟩ [])
        pure (quote _⊕_ $ᵉ x' ⟨∷⟩ -y' ⟨∷⟩ [])
      convertSub (x ∷ xs)         = convertSub xs
      convertSub _                = pure unknown

      -- Normalises the bare name and compares it (up to alpha
      -- equality) with each normalised ring operator; if none match,
      -- the whole application is treated as a constant.
      convertUnknownName : Name → Args Term → TC Term
      convertUnknownName nm xs = do
        nameTerm ← normalise (def nm [])
        if (nameTerm =α= add) then convertOp₂ (quote _⊕_) xs else
          if (nameTerm =α= mul) then convertOp₂ (quote _⊗_) xs else
            if (nameTerm =α= neg) then convertOp₁ (quote ⊝_)  xs else
              if (nameTerm =α= pow) then convertExp             xs else
                if (nameTerm =α= sub) then convertSub            xs else
                  pure (`Κ (def nm xs))

      -- `suc x` is converted to the sum of the constant 1 and `x`.
      convertSuc : Term → TC Term
      convertSuc x = do x' ← convert x; pure (quote _⊕_ $ᵉ (`Κ (toTerm 1) ⟨∷⟩ x' ⟨∷⟩ []))

------------------------------------------------------------------------
-- Macros
------------------------------------------------------------------------
-- Quantified macro

open RingReflection
open RingSolverReflection

-- Raises a type error reporting that the goal of the quantified macro
-- did not have the expected shape, and shows the term that was found.
malformedForallTypeError : ∀ {a} {A : Set a} → Term → TC A
malformedForallTypeError found = typeError
  ( strErr "Malformed call to solve."
  ∷ strErr "Expected target type to be like: ∀ x y → x + y ≈ y + x."
  ∷ strErr "Instead: "
  ∷ termErr found
  ∷ [])

-- Variable map for the quantified macro: an index below `numVars` is
-- mapped to the variable term `var i []`, and every other index to
-- `nothing`.
quantifiedVarMap : ℕ → VarMap
quantifiedVarMap numVars i =
  if i <ᵇ numVars
    then just (var i [])
    else nothing

-- Builds the term for a call to the non-reflective solver. Both sides
-- of the equation are converted to solver expressions; the equation
-- between them is wrapped with `prependVLams` and the reflexivity proof
-- with `prependHLams`, once per name in `variables`.
constructCallToSolver : Term → RingOperatorTerms → List String → Term → Term → TC Term
constructCallToSolver `ring opNames variables `lhs `rhs = do
  `lhsExpr ← conv `lhs
  `rhsExpr ← conv `rhs

  pure $ `solver `ring numVars
                    (prependVLams variables (_`⊜_ `ring numVars `lhsExpr `rhsExpr))
                    (prependHLams variables (`refl `ring))
  where
  -- One solver variable per quantified variable.
  numVars : ℕ
  numVars = List.length variables

  conv : Term → TC Term
  conv = convertTerm `ring numVars opNames (quantifiedVarMap numVars)

-- This is the main macro which solves for equations in which the
-- variables are universally quantified over:
--
--   lemma : ∀ x y → x + y ≈ y + x
--   lemma = solve-∀ ring
--
-- where ring is your implementation of AlmostCommutativeRing.
-- (Find some example implementations in
-- Polynomial.Solver.Ring.AlmostCommutativeRing.Instances).
-- NOTE(review): the module path above does not match any module this
-- file imports (AlmostCommutativeRing comes from
-- Tactic.RingSolver.Core.AlmostCommutativeRing) — confirm and update.
solve-∀-macro : Name → Term → TC ⊤
solve-∀-macro ring hole = do
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the goal type
  `hole ← inferType hole >>= reduce
  let variablesAndTypes , equation = stripPis `hole

  let variables = List.map proj₁ variablesAndTypes
  -- The equation's last two visible arguments are its two sides
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 equation)
    where nothing → malformedForallTypeError `hole

  solverCall ← constructCallToSolver `ring operatorTerms variables lhs rhs
  unify hole solverCall

-- Macro entry point for the quantified solver; see solve-∀-macro.
macro
  solve-∀ : Name → Term → TC ⊤
  solve-∀ = solve-∀-macro

------------------------------------------------------------------------
-- Unquantified macro

-- Raises a type error reporting that the first argument of `solve` was
-- not a list of free variables, and shows the term that was found.
malformedArgumentListError : ∀ {a} {A : Set a} → Term → TC A
malformedArgumentListError found = typeError
  ( strErr "Malformed call to solve."
  ∷ strErr "First argument should be a list of free variables."
  ∷ strErr "Instead: "
  ∷ termErr found
  ∷ [])

-- Raises a type error reporting that the goal of `solve` was not of
-- the form LHS ≈ RHS, and shows the term that was found.
malformedGoalError : ∀ {a} {A : Set a} → Term → TC A
malformedGoalError found = typeError
  ( strErr "Malformed call to solve."
  ∷ strErr "Goal type should be of the form: LHS ≈ RHS"
  ∷ strErr "Instead: "
  ∷ termErr found
  ∷ [])

-- Type-checks `xs against the type "list of the ring's carrier" and
-- normalises the checked term.
checkIsListOfVariables : Term → Term → TC Term
checkIsListOfVariables `ring `xs = checkType `xs (`List (`Carrier `ring)) >>= normalise

-- Extracts the deBruijn indices from a list of variables. Returns
-- `nothing` if the term is not a list literal whose elements are all
-- unapplied variables.
getVariableIndices : Term → Maybe NatSet
getVariableIndices = go []
  where
  -- Accumulates the indices seen so far in `t`.
  go : NatSet → Term → Maybe NatSet
  go t (var i [] `∷` xs) = go (insert i t) xs
  go t `[]`              = just t
  go _ _                 = nothing

-- Builds the proof term for the unquantified macro: each side of the
-- goal is converted to a solver expression and wrapped in a call to
-- `correct` with the environment ρ, and the two results are combined
-- as `trans (sym lhsProof) rhsProof`.
constructSolution : Term → RingOperatorTerms → NatSet → Term → Term → TC Term
constructSolution `ring opTerms variables `lhs `rhs = do
  `lhsExpr ← conv `lhs
  `rhsExpr ← conv `rhs
  pure $ `trans `ring (`sym `ring `lhsExpr) `rhsExpr
  where
  numVars = List.length variables

  -- A variable whose index is found in `variables` becomes the solver
  -- variable for the position returned by `lookup`.
  varMap : VarMap
  varMap i = Maybe.map (λ x → `I `ring numVars (toFinTerm x)) (lookup variables i)

  -- The environment passed to `correct`.
  ρ : Term
  ρ = curriedTerm variables

  conv = λ t → do
    t' ← convertTerm `ring numVars opTerms varMap t
    pure $ `correct `ring numVars t' ρ

-- Use this macro when you want to solve something *under* a lambda.
-- For example: say you have a long proof, and you just want the solver
-- to deal with an intermediate step. Call it like so:
--
--   lemma₃ : ∀ x y → x + y * 1 + 3 ≈ 2 + 1 + y + x
--   lemma₃ x y = begin
--     x + y * 1 + 3 ≈⟨ +-comm x (y * 1) ⟨ +-cong ⟩ refl ⟩
--     y * 1 + x + 3 ≈⟨ solve (x ∷ y ∷ []) Int.ring ⟩
--     3 + y + x     ≡⟨ refl ⟩
--     2 + 1 + y + x ∎
--
-- The first argument is the free variables, and the second is the
-- ring implementation (as before).
solve-macro : Term → Name → Term → TC ⊤
solve-macro variables ring hole = do
  `ring ← checkIsRing (def ring [])
  commitTC
  operatorTerms ← getRingOperatorTerms `ring

  -- Obtain and sanitise the list of variables
  listOfVariables′ ← checkIsListOfVariables `ring variables
  commitTC
  just variableIndices ← pure (getVariableIndices listOfVariables′)
    where nothing → malformedArgumentListError listOfVariables′

  -- Obtain and sanitise the goal type
  hole′ ← inferType hole >>= reduce
  -- The goal's last two visible arguments are its two sides
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 hole′)
    where nothing → malformedGoalError hole′

  solution ← constructSolution `ring operatorTerms variableIndices lhs rhs
  unify hole solution

-- Macro entry point for the unquantified solver; see solve-macro.
macro
  solve : Term → Name → Term → TC ⊤
  solve = solve-macro