{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE BangPatterns #-}
-- {-# LANGUAGE AllowAmbiguousTypes #-}

{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-} -- TEMP

-- | Experimenting with formulations of gradient descent minimization

module ConCat.GradientDescent where

import Data.Function (on)
import Data.List (iterate,unfoldr)
import Control.Arrow (first)

import GHC.Generics (Par1(..))

import Control.Newtype.Generics (unpack)
import Data.Key (Zip)

import ConCat.Misc (Unop,Binop,R)
import ConCat.AD
import ConCat.Free.VectorSpace (HasV(..),onV,onV2)
import qualified ConCat.Free.VectorSpace as V
import ConCat.Free.LinearRow
import ConCat.Orphans ()
import ConCat.Category (dup)

{--------------------------------------------------------------------
    Minimization via gradient descent
--------------------------------------------------------------------}

-- | Optimize a function with a fixed step size, discarding the step count
-- that 'maximizeN' reports.
--
-- Arguments, in order: the step size, the function to optimize, and the
-- starting point. 'maximize' keeps only the first component of the
-- 'maximizeN' result. 'minimize' is 'maximize' applied to the negated step
-- size, so each step moves in the opposite direction.
maximize, minimize :: (HasV R a, Zip (V R a), Eq a) => R -> D R a R -> a -> a
-- Three nested 'fmap's reach under the three function arguments to apply
-- 'fst' to the final @(a, Int)@ result.
maximize = (fmap . fmap . fmap) fst maximizeN
minimize = maximize . negate
-- {-# INLINE maximize #-}
-- {-# INLINE minimize #-}

-- | Optimize a function using gradient ascent, with step count.
--
-- Arguments, in order: the step size @gamma@, the function to optimize, and
-- the starting point. The result pairs the final point with the step count
-- produced by 'chaseN'. 'minimizeN' is 'maximizeN' applied to the negated
-- step size.
maximizeN, minimizeN :: (HasV R a, Zip (V R a), Eq a) => R -> D R a R -> a -> (a,Int)
-- maximizeN gamma f = fixN (\ a -> a ^+^ gamma *^ gradient' f a)
-- maximizeN gamma f = chaseN gamma (gradientD f)
maximizeN gamma = chaseN gamma . gradientD
minimizeN = maximizeN . negate
-- {-# INLINE maximizeN #-}
-- {-# INLINE minimizeN #-}

-- TODO: adaptive step sizes

-- | Repeatedly step from @a@ to @a ^+^ gamma *^ next a@ until 'fixN' stops
-- the iteration, i.e. until a step yields a value equal (by '==') to the one
-- it started from. Returns the final value together with the step count
-- reported by 'fixN'.
--
-- Does not terminate if no step ever reproduces its input exactly.
chaseN :: (HasV R a, Zip (V R a), Eq a) => R -> (a -> a) -> a -> (a,Int)
chaseN gamma next = fixN step
 where
   -- One update: move from @a@ by @gamma@ times the direction @next a@.
   step a = a ^+^ gamma *^ next a

-- | Like 'chaseN', but returns only the final value and drops the step count.
--
-- Arguments, in order: the step size, the function giving the direction at a
-- point, and the starting point.
chase :: (HasV R a, Zip (V R a), Eq a) => R -> Unop (a -> a)
-- Three nested 'fmap's reach under the three function arguments to apply
-- 'fst' to the final @(a, Int)@ result of 'chaseN'.
chase = (fmap . fmap . fmap) fst chaseN

-- Experiment: generate list of approximations

-- chaseL :: (HasV R a, Zip (V R a), Eq a) => R -> (a -> a) -> a -> [a]
-- Produce the successive approximations as a list instead of stopping at a
-- fixed point: the first element is the starting point, and each later
-- element is obtained from its predecessor @a@ as @a ^+^ gamma *^ next a@.
-- The list is built with 'iterate', so it is infinite; no 'Eq' constraint is
-- needed because no convergence test is performed here.
chaseL :: (HasV R a, Zip (V R a)) => R -> (a -> a) -> a -> [a]
chaseL gamma next = iterate step
 where
   -- One update: move from @a@ by @gamma@ times the direction @next a@.
   step a = a ^+^ gamma *^ next a

-- maximizeL, minimizeL :: (HasV R a, Zip (V R a), Eq a) => R -> D R a R -> a -> [a]
-- List-producing variants of the optimizers: given a step size, a function
-- to optimize and a starting point, return the list of approximations
-- generated by 'chaseL' along the direction given by 'gradientD'.
-- 'minimizeL' is 'maximizeL' applied to the negated step size.
maximizeL, minimizeL :: (HasV R a, Zip (V R a)) => R -> D R a R -> a -> [a]
maximizeL gamma = chaseL gamma . gradientD
minimizeL = maximizeL . negate

{--------------------------------------------------------------------
    Fixed points
--------------------------------------------------------------------}

-- Fixed point with comparison.
--
-- @fixBy eq next a@ applies @next@ repeatedly, starting from @a@, until a
-- newly computed value is related by @eq@ to the value it was computed from,
-- and returns that newly computed value. @next@ is therefore always applied
-- at least once. Does not terminate if no such pair of successive values is
-- ever reached.
fixBy :: (a -> a -> Bool) -> Unop (Unop a)
fixBy eq next = go
 where
   go a | a' `eq` a = a'      -- converged: return the latest iterate
        | otherwise = go a'   -- keep iterating from the new value
    where
      a' = next a

-- Fixed point with comparison and number of steps.
--
-- Runs 'fixBy' on pairs of a value and a counter that starts at 0 and is
-- incremented on every application of @next@. Only the value components are
-- compared (via @eq `on` fst@); the counter never affects convergence. The
-- result is the final value paired with the number of applications of @next@
-- that were performed.
fixByN :: (a -> a -> Bool) -> Unop a -> a -> (a,Int)
fixByN eq next a0 = fixBy (eq `on` fst) next' (a0,0)
 where
   -- The bang forces the incoming counter on each step, so a chain of
   -- unevaluated @+ 1@ thunks does not build up over a long iteration.
   next' (a, !n) = (next a, n + 1)

-- Fixed point using (==) and number of steps: 'fixByN' specialized to the
-- 'Eq' instance of the iterated type.
fixN :: Eq a => Unop a -> a -> (a,Int)
fixN = fixByN (==)

-- Fixed point: 'fixBy' specialized to (==), without a step count.
fixEq :: Eq a => Unop (Unop a)
fixEq = fixBy (==)

{--------------------------------------------------------------------
    Vector operations
--------------------------------------------------------------------}

-- The vector operations in VectorSpace are on free vector spaces (f s for
-- functor f and scalar field s), so define counterparts on regular values.

-- Scaling (precedence 7) binds tighter than vector addition and subtraction
-- (precedence 6); all three operators are left-associative.
infixl 7 *^
infixl 6 ^-^, ^+^

-- | Scale a value by a scalar: the free-vector-space scaling 'V.*^' with
-- scalar @s@, lifted to act on a regular value through 'onV'.
(*^) :: (HasV R a, Functor (V R a)) => R -> Unop a
(*^) s = onV ((V.*^) s)

-- | Negate a value, implemented as scaling by @-1@ with '*^'.
negateV :: (HasV R a, Functor (V R a)) => Unop a
negateV = (*^) (-1)
-- negateV = onV V.negateV

-- | Add two values: the free-vector-space addition 'V.^+^', lifted to act on
-- regular values through 'onV2'.
--
-- The explicit @forall a.@ brings @a@ into scope (ScopedTypeVariables) so the
-- inner annotation can pin the representation type to @V R a R@.
(^+^) :: forall a. (HasV R a, Zip (V R a)) => Binop a
(^+^) = onV2 ((V.^+^) :: Binop (V R a R))

-- (^+^) :: forall s a. (HasV s a, Zip (V s a), Num s) => Binop a
-- (^+^) = onV2 @s (V.^+^)

-- | Subtract two values: the free-vector-space subtraction 'V.^-^', lifted to
-- act on regular values through 'onV2'.
--
-- The explicit @forall a.@ brings @a@ into scope (ScopedTypeVariables) so the
-- inner annotation can pin the representation type to @V R a R@.
(^-^) :: forall a. (HasV R a, Zip (V R a)) => Binop a
(^-^) = onV2 ((V.^-^) :: Binop (V R a R))

-- The specialization to R helps with type checking. Generalize if needed.