{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE EmptyCase           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DefaultSignatures   #-}

{-# OPTIONS_GHC -Wall #-}

-- {-# OPTIONS_GHC -fno-warn-unused-imports #-} -- TEMP
-- {-# OPTIONS_GHC -fno-warn-unused-binds   #-} -- TEMP

----------------------------------------------------------------------
-- |
-- Module      :  ConCat.Scan
-- Copyright   :  (c) 2016 Conal Elliott
-- 
-- Maintainer  :  conal@tabula.com
-- Stability   :  experimental
-- 
-- Parallel scan
----------------------------------------------------------------------

module ConCat.Scan
  ( LScan(..)
  , lscanT, lscanTraversable
  , lsums, lproducts, lAlls, lAnys, lParities
  , multiples, powers, iota
  ) where

import Prelude hiding (zip,unzip,zipWith)

import Data.Monoid ((<>),Sum(..),Product(..),All(..),Any(..))
import Control.Arrow ((***),first)
import Data.Traversable (mapAccumL)
import Data.Tuple (swap)
import GHC.Generics

import Control.Newtype.Generics (Newtype(..))

import Data.Key
import Data.Pointed

import ConCat.Misc ((:*),Parity(..),absurdF,unzip) -- , Unop
-- import ConCat.Misc (absurdF)

-- | Functors supporting an exclusive left scan under any 'Monoid'.
--
-- @lscan xs@ returns a structure of the same shape as @xs@ in which each
-- position holds the combination ('<>') of all elements strictly before it,
-- paired with the combination of all elements.
class Functor f => LScan f where
  lscan :: forall a. Monoid a => f a -> f a :* a
  -- Generic default: convert to the 'Generic1' representation, scan that,
  -- then convert the scanned structure back with 'to1'.
  default lscan :: (Generic1 f, LScan (Rep1 f), Monoid a) => f a -> f a :* a
  lscan = first to1 . lscan . from1
  -- Temporary hack to avoid newtype-like representation. Still needed?
  -- (A second method keeps the class dictionary from being a newtype; the
  -- default is a placeholder and is not meant to be evaluated.)
  lscanDummy :: f a
  lscanDummy = undefined
--   lscanWork, lscanDepth :: forall a. MappendStats a => Int

-- TODO: Try removing lscanDummy and the comment and recompiling with reification

-- | Traversable version (sequential).
--
-- @lscanT op e xs@ threads an accumulator (starting at @e@) left to right
-- through @xs@. Each position receives the accumulator value /before/
-- combining that position's element (an exclusive scan). The final
-- accumulator is returned alongside the structure.
lscanT :: Traversable t => (b -> a -> b) -> b -> t a -> (t b,b)
lscanT op e = swap . mapAccumL step e
  where
    -- Emit the running value, then advance it with the current element.
    step acc x = (acc `op` x, acc)
{-# INLINABLE lscanT #-}

-- | Sequential exclusive left scan for any 'Traversable', using the
-- 'Monoid' operation starting from 'mempty'. It is a drop-in definition of
-- 'lscan' for types that do not use the generic instances.
lscanTraversable :: Traversable f => forall a. Monoid a => f a -> f a :* a
lscanTraversable = lscanT (<>) mempty
{-# INLINABLE lscanTraversable #-}

{--------------------------------------------------------------------
    Monoid specializations
--------------------------------------------------------------------}

-- Left-scan via a 'Newtype'.
--
-- Each element is wrapped with 'pack' into the newtype @n@, whose 'Monoid'
-- instance selects the combining operation. The wrapped structure is scanned
-- with 'lscan', and both the prefixes and the total are 'unpack'ed. Choose
-- @n@ with a type application, e.g. @lscanAla \@(Sum a)@.
lscanAla :: forall n o f. (Newtype n, o ~ O n, LScan f, Monoid n)
         => f o -> f o :* o
lscanAla = (fmap unpack *** unpack) . lscan . fmap (pack @n)

-- lscanAla k = underF k lscan
-- lscanAla _k = fmap unpack . lscan . fmap (pack :: o -> n)

-- | Exclusive running sums, paired with the overall sum.
lsums :: forall f a. (LScan f, Num a) => f a -> f a :* a
lsums = lscanAla @(Sum a)
{-# INLINABLE lsums #-}

-- | Exclusive running products, paired with the overall product.
lproducts :: forall f a. (LScan f, Num a) => f a -> f a :* a
lproducts = lscanAla @(Product a)
{-# INLINABLE lproducts #-}

-- | Exclusive running conjunctions (via 'All'), paired with the overall one.
lAlls :: LScan f => f Bool -> f Bool :* Bool
lAlls = lscanAla @All
{-# INLINABLE lAlls #-}

-- | Exclusive running disjunctions (via 'Any'), paired with the overall one.
lAnys :: LScan f => f Bool -> f Bool :* Bool
lAnys = lscanAla @Any
{-# INLINABLE lAnys #-}

-- | Exclusive running scan under the 'Parity' monoid, paired with the
-- overall parity.
lParities :: LScan f => f Bool -> f Bool :* Bool
lParities = lscanAla @Parity
{-# INLINABLE lParities #-}

-- | @multiples x@ fills @f@ with @x@ and takes exclusive running sums, giving
-- @0, x, 2x, ...@ in the structure and @n*x@ (for @n@ the size of @f@) as the
-- total.
multiples :: (LScan f, Pointed f, Num a) => a -> f a :* a
multiples = lsums . point
{-# INLINABLE multiples #-}

-- | @powers x@ fills @f@ with @x@ and takes exclusive running products,
-- giving @1, x, x^2, ...@ in the structure and @x^n@ (for @n@ the size of
-- @f@) as the total.
powers :: (LScan f, Pointed f, Num a) => a -> f a :* a
powers = lproducts . point
{-# INLINABLE powers #-}

-- | Numbers @0 .. n-1@ in the structure, paired with @n@ (the size of @f@),
-- because the underlying scan is exclusive. Named for the APL iota operation
-- (but 0 based).
iota :: (LScan f, Pointed f, Num a) => f a :* a
iota = multiples 1
{-# INLINABLE iota #-}

{--------------------------------------------------------------------
    Work and depth
--------------------------------------------------------------------}

-- class Monoid o => MappendStats o where
--   mappendWork, mappendDepth :: Int
--   mappendWork = 1
--   mappendDepth = 1

-- instance Num a => MappendStats (Sum     a)
-- instance Num a => MappendStats (Product a)

{--------------------------------------------------------------------
    Generic support
--------------------------------------------------------------------}

-- | 'V1' has no values, so 'lscan' can never be applied to one.
instance LScan V1 where
  lscan :: forall a. Monoid a => V1 a -> V1 a :* a
  lscan = absurdF
--   lscanWork = 0
--   lscanDepth = 0

-- | No elements: the structure is unchanged and the total is 'mempty'.
instance LScan U1 where
  lscan :: forall a. Monoid a => U1 a -> U1 a :* a
  lscan U1 = (U1, mempty)
--   lscanWork = 0
--   lscanDepth = 0

-- | A constant field holds no elements of the scanned type, so it is
-- returned unchanged with total 'mempty'.
instance LScan (K1 i c) where
  lscan :: forall a. Monoid a => K1 i c a -> K1 i c a :* a
  lscan w@(K1 _) = (w, mempty)
--   lscanWork = 0
--   lscanDepth = 0

-- | A single element: nothing precedes it, so its prefix is 'mempty', and
-- the total is the element itself.
instance LScan Par1 where
  lscan :: forall a. Monoid a => Par1 a -> Par1 a :* a
  lscan (Par1 a) = (Par1 mempty, a)
--   lscanWork = 0
--   lscanDepth = 0

-- foo :: Int
-- foo = lscanWork @Par1 @(Sum Int)

-- | Sums: scan whichever alternative is present and re-inject the result.
instance (LScan f, LScan g) => LScan (f :+: g) where
  lscan :: forall a. Monoid a => (f :+: g) a -> (f :+: g) a :* a
  lscan (L1 fa) = first L1 (lscan fa)
  lscan (R1 ga) = first R1 (lscan ga)
--   lscanWork, lscanDepth :: forall a. MappendStats a => Int
--   lscanWork = max (lscanWork @f @a) (lscanWork @g @a)
--   lscanDepth = max (lscanDepth @f @a) (lscanDepth @g @a)

-- GHC rejects the commented-out lscanWork/lscanDepth definitions above with:
-- 
--     • Could not deduce (MappendStats a0)
--       from the context: (LScan f, LScan g)
--         bound by the instance declaration
--         at /Users/conal/Haskell/shaped-types/src/ConCat/Scan.hs:157:10-46
--       or from: MappendStats a
--         bound by the type signature for:
--                    lscanWork :: MappendStats a => Int
--         at /Users/conal/Haskell/shaped-types/src/ConCat/Scan.hs:160:28-58
--       The type variable ‘a0’ is ambiguous
--
-- I wonder if ScopedTypeVariables is failing here

-- | Products: scan both halves independently, then offset every prefix of
-- the right half by the left half's total. The overall total combines both
-- halves' totals.
instance (LScan f, LScan g) => LScan (f :*: g) where
  lscan :: forall a. Monoid a => (f :*: g) a -> (f :*: g) a :* a
  lscan (fa :*: ga) = (fa' :*: ((fx <>) <$> ga'), fx <> gx)
   where
     (fa', fx) = lscan fa
     (ga', gx) = lscan ga
--   lscanWork :: 
--   lscanWork = lscanWork @f + lscanWork @g + mappendWork 

-- Alternatively,

--   lscan (fa :*: ga) = (fa' :*: ga', gx)
--    where
--      (fa', fx) =               lscan fa
--      (ga', gx) = mapl (fx <>) (lscan ga)

-- | Compositions: scan each inner @f@-structure independently, scan their
-- totals across @g@, and then shift every element of each inner scan by the
-- corresponding scanned total.
instance (LScan g, LScan f, Zip g) =>  LScan (g :.: f) where
  lscan :: forall a. Monoid a => (g :.: f) a -> (g :.: f) a :* a
  lscan (Comp1 gfa) = (Comp1 (zipWith adjustl tots' gfa'), tot)
   where
     -- Inner scans, plus the total of each inner structure.
     (gfa', tots)  = unzip (lscan <$> gfa)
     -- Exclusive scan of the inner totals: the offset for each inner
     -- structure, and the grand total.
     (tots', tot)  = lscan tots
     -- Prepend an offset to every element of an inner scan.
     adjustl t     = fmap (t <>)

-- TODO: maybe zipWith (fmap . mappend) tots' gfa'

-- | Metadata wrappers are transparent: scan the contents and re-wrap.
instance LScan f => LScan (M1 i c f) where
  lscan :: forall a. Monoid a => M1 i c f a -> M1 i c f a :* a
  lscan (M1 fa) = first M1 (lscan fa)