{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-} -- TEMP

-- | Automatic differentiation

module ConCat.ADFun where

import Prelude hiding (id,(.),curry,uncurry,const,zip,unzip)
import qualified Prelude as P
-- import Debug.Trace (trace)
import GHC.Generics (Par1)

import Control.Newtype.Generics (Newtype(..))
import Data.Pointed (Pointed)
import Data.Key (Zip(..))
import Data.Constraint hiding ((&&&),(***),(:=>))
import Data.Distributive (Distributive(..))
import Data.Functor.Rep (Representable(..))

import ConCat.Misc ((:*),R,Yes1,unzip,type (&+&),sqr,result)
import ConCat.Rep (repr)
import ConCat.Free.VectorSpace (HasV(..),inV,IsScalar)
import ConCat.Free.LinearRow -- hiding (linear)
import ConCat.AltCat
import ConCat.GAD -- hiding (linear)
import ConCat.AdditiveFun
-- The following import allows the instances to type-check. Why?
import qualified ConCat.Category  as C

-- | Differentiable functions: 'GD' instantiated with the (-+>) arrow
-- (presumably the additive-function arrow from ConCat.AdditiveFun), so that
-- derivatives are represented as plain functions (see the type of andDerF).
-- Kept unsaturated (no type parameters) so that it can be passed as a type
-- argument to toCcc; a fully-applied synonym could not be used that way.
type D = GD (-+>)

