{-# LANGUAGE CPP #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -Wno-unused-imports #-} -- TEMP

-- | Continuation-passing category

module ConCat.Continuation where

import Prelude hiding (id,(.),uncurry)

import Data.Constraint (Dict(..),(:-)(..))
import Data.Key (Zip)

import ConCat.Misc ((:*))
import ConCat.Rep
import qualified ConCat.Category
import ConCat.AltCat
import ConCat.Additive (Additive)

-- | A morphism from @a@ to @b@, represented by its action on continuations:
-- it maps an arrow @b `k` r@ (into the answer type @r@) to an arrow
-- @a `k` r@.
newtype Cont k r a b = Cont (k b r -> k a r)

-- Could (->) here be another category? I think so.

-- | Expose 'Cont' as its underlying continuation transformer, so generic
-- representation helpers such as 'inAbst2' can operate on it directly.
instance HasRep (Cont k r a b) where
  type Rep (Cont k r a b) = (b `k` r) -> (a `k` r)
  abst f = Cont f
  repr (Cont f) = f

-- | Lift an ordinary morphism into the continuation category by
-- precomposition: a continuation @g :: b `k` r@ becomes @g . f :: a `k` r@.
cont :: (Category k, Ok3 k r a b) => (a `k` b) -> Cont k r a b
cont f = abst (. f)

-- | Objects of @Cont k r@ are exactly the objects of @k@. Composition runs
-- in the opposite order to the underlying continuation transformers.
instance Category (Cont k r) where
  type Ok (Cont k r) = Ok k
  -- The identity leaves the continuation unchanged.
  id = Cont id
  -- For @Cont g . Cont f@ the result is @Cont (f . g)@, hence the 'flip'.
  (.) = inAbst2 (flip (.))

-- | Parallel composition of continuation transformers. It splits the
-- incoming continuation on @Prod k c d@ into its two components
-- ('unjoinP'), transforms each with the corresponding argument, and then
-- recombines them ('joinP').
instance (MProductCat k, TerminalCat k, CoterminalCat k, CoproductPCat k, OkAdd k, Ok k r)
      => MonoidalPCat (Cont k r) where
  (***) :: forall a b c d. Ok4 k a b c d
        => Cont k r a c -> Cont k r b d -> Cont k r (Prod k a b) (Prod k c d)
  -- 'unjoinP' and 'joinP' require 'Additive' for c, d and r; 'okAdd'
  -- derives that evidence from 'OkAdd k'.
  Cont f *** Cont g = Cont (joinP . (f *** g) . unjoinP)
    <+ okAdd @k @c
    <+ okAdd @k @d
    <+ okAdd @k @r

-- TODO: Give non-default definitions for lassocP and rassocP, and relax
-- ProductCat k back to MonoidalPCat and drop TerminalCat k and CoterminalCat k
-- if possible.

-- | Swap by splitting the continuation on @b :* a@, swapping the resulting
-- pair of continuations, and joining them back into a continuation on
-- @a :* b@.
instance (MonoidalPCat k, CoproductPCat k, Ok k r, OkAdd k, Additive r)
      => BraidedPCat (Cont k r) where
  swapP :: forall a b. Ok2 k a b => Cont k r (a :* b) (b :* a)
  -- 'Additive r' comes from the instance context. 'okAdd' supplies the
  -- 'Additive' evidence for a and b.
  swapP = Cont (joinP . swapP . unjoinP)
    <+ okAdd @k @a
    <+ okAdd @k @b

-- | Projections and duplication for 'Cont', expressed through the
-- coproduct and abelian structure of the underlying category @k@.
instance (ProductCat k, CoproductPCat k, AbelianCat k, OkAdd k, Ok k r)
      => ProductCat (Cont k r) where
  exl :: forall a b. Ok2 k a b => Cont k r (a :* b) a
  -- Extend a continuation on @a@ to one on @a :* b@ by pairing it with
  -- 'zeroC' on the @b@ side via '(||||)'.
  exl = Cont (|||| zeroC) <+ okAdd @k @r
  exr :: forall a b. Ok2 k a b => Cont k r (a :* b) b
  -- Mirror image of 'exl': 'zeroC' goes on the @a@ side.
  exr = Cont (zeroC ||||) <+ okAdd @k @r
  dup :: forall a. Ok k a => Cont k r a (a :* a)
  -- Split the continuation on @a :* a@ into two continuations on @a@ and
  -- combine them with 'plusC'.
  dup = Cont (uncurry plusC . unjoinP)
    <+ okAdd @k @a
    <+ okAdd @k @r

-- instance (CoproductPCat k, Ok k r) => CoproductPCat (Cont k r) where
--   inlP :: forall a b. Ok2 k a b => Cont k r a (a :* b)
--   inlP = cont inlP <+ okProd @k @a @b
--   inrP :: forall a b. Ok2 k a b => Cont k r b (a :* b)
--   inrP = cont inrP <+ okProd @k @a @b
--   (||||) = inAbst2 (\ f g -> uncurry (||||) . (f &&& g))


--            f       :: (c `k` r) -> (a `k` r)
--                  g :: (c `k` r) -> (b `k` r)
--            f &&& g :: (c `k` r) -> (a `k` r) :* (b `k` r)
-- uncurry (||||) . (f &&& g) :: (c `k` r) -> ((a :* b) `k` r)

-- TODO: Fix the ProductCat and CoproductPCat instances to match the paper.

-- instance (ProductCat k, TerminalCat k, CoproductPCat k, CoterminalCat k, OkUnit k, OkAdd k, Ok k r) => UnitCat (Cont k r)

-- TODO: fix this instance either via necessary superclasses or by not using defaults.

-- class (Category k, OkIxProd k h) => IxMonoidalPCat k h where
--   crossF :: forall a b. Ok2 k a b => h (a `k` b) -> (h a `k` h b)

-- | Objects of @Cont k r@ are those of @k@, so closure under @h@ is
-- inherited from @OkIxProd k h@. 'Additive' evidence is threaded in
-- through 'additive1' and 'okAdd'.
instance (OkIxProd k h, Additive1 h, OkAdd k) => OkIxProd (Cont k r) h where
  okIxProd :: forall a. Ok' (Cont k r) a |- Ok' (Cont k r) (h a)
  okIxProd = Entail (Sub (Dict <+ okIxProd @k @h @a
                               <+ additive1 @h @a
                               <+ okAdd @k @a))

-- | Indexed parallel composition, in three steps:
--
-- 1. Split the continuation on @h b@ into an @h@-structure of continuations
--    ('unjoinPF').
-- 2. Apply the unwrapped transformers to that structure with the
--    function-level 'crossF'.
-- 3. Join the results into a continuation on @h a@ ('joinPF').
instance (Zip h, IxCoproductPCat k h, Additive1 h, OkAdd k, Ok k r)
      => IxMonoidalPCat (Cont k r) h where
  crossF :: forall a b. Ok2 k a b => h (Cont k r a b) -> Cont k r (h a) (h b)
  crossF fs = Cont (joinPF . crossF (repr <$> fs) . ConCat.AltCat.unjoinPF)

-- instance ({- Zip h, IxCoproductPCat k h, Additive1 h, OkAdd k, Ok k r-})
--       => IxProductCat (Cont k r) h where

-- | Indexed projections and replication for 'Cont'.
--
-- Neither method is implemented yet. Forcing either one fails with a
-- descriptive 'error' that names the missing method, rather than an
-- anonymous 'undefined'.
instance (IxCoproductPCat k h, Zip h, Additive1 h, OkAdd k, Ok k r)
      => IxProductCat (Cont k r) h where
  exF :: forall a. Ok (Cont k r) a => h (Cont k r (h a) a)
  exF = error "ConCat.Continuation: exF not yet implemented for Cont"
  -- Untested sketch: inject a continuation at each index with a
  -- (->)-level inPF :: h ((a `k` r) -> h (a `k` r)), then join, i.e.
  --   (Cont . (joinPF .)) <$> inPF
  replF :: forall a. Ok (Cont k r) a => Cont k r a (h a)
  replF = error "ConCat.Continuation: replF not yet implemented for Cont"

#if 0

inPF :: h (a `k` h a)

fmap (joinPF .) inPF :: 


need :: h ((a `k` r) -> (h a `k` r))

f :: a `k` h a

f' ::  (a `k` r) -> (h a `k` r)

joinPF :: (IxCoproductPCat k h, Ok2 k a b) => h (b `k` a) -> (h b `k` a)


#endif

-- joinPF :: forall a b . (IxCoproductPCat k h, Ok2 k a b) => h (b `k` a) -> (h b `k` a)
-- unjoinPF :: forall a b . (IxCoproductPCat k h, Functor h, Ok2 k a b) => (h b `k` a) -> h (b `k` a)



-- instance ProductCat k => ProductCat (ContC k r) where
--   exl  = Cont (join . inl)
--   exr  = Cont (join . inr)
--   dup  = Cont (jamP . unjoin)

#if 0

exl :: Cont k r (a :* b) a
Cont (join . inl) :: Cont k r (a :* b) a
join . inl :: (a `k` r) -> (a :* b) `k` r

inl :: (a `k` r) -> (a `k` r) :* (b `k` r)
join :: (a `k` r) :* (b `k` r) -> (a :* b) `k` r

inPF :: h ((a `k` r) -> h (a `k` r))
joinPF :: h (a `k` r) -> (h a `k` r)

need :: h ((a `k` r) -> (h a `k` r))
need = (joinPF .) <$> inPF

need' :: h ((a `k` r) -> (h a `k` r))
need' = (joinPF .) <$> inPF

#endif

-- need :: (IxCoproductPCat k h)
--      => h ((a `k` r) -> (h a `k` r))
-- need = (joinPF .) <$> inPF