{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
-- {-# LANGUAGE StandaloneDeriving #-}

{-# OPTIONS_GHC -Wall #-}
-- {-# OPTIONS_GHC -fno-warn-unused-imports #-}  -- TEMP

#include "ConCat/AbsTy.inc"
AbsTyPragmas

-- | Linear maps as "row-major" functor compositions

module ConCat.Free.LinearRow where

import Prelude hiding (id,(.),zipWith)
import GHC.Exts (Coercible,coerce)
import Data.Foldable (toList)
import GHC.Generics (U1(..),(:*:)(..),(:.:)(..)) -- ,Par1(..)
-- import GHC.TypeLits (KnownNat)

import Data.Constraint
import Data.Key (Zip(..))
import Data.Distributive (collect)
import Data.Functor.Rep (Representable)
-- import qualified Data.Functor.Rep as R
import Text.PrettyPrint.HughesPJClass hiding (render)
import Control.Newtype.Generics
-- import Data.Vector.Sized (Vector)

import ConCat.Misc ((:*),PseudoFun(..),oops,R,Binop,inNew2)
import ConCat.Orphans ()
import ConCat.Free.VectorSpace
-- The following import allows the instances to type-check. Why?
import qualified ConCat.Category as C
import ConCat.AltCat hiding (const)
import ConCat.Rep
-- import ConCat.Free.Diagonal
import qualified ConCat.AdditiveFun as Ad
import ConCat.Additive

AbsTyImports

-- TODO: generalize from Num to Semiring

{--------------------------------------------------------------------
    Linear maps
--------------------------------------------------------------------}

-- Linear map from a s to b s, stored row-major: an outer b structure with
-- one row per output component, each row an (a s) of coefficients. See
-- ($*), which dots every row with the input vector.
infixr 1 :-*
type (a :-* b) s = b (a s)

-- TODO: consider instead
-- 
--   type Linear = (:.:)
-- 
-- so that Linear itself forms a vector space.

infixr 9 $*

-- | Apply a linear map to a vector. Each output component is the dot
-- product ('<.>') of the corresponding row of the map with the input.
-- 'lapplyL' is a named synonym for '($*)'.
($*), lapplyL :: forall s a b. (Zip a, Foldable a, Functor b, Num s)
              => (a :-* b) s -> a s -> b s
rows $* a = (<.> a) <$> rows

lapplyL = ($*)
{-# INLINE ($*) #-}
{-# INLINE lapplyL #-}

-- | The zero linear map: 'zeroV' at the composed functor @b :.: a@,
-- unwrapped with 'unComp1' into the row-major @b (a s)@ form.
zeroL :: (Zeroable a, Zeroable b, Num s) => (a :-* b) s
zeroL = unComp1 zeroV
-- zeroL = point zeroV

-- | Scaling by a constant, as a square matrix built with @diag 0@.
-- Presumably @diag z d@ puts @d@ on the diagonal and @z@ elsewhere; that is
-- what makes @scaleL 1@ ('idL') the identity.
scaleL :: (Diagonal a, Num s) => s -> (a :-* a) s
scaleL = diag 0

{--------------------------------------------------------------------
    Other operations
--------------------------------------------------------------------}

---- Category

-- | Identity linear map: scaling by 1.
idL :: (Diagonal a, Num s) => (a :-* a) s
idL = scaleL 1
{-# INLINE idL #-}

-- | Compose linear transformations (matrix product). As with function
-- composition, @bc `compL` ab@ applies @ab@ first, then @bc@. Each row of
-- @bc@ (a @b s@) becomes the sum over its components of that component
-- times the matching row of @ab@.
compL :: (Zip a, Zip b, Zeroable a, Foldable b, Functor c, Num s)
     => (b :-* c) s -> (a :-* b) s -> (a :-* c) s
bc `compL` ab = (\ b -> sumV (zipWith (*^) b ab)) <$> bc
{-# INLINE compL #-}

#if 0
bc :: c (b s)
ab :: b (a s)
b  :: b s
zipWith (*^) b ab :: b (a s)
sumV (zipWith (*^) b ab) :: a s
\ b -> sumV (zipWith (*^) b ab) :: b s -> a s
(\ b -> sumV (zipWith (*^) b ab)) <$> bc :: c (a s)
#endif

---- Product

-- | Left projection out of a product. Each row is the matching row of the
-- identity on @a@, extended with 'zeroV' for the @b@ component.
exlL :: (Zeroable a, Diagonal a, Zeroable b, Num s)
     => (a :*: b :-* a) s
exlL = (:*: zeroV) <$> idL
{-# INLINE exlL #-}

-- | Right projection out of a product. Each row is the matching row of the
-- identity on @b@, preceded by 'zeroV' for the @a@ component.
exrL :: (Zeroable b, Diagonal b, Zeroable a, Num s)
     => (a :*: b :-* b) s
exrL = (zeroV :*:) <$> idL
{-# INLINE exrL #-}

-- crossL :: (a :-* c) s -> (b :-* d) s -> (a :*: b :-* c :*: d) s
-- f `crossL` g = (f `compL` exlL) `forkL` (g `compL` exrL)

-- | Parallel composition as a block-diagonal matrix. @f@ acts on the @a@
-- part, @g@ acts on the @b@ part, and the off-diagonal blocks are 'zeroL'.
-- Unlike the commented-out definition above, this one uses no 'compL'.
crossL :: (Zeroable a, Zeroable b, Zeroable c, Zeroable d, Num s, Zip c, Zip d)
       => (a :-* c) s -> (b :-* d) s -> (a :*: b :-* c :*: d) s
f `crossL` g = (f `forkL` zeroL) `joinL` (zeroL `forkL` g)

-- | Combine two maps that share an input into one map into @c :*: d@:
-- the rows of the first map alongside the rows of the second.
forkL :: (a :-* c) s -> (a :-* d) s -> (a :-* c :*: d) s
forkL = (:*:)

-- | Duplicate the input: the identity forked with itself.
dupL :: (Diagonal a, Num s) => (a :-* (a :*: a)) s
dupL = idL `forkL` idL

-- | The map into the unit (empty) vector space: 'U1' holds no rows.
itL :: (a :-* U1) s
itL = U1

---- Coproduct as direct sum (represented as Cartesian product)

-- | Left injection into a direct sum. The @a@ rows are the identity and
-- the @b@ rows are zero.
inlL :: (Zeroable a, Diagonal a, Zeroable b, Num s)
     => (a :-* a :*: b) s
inlL = idL :*: zeroL
{-# INLINE inlL #-}

-- | Right injection into a direct sum. The @a@ rows are zero and the @b@
-- rows are the identity.
inrL :: (Zeroable a, Zeroable b, Diagonal b, Num s)
     => (b :-* a :*: b) s
inrL = zeroL :*: idL
{-# INLINE inrL #-}

-- | Combine a map from @a@ and a map from @b@ (same output) into one map
-- from @a :*: b@. Corresponding rows are paired with ':*:'.
joinL :: Zip c => (a :-* c) s -> (b :-* c) s -> (a :*: b :-* c) s
joinL = zipWith (:*:)
{-# INLINE joinL #-}

-- | Add the two halves of a pair: the identity joined with itself.
jamL :: (Diagonal a, Zip a, Num s) => ((a :*: a) :-* a) s
jamL = idL `joinL` idL

{--------------------------------------------------------------------
    Category
--------------------------------------------------------------------}

-- Linear map from a to b over scalar s, represented by the row-major matrix
-- between the vector-space functors V s a and V s b.
newtype L s a b = L ((V s a :-* V s b) s)
-- data L s a b = L ((V s a :-* V s b) s)

-- Linear maps with scalar type R (imported from ConCat.Misc).
type LR = L R

-- Switching to the commented-out `data` declaration of L is a workaround
-- for <https://ghc.haskell.org/trac/ghc/ticket/13083#ticket>, to be used when
-- that problem shows up. See notes from 2016-01-07.

-- deriving instance Show ((V s a :-* V s b) s) => Show (L s a b)

-- | The matrix of a linear map as a list of rows, each row a list of
-- coefficients. Both levels follow the 'Foldable' order of @V s b@ (rows)
-- and @V s a@ (columns).
flatten :: (Foldable (V s a), Foldable (V s b)) => L s a b -> [[s]]
flatten = fmap toList . toList . unpack

-- | Show a linear map as its list of rows (see 'flatten').
instance (Foldable (V s a), Foldable (V s b), Show s) => Show (L s a b) where
  show = show . flatten

-- | Pretty-print a linear map as its list of rows (see 'flatten').
instance (Foldable (V s a), Foldable (V s b), Pretty s) => Pretty (L s a b) where
  pPrint = pPrint . flatten

-- TODO: maybe 2D matrix form.

-- Just for AbsTy in ConCat.Circuit.
-- Wraps and unwraps the row-major matrix inside 'L'.
instance Newtype (L s a b) where
  type O (L s a b) = (V s a :-* V s b) s
  pack ab = L ab
  unpack (L ab) = ab

-- | A linear map is represented by its row-major matrix.
instance HasRep (L s a b) where
  type Rep (L s a b) = (V s a :-* V s b) s
  abst ab = L ab
  repr (L ab) = ab
  {-# INLINE abst #-}
  {-# INLINE repr #-}

-- Further instances generated by the AbsTy macro (ConCat/AbsTy.inc).
AbsTy(L s a b)

-- instance HasV s (L s a b) where
--   type V s (L s a b) = V s b :.: V s a
--   toV = abst . repr
--   unV = abst . repr

-- Vector-space view of a linear map via its representation type. The body
-- is empty, so every method comes from the class defaults (presumably
-- defined through HasRep; confirm in ConCat.Free.VectorSpace).
instance HasV s (Rep (L s a b)) => HasV s (L s a b)

-- Requirements on the functor that represents vectors in a linear map.
type OkLF f = (Foldable f, Zeroable f, Zip f, Diagonal f)

-- Like OkLM below, but without Additive and as a plain constraint synonym.
-- NOTE(review): not referenced in the visible part of this module.
type OkLM' s a = (Num s, HasV s a, OkLF (V s a))

-- Object constraint of the (L s) category (see `type Ok (L s) = OkLM s`).
-- It is a class with one catch-all instance, not a type synonym, so that it
-- can be partially applied as `OkLM s`.
class    (Num s, Additive a, HasV s a, OkLF (V s a)) => OkLM s a
instance (Num s, Additive a, HasV s a, OkLF (V s a)) => OkLM s a

-- | The zero linear map from @a@ to @b@ ('zeroL' wrapped in 'L').
zeroLM :: (Num s, Zeroable (V s a), Zeroable (V s b)) => L s a b
zeroLM = L zeroL
{-# INLINE zeroLM #-}

-- | Entrywise sum of two linear maps. The outer 'zipWith' pairs up rows,
-- the inner one adds coefficients, and 'inNew2' unwraps and rewraps 'L'.
addLM :: Ok2 (L s) a b => Binop (L s a b)
addLM = (inNew2 . zipWith . zipWith) (+)

-- | Additive structure on linear maps: the zero map and entrywise addition.
instance Ok2 (L s) a b => Ad.Additive (L s a b) where
  zero  = zeroLM
  (^+^) = addLM

-- lapply' :: forall s a b. (HasV s a, HasV s b, Zip (V s a), Zip (V s b), Foldable (V s a), Num s)
--         => L s a b -> (a -> b)
-- lapply' (L as) a = unV ((<.> toV a) <$> as)

-- as :: V s b (V s a s)
-- a :: a
-- toV a :: V s a s

-- | Linear maps form a category whose objects satisfy 'OkLM s'. The
-- identity is 'idL' and composition is the matrix product 'compL'.
instance Category (L s) where
  type Ok (L s) = OkLM s
  id  = abst idL
  (.) = inAbst2 compL
  {-# INLINE id #-}
  {-# INLINE (.) #-}

-- 'OkLM s' is closed under products. Given OkLM s a and OkLM s b, instance
-- resolution supplies OkLM s (a :* b), witnessed here by 'Dict'.
instance OpCon (:*) (Sat (OkLM s)) where inOp = Entail (Sub Dict)
-- instance OpCon (->) (Sat (OkLM s)) where inOp = Entail (Sub Dict)

-- | Parallel composition of linear maps is the block-diagonal 'crossL'.
instance MonoidalPCat (L s) where
  (***) = inAbst2 crossL
  
-- Braided structure for (L s). The empty body means every method uses the
-- class's default definition.
instance BraidedPCat  (L s)

-- | Products in 'L s': the projection matrices 'exlL' and 'exrL', and
-- duplication 'dupL'.
instance ProductCat (L s) where
  -- type Prod (L s) = (,)
  exl = abst exlL
  exr = abst exrL
  dup = abst dupL
  {-# INLINE exl #-}
  {-# INLINE exr #-}
  {-# INLINE dup #-}

-- instance Num s => UnitCat (L s)

-- Every object of 'L s' is 'Additive', since 'Additive a' is a superclass
-- constraint of 'OkLM s a'.
instance OkAdd (L s) where okAdd = Entail (Sub Dict)

-- | Coproducts as direct sums, represented by the product type. The
-- injections pad with zero rows, and 'jamP' adds the two halves.
instance CoproductPCat (L s) where
  inlP = abst inlL
  inrP = abst inrL
  jamP = abst jamL
  {-# INLINE inlP #-}
  {-# INLINE inrP #-}
  {-# INLINE jamP #-}

-- | Converting between a type and its representation is the identity
-- matrix, because both share the vector space @V s a@ (@V s r ~ V s a@).
instance (r ~ Rep a, V s r ~ V s a, Ok (L s) a) => RepCat (L s) a r where
  reprC = L idL
  abstC = L idL

-- idL :: (a :-* a) s
--     ~  V s (V s a s)
-- L id  :: V s (V s a s)
--       ~  V s (V s r s)

#if 0
instance (Ok2 (L s) a b, Coercible (V s a) (V s b)) => CoerceCat (L s) a b where
  coerceC = coerce (id :: L s a a)
#else
-- | Coercion as a linear map: the identity matrix on @V s a@, coerced to the
-- representation of @L s a b@. Requires the two representations to be
-- 'Coercible'.
instance ( -- Ok2 (L s) a b
           Num s, Diagonal (V s a)
         -- , Coercible (V s a) (V s b)
         , Coercible (Rep (L s a a)) (Rep (L s a b))
         -- , Coercible (V s a (V s a s)) (V s b (V s a s))
         ) => CoerceCat (L s) a b where
  -- coerceC = coerce (id :: L s a a)
  coerceC = L (coerce (idL :: Rep (L s a a)))
#endif

-- -- Okay:
-- foo :: L Float (L Float Float Float) (Par1 (V Float Float Float))
-- foo = coerceC
-- -- foo = L (coerce (idL :: Rep (L Float Float Float)))


-- -- -- Fail
-- foo :: L Float (L Float Float (Float, Float)) ((Par1 :*: Par1) (V Float Float Float))
-- foo = coerceC
-- -- foo = L (coerce (idL :: Rep (L Float Float Float)))

-- -- 
-- foo :: Rep (L Float (L Float Float (Float, Float)) ((Par1 :*: Par1) (V Float Float Float)))
-- foo = coerce (idL :: Rep (L Float Float Float))


-- We can't make a ClosedCat instance compatible with the ProductCat instance.
-- We'd have to change the latter to use the tensor product.

-- type instance Exp (L s) = (:.:)

-- | Conversion to a linear function. The input is embedded with 'toV',
-- multiplied by the matrix ('lapplyL'), and read back with 'unV'.
lapply :: (Num s, Ok2 (L s) a b) => L s a b -> (a -> b)
lapply (L gfa) = unV . lapplyL gfa . toV
{-# INLINE lapply #-}

-- Constraints on a single type a for square linear maps over V s a
-- (Diagonal is what idL/scaleL need).
-- NOTE(review): not referenced in the visible part of this module.
type HasL s a = (HasV s a, Diagonal (V s a), Num s)  

-- Constraints needed by `linear` to turn a function a -> b into L s a b.
type HasLin s a b =
  (HasV s a, HasV s b, Diagonal (V s a), Representable (V s b), Num s)  

-- | The matrix of a function that is assumed, not checked, to be linear.
-- The function is lifted to vectors with 'inV' and sampled by 'linearF'.
linear :: forall s a b. HasLin s a b => (a -> b) -> L s a b
linear f = L (linearF (inV f))
{-# INLINE linear #-}

-- | The matrix of a (presumed linear) function on vectors. The function is
-- applied to every row of 'idL' (the basis vectors), and 'collect'
-- transposes the resulting @f (g s)@ into the row-major @g (f s)@.
linearF :: forall s f g. (Diagonal f, Representable g, Num s)
        => (f s -> g s) -> (f :-* g) s
-- linearF q = undual <$> distribute q
-- linearF q = distribute (q <$> idL)
-- linearF q = distribute (fmap q idL)
-- linearF q = collect q idL
linearF = flip collect idL
{-# INLINE linearF #-}

-- q :: f s -> g s
--   :: (->) (f s) (g s)
-- distribute q :: g (f s -> s)
-- undual <$> distribute q :: g (f s)
--                         == (f :-* g) s

-- undual :: (Diagonal f, Num s) => (f s -> s) -> f s
-- undual p = p <$> idL

-- q :: f s -> g s
-- idL :: f (f s)
-- fmap q idL :: f (g s)
-- distribute (fmap q idL) :: g (f s)

-- collect :: (Distributive g, Functor f) => (a -> g b) -> f a -> g (f b)
-- collect f = distribute . fmap f

-- | Multiplication by a scalar, as a linear map (see 'scaleL').
scalarMul :: OkLM s a => s -> L s a a
scalarMul = L . scaleL

-- | Negation: scaling by -1.
negateLM :: OkLM s a => L s a a
negateLM = scalarMul (-1)

-- Disabled (#if 0): FunctorC instances that relate (L s) and (->) through
-- lapply and linear.
#if 0

{--------------------------------------------------------------------
    Functors
--------------------------------------------------------------------}

data Lapply s

instance FunctorC (Lapply s) (L s) (->) where fmapC = lapply

data Linear s

instance FunctorC (Linear s) (->) (L s) where fmapC = linear

#endif

{--------------------------------------------------------------------
    CCC conversion
--------------------------------------------------------------------}

-- | Placeholder for "the linear map of this function", meant to be
-- eliminated before it runs. The "lmap" rule rewrites @lmap h@ to
-- @toCcc h@, and NOINLINE keeps the call visible to that rule. If it is
-- ever actually evaluated, it fails via 'oops'.
lmap :: forall s a b. (a -> b) -> L s a b
lmap _ = oops "lmap called"
{-# NOINLINE lmap #-}
{-# RULES "lmap" forall h. lmap h = toCcc h #-}
{-# ANN lmap (PseudoFun 1) #-}

{--------------------------------------------------------------------
   Some specializations 
--------------------------------------------------------------------}

-- Disabled (#if 0): SPECIALIZE pragmas for small compL and (.) cases, each
-- annotated with what that specialization was observed to compile to.
#if 0

type One = Par1
type Two = One :*: One

-- Becomes (*) (and casts)
{-# SPECIALIZE compL :: Num s =>
  (One :-* One) s -> (One :-* One) s -> (One :-* One) s #-}

-- Becomes timesFloat
{-# SPECIALIZE compL ::
  (One :-* One) Float -> (One :-* One) Float -> (One :-* One) Float #-}

-- Becomes + (* ww1 ww4) (* ww2 ww5)
{-# SPECIALIZE compL :: Num s =>
  (Two :-* One) s -> (One :-* Two) s -> (One :-* One) s #-}

-- Becomes plusFloat# (timesFloat# x y) (timesFloat# x1 y1)
{-# SPECIALIZE compL ::
  (Two :-* One) Float -> (One :-* Two) Float -> (One :-* One) Float #-}

type LRRR = L Float Float Float

-- Becomes timesFloat (and casts)
{-# SPECIALIZE (.) :: LRRR -> LRRR -> LRRR #-}

#endif

-- Experiment: reassociation rules for composition in (L s). Both are
-- currently commented out. Enabling both at once would let each rule undo
-- the other, so the simplifier could keep rewriting without end.

{-# RULES

-- "assoc L (.) right" forall (f :: L s a b) g h. (h . g) . f = h . (g . f)

-- Alternatively (but not both!),

-- "assoc L (.) left"  forall (f :: L s a b) g h. h . (g . f) = (h . g) . f

 #-}