{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-orphans #-} -- TEMP
{-# OPTIONS_GHC -fno-warn-unused-imports #-} -- TEMP

#include "ConCat/AbsTy.inc"

-- | Generalized automatic differentiation

module ConCat.GAD where

import Prelude hiding (id,(.),curry,uncurry,const,zip,unzip,zipWith)
import qualified Prelude as P
-- import GHC.Exts (Coercible,coerce)
import GHC.Exts (Constraint)

import Data.Constraint (Dict(..),(:-)(..))

-- import GHC.Generics (Par1(..),(:.:)(..),(:*:)())
-- import Control.Newtype.Generics
-- import Data.Key (Zip(..))

import Data.Pointed
import Data.Key
import Data.Distributive (Distributive(..))
import Data.Functor.Rep (Representable)
import qualified Data.Functor.Rep as R

import ConCat.Misc ((:*),type (&&),type (&+&),cond,result,unzip,sqr,bottom)
import ConCat.Additive
import ConCat.AltCat
import ConCat.Rep

AbsTyImports

-- TODO: try again with importing Category qualified and AltCat unqualified.

-- | Generalized differentiable arrow. Each input @a@ is mapped to its output
-- @b@ paired with the derivative at that input, represented as a morphism
-- @a `k` b@ in the category @k@.
newtype GD k a b = D { unD :: a -> b :* (a `k` b) }
-- data GD k a b = D { unD :: a -> (b :* (a `k` b)) }

-- | Build a GD from a function whose derivative component is given in its
-- Rep form; the derivative is converted back with abst.
mkD :: HasRep (a `k` b) => (a -> b :* Rep (a `k` b)) -> GD k a b
mkD = D P.. (result P.. second) abst
{-# INLINE mkD #-}

-- | Inverse of mkD: unwrap a GD and expose its derivative component in
-- Rep form (via repr).
unMkD :: HasRep (a `k` b) => GD k a b -> (a -> b :* Rep (a `k` b))
unMkD = (result P.. second) repr P.. unD
{-# INLINE unMkD #-}

-- | Differentiable linear function, given the function and its constant
-- derivative. The derivative @f'@ is returned unchanged at every input.
linearD :: (a -> b) -> (a `k` b) -> GD k a b
-- linearD f f' = D (f &&& const f')
linearD f f' = D (\ a -> (f a, f'))
{-# INLINE linearD #-}

-- -- Differentiable linear function
-- linear :: (a -> b) -> GD k a b
-- linear f = linearD f (toCcc' f)
-- {-# INLINE linear #-}

-- Use of linear leads to a plugin error. TODO: track down. linear also has the
-- unfortunate effect of hiding the requirements on k, resulting in run-time
-- errors instead of compile-time errors.

-- instance Newtype (D s a b) where
--   type O (D s a b) = (a -> b :* L s a b)
--   pack f = D f
--   unpack (D f) = f

-- GD is represented by its underlying function, so abst and repr simply
-- wrap and unwrap the D constructor.
instance HasRep (GD k a b) where
  type Rep (GD k a b) = (a -> b :* (a `k` b))
  abst f = D f
  repr (D f) = f

-- Expand the AbsTy macro (from the included ConCat/AbsTy.inc) for GD.
-- NOTE(review): presumably this generates representation-based instance
-- boilerplate for GD; see that include file to confirm.
AbsTy(GD k a b)

-- Common pattern for linear functions: a linear map is its own derivative,
-- so method nm is defined as linearD nm nm, where the first nm is the (->)
-- version and the second is the k version (selected by type). Each use also
-- emits an INLINE pragma for nm.
#define Linear(nm) nm = linearD nm nm ; {-# INLINE nm #-}

-- A constant function yields b everywhere. Its derivative is the constant
-- zero linear map in k.
instance (TerminalCat k, CoterminalCat k, ConstCat k b, Additive b)
      => ConstCat (GD k) b where
  const b = linearD (const b) (const zero)
  {-# INLINE const #-}

-- What if we went further, and defined nonlinear arrows like mulC as if linear?
-- Probably wouldn't work, since the linear approximations depend on input. On
-- the other hand, maybe approximations of function shiftings at zero.

instance Category k => Category (GD k) where
  type Ok (GD k) = Ok k
  Linear(id)
  -- D g . D f = D (\ a ->
  --   let (b,f') = f a
  --       (c,g') = g b
  --   in
  --     (c, g' . f'))
  -- Chain rule: the derivative of the composite at a is the derivative of g
  -- at f a composed (in k) with the derivative of f at a.
  D g . D f = D (\ a -> let { (b,f') = f a ; (c,g') = g b } in (c, g' . f'))
  {-# INLINE (.) #-}

-- Reassociation is linear, so each method pairs its (->) and k versions
-- through linearD.
instance AssociativePCat k => AssociativePCat (GD k) where
  lassocP = linearD lassocP lassocP
  {-# INLINE lassocP #-}
  rassocP = linearD rassocP rassocP
  {-# INLINE rassocP #-}

-- Swapping is linear, so it serves as its own derivative.
instance BraidedPCat k => BraidedPCat (GD k) where
  swapP = linearD swapP swapP
  {-# INLINE swapP #-}

instance MonoidalPCat k => MonoidalPCat (GD k) where
  -- D f *** D g = D (second (uncurry (***)) . transposeP . (f *** g))
  -- Parallel composition: run each half on its own component and combine the
  -- two derivatives with (***) in k.
  D f *** D g =
    D (\ (a,b) -> let { (c,f') = f a ; (d,g') = g b } in ((c,d), f' *** g'))
  {-# INLINE (***) #-}

-- Projections and duplication are linear, so each is its own derivative.
instance ProductCat k => ProductCat (GD k) where
  exl = linearD exl exl
  {-# INLINE exl #-}
  exr = linearD exr exr
  {-# INLINE exr #-}
  dup = linearD dup dup
  {-# INLINE dup #-}

-- Unit introduction and elimination are linear, so each method pairs its
-- (->) and k versions through linearD.
instance UnitCat k => UnitCat (GD k) where
  lunit = linearD lunit lunit
  {-# INLINE lunit #-}
  runit = linearD runit runit
  {-# INLINE runit #-}
  lcounit = linearD lcounit lcounit
  {-# INLINE lcounit #-}
  rcounit = linearD rcounit rcounit
  {-# INLINE rcounit #-}

-- Additivity evidence for GD k objects comes straight from k, because the
-- objects of GD k are those of k (Ok (GD k) is Ok k).
instance OkAdd k => OkAdd (GD k) where
  okAdd :: forall a. Ok' (GD k) a |- Sat Additive a
  okAdd = Entail (Sub (Dict <+ okAdd @k @a))

-- NOTE(review): this instance is compiled out by the CPP guard below; the
-- (++++) and (||||) definitions inside are only sketches kept as comments.
#if 0
-- Unused, I think, and relies on CoproductPCat (->).
instance CoproductPCat k => CoproductPCat (GD k) where
  Linear(inlP)
  Linear(inrP)
  Linear(jamP)
  -- Linear(swapPS)
  -- D f ++++ D g = D (second (uncurry (++++)) . transposeP . (f ++++ g))
  -- D f ++++ D g = D (\ (a,b) ->
  --   let (c,f') = f a
  --       (d,g') = g b
  --   in
  --     ((c,d), f' ++++ g'))
  -- D f ++++ D g =
  --   D (\ (a,b) -> let { (c,f') = f a ; (d,g') = g b } in ((c,d), f' ++++ g'))
  -- {-# INLINE (++++) #-}
  -- D f |||| D g = D (\ (a,b) ->
  --   let (c ,f') = f a
  --       (c',g') = g b
  --   in
  --     (c ^+^ c', f' |||| g')) -- or default
  -- {-# INLINE (||||) #-}

#endif

{--------------------------------------------------------------------
    Indexed products and coproducts
--------------------------------------------------------------------}

-- Compiled-out copies of the indexed product/coproduct class declarations,
-- kept for reference while reading the instances below.
-- NOTE(review): these may be out of date; for example, crossF is instantiated
-- below through IxMonoidalPCat rather than IxProductCat.
#if 0
class (Category k, OkIxProd k h) => IxProductCat k h where
  exF    :: forall a  . Ok  k a   => h (h a `k` a)
  forkF  :: forall a b. Ok2 k a b => h (a `k` b) -> (a `k` h b)
  crossF :: forall a b. Ok2 k a b => h (a `k` b) -> (h a `k` h b)
  replF  :: forall a  . Ok  k a   => a `k` h a

class (Category k, OkIxProd k h) => IxCoproductPCat k h where
  inPF   :: forall a   . (Additive a, Ok  k a  ) => h (a `k` h a)
  joinPF :: forall a b . (Additive a, Ok2 k a b) => h (b `k` a) -> (h b `k` a)
  plusPF :: forall a b . (Additive a, Ok2 k a b) => h (b `k` a) -> (h b `k` h a)  -- same as crossPF
  jamPF  :: forall a   . (Additive a, Ok  k a  ) => h a `k` a

class OkIxProd k h where
  okIxProd :: Ok' k a |- Ok' k h a
#endif

-- Indexed-product object evidence for GD k is inherited from k, because the
-- objects of GD k are those of k (Ok (GD k) is Ok k).
instance OkIxProd k h => OkIxProd (GD k) h where
  okIxProd :: forall a. Ok' (GD k) a |- Ok' (GD k) (h a)
  okIxProd = Entail (Sub (Dict <+ okIxProd @k @h @a))

-- Indexed analogue of the Linear macro: zip the (->) and k versions of an
-- h-indexed family of linear maps together with linearD. Unlike Linear, it
-- emits no INLINE pragma.
#define Linears(nm) nm = zipWith linearD nm nm

instance (IxMonoidalPCat (->) h, IxMonoidalPCat k h, Zip h) => IxMonoidalPCat (GD k) h where
  -- Unwrap each GD in the family, apply them element-wise with crossF in
  -- (->), split values from derivatives with unzip, then combine the
  -- per-element derivatives into a single k morphism with crossF in k.
  crossF (fmap unD -> fs) = D (second crossF . unzip . crossF fs)
  {-# INLINE crossF #-}

-- Indexed projections and replication are linear. Each projection in the
-- family pairs its (->) and k versions element-wise with linearD.
instance (IxProductCat (->) h, IxProductCat k h, Zip h) => IxProductCat (GD k) h where
  exF = zipWith linearD exF exF
  replF = linearD replF replF
  {-# INLINE replF #-}

-- crossF types:
-- 
--   crossF fs     :: h a -> h (b :* (a `k` b))
--   unzip         :: .. -> h b :* h (a `k` b)
--   second crossF :: .. -> h b :* (h a `k` h b)

-- instance (IxCoproductPCat (->) h, IxCoproductPCat k h, Zip h) => IxCoproductPCat (GD k) h where
--   Linears(inPF)
--   Linear(jamPF)
--   -- plusPF (fmap repr -> fs) = D (second plusPF . unzip . plusPF fs)
--   -- {-# INLINE plusPF #-}

{--------------------------------------------------------------------
    NumCat etc
--------------------------------------------------------------------}

-- Numeric operations on GD k:
--   addC is linear, with derivative jamP (sum of the two input components);
--   negateC is linear, with derivative scaling by -1;
--   mulC follows the product rule;
--   powIC is not implemented and fails at run time via notDef.
instance {-# overlappable #-} (LinearCat k s, Additive s, Num s) => NumCat (GD k) s where
  addC    = linearD addC jamP
  negateC = linearD negateC (scale (-1))
  mulC    = D (mulC &&& \ (u,v) -> scale v |||| scale u) -- \ (du,dv) -> u*dv + v*du
  powIC   = notDef "powIC"       -- TODO
  {-# INLINE negateC #-}
  {-# INLINE addC    #-}
  {-# INLINE mulC    #-}
  {-# INLINE powIC   #-}

-- | Differentiable scalar function, given the function @f@ and a derivative
-- function @d@. The function @d@ receives both the input @x@ and the
-- already-computed result @f x@, so derivatives expressed in terms of the
-- result can reuse it. The derivative is represented as @scale (d x (f x))@.
scalarD :: ScalarCat k s => (s -> s) -> (s -> s -> s) -> GD k s s
scalarD f d = D (\ x -> let r = f x in (r, scale (d x r)))
{-# INLINE scalarD #-}

-- Specializations

-- | Specialization of scalarD where the derivative is a function of the
-- result @f x@ only (for example, exp).
scalarR :: ScalarCat k s => (s -> s) -> (s -> s) -> GD k s s
scalarR f f' = scalarD f (const f')
{-# INLINE scalarR #-}

-- | Specialization of scalarD where the derivative is a function of the
-- input @x@ only (for example, sin with derivative cos).
scalarX :: ScalarCat k s => (s -> s) -> (s -> s) -> GD k s s
scalarX f f' = scalarD f (const . f')
{-# INLINE scalarX #-}

instance (LinearCat k s, Additive s, Fractional s) => FractionalCat (GD k) s where
  -- d(1/x) = -(1/x)^2, expressed via the result r = 1/x as -(r*r).
  recipC = scalarR recip (negate . sqr)
  {-# INLINE recipC #-}

-- Elementary functions, each with its derivative written in terms of the
-- result r (scalarR) or the input x (scalarX):
--   exp' = r,  log' = 1/x,  sin' = cos x,  cos' = -(sin x),
--   sqrt' = 1/(2 r),  tanh' = 1 - r*r.
instance (ScalarCat k s, Ok k s, Floating s) => FloatingCat (GD k) s where
  expC  = scalarR exp id
  logC  = scalarX log recip
  sinC  = scalarX sin cos
  cosC  = scalarX cos (negate . sin)
  sqrtC = scalarR sqrt (recip . scale 2)
  tanhC = scalarR tanh ((-) 1 . sqr)
  {-# INLINE expC #-}
  {-# INLINE sinC #-}
  {-# INLINE cosC #-}
  {-# INLINE logC #-}
  {-# INLINE sqrtC #-}
  {-# INLINE tanhC #-}

-- TODO: experiment with moving some of these dual derivatives to DualAdditive,
-- in the style of addD, mulD, etc.

-- The derivative of min/max is the projection onto whichever argument was
-- selected. On ties (x <= y holds), min selects exl and max selects exr.
instance (ProductCat k, Ok k a, Ord a) => MinMaxCat (GD k) a where
  minC = D (minC &&& cond exl exr . lessThanOrEqual)
  maxC = D (maxC &&& cond exr exl . lessThanOrEqual)
  {-# INLINE minC #-}
  {-# INLINE maxC #-}

-- Equivalently,
-- 
-- minC = D (\ (x,y) -> (minC (x,y), if x <= y then exl else exr))
-- maxC = D (\ (x,y) -> (maxC (x,y), if x <= y then exr else exl))

-- Functor-level operations:

-- TODO: IfCat. Maybe make ifC :: (a :* a) `k` (Bool -> a), which is linear.

{--------------------------------------------------------------------
    Discrete
--------------------------------------------------------------------}

-- Experiment

-- | Differentiable discrete function. The value comes from @f@. The
-- derivative is the constant zero linear map (const zero in k), not an
-- undefined (bottom) derivative.
discreteD :: (ConstCat k b, Ok k a, Additive b) => (a -> b) -> GD k a b
discreteD f = D (\ a -> (f a, const zero))
{-# INLINE discreteD #-}

-- Helpers for discrete methods. DiscreteEntail defines method nm as
-- discreteD nm, supplying additional Ok evidence through the entailment ent
-- with <+. Discrete uses the trivial entailment (id in the |- category).
-- DiscreteBB and DiscreteAA use okTT at Bool and at the instance type a.
-- NOTE(review): okTT presumably derives Ok for a pair of the given type;
-- confirm against its definition.
#define DiscreteEntail(nm,ent) nm = discreteD nm <+ (ent) ; {-# INLINE nm #-}
#define Discrete(nm) DiscreteEntail(nm,id @(|-) @())
#define DiscreteBB(nm) DiscreteEntail(nm,okTT @k @Bool)
#define DiscreteAA(nm) DiscreteEntail(nm,okTT @k @a)

-- Boolean operations are discrete: values come from (->) and the
-- derivative is zero (via discreteD).
instance (ProductCat k, ConstCat k Bool, Ok k Bool) => BoolCat (GD k) where
  notC = discreteD notC <+ (id @(|-) @())
  {-# INLINE notC #-}
  andC = discreteD andC <+ (okTT @k @Bool)
  {-# INLINE andC #-}
  orC = discreteD orC <+ (okTT @k @Bool)
  {-# INLINE orC #-}
  xorC = discreteD xorC <+ (okTT @k @Bool)
  {-# INLINE xorC #-}

-- Equality tests are discrete: values come from (->) and the derivative is
-- zero (via discreteD).
instance (ProductCat k, ConstCat k Bool, Eq a, Ok2 k a Bool) => EqCat (GD k) a where
  equal = discreteD equal <+ (okTT @k @a)
  {-# INLINE equal #-}
  notEqual = discreteD notEqual <+ (okTT @k @a)
  {-# INLINE notEqual #-}

-- Ordering tests are discrete: values come from (->) and the derivative is
-- zero (via discreteD).
instance (ProductCat k, ConstCat k Bool, Ord a, Ok2 k a Bool) => OrdCat (GD k) a where
  greaterThan = discreteD greaterThan <+ (okTT @k @a)
  {-# INLINE greaterThan #-}
  lessThan = discreteD lessThan <+ (okTT @k @a)
  {-# INLINE lessThan #-}
  lessThanOrEqual = discreteD lessThanOrEqual <+ (okTT @k @a)
  {-# INLINE lessThanOrEqual #-}
  greaterThanOrEqual = discreteD greaterThanOrEqual <+ (okTT @k @a)
  {-# INLINE greaterThanOrEqual #-}

instance (ProductCat k, ConstCat k Bool, Ok2 k Bool a) => IfCat (GD k) a where
  -- Linear(ifC)
  -- ifC = D (ifC &&& \ (i,_) -> ifC (i,(der t, der e)))
  -- The value comes from ifC in (->). The derivative ignores the condition
  -- component and projects the perturbation of the selected branch: exl
  -- (then-branch) when i holds, otherwise exr (else-branch).
  ifC :: GD k (Bool :* (a :* a)) a
  ifC = -- D (ifC &&& \ (i,_) -> ifC (i,(exl,exr)) . exr)
        -- D (ifC &&& \ (i,_) -> ifC (i,(exl.exr,exr.exr)))
        -- D (ifC &&& \ (i,_) -> cond exl exr i . exr)
        D (ifC &&& \ (i,_) -> (if i then exl else exr) . exr)
        <+ okProd @k @Bool @(a :* a)
        <+ okProd @k @a @a

{--------------------------------------------------------------------
    Functor-level operations
--------------------------------------------------------------------}

instance (IxProductCat k h, Functor h, FunctorCat k h) => FunctorCat (GD k) h where
  -- q gives each element's value and derivative. Map q over the structure
  -- in (->), separate values from derivatives with unzipC, and combine the
  -- per-element derivatives into one k morphism with crossF.
  fmapC = inAbst (\ q -> second crossF . unzipC . fmapC q)
  Linear(unzipC)
  {-# INLINE fmapC #-}

-- See 2017-12-27 notes
-- 
--      q :: a -> b :* (a `k` b)
-- fmap q :: h a -> h (b :* (a `k` b))
-- unzip  :: h (b :* (a `k` b)) -> h b :* h (a `k` b)
-- crossF :: h (a `k` b) -> (h a `k` h b)

-- Functor object evidence for GD k is inherited from k, because the objects
-- of GD k are those of k (Ok (GD k) is Ok k).
instance OkFunctor k h => OkFunctor (GD k) h where
  okFunctor :: forall a. Ok' (GD k) a |- Ok' (GD k) (h a)
  okFunctor = Entail (Sub (Dict <+ okFunctor @k @h @a))
  -- okFunctor = Entail (Sub Dict)
  -- okFunctor = inForkCon (yes1 *** okFunctor @k @h @a)
  {-# INLINE okFunctor #-}

-- TODO: FunctorCat. See RAD

-- Summation is linear, so sumAC pairs its (->) and k versions via linearD.
instance (AddCat (->) h a, AddCat k h a, OkFunctor (GD k) h)
      => AddCat (GD k) h a where
  sumAC = linearD sumAC sumAC
  {-# INLINE sumAC #-}

-- Zipping is linear, so zipC pairs its (->) and k versions via linearD.
instance (ZipCat k h, OkFunctor (GD k) h) => ZipCat (GD k) h where
  zipC = linearD zipC zipC
  {-# INLINE zipC #-}
  -- zipWithC = ??
  -- {-# INLINE zipWithC #-}

-- Unwrap each GD (fmap repr), apply the family element-wise with zapC in
-- (->), split values from derivatives (unzip), combine the per-element
-- derivatives with zapC in k, and rewrap the result with abst.
instance (ZapCat k h, OkFunctor k h, Zip h) => ZapCat (GD k) h where
  zapC = abst . result (second zapC . unzip) . zapC . fmap repr

-- fmap repr            :: h (GD k a b) -> h (a -> b :* k a b)
-- zapC                 :: h (a -> b :* k a b) -> (h a -> h (b :* k a b))
-- result unzip         :: (h a -> h (b :* k a b)) -> (h a -> h b :* h (k a b))
-- (result.second) zapC :: (h a -> h b :* h (k a b)) -> (h a -> h b :* k (h a) (h b))
-- abst                 :: (h a -> h b :* k (h a) (h b)) -> GD k (h a) (h b)

-- TODO: What use can we make of the ZapCat instance? Maybe repeated differentiation.

-- Pointing (filling a structure with one value) is linear, so pointC pairs
-- its (->) and k versions via linearD.
instance (OkFunctor (GD k) h, Pointed h, PointedCat k h a) => PointedCat (GD k) h a where
  Linear(pointC)

-- instance (IxProductCat k h, FunctorCat k h, Strong k h)
--       => Strong (GD k) h where
--   Linear(strength)

-- Sequencing is linear, so sequenceAC pairs its (->) and k versions via
-- linearD.
instance (TraversableCat (->) t f, TraversableCat k t f)
      => TraversableCat (GD k) t f where
  sequenceAC = linearD sequenceAC sequenceAC
  {-# INLINE sequenceAC #-}

-- Distribution is linear, so distributeC pairs its (->) and k versions via
-- linearD.
instance (DistributiveCat (->) g f, DistributiveCat k g f)
      => DistributiveCat (GD k) g f where
  distributeC = linearD distributeC distributeC
  {-# INLINE distributeC #-}

-- Indexing and tabulation are linear, so each pairs its (->) and k versions
-- via linearD.
instance (RepresentableCat (->) g, RepresentableCat k g)
      => RepresentableCat (GD k) g where
  Linear(indexC)
  Linear(tabulateC)

-- Structure-wide minimum/maximum reuse the value-and-derivative functions
-- minimumCF and maximumCF, wrapped into GD with abst.
instance (ProductCat k, MinMaxFFunctorCat k h a, Ord a) => MinMaxFunctorCat (GD k) h a where
  minimumC = abst minimumCF
  {-# INLINE minimumC #-}
  maximumC = abst maximumCF
  {-# INLINE maximumC #-}


{--------------------------------------------------------------------
    Other instances
--------------------------------------------------------------------}

-- | Placeholder for a method that GD does not implement. Evaluating it fails
-- at run time with an error naming the method.
notDef :: String -> a
notDef meth = error (meth ++ " on D not defined")

-- Converting to and from a representation is linear, so reprC and abstC
-- each pair their (->) and k versions via linearD.
instance (RepCat (->) a r, RepCat k a r) => RepCat (GD k) a r where
  reprC = linearD reprC reprC
  {-# INLINE reprC #-}
  abstC = linearD abstC abstC
  {-# INLINE abstC #-}

-- coerceC is treated as linear (its own derivative) in both variants. Only
-- the second branch of the CPP conditional below is compiled: GD k gets the
-- instance whenever both (->) and k have one.
#if 0
instance (Coercible a b, V s a ~ V s b, Ok2 k a b) => CoerceCat (GD k) a b where
  Linear(coerceC)
#else
instance ( CoerceCat (->) a b
         , CoerceCat k a b
         ) => CoerceCat (GD k) a b where
  Linear(coerceC)
#endif

{--------------------------------------------------------------------
    Differentiation interface
--------------------------------------------------------------------}

-- | A function combined with its derivative: interpret @h@ in GD k via toCcc
-- and unwrap the result. The category @k@ is fixed by the result type or by
-- an explicit type application.
andDeriv :: forall k a b . (a -> b) -> (a -> b :* (a `k` b))
andDeriv h = unD (toCcc h)
{-# INLINE andDeriv #-}

-- | The derivative of a given function: the second component of andDeriv,
-- discarding the function value.
deriv :: forall k a b . (a -> b) -> (a -> (a `k` b))
deriv h = snd P.. andDeriv h
{-# INLINE deriv #-}