module Eval where

import Parser
import Research

import Data.List (partition, (\\))
import Data.Map  (Map)
import qualified Data.Map as Map
import qualified Data.Set as Set

-- | Evaluate one top-level term against @env@ and return the extended
-- environment.
--
-- Every branch that succeeds stores the term's value under the key
-- @"!result"@. A definition @SDef name [] body@ additionally binds @name@ to
-- the value of @body@ and fails if @name@ is already bound. A bare variable is
-- looked up directly in @env@; its absence is reported as an internal error.
--
-- NOTE(review): an 'SDef' with a non-empty parameter list takes the generic
-- branch, where 'evalAST' has no 'SDef' case and reports
-- "Unexpected AST term"; presumably the parser never produces one — confirm.
evalSingle :: Env -> TricuAST -> Env
evalSingle env term =
  case term of
    SDef name [] body
      | Map.member name env ->
          errorWithoutStackTrace $
            "Error: Identifier '" ++ name ++ "' is already defined."
      | otherwise ->
          let value = evalAST env body
          in withResult value (Map.insert name value env)
    SApp func arg ->
      withResult (apply (evalAST env func) (evalAST env arg)) env
    SVar name ->
      case Map.lookup name env of
        Nothing ->
          errorWithoutStackTrace $ "Variable `" ++ name ++ "` not defined\n\
          \This error should never occur here. Please report this as an issue."
        Just value -> withResult value env
    _ -> withResult (evalAST env term) env
  where
    -- Record a value as the most recent result.
    withResult value = Map.insert "!result" value

-- | Evaluate a batch of top-level terms, threading the environment through.
--
-- The batch goes through 'reorderDefs' exactly once. That step puts the
-- definitions in dependency order ahead of the remaining expressions. It also
-- rejects missing, duplicate and cyclic names. Each term is then evaluated in
-- that order with 'evalSingle', which records its value under @"!result"@.
-- After a non-empty batch, @"!result"@ therefore holds the value of the last
-- term in the computed order. An empty batch returns the environment
-- unchanged.
evalTricu :: Env -> [TricuAST] -> Env
evalTricu env terms = go env (reorderDefs env terms)
  where
    -- Recurse on 'go' itself rather than on 'evalTricu'. The batch is already
    -- ordered, so re-running 'reorderDefs' on every suffix would repeat the
    -- dependency analysis once per term. It could also reshuffle the
    -- remaining definitions.
    go acc []       = acc
    go acc (t : ts) =
      let acc' = evalSingle acc t
      -- Force each intermediate map so no chain of pending inserts builds up.
      in acc' `seq` go acc' ts

-- | Evaluate a surface term to a tree value in environment @env@.
--
-- Lambdas are first rewritten by 'elimLambda' and the result is evaluated.
-- Tree literals, strings, numbers and lists are mapped to their tree
-- encodings. Applications are reduced with 'apply'. An unbound variable, or a
-- constructor not handled here (e.g. 'SDef'), raises an error.
evalAST :: Env -> TricuAST -> T
evalAST env = eval
  where
    eval term =
      case term of
        SLambda _ _ -> eval (elimLambda term)
        SVar name   -> lookupVar name
        TLeaf       -> Leaf
        TStem t     -> Stem (eval t)
        TFork t u   -> Fork (eval t) (eval u)
        SApp t u    -> apply (eval t) (eval u)
        SStr s      -> ofString s
        SInt n      -> ofNumber n
        SList items -> ofList (map eval items)
        SEmpty      -> Leaf
        _           -> errorWithoutStackTrace "Unexpected AST term"

    -- Resolve a name in the environment; missing names are fatal.
    lookupVar name =
      case Map.lookup name env of
        Just value -> value
        Nothing    ->
          errorWithoutStackTrace $ "Variable " ++ name ++ " not defined"

