{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
module ConCat.GradientDescent where
import Data.Function (on)
import Data.List (iterate,unfoldr)
import Control.Arrow (first)
import GHC.Generics (Par1(..))
import Control.Newtype.Generics (unpack)
import Data.Key (Zip)
import ConCat.Misc (Unop,Binop,R)
import ConCat.AD
import ConCat.Free.VectorSpace (HasV(..),onV,onV2)
import qualified ConCat.Free.VectorSpace as V
import ConCat.Free.LinearRow
import ConCat.Orphans ()
import ConCat.Category (dup)
-- | Optimize by repeatedly stepping along the gradient until the point stops
-- changing, returning only the final point.
--
-- 'maximize' is 'maximizeN' with the iteration count discarded: the three
-- composed 'fmap's reach under the three function arguments (step size,
-- differentiable function, starting point) to apply 'fst' to the result pair.
--
-- 'minimize' negates the step size before delegating to 'maximize', so each
-- step moves against the gradient instead of along it.
--
-- Termination is inherited from 'maximizeN': the result is produced only once
-- an iteration yields a point that is '==' to its predecessor.
maximize, minimize :: (HasV R a, Zip (V R a), Eq a) => R -> D R a R -> a -> a
maximize = (fmap.fmap.fmap) fst maximizeN
minimize = maximize . negate
-- | Like 'maximize' and 'minimize', but also return the number of steps taken.
--
-- 'maximizeN' takes a step size @gamma@ and a differentiable function, passes
-- the function through 'gradientD' to obtain an @a -> a@ step direction
-- (presumably the gradient at each point -- confirm against "ConCat.AD"), and
-- hands that to 'chaseN'.
--
-- 'minimizeN' negates the step size and reuses 'maximizeN'.
maximizeN, minimizeN :: (HasV R a, Zip (V R a), Eq a) => R -> D R a R -> a -> (a,Int)
maximizeN gamma = chaseN gamma . gradientD
minimizeN = maximizeN . negate
-- | Iterate @a -> a ^+^ gamma *^ next a@ from a starting point until a step
-- produces a value '==' to the one before it, returning that value together
-- with the number of steps performed (see 'fixN').
--
-- @gamma@ scales the direction computed by @next@. Because the stopping test
-- is exact equality, this does not return if the iteration never reaches an
-- exact fixed point.
chaseN :: (HasV R a, Zip (V R a), Eq a) => R -> (a -> a) -> a -> (a,Int)
chaseN gamma next = fixN (\ a -> a ^+^ gamma *^ next a)
-- | 'chaseN' with the step count dropped: given a step size and a direction
-- function, produce the function that maps a starting point to the fixed
-- point reached from it.
--
-- The three composed 'fmap's apply 'fst' beneath 'chaseN''s three arguments.
chase :: (HasV R a, Zip (V R a), Eq a) => R -> Unop (a -> a)
chase = (fmap.fmap.fmap) fst chaseN
-- | The unbounded list of successive points visited by the update
-- @a -> a ^+^ gamma *^ next a@, beginning with the starting point itself.
--
-- Unlike 'chaseN' there is no stopping test (and hence no 'Eq' constraint);
-- the caller decides how much of the list to consume.
chaseL :: (HasV R a, Zip (V R a)) => R -> (a -> a) -> a -> [a]
chaseL gamma next = iterate (\ a -> a ^+^ gamma *^ next a)
-- | List-producing counterparts of 'maximize' and 'minimize': return every
-- intermediate point instead of iterating to a fixed point.
--
-- 'maximizeL' feeds the result of 'gradientD' to 'chaseL' as the step
-- direction; 'minimizeL' negates the step size and reuses 'maximizeL'.
-- The resulting list is infinite (it is built with 'iterate' in 'chaseL').
maximizeL, minimizeL :: (HasV R a, Zip (V R a)) => R -> D R a R -> a -> [a]
maximizeL gamma = chaseL gamma . gradientD
minimizeL = maximizeL . negate
-- | Iterate @next@ until two consecutive values are related by @eq@, and
-- return the later of the two.
--
-- The comparison is @eq newValue oldValue@. If no consecutive pair ever
-- satisfies @eq@, the function does not return.
fixBy :: (a -> a -> Bool) -> Unop (Unop a)
fixBy eq next = go
  where
    go a
      | a' `eq` a = a'     -- converged: hand back the freshly computed value
      | otherwise = go a'  -- keep iterating from the new value
      where
        a' = next a
-- | Like 'fixBy', but also count how many times @next@ was applied.
--
-- The state is a @(value, count)@ pair starting at @(a0, 0)@. Convergence is
-- decided on the value component only (@eq `on` fst@), so the count never
-- influences termination. The returned count includes the final application
-- whose result satisfied @eq@ against its predecessor.
fixByN :: (a -> a -> Bool) -> Unop a -> a -> (a,Int)
fixByN eq next a0 = fixBy (eq `on` fst) next' (a0,0)
  where
    -- The bang forces the running count whenever the pair is taken apart,
    -- so a chain of unevaluated @+1@ thunks does not accumulate.
    next' (a,!n) = (next a, n+1)
-- | 'fixByN' specialised to '==': iterate until a step returns a value equal
-- to its input, yielding that value and the number of steps taken.
fixN :: Eq a => Unop a -> a -> (a,Int)
fixN = fixByN (==)
-- | 'fixBy' specialised to '==': iterate a function until a step returns a
-- value equal to its input, and return that value.
fixEq :: Eq a => Unop (Unop a)
fixEq = fixBy (==)
-- Scaling binds tighter than addition/subtraction, mirroring the Prelude's
-- fixities for (*) versus (+) and (-). This lets an update such as
-- @a ^+^ gamma *^ next a@ parse as @a ^+^ (gamma *^ next a)@.
infixl 7 *^
infixl 6 ^-^, ^+^
-- | Scale a value by a scalar. The value is handled through its 'V'
-- representation: 'onV' lifts the representation-level scaling @(V.*^) s@
-- to operate on @a@ directly.
(*^) :: (HasV R a, Functor (V R a)) => R -> Unop a
(*^) s = onV ((V.*^) s)
-- | Negate a value by scaling it by @-1@ with '*^'.
negateV :: (HasV R a, Functor (V R a)) => Unop a
negateV = (*^) (-1)
-- | Add two values via their 'V' representations: 'onV2' lifts the
-- representation-level @(V.^+^)@ to operate on @a@ directly.
--
-- The explicit @forall a.@ brings @a@ into scope (ScopedTypeVariables) for
-- the inner annotation, which pins the representation type that 'onV2'
-- should use.
(^+^) :: forall a. (HasV R a, Zip (V R a)) => Binop a
(^+^) = onV2 ((V.^+^) :: Binop (V R a R))
-- | Subtract two values via their 'V' representations: 'onV2' lifts the
-- representation-level @(V.^-^)@ to operate on @a@ directly.
--
-- The explicit @forall a.@ brings @a@ into scope (ScopedTypeVariables) for
-- the inner annotation, which pins the representation type that 'onV2'
-- should use.
(^-^) :: forall a. (HasV R a, Zip (V R a)) => Binop a
(^-^) = onV2 ((V.^-^) :: Binop (V R a R))