{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
module ConCat.ADFun where
import Prelude hiding (id,(.),curry,uncurry,const,zip,unzip)
import qualified Prelude as P
import GHC.Generics (Par1)
import Control.Newtype.Generics (Newtype(..))
import Data.Pointed (Pointed)
import Data.Key (Zip(..))
import Data.Constraint hiding ((&&&),(***),(:=>))
import Data.Distributive (Distributive(..))
import Data.Functor.Rep (Representable(..))
import ConCat.Misc ((:*),R,Yes1,unzip,type (&+&),sqr,result)
import ConCat.Rep (repr)
import ConCat.Free.VectorSpace (HasV(..),inV,IsScalar)
import ConCat.Free.LinearRow
import ConCat.AltCat
import ConCat.GAD
import ConCat.AdditiveFun
import qualified ConCat.Category as C
-- | Differentiable functions: the 'GD' category instantiated with @(-+>)@
-- as the derivative representation.
-- NOTE(review): @(-+>)@ presumably comes from "ConCat.AdditiveFun" -- confirm.
type D = GD (-+>)
-- NOTE(review): both branches below are disabled by CPP and kept for
-- reference. Their layout is restored so that either branch parses if it
-- is re-enabled; names such as deriv, unfork, zero and (^+^) must then be
-- in scope -- confirm before enabling.
#if 0
instance ClosedCat D where
  apply = D (\ (f,a) -> (f a, \ (df,da) -> df a ^+^ deriv f a da))
  curry (D h) = D (\ a -> (curry f a, \ da -> \ b -> f' (a,b) (da,zero)))
    where
      (f,f') = unfork h
  {-# INLINE apply #-}
  {-# INLINE curry #-}
#elif 0
instance ClosedCat D where
  apply = applyD ; {-# INLINE apply #-}
  curry = curryD ; {-# INLINE curry #-}

applyD :: forall a b. Ok2 D a b => D ((a -> b) :* a) b
applyD =
  D (\ (f,a) -> let (b,f') = andDerF f a in (b, \ (df,da) -> df a ^+^ f' da))

curryD :: forall a b c. Ok3 D a b c => D (a :* b) c -> D a (b -> c)
curryD (D (unfork -> (f,f'))) =
  D (\ a -> (curry f a, \ da -> \ b -> f' (a,b) (da,zero)))

{-# INLINE applyD #-}
{-# INLINE curryD #-}
#else
#endif
-- | Convert a function into the differentiable-function category 'D'
-- (via 'toCcc') and unwrap it with 'unMkD'. For each input @a@, the result
-- pairs the function value with its derivative at @a@, represented as a
-- plain function @a -> b@.
andDerF :: forall a b . (a -> b) -> (a -> b :* (a -> b))
andDerF f = unMkD (toCcc @D f)
{-# INLINE andDerF #-}
-- | Same shape as 'andDerF', but converts into 'D' with the primed
-- conversion @toCcc'@ instead of 'toCcc'.
-- NOTE(review): how the two conversions differ is defined in
-- "ConCat.AltCat"; confirm before choosing one over the other.
andDerF' :: forall a b . (a -> b) -> (a -> b :* (a -> b))
andDerF' f = unMkD (toCcc' @D f)
{-# INLINE andDerF' #-}
-- | Derivative of a function at each point, as a plain function:
-- @derF f a@ is the second component of @andDerF f a@.
-- The two 'result's are composed with Prelude composition (@P..@), not the
-- categorical composition imported from "ConCat.AltCat".
derF :: forall a b . (a -> b) -> (a -> (a -> b))
derF = (result P.. result) snd andDerF
{-# INLINE derF #-}
-- | Like 'andDerF', but the derivative is converted from a function into an
-- explicit linear map @L s a b@ with 'linear'.
andDerFL :: forall s a b. HasLin s a b => (a -> b) -> (a -> b :* L s a b)
andDerFL f = second linear . andDerF f
{-# INLINE andDerFL #-}
-- | Like 'derF', but the derivative at each point is converted into an
-- explicit linear map @L s a b@ with 'linear'.
derFL :: forall s a b. HasLin s a b => (a -> b) -> (a -> L s a b)
derFL f = linear . derF f
{-# INLINE derFL #-}
-- | Turn a scalar-valued function into a value of type @a@. It builds the
-- linear map @linear \@s h@, unwraps it twice with 'unpack' to reach the
-- vector representation @V s a s@, and converts that back with 'unV'.
-- NOTE(review): presumably only meaningful when @h@ is linear -- confirm.
dualV :: forall s a. (HasLin s a s, IsScalar s) => (a -> s) -> a
dualV h = unV (unpack (unpack (linear @s h)))
{-# INLINE dualV #-}
-- | For a scalar-valued function, pair the value at each point with the
-- 'dualV' of the derivative there, i.e. the gradient as a value of type @a@.
andGradFL :: (HasLin s a s, IsScalar s) => (a -> s) -> (a -> s :* a)
andGradFL f = second dualV . andDerF f
{-# INLINE andGradFL #-}
-- | Gradient of a scalar-valued function at each point: 'dualV' applied to
-- the derivative computed by 'derF'.
gradF :: (HasLin s a s, IsScalar s) => (a -> s) -> (a -> a)
gradF f = dualV . derF f
{-# INLINE gradF #-}
#if 1
-- | Apply a scalar-valued function @h@ to every row of @diag 0 1@ (mapped
-- with 'fmapC'), collecting the results in an @f s@.
-- NOTE(review): assuming @diag 0 1@ is the identity (0 off the diagonal,
-- 1 on it), this samples @h@ on each basis vector, which yields the
-- coefficients of @h@ when @h@ is linear -- confirm the argument order of
-- diag.
linear1 :: (Representable f, Eq (Rep f), Num s)
        => (f s -> s) -> f s
linear1 = (`fmapC` diag 0 1)
{-# INLINE linear1 #-}
-- | Build @(f :-* g) s@ from a function @f s -> g s@. 'distribute' splits
-- @h@ into one scalar-valued function per position of @g@, and 'linear1'
-- turns each of those into an @f s@ row.
linearN :: (Representable f, Eq (Rep f), Distributive g, Num s)
        => (f s -> g s) -> (f :-* g) s
linearN h = linear1 <$> distribute h
{-# INLINE linearN #-}
-- | @a@ has a vector representation @V s a@ over scalars @s@, and that
-- representation functor is 'Representable'.
type RepresentableV s a =
  ( HasV s a
  , Representable (V s a)
  )

-- | 'RepresentableV', plus an 'Eq' instance on the representation index
-- @Rep (V s a)@.
type RepresentableVE s a =
  ( RepresentableV s a
  , Eq (Rep (V s a))
  )

-- | Constraints shared by the representable-based variants such as
-- @andDerFLR@ and @derFLR@.
type HasLinR s a b =
  ( RepresentableVE s a
  , RepresentableV s b
  , Num s
  )
-- | Linear map @L s a b@ of a function @a -> b@, built from the
-- 'Representable' / 'Distributive' structure of the vector representations
-- rather than from 'HasLin'. The function is lifted to
-- @V s a s -> V s b s@ with 'inV', rows are computed as in 'linearN', and
-- the result is wrapped with 'pack'.
linearNR :: ( HasV s a, RepresentableVE s a
            , HasV s b, Distributive (V s b), Num s )
         => (a -> b) -> L s a b
linearNR h = pack (linear1 <$> distribute (inV h))
{-# INLINE linearNR #-}
-- | Like 'andDerF', but the derivative is converted into an explicit linear
-- map @L s a b@ with 'linearNR'.
andDerFLR :: forall s a b. HasLinR s a b => (a -> b) -> (a -> b :* L s a b)
andDerFLR f = second linearNR . andDerF f
{-# INLINE andDerFLR #-}
-- | Like 'derF', but the derivative at each point is converted into an
-- explicit linear map @L s a b@ with 'linearNR'.
derFLR :: forall s a b. HasLinR s a b => (a -> b) -> (a -> L s a b)
derFLR f = linearNR . derF f
{-# INLINE derFLR #-}
-- | Counterpart of 'dualV' built on 'linear1' instead of 'linear'.
-- The function is lifted to the vector representation with 'inV' (at
-- scalar type @s@) and its single-component result is unwrapped with
-- 'unpack'. The result is sampled with 'linear1' and converted back with
-- 'unV'.
-- NOTE(review): presumably only meaningful when @h@ is linear -- confirm.
dualVR :: forall s a. (HasV s a, RepresentableVE s a, IsScalar s, Num s)
       => (a -> s) -> a
dualVR h = unV (linear1 (unpack . inV @s h))
{-# INLINE dualVR #-}
-- | Like 'andGradFL', but the derivative is dualized with 'dualVR'
-- instead of 'dualV'.
andGradFLR :: forall s a. (IsScalar s, RepresentableVE s a, Num s)
           => (a -> s) -> (a -> s :* a)
andGradFLR f = second dualVR . andDerF f
{-# INLINE andGradFLR #-}
-- | Like 'gradF', but the derivative is dualized with 'dualVR'
-- instead of 'dualV'.
gradFR :: forall s a. (IsScalar s, RepresentableVE s a, Num s)
       => (a -> s) -> (a -> a)
gradFR f = dualVR . derF f
{-# INLINE gradFR #-}
#endif