-- | Translate lambda abstractions into tree-calculus combinator terms.
--
-- The function recurses through 'SApp' spines and lambda bodies. Lambdas
-- nested inside 'TStem', 'TFork' or 'SList' nodes are left in place; 'evalAST'
-- eliminates them when it reaches them. Shadowing follows the usual lexical
-- rule: an inner binder hides an outer binder of the same name.
elimLambda :: TricuAST -> TricuAST
elimLambda = go
  where
    -- A lambda with no parameters binds nothing, so it is just its body.
    -- Without this case the term came back unchanged. 'evalAST' re-runs
    -- elimLambda on every SLambda it meets, so it would loop forever.
    go (SLambda [] body) = elimLambda body
    -- η-reduction: \v : f v  ==>  f   (only when v is not free in f)
    go (SLambda [v] (SApp f (SVar x)))
      | v == x && not (isFree v f) = elimLambda f
    -- Triage optimization: \a b c : t (t a b) c  ==>  _TRIAGE.
    -- The binders must be pairwise distinct. With a repeated name the inner
    -- binder shadows the outer ones. For example, \a a a : t (t a a) a
    -- ignores its first two arguments, so it is not triage.
    go (SLambda [a] (SLambda [b] (SLambda [c] body)))
      | allDistinct [a, b, c] && body == triageBody = _TRIAGE
      where
        triageBody =
          SApp (SApp TLeaf (SApp (SApp TLeaf (SVar a)) (SVar b))) (SVar c)
    -- Composition optimization: \f g x : f (g x)  ==>  _COMPOSE.
    -- Same distinctness requirement: \f f x : f (f x) is not composition.
    go (SLambda [f] (SLambda [g] (SLambda [x] body)))
      | allDistinct [f, g, x] && body == composeBody = _COMPOSE
      where
        composeBody = SApp (SVar f) (SApp (SVar g) (SVar x))
    -- General elimination: curry multi-parameter lambdas, then abstract the
    -- single parameter out of the already-eliminated body.
    go (SLambda (v:vs) body)
      | null vs                    = toSKI v (elimLambda body)
      | otherwise                  = elimLambda (SLambda [v] (SLambda vs body))
    go (SApp f g)                  = SApp (elimLambda f) (elimLambda g)
    go x                           = x

    -- True when no name occurs twice in the list.
    allDistinct names = Set.size (Set.fromList names) == length names

    -- Bracket abstraction of variable x out of a lambda-free term:
    --   x          ==> _I
    --   y (y /= x) ==> _K y
    --   n u        ==> _K (n u) if x is not free, else _S [x]n [x]u
    -- Any other term without x free becomes _K t. If x is free in a node that
    -- is neither SVar nor SApp (e.g. TStem, TFork, SList), there is no rule
    -- and conversion fails.
    toSKI x (SVar y)
      | x == y           = _I
      | otherwise        = SApp _K (SVar y)
    toSKI x t@(SApp n u)
      | not (isFree x t) = SApp _K t
      | otherwise        = SApp (SApp _S (toSKI x n)) (toSKI x u)
    toSKI x t
      | not (isFree x t) = SApp _K t
      | otherwise        = errorWithoutStackTrace "Unhandled toSKI conversion"

    -- Combinator terms, parsed from tree-calculus source. toSKI uses the
    -- first three as S, K and I.
    _S       = parseSingle "t (t (t t t)) t"
    _K       = parseSingle "t t"
    _I       = parseSingle "t (t (t t)) t"
    _TRIAGE  = parseSingle "t (t (t t (t (t (t t t))))) t"
    _COMPOSE = parseSingle "t (t (t t (t (t (t t t)) t))) (t t)"

-- | Does the given name occur free in the term?
isFree :: String -> TricuAST -> Bool
isFree name term = name `Set.member` freeVars term

-- | Collect the variable names that occur free in a term.
--
-- A lambda removes its own parameters from the free variables of its body.
-- Literals, leaves and any constructor not listed below contribute nothing.
--
-- NOTE(review): for 'SDef' only the body is scanned; the definition's
-- parameter list is not subtracted. Presumably definitions reach this point
-- with an empty parameter list — confirm against the parser.
freeVars :: TricuAST -> Set.Set String
freeVars term =
  case term of
    SVar name           -> Set.singleton name
    SApp fun arg        -> freeVars fun `Set.union` freeVars arg
    TFork left right    -> freeVars left `Set.union` freeVars right
    TStem child         -> freeVars child
    SList items         -> Set.unions (map freeVars items)
    SDef _ _ body       -> freeVars body
    SLambda params body -> freeVars body `Set.difference` Set.fromList params
    _                   -> Set.empty

