{-# LANGUAGE CPP #-}

-- | Utility functions for normalising, comparing types modulo type families.
module ConCat.NormaliseType (eqTypeM) where

import GHC.Plugins
#if MIN_VERSION_GLASGOW_HASKELL(9,4,0,0)
import GHC.HsToCore.Monad
import Data.Maybe (maybe)
import GHC.HsToCore.Monad
import GHC.Tc.Module (withTcPlugins, withHoleFitPlugins)
import GHC.Tc.Instance.Family (tcGetFamInstEnvs)
import GHC.Core.FamInstEnv (normaliseType)
import GHC.Core.Reduction (reductionReducedType)
import GHC.Tc.Types (TcM)
#endif

-- | Compare two types for equality modulo type families.
--
-- Returns 'Nothing' when the types are considered equal, and otherwise
-- 'Just' the pair of types that were found to differ.
--
-- On GHC >= 9.4, if the types are not already equal according to 'eqType',
-- both are normalised with 'normaliseType' (nominal role) using the
-- family-instance environments from 'tcGetFamInstEnvs', and the returned
-- pair consists of the normalised types.
--
-- On older GHCs no normalisation is performed: the types are compared with
-- 'eqType' directly and, if they differ, returned unchanged.
eqTypeM :: HscEnv -> DynFlags -> ModGuts -> Type -> Type -> IO (Maybe (Type, Type))
#if MIN_VERSION_GLASGOW_HASKELL(9,4,0,0)
eqTypeM env dflags guts ty1 ty2
  -- Fast path: syntactically equal types are accepted without running TcM.
  | ty1 `eqType` ty2 = return Nothing
  | otherwise =
      runTcForSolver env dflags guts $ do
        famInstEnvs <- tcGetFamInstEnvs
        let normalise :: Type -> Type
            normalise = reductionReducedType . normaliseType famInstEnvs Nominal
            normalisedTy1 = normalise ty1
            normalisedTy2 = normalise ty2
        return $
          if normalisedTy1 `eqType` normalisedTy2
            then Nothing
            else Just (normalisedTy1, normalisedTy2)

-- | Run a 'DsM' computation in 'IO' via 'initDsWithModGuts'.
--
-- If the computation yields no result, the accumulated desugarer messages
-- are rendered with the given 'DynFlags' and raised with 'fail'. Messages
-- produced alongside a successful result are discarded.
runDsM :: HscEnv -> DynFlags -> ModGuts -> DsM a -> IO a
runDsM env dflags guts m = do
  (messages, result) <- initDsWithModGuts env guts m
  case result of
    Just a  -> return a
    Nothing -> fail (showSDoc dflags (ppr messages))

-- | Run a 'TcM' computation in 'IO' with the 'HscEnv''s plugins initialised.
--
-- The computation is first wrapped by 'withHoleFitPlugins', then by
-- 'withTcPlugins', lifted into 'DsM' with 'initTcDsForSolver', and finally
-- executed by 'runDsM' (which fails in 'IO' if no result is produced).
runTcForSolver :: HscEnv -> DynFlags -> ModGuts -> TcM a -> IO a
runTcForSolver env dflags guts m =
  runDsM env dflags guts $
    initTcDsForSolver $ withTcPlugins env $ withHoleFitPlugins env m

-- | Normalise a type with respect to type families.
--
-- Runs 'normaliseType' (nominal role) against the family-instance
-- environments from 'tcGetFamInstEnvs' and returns the
-- 'reductionReducedType' of the resulting reduction.
--
-- NOTE(review): this helper is not in the module's export list and is not
-- called elsewhere in this module; consider exporting it or removing it.
normaliseTypeM :: HscEnv -> DynFlags -> ModGuts -> Type -> IO Type
normaliseTypeM env dflags guts ty =
  runTcForSolver env dflags guts $ do
    famInstEnvs <- tcGetFamInstEnvs
    return (reductionReducedType (normaliseType famInstEnvs Nominal ty))
#else
-- Fallback for GHC < 9.4: no type-family normalisation is performed. The
-- two types are compared with 'eqType' and, if they differ, returned as-is.
eqTypeM _ _ _ ty1 ty2 =
  return $ if eqType ty1 ty2 then Nothing else Just (ty1, ty2)
#endif