-- Experimental ClosedCat instances for D, kept for reference only.
-- Both the "#if 0" and "#elif 0" branches are removed by CPP, so GHC never
-- sees or type-checks them; the "#else" branch is the one in effect.
#if 0
instance ClosedCat D where
  apply = D (\ (f,a) -> (f a, \ (df,da) -> df a ^+^ deriv f a da))
  curry (D h) = D (\ a -> (curry f a, \ da -> \ b -> f' (a,b) (da,zero)))
   where
     (f,f') = unfork h
  {-# INLINE apply #-}
  {-# INLINE curry #-}

-- TODO: generalize to ClosedCat k for an arbitrary CCC k. I guess I can simply
-- apply ccc to the lambda expressions.
#elif 0

instance ClosedCat D where
  apply = applyD ; {-# INLINE apply #-}
  curry = curryD ; {-# INLINE curry #-}

applyD :: forall a b. Ok2 D a b => D ((a -> b) :* a) b
-- applyD = D (\ (f,a) -> (f a, \ (df,da) -> df a ^+^ f da))
applyD = -- trace "calling applyD" $
 D (\ (f,a) -> let (b,f') = andDerF f a in (b, \ (df,da) -> df a ^+^ f' da))
-- applyD = oops "applyD called"   -- does it?

curryD :: forall a b c. Ok3 D a b c => D (a :* b) c -> D a (b -> c)
curryD (D (unfork -> (f,f'))) =
  D (\ a -> (curry f a, \ da -> \ b -> f' (a,b) (da,zero)))

{-# INLINE applyD #-}
{-# INLINE curryD #-}
#else
-- No ClosedCat D instance (both alternatives above are disabled).
#endif

{--------------------------------------------------------------------
    Differentiation interface
--------------------------------------------------------------------}

-- | Differentiate a function: @andDerF f a@ returns the value @f a@ paired
-- with the derivative of @f@ at @a@, represented as an ordinary function.
--
-- Works by converting @f@ into the differentiable-function category 'D'
-- with 'toCcc' and then unwrapping the result with 'unMkD'.
andDerF :: forall a b . (a -> b) -> (a -> b :* (a -> b))
andDerF f = unMkD (toCcc @D f)
{-# INLINE andDerF #-}

-- | Like 'andDerF', but converts the function into 'D' with toCcc' instead
-- of 'toCcc'.
-- NOTE(review): how toCcc' differs from toCcc is defined in ConCat.AltCat;
-- confirm there before choosing one over the other.
andDerF' :: forall a b . (a -> b) -> (a -> b :* (a -> b))
andDerF' f = unMkD (toCcc' @D f)
{-# INLINE andDerF' #-}

-- | Type specialization of deriv: the derivative alone, with the value
-- discarded. Composes 'snd' after the two-argument 'andDerF' using 'result'
-- twice; Prelude's composition (P..) is used here rather than the categorical
-- (.) imported from ConCat.AltCat.
derF :: forall a b . (a -> b) -> (a -> (a -> b))
derF = (result P.. result) snd andDerF
-- derF f = \ a -> snd (andDerF f a)
{-# INLINE derF #-}

-- | AD with derivative-as-function, then converted to a linear map: like
-- 'andDerF', but the derivative in the second component is turned into an
-- @L s a b@ with 'linear'.
andDerFL :: forall s a b. HasLin s a b => (a -> b) -> (a -> b :* L s a b)
andDerFL f = second linear . andDerF f
{-# INLINE andDerFL #-}

-- | Derivative alone (as computed by 'derF'), converted to a linear map
-- @L s a b@ with 'linear'. The function value is not returned.
derFL :: forall s a b. HasLin s a b => (a -> b) -> (a -> L s a b)
derFL f = linear . derF f
{-# INLINE derFL #-}

-- | Represent a scalar-valued (presumably linear) function @h :: a -> s@ by a
-- value of type @a@. Builds its linear map with @linear \@s@, strips the
-- two newtype layers with 'unpack', and converts out of the vector
-- representation with 'unV'.
dualV :: forall s a. (HasLin s a s, IsScalar s) => (a -> s) -> a
dualV h = unV (unpack (unpack (linear @s h)))
{-# INLINE dualV #-}

--                                h    :: a -> s
--                      linear @s h    :: L s a s
--              unpack (linear @s h)   :: V s s (V s a s)
--                                     :: Par1 (V s a s)
--      unpack (unpack (linear @s h))  :: V s a s
-- unV (unpack (unpack (linear @s h))) :: a

-- | Value and gradient of a scalar-valued function: like 'andDerF', with the
-- derivative converted to a vector of type @a@ via 'dualV'.
andGradFL :: (HasLin s a s, IsScalar s) => (a -> s) -> (a -> s :* a)
andGradFL f = second dualV . andDerF f
{-# INLINE andGradFL #-}

-- | Gradient of a scalar-valued function: 'dualV' applied to the derivative
-- produced by 'derF'. Relatively expensive, because 'dualV' goes through
-- 'linear'.
gradF :: (HasLin s a s, IsScalar s) => (a -> s) -> (a -> a)
gradF f = dualV . derF f
{-# INLINE gradF #-}

-- NOTE: gradF is fairly expensive due to linear (via dualV). For efficiency,
-- use GD (Dual AdditiveFun) instead.

#if 1

{--------------------------------------------------------------------
    Conversion to linear map. Replace HasL in LinearRow and LinearCol
--------------------------------------------------------------------}

-- | Build the vector representing a linear functional on @f s@ by applying
-- it to each element of @diag 0 1@. That structure is presumably the
-- identity (1 on the diagonal, 0 elsewhere), i.e. the standard basis
-- vectors, so each component is the functional applied to one basis vector.
linear1 :: (Representable f, Eq (Rep f), Num s)
        => (f s -> s) -> f s
-- linear1 = (<$> diag 0 1)
linear1 = (`fmapC` diag 0 1)
{-# INLINE linear1 #-}

-- | Representation of a linear map @f s -> g s@. 'distribute' splits @h@
-- into one scalar-valued function per component of @g@ (a @g (f s -> s)@).
-- 'linear1' then turns each of those into an @f s@ vector.
linearN :: (Representable f, Eq (Rep f), Distributive g, Num s)
        => (f s -> g s) -> (f :-* g) s
linearN h = linear1 <$> distribute h
{-# INLINE linearN #-}

-- h :: f s -> g s
-- distribute h :: g (f s -> s)
-- linear1 <$> distribute h :: g (f s)

{--------------------------------------------------------------------
    Alternative definitions using Representable
--------------------------------------------------------------------}

-- | A 'HasV' type whose vector representation @V s a@ is 'Representable'.
type RepresentableV s a = (Representable (V s a), HasV s a)

-- | 'RepresentableV', additionally requiring equality on the representation's
-- index type (needed by linear1).
type RepresentableVE s a = (Eq (Rep (V s a)), RepresentableV s a)

-- | Constraints for the Representable-based conversions to linear maps:
-- numeric scalars, a representable domain with comparable indices, and a
-- representable codomain.
type HasLinR s a b = (Num s, RepresentableVE s a, RepresentableV s b)

-- | Representable-based conversion of a (presumably linear) function to a
-- linear map. It steps through four stages:
--
-- 1. 'inV' moves @h@ into vector representation.
-- 2. 'distribute' splits it into one scalar function per output component.
-- 3. 'linear1' turns each of those into a vector.
-- 4. 'pack' wraps the resulting structure as an @L s a b@.
linearNR :: ( HasV s a, RepresentableVE s a
            , HasV s b, Distributive (V s b), Num s )
         => (a -> b) -> L s a b
linearNR h = pack (linear1 <$> distribute (inV h))
{-# INLINE linearNR #-}

-- | AD with derivative-as-function, then converted to a linear map using the
-- Representable-based 'linearNR' (counterpart of 'andDerFL').
andDerFLR :: forall s a b. HasLinR s a b => (a -> b) -> (a -> b :* L s a b)
andDerFLR f = second linearNR . andDerF f
{-# INLINE andDerFLR #-}

-- | Derivative alone (as computed by 'derF'), converted to a linear map using
-- the Representable-based 'linearNR' (counterpart of 'derFL').
derFLR :: forall s a b. HasLinR s a b => (a -> b) -> (a -> L s a b)
derFLR f = linearNR . derF f
{-# INLINE derFLR #-}

-- | Representable-based counterpart of 'dualV'. It proceeds in three steps:
--
-- 1. View @h@ as a scalar function on @V s a s@, via @inV \@s@ followed by
--    'unpack' (which removes the Par1 wrapper on the result).
-- 2. Turn that function into a vector with 'linear1'.
-- 3. Convert the vector back to @a@ with 'unV'.
dualVR :: forall s a. (HasV s a, RepresentableVE s a, IsScalar s, Num s)
      => (a -> s) -> a
dualVR h = unV (linear1 (unpack . inV @s h))
{-# INLINE dualVR #-}

--                            h   :: a -> s
--                        inV h   :: V s a s -> V s s s
--                                :: V s a s -> Par1 s
--              (unpack . inV h)  :: V s a s -> s
--      linear1 (unpack . inV h)  :: V s a s
-- unV (linear1 (unpack . inV h)) :: a

-- | Value and gradient of a scalar-valued function, using the
-- Representable-based 'dualVR' (counterpart of 'andGradFL').
andGradFLR :: forall s a. (IsScalar s, RepresentableVE s a, Num s)
           => (a -> s) -> (a -> s :* a)
andGradFLR f = second dualVR . andDerF f
{-# INLINE andGradFLR #-}

-- | Gradient of a scalar-valued function, using the Representable-based
-- 'dualVR' on the derivative from 'derF' (counterpart of 'gradF').
gradFR :: forall s a. (IsScalar s, RepresentableVE s a, Num s)
       => (a -> s) -> (a -> a)
gradFR f = dualVR . derF f
{-# INLINE gradFR #-}

#endif