-- | Put a batch of top-level terms into evaluation order.
--
-- Definitions come first, sorted so each one follows the definitions it
-- depends on (see 'buildDepGraph' / 'sortDeps'). All other terms follow in
-- their original relative order. Every free variable in the batch must name
-- either a definition in the batch or an entry already in @env@. Otherwise the
-- unresolved names are reported as an error.
reorderDefs :: Env -> [TricuAST] -> [TricuAST]
reorderDefs env defs =
  case Set.toList (referenced `Set.difference` known) of
    [] -> map (byName Map.!) (sortDeps (buildDepGraph definitions)) ++ rest
    missing ->
      errorWithoutStackTrace $
        "Missing dependencies detected: " ++ show missing
  where
    isDefinition SDef{} = True
    isDefinition _      = False

    (definitions, rest) = partition isDefinition defs

    -- Definition lookup by name, used to materialise the sorted name list.
    byName = Map.fromList [ (name, def) | def@(SDef name _ _) <- definitions ]

    -- Names that may legally be referenced: batch definitions plus env.
    known = Map.keysSet byName `Set.union` Map.keysSet env

    referenced =
      Set.unions [ freeVars body | SDef _ _ body <- definitions ]
        `Set.union` Set.unions (map freeVars rest)

-- | Build the dependency graph of a list of definitions.
--
-- Each definition name maps to the set of other batch definition names that
-- occur free in its body. Names outside the batch are not edges. A name
-- defined more than once is reported as an error, listing every duplicated
-- name in ascending order.
buildDepGraph :: [TricuAST] -> Map.Map String (Set.Set String)
buildDepGraph topDefs
  | not (null duplicateNames) =
      errorWithoutStackTrace $
        "Duplicate definitions detected: " ++ show duplicateNames
  | otherwise =
      Map.fromList
        [ (name, topNames `Set.intersection` freeVars body)
        | SDef name _ body <- topDefs ]
  where
    names = [name | SDef name _ _ <- topDefs]
    -- Built once and shared by every definition. This gives the same edges
    -- as calling 'depends' per definition, without rebuilding the set each
    -- time, which made the graph construction quadratic.
    topNames = Set.fromList names
    duplicateNames =
      [ name | (name, count) <- Map.toList occurrences, count > 1 ]
    occurrences = Map.fromListWith (+) [ (name, 1 :: Int) | name <- names ]

-- | Topologically sort a dependency graph, one level per round.
--
-- Each round emits, in ascending name order, every remaining name whose
-- dependencies were all emitted in earlier rounds. If names remain but none
-- is ready, the graph has a cycle and an error is raised. A dependency that
-- is not itself a key can never be satisfied, so it is also reported this
-- way. The full order is computed before any of it is returned.
sortDeps :: Map.Map String (Set.Set String) -> [String]
sortDeps graph = go [] Set.empty (Map.keys graph)
  where
    -- levels: emitted rounds, most recent first; done: every name emitted.
    go levels _    []        = concat (reverse levels)
    go levels done remaining =
      -- 'partition' splits in one pass and keeps the ascending order. This
      -- replaces a separate filter plus an O(n*k) list difference.
      case partition isReady remaining of
        ([], _) ->
          errorWithoutStackTrace
            "ERROR: Cyclic dependency detected and prohibited.\n\
            \RESOLVE: Use nested lambdas."
        (ready, notReady) ->
          go (ready : levels) (Set.union done (Set.fromList ready)) notReady
      where
        isReady name =
          Map.findWithDefault Set.empty name graph `Set.isSubsetOf` done

-- | The names of definitions in @topDefs@ that occur free in the body of the
-- given definition. Any term that is not an 'SDef' has no dependencies.
depends :: [TricuAST] -> TricuAST -> Set.Set String
depends topDefs def =
  case def of
    SDef _ _ body -> topLevelNames `Set.intersection` freeVars body
    _             -> Set.empty
  where
    topLevelNames = Set.fromList [ name | SDef name _ _ <- topDefs ]

-- | The value stored under @"!result"@, i.e. the most recently evaluated term.
-- Raises an error if the environment has no such entry.
result :: Env -> T
result env = Map.findWithDefault missing "!result" env
  where
    missing = errorWithoutStackTrace "No !result field found in provided env"

-- | The value bound to @main@. Raises an error if @main@ was never defined.
mainResult :: Env -> T
mainResult env = Map.findWithDefault missing "main" env
  where
    missing = errorWithoutStackTrace "No valid definition for `main` found."