{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- {-# OPTIONS_GHC -fno-warn-unused-imports #-} -- TEMP

-- | Automatic differentiation

module ConCat.AD where

import Prelude hiding (id,(.),curry,uncurry,const,unzip)

import GHC.Generics(Par1(..))

import ConCat.Misc ((:*))
import ConCat.Rep (repr)
import ConCat.Free.VectorSpace (HasV(..),IsScalar)
import ConCat.Free.LinearRow
import ConCat.AltCat
import ConCat.GAD

-- | Differentiable functions with composed-functor style linear maps as
-- derivatives: the general 'GD' construction specialized to linear maps of
-- type @'L' s@, where @s@ is (presumably) the scalar type of the underlying
-- vector spaces.
type D s = GD (L s)

{--------------------------------------------------------------------
    Differentiation interface
--------------------------------------------------------------------}

-- | Specialization of 'andDeriv' to linear maps @'L' s@: turns a function
-- into one that returns its result paired with its derivative at the given
-- point (a linear map from @a@ to @b@).
andDer :: forall s a b . (a -> b) -> (a -> b :* L s a b)
andDer = andDeriv
{-# INLINE andDer #-}

-- | Specialization of 'deriv' to linear maps @'L' s@: turns a function into
-- one that yields its derivative (a linear map from @a@ to @b@) at each point.
der :: forall s a b . (a -> b) -> (a -> L s a b)
der = deriv
{-# INLINE der #-}

-- | Gradient of a scalar-valued function at a point.
--
-- Converts @f@ into a 'D' morphism with 'toCcc' and delegates to
-- 'gradientD'. The result has the same type as the input point.
gradient :: (HasV s a, IsScalar s) => (a -> s) -> a -> a
gradient f = gradientD (toCcc f)
{-# INLINE gradient #-}

-- | Gradient of a scalar-valued differentiable function given as a 'D'
-- morphism.
--
-- Reading the pipeline right to left:
--
-- * @h@ maps a point to its value paired with its derivative.
-- * 'snd' keeps the derivative, a linear map @'L' s a s@.
-- * 'repr' exposes its representation, @Par1 (V s a s)@ (a single row).
-- * 'unPar1' unwraps that row.
-- * 'unV' converts the row vector back into the input type @a@.
gradientD :: (HasV s a, IsScalar s) => D s a s -> a -> a
gradientD (D h) = unV . unPar1 . repr . snd . h
{-# INLINE gradientD #-}


-- How the gradient falls out of the derivative, stage by stage (cf. the
-- pipeline in 'gradientD'):
--
--                             f :: a -> s
--                         der f :: a -> L s a s
--                  repr . der f :: a -> V s s (V s a s)
--                               :: a -> Par1 (V s a s)
--         unPar1 . repr . der f :: a -> V s a s
--   unV . unPar1 . repr . der f :: a -> a