{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module ConCat.AD where
import Prelude hiding (id,(.),curry,uncurry,const,unzip)
import GHC.Generics(Par1(..))
import ConCat.Misc ((:*))
import ConCat.Rep (repr)
import ConCat.Free.VectorSpace (HasV(..),IsScalar)
import ConCat.Free.LinearRow
import ConCat.AltCat
import ConCat.GAD
-- | Differentiable functions with scalar field @s@: the 'GD' category from
-- "ConCat.GAD" instantiated at the @L s@ category from "ConCat.Free.LinearRow"
-- (so derivatives are represented as @L s@ values).
type D s = GD (L s)
-- | 'andDeriv' specialised to the @L s@ category.
--
-- Given a function @f@ and a point @a@, yields @f a@ paired with an
-- @L s a b@ value (the derivative of @f@ at @a@, as produced by 'andDeriv').
andDer :: forall s a b . (a -> b) -> (a -> b :* L s a b)
andDer = andDeriv
-- NOTE(review): inlined so the underlying 'andDeriv' call is visible at use
-- sites — presumably required by the concat compiler plugin; confirm.
{-# INLINE andDer #-}
-- | 'deriv' specialised to the @L s@ category.
--
-- Given a function @f@ and a point @a@, yields an @L s a b@ value (the
-- derivative of @f@ at @a@, as produced by 'deriv'), without the value @f a@.
der :: forall s a b . (a -> b) -> (a -> L s a b)
der = deriv
-- NOTE(review): inlined so the underlying 'deriv' call is visible at use
-- sites — presumably required by the concat compiler plugin; confirm.
{-# INLINE der #-}
-- | Gradient of a scalar-valued function at a point.
--
-- The function is first translated into the differentiable-function category
-- @'D' s@ with 'toCcc', and the result is handed to 'gradientD'.
gradient :: (HasV s a, IsScalar s) => (a -> s) -> a -> a
gradient f = gradientD (toCcc f)
{-# INLINE gradient #-}
-- | Gradient of a scalar-valued differentiable function at a point.
--
-- The pipeline (read right to left) is:
--
--   * @h@      : @a -> (s, L s a s)@, the value paired with the linear map;
--   * 'snd'    : keep only the @L s a s@ linear map;
--   * 'repr'   : expose its representation, @Par1 (V s a s)@ (a single row);
--   * 'unPar1' : strip the 'Par1' wrapper, leaving @V s a s@;
--   * 'unV'    : convert that row back into a value of type @a@.
--
-- Note: @(.)@ here is the categorical composition from "ConCat.AltCat",
-- not the Prelude one.
gradientD :: (HasV s a, IsScalar s) => D s a s -> a -> a
gradientD (D h) = unV . unPar1 . repr . snd . h
{-# INLINE gradientD #-}