{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}

{-# OPTIONS_GHC -Wall #-}
-- {-# OPTIONS_GHC -Wno-unused-imports #-} -- TEMP

-- | A category of probabilistic functions using discrete distributions

module ConCat.Distribution where

import Prelude hiding (id,(.))

import Data.Map

import ConCat.Misc (R)
import ConCat.AltCat
import qualified ConCat.Category

-- | Distribution category.
--
-- A morphism @Dist f :: Dist a b@ sends each input @a@ to a finite map from
-- outcomes of type @b@ to their probability (weight) of type 'R'. Outcomes
-- absent from the map carry no weight. Nothing in this type enforces that the
-- weights are non-negative or that they sum to 1.
newtype Dist a b = Dist (a -> Map b R)

-- TODO: generalize Dist to a category transformer

-- | The one category-specific operation.
--
-- Build a 'Dist' morphism directly from a function that gives, for each
-- input, a finite map from outcomes to their probabilities. This is currently
-- just the 'Dist' constructor.
distrib :: (a -> Map b R) -> Dist a b
distrib = Dist

-- TODO: Perhaps replace 'distrib' with a simpler alternative.

-- | Embed a regular deterministic function.
--
-- Each input @a@ is sent to the single outcome @f a@ with probability 1.
-- Equivalent point-free form: @Dist (flip singleton 1 . f)@.
exactly :: (a -> b) -> Dist a b
exactly f = Dist (\ a -> singleton (f a) 1)

instance Category Dist where
  type Ok Dist = Ord  -- needed for Map keys
  -- Identity: every input is its own outcome, with probability 1.
  id = exactly id
  -- Composition: run @f@ on the input. Each intermediate outcome @b@ with
  -- weight @p@ is fed through @g@, and @g@'s weights are scaled by @p@.
  -- Weights of final outcomes reachable along several paths are summed.
  Dist g . Dist f = Dist h
   where
     h a = unionsWith (+) [ (p *) <$> g b | (b,p) <- toList (f a) ]

-- Finite maps here are meant to denote total functions whose missing entries
-- are implicitly zero. Under that reading, equality should treat an explicit
-- zero entry and a missing entry as equal. It might be worth removing zero
-- values as they arise.

instance AssociativePCat Dist where
  -- Reassociating a product is deterministic, so lift the function versions.
  lassocP = exactly lassocP
  rassocP = exactly rassocP

instance BraidedPCat Dist where
  -- Swapping the components of a pair is deterministic.
  swapP = exactly swapP

instance MonoidalPCat Dist where
  -- Independent parallel composition: the weight of the outcome @(c,d)@ is
  -- the weight of @c@ under @f a@ times the weight of @d@ under @g b@.
  --
  -- 'toList' yields keys in ascending order, and the comprehension varies @d@
  -- fastest. Within one map each @c@ (resp. @d@) is a distinct key, so the
  -- pairs @(c,d)@ come out strictly ascending in the lexicographic 'Ord' on
  -- pairs. That lets us build the result in linear time with
  -- 'fromDistinctAscList' rather than the O(n log n) 'fromList'.
  Dist f *** Dist g = Dist h
   where
     h (a,b) = fromDistinctAscList
                 [ ((c,d),p*q) | (c,p) <- toList (f a), (d,q) <- toList (g b) ]
  -- We could default first and second, but the following may be more efficient.
  -- Pairing with a fixed component, @(,b)@ or @(a,)@, is strictly monotonic.
  -- So 'mapKeysMonotonic' can rebuild the map in linear time, unlike 'mapKeys'.
  first  (Dist f) = Dist (\ (a,b) -> mapKeysMonotonic (,b) (f a))
  second (Dist g) = Dist (\ (a,b) -> mapKeysMonotonic (a,) (g b))

instance ProductCat Dist where
  -- Projections and duplication are deterministic.
  exl = exactly exl
  exr = exactly exr
  dup = exactly dup

instance AssociativeSCat Dist where
  -- Reassociating a sum is deterministic, so lift the function versions.
  lassocS = exactly lassocS
  rassocS = exactly rassocS

instance BraidedSCat Dist where
  -- Swapping the tags of a sum is deterministic.
  swapS = exactly swapS

instance MonoidalSCat Dist where
  -- Sum of morphisms: dispatch on the input's tag, run the matching morphism,
  -- and tag its outcomes the same way.
  --
  -- Under the derived 'Ord' for 'Either', 'Left' and 'Right' are each strictly
  -- monotonic. So 'mapKeysMonotonic' retags a map in linear time, whereas
  -- 'mapKeys' rebuilds it.
  Dist f +++ Dist g = Dist h
   where
     h = mapKeysMonotonic Left . f ||| mapKeysMonotonic Right . g
  -- We could default left and right, but the following may be more efficient.
  -- The untouched side passes through with probability 1.
  left  (Dist f) = Dist (mapKeysMonotonic Left . f ||| flip singleton 1 . Right)
  right (Dist g) = Dist (flip singleton 1 . Left ||| mapKeysMonotonic Right . g)

-- TODO: test whether the first/second and left/right definitions produce more
-- efficient implementations than the defaults. Can GHC optimize away the
-- singleton dictionaries in the defaults?

instance CoproductCat Dist where
  -- Injections and jamming are deterministic.
  inl = exactly inl
  inr = exactly inr
  jam = exactly jam

instance DistribCat Dist where
  -- Distributing a product over a sum is deterministic.
  distl = exactly distl
  distr = exactly distr

instance Num a => ScalarCat Dist a where
  -- Scaling is deterministic: lift the function-level 'scale'.
  scale s = exactly (scale s)

-- TODO: ClosedCat.