{-# LANGUAGE CPP #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -Wno-unused-imports #-} -- TEMP

-- | A category of local approximations (and probably other uses)

module ConCat.Local where

import Prelude hiding (id,(.),curry,uncurry,const)
import Control.Applicative (pure,liftA2)

import Control.Newtype.Generics
import Data.Copointed
import Data.Constraint (Dict(..),(:-)(..))

import ConCat.Misc ((:*),inNew2,type (&&))
import qualified ConCat.Category as C
import ConCat.AltCat
import ConCat.Free.LinearRow
-- import ConCat.Rep

-- | A morphism of @'Local' k@ from @a@ to @b@: for every point @a@, a
-- morphism of the underlying category @k@ (the "local approximation" at
-- that point, in the terminology of the module header).
newtype Local k a b = Local (a -> k a b)

-- | 'Local' is a plain wrapper around @a -> (a `k` b)@.  This instance lets
-- generic helpers such as 'inNew2' lift operations on the underlying
-- function to operations on 'Local'.
instance Newtype (Local k a b) where
  type O (Local k a b) = a -> (a `k` b)
  pack = Local
  unpack (Local f) = f

-- | Embed a single morphism of @k@ into @'Local' k@ by using that same
-- morphism at every point @a@.
--
-- 'pure' for the function applicative ignores its argument; it is used
-- here because 'const' is hidden by the Prelude import of this module.
simpleL :: (a `k` b) -> Local k a b
simpleL f = Local (pure f)

-- | Object constraint for @'Local' k@.  An object @a@ must be an object of
-- @k@, and @k a@ must be 'Copointed', so that 'copoint' can extract a point
-- of type @b@ from a local morphism @k a b@.
--
-- The constraint is declared as a class with a single catch-all instance,
-- rather than as a type synonym, so that it can be partially applied
-- (@OkLocal k@) as the right-hand side of the 'Ok' family.
--
-- Two extensions enabled at the top of the file are needed here:
--
--   * UndecidableInstances: without it GHC rejects @Ok k a@ as an
--     illegal constraint in a superclass context.
--   * UndecidableSuperClasses: a superclass headed by the type family
--     @Ok@ is reported as a potential superclass cycle for 'OkLocal'.
class    (Copointed (k a), Ok k a) => OkLocal k a
instance (Copointed (k a), Ok k a) => OkLocal k a

-- | Category structure on local morphisms.
--
-- * Identity: the identity of @k@, used at every point (via 'simpleL').
-- * Composition: to get the local morphism of @g . f@ at @a@, take @f@'s
--   local morphism @f'@ at @a@. Then take @g@'s local morphism at the
--   point that 'copoint' extracts from @f'@. Finally, compose the two
--   in @k@.
instance Category k => Category (Local k) where
  type Ok (Local k) = OkLocal k
  id = simpleL id
  (.) = inNew2 $ \ g f -> \ a ->
    let f' = f a               -- f's local morphism at a
        g' = g (copoint f')    -- g's local morphism at the extracted point
    in
      g' . f'

#if 0

-- | Constraint used as the 'ProductCat' (Local k) context below.
-- Presumably it states that the product @Prod k@ preserves objects of
-- @'Local' k@ (via 'OpCon'); NOTE(review): confirm against ConCat.Category.
-- Disabled, together with everything else up to the matching @#endif@.
type OkCart k = OpCon (Prod k) (Sat (Ok (Local k)))

-- Work in progress (disabled by the enclosing @#if 0@).  The instance head
-- presumably aims to show that products preserve @OkLocal k@ whenever they
-- preserve @Ok k@.
instance OpCon (:*) (Sat (Ok k)) => OpCon (:*) (Sat (OkLocal k)) where
  inOp :: (Sat (OkLocal k) a, Sat (OkLocal k) b) |- Sat (OkLocal k) (a :* b)
  inOp = Entail (Sub Dict) -- <+ okProd @k @a @b

-- This currently fails with: Could not deduce (Copointed (k (a, b))).
-- 'OkLocal' requires @Copointed (k (a :* b))@, and nothing in the context
-- provides it for a product object.

-- instance OpCon (:*) (Sat (Ok k)) => OpCon (:*) (Sat (OkLocal k)) where
--   inOp = Entail (Sub Dict)
--   {-# INLINE inOp #-}

-- Pointwise product structure (disabled by the enclosing @#if 0@).
-- The projections are the same at every point (via 'simpleL').  '&&&'
-- pairs the two arguments' local morphisms at each point: it lifts @k@'s
-- '&&&' through the function applicative ('liftA2') and then through the
-- newtype ('inNew2').
instance (OkCart k, ProductCat k)
      => ProductCat (Local k) where
  exl = simpleL exl
  exr = simpleL exr
  (&&&) = (inNew2.liftA2) (&&&)

-- Unfolding the definition (the last step specializes to k = (->)):
--
--    Local f &&& Local g
-- == Local (\ a -> f a &&& g a)
-- == Local (\ a da -> (f a da, g a da))
  
-- Affine approximation. Later make explicit via ConCat.Free.Affine.
-- (Disabled by the enclosing @#if 0@.)
--
-- Each operation takes a point and an offset and returns the first-order
-- (affine) estimate of the operation's result at point + offset:
--
--   negate: -(x + dx)
--   add:    x + y + dx + dy
--   mul:    x*y + x*dy + dx*y   -- (x+dx)*(y+dy) with the dx*dy term dropped
--
-- NOTE(review): this relies on a @Copointed ((->) a)@ instance being
-- available; confirm which instance is intended.
instance (Num a, Copointed ((->) a)) => NumCat (Local (->)) a where
  negateC = Local (\ x dx -> - (x + dx))
  addC = Local (\ (x,y) (dx,dy) -> x + y + dx + dy)
  mulC = Local (\ (x,y) (dx,dy) -> x*y + dy*x + dx*y)

#endif